13 Commits

Author SHA1 Message Date
bff92738d5 Allow config path from FOURTH_CONFIG 2021-11-01 16:06:47 +08:00
754a5af794 Add publish CI and run fmt 2021-11-01 15:56:57 +08:00
fc7a3038bd Add unknown protocol error 2021-11-01 15:32:08 +08:00
8a96de9666 Update README and minor refactor 2021-11-01 15:25:12 +08:00
0407f4b40c Add config validation 2021-11-01 13:45:47 +08:00
47be2568ba Add upstream scheme support
Need to implement TCP and UDP upstream support.
2021-10-31 19:21:32 +08:00
5944beb6a2 Combine TCP and KCP tests 2021-10-27 08:36:24 +08:00
4363e3f76a Publish 0.1.3 and update README 2021-10-26 23:58:00 +08:00
ee9d0685b3 Refactor TCP and KCP test 2021-10-26 23:52:07 +08:00
421ad8c979 Fix example config 2021-10-26 23:27:03 +08:00
a88a263d20 Move tokio_kcp to local files 2021-10-26 23:02:05 +08:00
bfce455a7e Add Cargo installation method 2021-10-26 21:40:40 +08:00
55eef8581c Add KCP support 2021-10-26 21:36:12 +08:00
22 changed files with 1748 additions and 192 deletions

39
.github/workflows/publish-binaries.yml vendored Normal file
View File

@ -0,0 +1,39 @@
on:
release:
types: [published]
name: Publish binaries to release
jobs:
publish:
name: Publish for ${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
include:
- os: ubuntu-latest
artifact_name: fourth
asset_name: fourth-linux-amd64
- os: macos-latest
artifact_name: fourth
asset_name: fourth-macos-amd64
- os: windows-latest
artifact_name: fourth.exe
asset_name: fourth-windows-amd64.exe
steps:
- uses: hecrj/setup-rust-action@master
with:
rust-version: stable
- uses: actions/checkout@v2
- name: Build
run: cargo build --release --locked
- name: Publish
uses: svenstaro/upload-release-action@v1-release
with:
repo_token: ${{ secrets.PUBLISH_TOKEN }}
file: target/release/${{ matrix.artifact_name }}
asset_name: ${{ matrix.asset_name }}
tag: ${{ github.ref }}

143
Cargo.lock generated
View File

@ -34,6 +34,28 @@ version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
[[package]]
name = "byte_string"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11aade7a05aa8c3a351cedc44c3fc45806430543382fcc4743a9b757a2a0b4ed"
[[package]]
name = "byteorder"
version = "1.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
[[package]]
name = "bytes"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "206fdffcfa2df7cbe15601ef46c813fce0965eb3286db6b56c583b814b51c81c"
dependencies = [
"byteorder",
"iovec",
]
[[package]] [[package]]
name = "bytes" name = "bytes"
version = "1.1.0" version = "1.1.0"
@ -69,22 +91,36 @@ checksum = "44533bbbb3bb3c1fa17d9f2e4e38bbbaf8396ba82193c4cb1b6445d711445d36"
dependencies = [ dependencies = [
"atty", "atty",
"humantime", "humantime",
"log", "log 0.4.14",
"regex", "regex",
"termcolor", "termcolor",
] ]
[[package]] [[package]]
name = "fourth" name = "form_urlencoded"
version = "0.1.1" version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5fc25a87fa4fd2094bffb06925852034d90a17f0d1e05197d4956d3555752191"
dependencies = [ dependencies = [
"matches",
"percent-encoding",
]
[[package]]
name = "fourth"
version = "0.1.4"
dependencies = [
"byte_string",
"bytes 1.1.0",
"futures", "futures",
"log", "kcp",
"log 0.4.14",
"pretty_env_logger", "pretty_env_logger",
"serde", "serde",
"serde_yaml", "serde_yaml",
"tls-parser", "tls-parser",
"tokio", "tokio",
"url",
] ]
[[package]] [[package]]
@ -216,6 +252,17 @@ dependencies = [
"quick-error", "quick-error",
] ]
[[package]]
name = "idna"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8"
dependencies = [
"matches",
"unicode-bidi",
"unicode-normalization",
]
[[package]] [[package]]
name = "indexmap" name = "indexmap"
version = "1.7.0" version = "1.7.0"
@ -235,6 +282,25 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "iovec"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2b3ea6ff95e175473f8ffe6a7eb7c00d054240321b84c57051175fe3c1e075e"
dependencies = [
"libc",
]
[[package]]
name = "kcp"
version = "0.4.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09481e52ffe09d417d8b770217faca77eeb048ab5f337562cede72070fc91b21"
dependencies = [
"bytes 0.4.12",
"log 0.3.9",
]
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.103" version = "0.2.103"
@ -256,6 +322,15 @@ dependencies = [
"scopeguard", "scopeguard",
] ]
[[package]]
name = "log"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e19e8d5c34a3e0e2223db8e060f9e8264aeeb5c5fc64a4ee9965c062211c024b"
dependencies = [
"log 0.4.14",
]
[[package]] [[package]]
name = "log" name = "log"
version = "0.4.14" version = "0.4.14"
@ -265,6 +340,12 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "matches"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f"
[[package]] [[package]]
name = "memchr" name = "memchr"
version = "2.4.1" version = "2.4.1"
@ -284,7 +365,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8067b404fe97c70829f082dec8bcf4f71225d7eaea1d8645349cb76fa06205cc" checksum = "8067b404fe97c70829f082dec8bcf4f71225d7eaea1d8645349cb76fa06205cc"
dependencies = [ dependencies = [
"libc", "libc",
"log", "log 0.4.14",
"miow", "miow",
"ntapi", "ntapi",
"winapi", "winapi",
@ -400,6 +481,12 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "percent-encoding"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e"
[[package]] [[package]]
name = "phf" name = "phf"
version = "0.10.0" version = "0.10.0"
@ -463,7 +550,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "926d36b9553851b8b0005f1275891b392ee4d2d833852c417ed025477350fb9d" checksum = "926d36b9553851b8b0005f1275891b392ee4d2d833852c417ed025477350fb9d"
dependencies = [ dependencies = [
"env_logger", "env_logger",
"log", "log 0.4.14",
] ]
[[package]] [[package]]
@ -668,6 +755,21 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "tinyvec"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83b2a3d4d9091d0abd7eba4dc2710b1718583bd4d8992e2190720ea38f391f7"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
[[package]] [[package]]
name = "tls-parser" name = "tls-parser"
version = "0.11.0" version = "0.11.0"
@ -689,7 +791,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2c2416fdedca8443ae44b4527de1ea633af61d8f7169ffa6e72c5b53d24efcc" checksum = "c2c2416fdedca8443ae44b4527de1ea633af61d8f7169ffa6e72c5b53d24efcc"
dependencies = [ dependencies = [
"autocfg", "autocfg",
"bytes", "bytes 1.1.0",
"libc", "libc",
"memchr", "memchr",
"mio", "mio",
@ -713,12 +815,39 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "unicode-bidi"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a01404663e3db436ed2746d9fefef640d868edae3cceb81c3b8d5732fda678f"
[[package]]
name = "unicode-normalization"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d54590932941a9e9266f0832deed84ebe1bf2e4c9e4a3554d393d18f5e854bf9"
dependencies = [
"tinyvec",
]
[[package]] [[package]]
name = "unicode-xid" name = "unicode-xid"
version = "0.2.2" version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
[[package]]
name = "url"
version = "2.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a507c383b2d33b5fc35d1861e77e6b383d158b2da5e14fe51b83dfedf6fd578c"
dependencies = [
"form_urlencoded",
"idna",
"matches",
"percent-encoding",
]
[[package]] [[package]]
name = "version_check" name = "version_check"
version = "0.9.3" version = "0.9.3"

View File

@ -1,6 +1,6 @@
[package] [package]
name = "fourth" name = "fourth"
version = "0.1.1" version = "0.1.4"
edition = "2021" edition = "2021"
authors = ["LI Rui <lr_cn@outlook.com>"] authors = ["LI Rui <lr_cn@outlook.com>"]
license = "Apache-2.0" license = "Apache-2.0"
@ -13,6 +13,8 @@ categories = ["web-programming"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
exclude = [".*"]
[dependencies] [dependencies]
log = "0.4" log = "0.4"
pretty_env_logger = "0.4" pretty_env_logger = "0.4"
@ -20,5 +22,10 @@ serde = { version = "1.0", features = ["derive"] }
serde_yaml = "0.8" serde_yaml = "0.8"
futures = "0.3" futures = "0.3"
tls-parser = "0.11" tls-parser = "0.11"
url = "2.2.2"
tokio = { version = "1.0", features = ["full"] } tokio = { version = "1.0", features = ["full"] }
bytes = "1.1"
kcp = "0.4"
byte_string = "1"

View File

@ -4,12 +4,15 @@
[![](https://img.shields.io/crates/v/fourth)](https://crates.io/crates/fourth) [![CI](https://img.shields.io/github/workflow/status/kernelerr/fourth/Rust)](https://github.com/KernelErr/fourth/actions/workflows/rust.yml) [![](https://img.shields.io/crates/v/fourth)](https://crates.io/crates/fourth) [![CI](https://img.shields.io/github/workflow/status/kernelerr/fourth/Rust)](https://github.com/KernelErr/fourth/actions/workflows/rust.yml)
Fourth is a layer 4 proxy implemented by Rust to listen on specific ports and transfer data to remote addresses according to configuration. **Under heavy development, version 0.1 may update frequently**
Fourth is a layer 4 proxy implemented by Rust to listen on specific ports and transfer TCP/KCP data to remote addresses(only TCP) according to configuration.
## Features ## Features
- Listen on specific port and proxy to local or remote port - Listen on specific port and proxy to local or remote port
- SNI-based rule without terminating TLS connection - SNI-based rule without terminating TLS connection
- Allow KCP inbound(warning: untested)
## Installation ## Installation
@ -22,36 +25,45 @@ $ cargo build --release
Binary file will be generated at `target/release/fourth`, or you can use `cargo install --path .` to install. Binary file will be generated at `target/release/fourth`, or you can use `cargo install --path .` to install.
Or you can use Cargo to install Fourth:
```bash
$ cargo install fourth
```
Or you can download binary file form the Release page.
## Configuration ## Configuration
Fourth will read yaml format configuration file from `/etc/fourth/config.yaml`, here is an example: Fourth will read yaml format configuration file from `/etc/fourth/config.yaml`, and you can set custom path to environment variable `FOURTH_CONFIG`, here is an minimal viable example:
```yaml ```yaml
version: 1 version: 1
log: info log: info
servers: servers:
example_server: proxy_server:
listen:
- "0.0.0.0:443"
- "[::]:443"
tls: true # Enable TLS features like SNI
sni:
proxy.example.com: proxy
www.example.com: nginx
default: ban
relay_server:
listen: listen:
- "127.0.0.1:8081" - "127.0.0.1:8081"
default: remote default: remote
upstream: upstream:
nginx: "127.0.0.1:8080" remote: "tcp://www.remote.example.com:8082" # proxy to remote address
proxy: "127.0.0.1:1024"
other: "www.remote.example.com:8082" # proxy to remote address
``` ```
Built-in two upstreams: ban(terminate connection immediately), echo Built-in two upstreams: ban(terminate connection immediately), echo. For detailed configuration, check [this example](./example-config.yaml).
## Performance Benchmark
Tested on 4C2G server:
Use fourth to proxy to Nginx(QPS of direct connection: ~120000): ~70000 req/s (Command: `wrk -t200 -c1000 -d120s --latency http://proxy-server:8081`)
Use fourth to proxy to local iperf3: 8Gbps
## Thanks
- [tokio_kcp](https://github.com/Matrix-Zhang/tokio_kcp)
## License ## License

View File

@ -6,12 +6,15 @@
[English](/README-EN.md) [English](/README-EN.md)
Fourth是一个Rust实现的Layer 4代理用于监听指定端口TCP流量并根据规则转发到指定目标。 **积极开发中0.1版本迭代可能较快**
Fourth是一个Rust实现的Layer 4代理用于监听指定端口TCP/KCP流量并根据规则转发到指定目标目前只支持TCP
## 功能 ## 功能
- 监听指定端口代理到本地或远端指定端口 - 监听指定端口代理到本地或远端指定端口
- 监听指定端口通过TLS ClientHello消息中的SNI进行分流 - 监听指定端口通过TLS ClientHello消息中的SNI进行分流
- 支持KCP入站警告未测试
## 安装方法 ## 安装方法
@ -24,36 +27,41 @@ $ cargo build --release
将在`target/release/fourth`生成二进制文件,您也可以使用`cargo install --path . `来安装二进制文件。 将在`target/release/fourth`生成二进制文件,您也可以使用`cargo install --path . `来安装二进制文件。
或者您也可以使用Cargo直接安装
```bash
$ cargo install fourth
```
或者您也可以直接从Release中下载二进制文件。
## 配置 ## 配置
Fourth使用yaml格式的配置文件默认情况下会读取`/etc/fourth/config.yaml`,如下是一个示例配置 Fourth使用yaml格式的配置文件默认情况下会读取`/etc/fourth/config.yaml`您也可以设置自定义路径到环境变量`FOURTH_CONFIG`如下是一个最小有效配置
```yaml ```yaml
version: 1 version: 1
log: info log: info
servers: servers:
example_server: proxy_server:
listen:
- "0.0.0.0:443"
- "[::]:443"
tls: true # 启动SNI分流将根据TLS请求中的主机名分流
sni:
proxy.example.com: proxy
www.example.com: nginx
default: ban
relay_server:
listen: listen:
- "127.0.0.1:8081" - "127.0.0.1:8081"
default: remote default: remote
upstream: upstream:
nginx: "127.0.0.1:8080" remote: "tcp://www.remote.example.com:8082" # proxy to remote address
proxy: "127.0.0.1:1024"
other: "www.remote.example.com:8082" # 代理到远端地址
``` ```
内置两个的upstreamban立即中断连接、echo返回读到的数据 内置两个的upstreamban立即中断连接、echo返回读到的数据更详细的配置可以参考[示例配置](./example-config.yaml)。
## 性能测试
在4C2G的服务器上测试
使用Fourth代理到Nginx直连QPS 120000: ~70000req/s (测试命令:`wrk -t200 -c1000 -d120s --latency http://proxy-server:8081 `
使用Fourth代理到本地iperf38Gbps
## io_uring? ## io_uring?
@ -61,6 +69,10 @@ upstream:
可能以后会为Linux高内核版本的用户提供可选的io_uring加速。 可能以后会为Linux高内核版本的用户提供可选的io_uring加速。
## 感谢
- [tokio_kcp](https://github.com/Matrix-Zhang/tokio_kcp)
## 协议 ## 协议
Fourth以Apache-2.0协议开源。 Fourth以Apache-2.0协议开源。

View File

@ -15,8 +15,13 @@ servers:
listen: listen:
- "127.0.0.1:8081" - "127.0.0.1:8081"
default: remote default: remote
kcp_server:
protocol: kcp # default TCP
listen:
- "127.0.0.1:8082"
default: echo
upstream: upstream:
nginx: "127.0.0.1:8080" nginx: "tcp://127.0.0.1:8080"
proxy: "127.0.0.1:1024" proxy: "tcp://127.0.0.1:1024"
other: "www.remote.example.com:8082" # proxy to remote address remote: "tcp://www.remote.example.com:8082" # proxy to remote address

View File

@ -1,12 +1,21 @@
use log::debug; use log::{debug, warn};
use serde::Deserialize; use serde::Deserialize;
use std::collections::HashMap; use std::collections::{HashMap, HashSet};
use std::fs::File; use std::fs::File;
use std::io::{Error as IOError, Read}; use std::io::{Error as IOError, Read};
use url::Url;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Config { pub struct Config {
pub base: BaseConfig, pub base: ParsedConfig,
}
#[derive(Debug, Default, Deserialize, Clone)]
pub struct ParsedConfig {
pub version: i32,
pub log: Option<String>,
pub servers: HashMap<String, ServerConfig>,
pub upstream: HashMap<String, Upstream>,
} }
#[derive(Debug, Default, Deserialize, Clone)] #[derive(Debug, Default, Deserialize, Clone)]
@ -20,11 +29,26 @@ pub struct BaseConfig {
#[derive(Debug, Default, Deserialize, Clone)] #[derive(Debug, Default, Deserialize, Clone)]
pub struct ServerConfig { pub struct ServerConfig {
pub listen: Vec<String>, pub listen: Vec<String>,
pub protocol: Option<String>,
pub tls: Option<bool>, pub tls: Option<bool>,
pub sni: Option<HashMap<String, String>>, pub sni: Option<HashMap<String, String>>,
pub default: Option<String>, pub default: Option<String>,
} }
#[derive(Debug, Clone, Deserialize)]
pub enum Upstream {
Ban,
Echo,
Custom(CustomUpstream),
}
#[derive(Debug, Clone, Deserialize)]
pub struct CustomUpstream {
pub name: String,
pub addr: String,
pub protocol: String,
}
#[derive(Debug)] #[derive(Debug)]
pub enum ConfigError { pub enum ConfigError {
IO(IOError), IO(IOError),
@ -40,29 +64,146 @@ impl Config {
} }
} }
fn load_config(path: &str) -> Result<BaseConfig, ConfigError> { fn load_config(path: &str) -> Result<ParsedConfig, ConfigError> {
let mut contents = String::new(); let mut contents = String::new();
let mut file = (File::open(path))?; let mut file = (File::open(path))?;
(file.read_to_string(&mut contents))?; (file.read_to_string(&mut contents))?;
let parsed: BaseConfig = serde_yaml::from_str(&contents)?; let base: BaseConfig = serde_yaml::from_str(&contents)?;
if parsed.version != 1 { if base.version != 1 {
return Err(ConfigError::Custom( return Err(ConfigError::Custom(
"Unsupported config version".to_string(), "Unsupported config version".to_string(),
)); ));
} }
let log_level = parsed.log.clone().unwrap_or_else(|| "info".to_string()); let log_level = base.log.clone().unwrap_or_else(|| "info".to_string());
if !log_level.eq("disable") { if !log_level.eq("disable") {
std::env::set_var("FOURTH_LOG", log_level.clone()); std::env::set_var("FOURTH_LOG", log_level.clone());
pretty_env_logger::init_custom_env("FOURTH_LOG"); pretty_env_logger::init_custom_env("FOURTH_LOG");
debug!("Set log level to {}", log_level); debug!("Set log level to {}", log_level);
} }
debug!("Config version {}", parsed.version); debug!("Config version {}", base.version);
Ok(parsed) let mut parsed_upstream: HashMap<String, Upstream> = HashMap::new();
for (name, upstream) in base.upstream.iter() {
let upstream_url = match Url::parse(upstream) {
Ok(url) => url,
Err(_) => {
return Err(ConfigError::Custom(format!(
"Invalid upstream url {}",
upstream
)))
}
};
let upstream_host = match upstream_url.host_str() {
Some(host) => host,
None => {
return Err(ConfigError::Custom(format!(
"Invalid upstream url {}",
upstream
)))
}
};
let upsteam_port = match upstream_url.port_or_known_default() {
Some(port) => port,
None => {
return Err(ConfigError::Custom(format!(
"Invalid upstream url {}",
upstream
)))
}
};
if upstream_url.scheme() != "tcp" {
return Err(ConfigError::Custom(format!(
"Invalid upstream scheme {}",
upstream
)));
}
parsed_upstream.insert(
name.to_string(),
Upstream::Custom(CustomUpstream {
name: name.to_string(),
addr: format!("{}:{}", upstream_host, upsteam_port),
protocol: upstream_url.scheme().to_string(),
}),
);
}
parsed_upstream.insert("ban".to_string(), Upstream::Ban);
parsed_upstream.insert("echo".to_string(), Upstream::Echo);
let parsed = ParsedConfig {
version: base.version,
log: base.log,
servers: base.servers,
upstream: parsed_upstream,
};
verify_config(parsed)
}
fn verify_config(config: ParsedConfig) -> Result<ParsedConfig, ConfigError> {
let mut used_upstreams: HashSet<String> = HashSet::new();
let mut upstream_names: HashSet<String> = HashSet::new();
let mut listen_addresses: HashSet<String> = HashSet::new();
// Check for duplicate upstream names
for (name, _) in config.upstream.iter() {
if upstream_names.contains(name) {
return Err(ConfigError::Custom(format!(
"Duplicate upstream name {}",
name
)));
}
upstream_names.insert(name.to_string());
}
for (_, server) in config.servers.clone() {
// check for duplicate listen addresses
for listen in server.listen {
if listen_addresses.contains(&listen) {
return Err(ConfigError::Custom(format!(
"Duplicate listen address {}",
listen
)));
}
listen_addresses.insert(listen.to_string());
}
if server.tls.unwrap_or_default() && server.sni.is_some() {
for (_, val) in server.sni.unwrap() {
used_upstreams.insert(val.to_string());
}
}
if server.default.is_some() {
used_upstreams.insert(server.default.unwrap().to_string());
}
for key in &used_upstreams {
if !config.upstream.contains_key(key) {
return Err(ConfigError::Custom(format!("Upstream {} not found", key)));
}
}
}
for key in &upstream_names {
if !used_upstreams.contains(key) {
warn!("Upstream {} not used", key);
}
}
Ok(config)
} }
impl From<IOError> for ConfigError { impl From<IOError> for ConfigError {
@ -86,7 +227,7 @@ mod tests {
let config = Config::new("tests/config.yaml").unwrap(); let config = Config::new("tests/config.yaml").unwrap();
assert_eq!(config.base.version, 1); assert_eq!(config.base.version, 1);
assert_eq!(config.base.log.unwrap(), "disable"); assert_eq!(config.base.log.unwrap(), "disable");
assert_eq!(config.base.servers.len(), 2); assert_eq!(config.base.servers.len(), 5);
assert_eq!(config.base.upstream.len(), 2); assert_eq!(config.base.upstream.len(), 3 + 2); // Add ban and echo upstreams
} }
} }

View File

@ -1,13 +1,17 @@
mod config; mod config;
mod plugins;
mod servers; mod servers;
use crate::config::Config; use crate::config::Config;
use crate::servers::Server; use crate::servers::Server;
use std::env;
use log::{debug, error}; use log::{debug, error};
fn main() { fn main() {
let config = match Config::new("/etc/fourth/config.yaml") { let config_path = env::var("FOURTH_CONFIG").unwrap_or_else(|_| "/etc/fourth/config.yaml".to_string());
let config = match Config::new(&config_path) {
Ok(config) => config, Ok(config) => config,
Err(e) => { Err(e) => {
println!("Could not load config: {:?}", e); println!("Could not load config: {:?}", e);

110
src/plugins/kcp/config.rs Normal file
View File

@ -0,0 +1,110 @@
use std::{io::Write, time::Duration};
use kcp::Kcp;
/// Kcp Delay Config
#[derive(Debug, Clone, Copy)]
pub struct KcpNoDelayConfig {
/// Enable nodelay
pub nodelay: bool,
/// Internal update interval (ms)
pub interval: i32,
/// ACK number to enable fast resend
pub resend: i32,
/// Disable congetion control
pub nc: bool,
}
impl Default for KcpNoDelayConfig {
fn default() -> KcpNoDelayConfig {
KcpNoDelayConfig {
nodelay: false,
interval: 100,
resend: 0,
nc: false,
}
}
}
#[allow(unused)]
impl KcpNoDelayConfig {
/// Get a fastest configuration
///
/// 1. Enable NoDelay
/// 2. Set ticking interval to be 10ms
/// 3. Set fast resend to be 2
/// 4. Disable congestion control
pub fn fastest() -> KcpNoDelayConfig {
KcpNoDelayConfig {
nodelay: true,
interval: 10,
resend: 2,
nc: true,
}
}
/// Get a normal configuration
///
/// 1. Disable NoDelay
/// 2. Set ticking interval to be 40ms
/// 3. Disable fast resend
/// 4. Enable congestion control
pub fn normal() -> KcpNoDelayConfig {
KcpNoDelayConfig {
nodelay: false,
interval: 40,
resend: 0,
nc: false,
}
}
}
/// Kcp Config
#[derive(Debug, Clone, Copy)]
pub struct KcpConfig {
/// Max Transmission Unit
pub mtu: usize,
/// nodelay
pub nodelay: KcpNoDelayConfig,
/// Send window size
pub wnd_size: (u16, u16),
/// Session expire duration, default is 90 seconds
pub session_expire: Duration,
/// Flush KCP state immediately after write
pub flush_write: bool,
/// Flush ACKs immediately after input
pub flush_acks_input: bool,
/// Stream mode
pub stream: bool,
}
impl Default for KcpConfig {
fn default() -> KcpConfig {
KcpConfig {
mtu: 1400,
nodelay: KcpNoDelayConfig::normal(),
wnd_size: (256, 256),
session_expire: Duration::from_secs(90),
flush_write: false,
flush_acks_input: false,
stream: true,
}
}
}
impl KcpConfig {
/// Applies config onto `Kcp`
#[doc(hidden)]
pub fn apply_config<W: Write>(&self, k: &mut Kcp<W>) {
k.set_mtu(self.mtu).expect("invalid MTU");
k.set_nodelay(
self.nodelay.nodelay,
self.nodelay.interval,
self.nodelay.resend,
self.nodelay.nc,
);
k.set_wndsize(self.wnd_size.0, self.wnd_size.1);
}
}

128
src/plugins/kcp/listener.rs Normal file
View File

@ -0,0 +1,128 @@
use std::{
io::{self, ErrorKind},
net::SocketAddr,
sync::Arc,
time::Duration,
};
use byte_string::ByteStr;
use kcp::{Error as KcpError, KcpResult};
use log::{debug, error, trace};
use tokio::{
net::{ToSocketAddrs, UdpSocket},
sync::mpsc,
task::JoinHandle,
time,
};
use crate::plugins::kcp::{config::KcpConfig, session::KcpSessionManager, stream::KcpStream};
#[allow(unused)]
pub struct KcpListener {
udp: Arc<UdpSocket>,
accept_rx: mpsc::Receiver<(KcpStream, SocketAddr)>,
task_watcher: JoinHandle<()>,
}
impl Drop for KcpListener {
fn drop(&mut self) {
self.task_watcher.abort();
}
}
impl KcpListener {
pub async fn bind<A: ToSocketAddrs>(config: KcpConfig, addr: A) -> KcpResult<KcpListener> {
let udp = UdpSocket::bind(addr).await?;
let udp = Arc::new(udp);
let server_udp = udp.clone();
let (accept_tx, accept_rx) = mpsc::channel(1024 /* backlogs */);
let task_watcher = tokio::spawn(async move {
let (close_tx, mut close_rx) = mpsc::channel(64);
let mut sessions = KcpSessionManager::new();
let mut packet_buffer = [0u8; 65536];
loop {
tokio::select! {
conv = close_rx.recv() => {
let conv = conv.expect("close_tx closed unexpectly");
sessions.close_conv(conv);
trace!("session conv: {} removed", conv);
}
recv_res = udp.recv_from(&mut packet_buffer) => {
match recv_res {
Err(err) => {
error!("udp.recv_from failed, error: {}", err);
time::sleep(Duration::from_secs(1)).await;
}
Ok((n, peer_addr)) => {
let packet = &mut packet_buffer[..n];
log::trace!("received peer: {}, {:?}", peer_addr, ByteStr::new(packet));
let mut conv = kcp::get_conv(packet);
if conv == 0 {
// Allocate a conv for client.
conv = sessions.alloc_conv();
debug!("allocate {} conv for peer: {}", conv, peer_addr);
kcp::set_conv(packet, conv);
}
let session = match sessions.get_or_create(&config, conv, &udp, peer_addr, &close_tx) {
Ok((s, created)) => {
if created {
// Created a new session, constructed a new accepted client
let stream = KcpStream::with_session(s.clone());
if let Err(..) = accept_tx.try_send((stream, peer_addr)) {
debug!("failed to create accepted stream due to channel failure");
// remove it from session
sessions.close_conv(conv);
continue;
}
}
s
},
Err(err) => {
error!("failed to create session, error: {}, peer: {}, conv: {}", err, peer_addr, conv);
continue;
}
};
// let mut kcp = session.kcp_socket().lock().await;
// if let Err(err) = kcp.input(packet) {
// error!("kcp.input failed, peer: {}, conv: {}, error: {}, packet: {:?}", peer_addr, conv, err, ByteStr::new(packet));
// }
session.input(packet).await;
}
}
}
}
}
});
Ok(KcpListener {
udp: server_udp,
accept_rx,
task_watcher,
})
}
pub async fn accept(&mut self) -> KcpResult<(KcpStream, SocketAddr)> {
match self.accept_rx.recv().await {
Some(s) => Ok(s),
None => Err(KcpError::IoError(io::Error::new(
ErrorKind::Other,
"accept channel closed unexpectly",
))),
}
}
#[allow(unused)]
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.udp.local_addr()
}
}

14
src/plugins/kcp/mod.rs Normal file
View File

@ -0,0 +1,14 @@
//! Library of KCP on Tokio
pub use self::{
config::{KcpConfig, KcpNoDelayConfig},
listener::KcpListener,
stream::KcpStream,
};
mod config;
mod listener;
mod session;
mod skcp;
mod stream;
mod utils;

256
src/plugins/kcp/session.rs Normal file
View File

@ -0,0 +1,256 @@
use std::{
collections::{hash_map::Entry, HashMap},
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use byte_string::ByteStr;
use kcp::KcpResult;
use log::{error, trace};
use tokio::{
net::UdpSocket,
sync::{mpsc, Mutex},
time::{self, Instant},
};
use crate::plugins::kcp::{skcp::KcpSocket, KcpConfig};
pub struct KcpSession {
socket: Mutex<KcpSocket>,
closed: AtomicBool,
session_expire: Duration,
session_close_notifier: Option<mpsc::Sender<u32>>,
input_tx: mpsc::Sender<Vec<u8>>,
}
impl KcpSession {
fn new(
socket: KcpSocket,
session_expire: Duration,
session_close_notifier: Option<mpsc::Sender<u32>>,
input_tx: mpsc::Sender<Vec<u8>>,
) -> KcpSession {
KcpSession {
socket: Mutex::new(socket),
closed: AtomicBool::new(false),
session_expire,
session_close_notifier,
input_tx,
}
}
pub fn new_shared(
socket: KcpSocket,
session_expire: Duration,
session_close_notifier: Option<mpsc::Sender<u32>>,
) -> Arc<KcpSession> {
let is_client = session_close_notifier.is_none();
let (input_tx, mut input_rx) = mpsc::channel(64);
let udp_socket = socket.udp_socket().clone();
let session = Arc::new(KcpSession::new(
socket,
session_expire,
session_close_notifier,
input_tx,
));
{
let session = session.clone();
tokio::spawn(async move {
let mut input_buffer = [0u8; 65536];
let update_timer = time::sleep(Duration::from_millis(10));
tokio::pin!(update_timer);
loop {
tokio::select! {
// recv() then input()
// Drives the KCP machine forward
recv_result = udp_socket.recv(&mut input_buffer), if is_client => {
match recv_result {
Err(err) => {
error!("[SESSION] UDP recv failed, error: {}", err);
}
Ok(n) => {
let input_buffer = &input_buffer[..n];
trace!("[SESSION] UDP recv {} bytes, going to input {:?}", n, ByteStr::new(input_buffer));
let mut socket = session.socket.lock().await;
match socket.input(input_buffer) {
Ok(true) => {
trace!("[SESSION] UDP input {} bytes and waked sender/receiver", n);
}
Ok(false) => {}
Err(err) => {
error!("[SESSION] UDP input {} bytes error: {}, input buffer {:?}", n, err, ByteStr::new(input_buffer));
}
}
}
}
}
// bytes received from listener socket
input_opt = input_rx.recv() => {
if let Some(input_buffer) = input_opt {
let mut socket = session.socket.lock().await;
match socket.input(&input_buffer) {
Ok(..) => {
trace!("[SESSION] UDP input {} bytes from channel {:?}", input_buffer.len(), ByteStr::new(&input_buffer));
}
Err(err) => {
error!("[SESSION] UDP input {} bytes from channel failed, error: {}, input buffer {:?}",
input_buffer.len(), err, ByteStr::new(&input_buffer));
}
}
}
}
// Call update() in period
_ = &mut update_timer => {
let mut socket = session.socket.lock().await;
let is_closed = session.closed.load(Ordering::Acquire);
if is_closed && socket.can_close() {
trace!("[SESSION] KCP session closed");
break;
}
// server socket expires
if !is_client {
// If this is a server stream, close it automatically after a period of time
let last_update_time = socket.last_update_time();
let elapsed = last_update_time.elapsed();
if elapsed > session.session_expire {
if elapsed > session.session_expire * 2 {
// Force close. Client may have already gone.
trace!(
"[SESSION] force close inactive session, conv: {}, last_update: {}s ago",
socket.conv(),
elapsed.as_secs()
);
break;
}
if !is_closed {
trace!(
"[SESSION] closing inactive session, conv: {}, last_update: {}s ago",
socket.conv(),
elapsed.as_secs()
);
session.closed.store(true, Ordering::Release);
}
}
}
match socket.update() {
Ok(next_next) => {
update_timer.as_mut().reset(Instant::from_std(next_next));
}
Err(err) => {
error!("[SESSION] KCP update failed, error: {}", err);
update_timer.as_mut().reset(Instant::now() + Duration::from_millis(10));
}
}
}
}
}
{
// Close the socket.
// Wake all pending tasks and let all send/recv return EOF
let mut socket = session.socket.lock().await;
socket.close();
}
if let Some(ref notifier) = session.session_close_notifier {
let socket = session.socket.lock().await;
let _ = notifier.send(socket.conv()).await;
}
});
}
session
}
pub fn kcp_socket(&self) -> &Mutex<KcpSocket> {
&self.socket
}
pub fn close(&self) {
self.closed.store(true, Ordering::Release);
}
pub async fn input(&self, buf: &[u8]) {
self.input_tx
.send(buf.to_owned())
.await
.expect("input channel closed")
}
}
pub struct KcpSessionManager {
sessions: HashMap<u32, Arc<KcpSession>>,
next_free_conv: u32,
}
impl KcpSessionManager {
pub fn new() -> KcpSessionManager {
KcpSessionManager {
sessions: HashMap::new(),
next_free_conv: 0,
}
}
pub fn close_conv(&mut self, conv: u32) {
self.sessions.remove(&conv);
}
pub fn alloc_conv(&mut self) -> u32 {
loop {
let (mut c, _) = self.next_free_conv.overflowing_add(1);
if c == 0 {
let (nc, _) = c.overflowing_add(1);
c = nc;
}
self.next_free_conv = c;
if self.sessions.get(&self.next_free_conv).is_none() {
let conv = self.next_free_conv;
return conv;
}
}
}
pub fn get_or_create(
&mut self,
config: &KcpConfig,
conv: u32,
udp: &Arc<UdpSocket>,
peer_addr: SocketAddr,
session_close_notifier: &mpsc::Sender<u32>,
) -> KcpResult<(Arc<KcpSession>, bool)> {
match self.sessions.entry(conv) {
Entry::Occupied(occ) => Ok((occ.get().clone(), false)),
Entry::Vacant(vac) => {
let socket = KcpSocket::new(config, conv, udp.clone(), peer_addr, config.stream)?;
let session = KcpSession::new_shared(
socket,
config.session_expire,
Some(session_close_notifier.clone()),
);
trace!("created session for conv: {}, peer: {}", conv, peer_addr);
vac.insert(session.clone());
Ok((session, true))
}
}
}
}

288
src/plugins/kcp/skcp.rs Normal file
View File

@ -0,0 +1,288 @@
use std::{
io::{self, ErrorKind, Write},
net::SocketAddr,
sync::Arc,
task::{Context, Poll, Waker},
time::{Duration, Instant},
};
use futures::future;
use kcp::{Error as KcpError, Kcp, KcpResult};
use log::{error, trace};
use tokio::{net::UdpSocket, sync::mpsc};
use crate::plugins::kcp::{utils::now_millis, KcpConfig};
/// Writer for sending packets to the underlying UdpSocket
struct UdpOutput {
socket: Arc<UdpSocket>,
target_addr: SocketAddr,
delay_tx: mpsc::UnboundedSender<Vec<u8>>,
}
impl UdpOutput {
/// Create a new Writer for writing packets to UdpSocket
pub fn new(socket: Arc<UdpSocket>, target_addr: SocketAddr) -> UdpOutput {
let (delay_tx, mut delay_rx) = mpsc::unbounded_channel::<Vec<u8>>();
{
let socket = socket.clone();
tokio::spawn(async move {
while let Some(buf) = delay_rx.recv().await {
if let Err(err) = socket.send_to(&buf, target_addr).await {
error!("[SEND] UDP delayed send failed, error: {}", err);
}
}
});
}
UdpOutput {
socket,
target_addr,
delay_tx,
}
}
}
impl Write for UdpOutput {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.socket.try_send_to(buf, self.target_addr) {
Ok(n) => Ok(n),
Err(ref err) if err.kind() == ErrorKind::WouldBlock => {
// send return EAGAIN
// ignored as packet was lost in transmission
trace!(
"[SEND] UDP send EAGAIN, packet.size: {} bytes, delayed send",
buf.len()
);
self.delay_tx
.send(buf.to_owned())
.expect("channel closed unexpectly");
Ok(buf.len())
}
Err(err) => Err(err),
}
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
pub struct KcpSocket {
kcp: Kcp<UdpOutput>,
last_update: Instant,
socket: Arc<UdpSocket>,
flush_write: bool,
flush_ack_input: bool,
sent_first: bool,
pending_sender: Option<Waker>,
pending_receiver: Option<Waker>,
closed: bool,
}
impl KcpSocket {
pub fn new(
c: &KcpConfig,
conv: u32,
socket: Arc<UdpSocket>,
target_addr: SocketAddr,
stream: bool,
) -> KcpResult<KcpSocket> {
let output = UdpOutput::new(socket.clone(), target_addr);
let mut kcp = if stream {
Kcp::new_stream(conv, output)
} else {
Kcp::new(conv, output)
};
c.apply_config(&mut kcp);
// Ask server to allocate one
if conv == 0 {
kcp.input_conv();
}
kcp.update(now_millis())?;
Ok(KcpSocket {
kcp,
last_update: Instant::now(),
socket,
flush_write: c.flush_write,
flush_ack_input: c.flush_acks_input,
sent_first: false,
pending_sender: None,
pending_receiver: None,
closed: false,
})
}
/// Call every time you got data from transmission
pub fn input(&mut self, buf: &[u8]) -> KcpResult<bool> {
match self.kcp.input(buf) {
Ok(..) => {}
Err(KcpError::ConvInconsistent(expected, actual)) => {
trace!(
"[INPUT] Conv expected={} actual={} ignored",
expected,
actual
);
return Ok(false);
}
Err(err) => return Err(err),
}
self.last_update = Instant::now();
if self.flush_ack_input {
self.kcp.flush_ack()?;
}
Ok(self.try_wake_pending_waker())
}
/// Call if you want to send some data
pub fn poll_send(&mut self, cx: &mut Context<'_>, mut buf: &[u8]) -> Poll<KcpResult<usize>> {
if self.closed {
return Ok(0).into();
}
// If:
// 1. Have sent the first packet (asking for conv)
// 2. Too many pending packets
if self.sent_first
&& (self.kcp.wait_snd() >= self.kcp.snd_wnd() as usize || self.kcp.waiting_conv())
{
trace!(
"[SEND] waitsnd={} sndwnd={} excceeded or waiting conv={}",
self.kcp.wait_snd(),
self.kcp.snd_wnd(),
self.kcp.waiting_conv()
);
self.pending_sender = Some(cx.waker().clone());
return Poll::Pending;
}
if !self.sent_first && self.kcp.waiting_conv() && buf.len() > self.kcp.mss() as usize {
buf = &buf[..self.kcp.mss() as usize];
}
let n = self.kcp.send(buf)?;
self.sent_first = true;
self.last_update = Instant::now();
if self.flush_write {
self.kcp.flush()?;
}
Ok(n).into()
}
/// Call if you want to send some data
#[allow(dead_code)]
pub async fn send(&mut self, buf: &[u8]) -> KcpResult<usize> {
future::poll_fn(|cx| self.poll_send(cx, buf)).await
}
#[allow(dead_code)]
pub fn try_recv(&mut self, buf: &mut [u8]) -> KcpResult<usize> {
if self.closed {
return Ok(0);
}
self.kcp.recv(buf)
}
pub fn poll_recv(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<KcpResult<usize>> {
if self.closed {
return Ok(0).into();
}
match self.kcp.recv(buf) {
Ok(n) => Ok(n).into(),
Err(KcpError::RecvQueueEmpty) => {
self.pending_receiver = Some(cx.waker().clone());
Poll::Pending
}
Err(err) => Err(err).into(),
}
}
#[allow(dead_code)]
pub async fn recv(&mut self, buf: &mut [u8]) -> KcpResult<usize> {
future::poll_fn(|cx| self.poll_recv(cx, buf)).await
}
pub fn flush(&mut self) -> KcpResult<()> {
self.kcp.flush()?;
self.last_update = Instant::now();
Ok(())
}
fn try_wake_pending_waker(&mut self) -> bool {
let mut waked = false;
if self.pending_sender.is_some()
&& self.kcp.wait_snd() < self.kcp.snd_wnd() as usize
&& !self.kcp.waiting_conv()
{
let waker = self.pending_sender.take().unwrap();
waker.wake();
waked = true;
}
if self.pending_receiver.is_some() {
if let Ok(peek) = self.kcp.peeksize() {
if peek > 0 {
let waker = self.pending_receiver.take().unwrap();
waker.wake();
waked = true;
}
}
}
waked
}
pub fn update(&mut self) -> KcpResult<Instant> {
let now = now_millis();
self.kcp.update(now)?;
let next = self.kcp.check(now);
self.try_wake_pending_waker();
Ok(Instant::now() + Duration::from_millis(next as u64))
}
pub fn close(&mut self) {
self.closed = true;
if let Some(w) = self.pending_sender.take() {
w.wake();
}
if let Some(w) = self.pending_receiver.take() {
w.wake();
}
}
pub fn udp_socket(&self) -> &Arc<UdpSocket> {
&self.socket
}
pub fn can_close(&self) -> bool {
self.kcp.wait_snd() == 0
}
pub fn conv(&self) -> u32 {
self.kcp.conv()
}
pub fn peek_size(&self) -> KcpResult<usize> {
self.kcp.peeksize()
}
pub fn last_update_time(&self) -> Instant {
self.last_update
}
}

183
src/plugins/kcp/stream.rs Normal file
View File

@ -0,0 +1,183 @@
use std::{
io::{self, ErrorKind},
net::{IpAddr, SocketAddr},
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures::{future, ready};
use kcp::{Error as KcpError, KcpResult};
use log::trace;
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::UdpSocket,
};
use crate::plugins::kcp::{config::KcpConfig, session::KcpSession, skcp::KcpSocket};
pub struct KcpStream {
session: Arc<KcpSession>,
recv_buffer: Vec<u8>,
recv_buffer_pos: usize,
recv_buffer_cap: usize,
}
impl Drop for KcpStream {
fn drop(&mut self) {
self.session.close();
}
}
#[allow(unused)]
impl KcpStream {
pub async fn connect(config: &KcpConfig, addr: SocketAddr) -> KcpResult<KcpStream> {
let udp = match addr.ip() {
IpAddr::V4(..) => UdpSocket::bind("0.0.0.0:0").await?,
IpAddr::V6(..) => UdpSocket::bind("[::]:0").await?,
};
let udp = Arc::new(udp);
let socket = KcpSocket::new(config, 0, udp, addr, config.stream)?;
let session = KcpSession::new_shared(socket, config.session_expire, None);
Ok(KcpStream::with_session(session))
}
pub(crate) fn with_session(session: Arc<KcpSession>) -> KcpStream {
KcpStream {
session,
recv_buffer: Vec::new(),
recv_buffer_pos: 0,
recv_buffer_cap: 0,
}
}
pub fn poll_send(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<KcpResult<usize>> {
// Mutex doesn't have poll_lock, spinning on it.
let socket = self.session.kcp_socket();
let mut kcp = match socket.try_lock() {
Ok(guard) => guard,
Err(..) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
};
kcp.poll_send(cx, buf)
}
pub async fn send(&mut self, buf: &[u8]) -> KcpResult<usize> {
future::poll_fn(|cx| self.poll_send(cx, buf)).await
}
pub fn poll_recv(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<KcpResult<usize>> {
loop {
// Consumes all data in buffer
if self.recv_buffer_pos < self.recv_buffer_cap {
let remaining = self.recv_buffer_cap - self.recv_buffer_pos;
let copy_length = remaining.min(buf.len());
buf.copy_from_slice(
&self.recv_buffer[self.recv_buffer_pos..self.recv_buffer_pos + copy_length],
);
self.recv_buffer_pos += copy_length;
return Ok(copy_length).into();
}
// Mutex doesn't have poll_lock, spinning on it.
let socket = self.session.kcp_socket();
let mut kcp = match socket.try_lock() {
Ok(guard) => guard,
Err(..) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
};
// Try to read from KCP
// 1. Read directly with user provided `buf`
match ready!(kcp.poll_recv(cx, buf)) {
Ok(n) => {
trace!("[CLIENT] recv directly {} bytes", n);
return Ok(n).into();
}
Err(KcpError::UserBufTooSmall) => {}
Err(err) => return Err(err).into(),
}
// 2. User `buf` too small, read to recv_buffer
let required_size = kcp.peek_size()?;
if self.recv_buffer.len() < required_size {
self.recv_buffer.resize(required_size, 0);
}
match ready!(kcp.poll_recv(cx, &mut self.recv_buffer)) {
Ok(n) => {
trace!("[CLIENT] recv buffered {} bytes", n);
self.recv_buffer_pos = 0;
self.recv_buffer_cap = n;
}
Err(err) => return Err(err).into(),
}
}
}
pub async fn recv(&mut self, buf: &mut [u8]) -> KcpResult<usize> {
future::poll_fn(|cx| self.poll_recv(cx, buf)).await
}
}
impl AsyncRead for KcpStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match ready!(self.poll_recv(cx, buf.initialize_unfilled())) {
Ok(n) => {
buf.advance(n);
Ok(()).into()
}
Err(KcpError::IoError(err)) => Err(err).into(),
Err(err) => Err(io::Error::new(ErrorKind::Other, err)).into(),
}
}
}
impl AsyncWrite for KcpStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match ready!(self.poll_send(cx, buf)) {
Ok(n) => Ok(n).into(),
Err(KcpError::IoError(err)) => Err(err).into(),
Err(err) => Err(io::Error::new(ErrorKind::Other, err)).into(),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
// Mutex doesn't have poll_lock, spinning on it.
let socket = self.session.kcp_socket();
let mut kcp = match socket.try_lock() {
Ok(guard) => guard,
Err(..) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
};
match kcp.flush() {
Ok(..) => Ok(()).into(),
Err(KcpError::IoError(err)) => Err(err).into(),
Err(err) => Err(io::Error::new(ErrorKind::Other, err)).into(),
}
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Ok(()).into()
}
}

10
src/plugins/kcp/utils.rs Normal file
View File

@ -0,0 +1,10 @@
use std::time::{SystemTime, UNIX_EPOCH};
#[inline]
pub fn now_millis() -> u32 {
let start = SystemTime::now();
let since_the_epoch = start
.duration_since(UNIX_EPOCH)
.expect("time went afterwards");
(since_the_epoch.as_secs() * 1000 + since_the_epoch.subsec_millis() as u64 / 1_000_000) as u32
}

1
src/plugins/mod.rs Normal file
View File

@ -0,0 +1 @@
pub mod kcp;

View File

@ -1,41 +1,39 @@
use futures::future::try_join; use log::{error, info};
use log::{debug, error, info, warn};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use tokio::io;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
mod tls; mod protocol;
use self::tls::get_sni; use crate::config::{ParsedConfig, Upstream};
use crate::config::BaseConfig; use protocol::{kcp, tcp};
#[derive(Debug)] #[derive(Debug)]
pub struct Server { pub struct Server {
pub proxies: Vec<Arc<Proxy>>, pub proxies: Vec<Arc<Proxy>>,
pub config: BaseConfig, pub config: ParsedConfig,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Proxy { pub struct Proxy {
pub name: String, pub name: String,
pub listen: SocketAddr, pub listen: SocketAddr,
pub protocol: String,
pub tls: bool, pub tls: bool,
pub sni: Option<HashMap<String, String>>, pub sni: Option<HashMap<String, String>>,
pub default: String, pub default: String,
pub upstream: HashMap<String, String>, pub upstream: HashMap<String, Upstream>,
} }
impl Server { impl Server {
pub fn new(config: BaseConfig) -> Self { pub fn new(config: ParsedConfig) -> Self {
let mut new_server = Server { let mut new_server = Server {
proxies: Vec::new(), proxies: Vec::new(),
config: config.clone(), config: config.clone(),
}; };
for (name, proxy) in config.servers.iter() { for (name, proxy) in config.servers.iter() {
let protocol = proxy.protocol.clone().unwrap_or_else(|| "tcp".to_string());
let tls = proxy.tls.unwrap_or(false); let tls = proxy.tls.unwrap_or(false);
let sni = proxy.sni.clone(); let sni = proxy.sni.clone();
let default = proxy.default.clone().unwrap_or_else(|| "ban".to_string()); let default = proxy.default.clone().unwrap_or_else(|| "ban".to_string());
@ -48,7 +46,6 @@ impl Server {
upstream_set.insert(key.clone()); upstream_set.insert(key.clone());
} }
for listen in proxy.listen.clone() { for listen in proxy.listen.clone() {
println!("{:?}", listen);
let listen_addr: SocketAddr = match listen.parse() { let listen_addr: SocketAddr = match listen.parse() {
Ok(addr) => addr, Ok(addr) => addr,
Err(_) => { Err(_) => {
@ -56,9 +53,11 @@ impl Server {
continue; continue;
} }
}; };
let proxy = Proxy { let proxy = Proxy {
name: name.clone(), name: name.clone(),
listen: listen_addr, listen: listen_addr,
protocol: protocol.clone(),
tls, tls,
sni: sni.clone(), sni: sni.clone(),
default: default.clone(), default: default.clone(),
@ -77,9 +76,22 @@ impl Server {
let mut handles: Vec<JoinHandle<()>> = Vec::new(); let mut handles: Vec<JoinHandle<()>> = Vec::new();
for config in proxies { for config in proxies {
info!("Starting server {} on {}", config.name, config.listen); info!(
"Starting {} server {} on {}",
config.protocol, config.name, config.listen
);
let handle = tokio::spawn(async move { let handle = tokio::spawn(async move {
let _ = proxy(config).await; match config.protocol.as_ref() {
"tcp" => {
let _ = tcp::proxy(config).await;
}
"kcp" => {
let _ = kcp::proxy(config).await;
}
_ => {
error!("Invalid protocol: {}", config.protocol)
}
}
}); });
handles.push(handle); handles.push(handle);
} }
@ -91,145 +103,89 @@ impl Server {
} }
} }
async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
let listener = TcpListener::bind(config.listen).await?;
let config = config.clone();
loop {
let thread_proxy = config.clone();
match listener.accept().await {
Err(err) => {
error!("Failed to accept connection: {}", err);
return Err(Box::new(err));
}
Ok((stream, _)) => {
tokio::spawn(async move {
match accept(stream, thread_proxy).await {
Ok(_) => {}
Err(err) => {
error!("Relay thread returned an error: {}", err);
}
};
});
}
}
}
}
async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
debug!("New connection from {:?}", inbound.peer_addr()?);
let upstream_name = match proxy.tls {
false => proxy.default.clone(),
true => {
let mut hello_buf = [0u8; 1024];
inbound.peek(&mut hello_buf).await?;
let snis = get_sni(&hello_buf);
if snis.is_empty() {
proxy.default.clone()
} else {
match proxy.sni.clone() {
Some(sni_map) => {
let mut upstream = proxy.default.clone();
for sni in snis {
let m = sni_map.get(&sni);
if m.is_some() {
upstream = m.unwrap().clone();
break;
}
}
upstream
}
None => proxy.default.clone(),
}
}
}
};
debug!("Upstream: {}", upstream_name);
let upstream = match proxy.upstream.get(&upstream_name) {
Some(upstream) => upstream,
None => {
warn!(
"No upstream named {:?} on server {:?}",
proxy.default, proxy.name
);
return process(inbound, &proxy.default).await;
}
};
return process(inbound, upstream).await;
}
async fn process(mut inbound: TcpStream, upstream: &str) -> Result<(), Box<dyn std::error::Error>> {
if upstream == "ban" {
let _ = inbound.shutdown();
return Ok(());
} else if upstream == "echo" {
loop {
let mut buf = [0u8; 1];
let b = inbound.read(&mut buf).await?;
if b == 0 {
break;
} else {
inbound.write(&buf).await?;
}
}
return Ok(());
}
let outbound = TcpStream::connect(upstream).await?;
let (mut ri, mut wi) = io::split(inbound);
let (mut ro, mut wo) = io::split(outbound);
let inbound_to_outbound = copy(&mut ri, &mut wo);
let outbound_to_inbound = copy(&mut ro, &mut wi);
let (bytes_tx, bytes_rx) = try_join(inbound_to_outbound, outbound_to_inbound).await?;
debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx);
Ok(())
}
async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
where
R: AsyncRead + Unpin + ?Sized,
W: AsyncWrite + Unpin + ?Sized,
{
match io::copy(reader, writer).await {
Ok(u64) => {
let _ = writer.shutdown().await;
Ok(u64)
}
Err(_) => Ok(0),
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod tests {
use crate::plugins::kcp::{KcpConfig, KcpStream};
use std::net::SocketAddr;
use std::thread::{self, sleep}; use std::thread::{self, sleep};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use super::*; use super::*;
#[tokio::main]
async fn tcp_mock_server() {
let server_addr: SocketAddr = "127.0.0.1:54599".parse().unwrap();
let listener = TcpListener::bind(server_addr).await.unwrap();
loop {
let (mut stream, _) = listener.accept().await.unwrap();
let mut buf = [0u8; 2];
let mut n = stream.read(&mut buf).await.unwrap();
while n > 0 {
stream.write(b"hello").await.unwrap();
if buf.eq(b"by") {
stream.shutdown().await.unwrap();
break;
}
n = stream.read(&mut buf).await.unwrap();
}
stream.shutdown().await.unwrap();
}
}
#[tokio::test] #[tokio::test]
async fn test_echo_server() { async fn test_proxy() {
use crate::config::Config; use crate::config::Config;
let config = Config::new("tests/config.yaml").unwrap(); let config = Config::new("tests/config.yaml").unwrap();
let mut server = Server::new(config.base); let mut server = Server::new(config.base);
thread::spawn(move || {
tcp_mock_server();
});
sleep(Duration::from_secs(1)); // wait for server to start
thread::spawn(move || { thread::spawn(move || {
let _ = server.run(); let _ = server.run();
}); });
sleep(Duration::from_secs(1)); // wait for server to start sleep(Duration::from_secs(1)); // wait for server to start
// test TCP proxy
let mut conn = TcpStream::connect("127.0.0.1:54500").await.unwrap();
let mut buf = [0u8; 5];
conn.write(b"hi").await.unwrap();
conn.read(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello");
conn.shutdown().await.unwrap();
// test TCP echo
let mut conn = TcpStream::connect("127.0.0.1:54956").await.unwrap(); let mut conn = TcpStream::connect("127.0.0.1:54956").await.unwrap();
let mut buf = [0u8; 1]; let mut buf = [0u8; 1];
for i in 0..=255u8 { for i in 0..=10u8 {
conn.write(&[i]).await.unwrap(); conn.write(&[i]).await.unwrap();
conn.read(&mut buf).await.unwrap(); conn.read(&mut buf).await.unwrap();
assert_eq!(&buf, &[i]); assert_eq!(&buf, &[i]);
} }
conn.shutdown().await.unwrap(); conn.shutdown().await.unwrap();
// test KCP echo
let kcp_config = KcpConfig::default();
let server_addr: SocketAddr = "127.0.0.1:54959".parse().unwrap();
let mut conn = KcpStream::connect(&kcp_config, server_addr).await.unwrap();
let mut buf = [0u8; 1];
for i in 0..=10u8 {
conn.write(&[i]).await.unwrap();
conn.read(&mut buf).await.unwrap();
assert_eq!(&buf, &[i]);
}
conn.shutdown().await.unwrap();
// test KCP proxy and close mock server
let kcp_config = KcpConfig::default();
let server_addr: SocketAddr = "127.0.0.1:54958".parse().unwrap();
let mut conn = KcpStream::connect(&kcp_config, server_addr).await.unwrap();
let mut buf = [0u8; 5];
conn.write(b"by").await.unwrap();
conn.read(&mut buf).await.unwrap();
assert_eq!(&buf, b"hello");
conn.shutdown().await.unwrap();
} }
} }

112
src/servers/protocol/kcp.rs Normal file
View File

@ -0,0 +1,112 @@
use crate::config::Upstream;
use crate::plugins::kcp::{KcpConfig, KcpListener, KcpStream};
use crate::servers::Proxy;
use futures::future::try_join;
use log::{debug, error, warn};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
pub async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
let kcp_config = KcpConfig::default();
let mut listener = KcpListener::bind(kcp_config, config.listen).await?;
let config = config.clone();
loop {
let thread_proxy = config.clone();
match listener.accept().await {
Err(err) => {
error!("Failed to accept connection: {}", err);
return Err(Box::new(err));
}
Ok((stream, peer)) => {
tokio::spawn(async move {
match accept(stream, peer, thread_proxy).await {
Ok(_) => {}
Err(err) => {
error!("Relay thread returned an error: {}", err);
}
};
});
}
}
}
}
async fn accept(
inbound: KcpStream,
peer: SocketAddr,
proxy: Arc<Proxy>,
) -> Result<(), Box<dyn std::error::Error>> {
debug!("New connection from {:?}", peer);
let upstream_name = proxy.default.clone();
debug!("Upstream: {}", upstream_name);
let upstream = match proxy.upstream.get(&upstream_name) {
Some(upstream) => upstream,
None => {
warn!(
"No upstream named {:?} on server {:?}",
proxy.default, proxy.name
);
return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await;
// ToDo: Remove unwrap and check default option
}
};
return process(inbound, upstream).await;
}
async fn process(
mut inbound: KcpStream,
upstream: &Upstream,
) -> Result<(), Box<dyn std::error::Error>> {
match upstream {
Upstream::Ban => {
let _ = inbound.shutdown();
}
Upstream::Echo => {
let (mut ri, mut wi) = io::split(inbound);
let inbound_to_inbound = copy(&mut ri, &mut wi);
let bytes_tx = inbound_to_inbound.await;
debug!("Bytes read: {:?}", bytes_tx);
}
Upstream::Custom(custom) => match custom.protocol.as_ref() {
"tcp" => {
let outbound = TcpStream::connect(custom.addr.clone()).await?;
let (mut ri, mut wi) = io::split(inbound);
let (mut ro, mut wo) = io::split(outbound);
let inbound_to_outbound = copy(&mut ri, &mut wo);
let outbound_to_inbound = copy(&mut ro, &mut wi);
let (bytes_tx, bytes_rx) =
try_join(inbound_to_outbound, outbound_to_inbound).await?;
debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx);
}
_ => {
error!("Reached unknown protocol: {:?}", custom.protocol);
}
},
};
Ok(())
}
async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
where
R: AsyncRead + Unpin + ?Sized,
W: AsyncWrite + Unpin + ?Sized,
{
match io::copy(reader, writer).await {
Ok(u64) => {
let _ = writer.shutdown().await;
Ok(u64)
}
Err(_) => Ok(0),
}
}

View File

@ -0,0 +1,3 @@
pub mod kcp;
pub mod tcp;
pub mod tls;

131
src/servers/protocol/tcp.rs Normal file
View File

@ -0,0 +1,131 @@
use crate::config::Upstream;
use crate::servers::protocol::tls::get_sni;
use crate::servers::Proxy;
use futures::future::try_join;
use log::{debug, error, warn};
use std::sync::Arc;
use tokio::io;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
pub async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
let listener = TcpListener::bind(config.listen).await?;
let config = config.clone();
loop {
let thread_proxy = config.clone();
match listener.accept().await {
Err(err) => {
error!("Failed to accept connection: {}", err);
return Err(Box::new(err));
}
Ok((stream, _)) => {
tokio::spawn(async move {
match accept(stream, thread_proxy).await {
Ok(_) => {}
Err(err) => {
error!("Relay thread returned an error: {}", err);
}
};
});
}
}
}
}
async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
debug!("New connection from {:?}", inbound.peer_addr()?);
let upstream_name = match proxy.tls {
false => proxy.default.clone(),
true => {
let mut hello_buf = [0u8; 1024];
inbound.peek(&mut hello_buf).await?;
let snis = get_sni(&hello_buf);
if snis.is_empty() {
proxy.default.clone()
} else {
match proxy.sni.clone() {
Some(sni_map) => {
let mut upstream = proxy.default.clone();
for sni in snis {
let m = sni_map.get(&sni);
if m.is_some() {
upstream = m.unwrap().clone();
break;
}
}
upstream
}
None => proxy.default.clone(),
}
}
}
};
debug!("Upstream: {}", upstream_name);
let upstream = match proxy.upstream.get(&upstream_name) {
Some(upstream) => upstream,
None => {
warn!(
"No upstream named {:?} on server {:?}",
proxy.default, proxy.name
);
return process(inbound, proxy.upstream.get(&proxy.default).unwrap()).await;
// ToDo: Remove unwrap and check default option
}
};
return process(inbound, upstream).await;
}
async fn process(
mut inbound: TcpStream,
upstream: &Upstream,
) -> Result<(), Box<dyn std::error::Error>> {
match upstream {
Upstream::Ban => {
let _ = inbound.shutdown();
}
Upstream::Echo => {
let (mut ri, mut wi) = io::split(inbound);
let inbound_to_inbound = copy(&mut ri, &mut wi);
let bytes_tx = inbound_to_inbound.await;
debug!("Bytes read: {:?}", bytes_tx);
}
Upstream::Custom(custom) => match custom.protocol.as_ref() {
"tcp" => {
let outbound = TcpStream::connect(custom.addr.clone()).await?;
let (mut ri, mut wi) = io::split(inbound);
let (mut ro, mut wo) = io::split(outbound);
let inbound_to_outbound = copy(&mut ri, &mut wo);
let outbound_to_inbound = copy(&mut ro, &mut wi);
let (bytes_tx, bytes_rx) =
try_join(inbound_to_outbound, outbound_to_inbound).await?;
debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx);
}
_ => {
error!("Reached unknown protocol: {:?}", custom.protocol);
}
},
};
Ok(())
}
async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
where
R: AsyncRead + Unpin + ?Sized,
W: AsyncWrite + Unpin + ?Sized,
{
match io::copy(reader, writer).await {
Ok(u64) => {
let _ = writer.shutdown().await;
Ok(u64)
}
Err(_) => Ok(0),
}
}

View File

@ -11,11 +11,26 @@ servers:
proxy.test.com: proxy proxy.test.com: proxy
www.test.com: web www.test.com: web
default: ban default: ban
echo_server: tcp_server:
listen:
- "127.0.0.1:54500"
default: tester
tcp_echo_server:
listen: listen:
- "0.0.0.0:54956" - "0.0.0.0:54956"
default: echo default: echo
kcp_server:
protocol: kcp
listen:
- "127.0.0.1:54958"
default: tester
kcp_echo_server:
protocol: kcp
listen:
- "127.0.0.1:54959"
default: echo
upstream: upstream:
web: "127.0.0.1:8080" web: "tcp://127.0.0.1:8080"
proxy: "www.example.com:1024" proxy: "tcp://www.example.com:1024"
tester: "tcp://127.0.0.1:54599"