Compare commits
	
		
			50 Commits
		
	
	
		
			v0.1.1
			...
			aa74585a1a
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| aa74585a1a | |||
| 646afe2c40 | |||
| 77bc8364f2 | |||
| ec9ab1d2bc | |||
| bb81a32349 | |||
| 17b39dc6bc | |||
| 07fccb6b2a | |||
| 3a2367ef28 | |||
| 2116659a14 | |||
| 8404f38182 | |||
| 23296c6436 | |||
| 84f0499ec8 | |||
| ae594135a1 | |||
| 9564fbed6e | |||
| a574163aef | |||
| 2651ec1f4a | |||
| 8dae1126d5 | |||
| da46c5873f | |||
| 086e2b4766 | |||
| 5f0de72b88 | |||
| 40b890bc13 | |||
| 483c058105 | |||
| 6349fc6502 | |||
| cd35859c9b | |||
| 7f399af713 | |||
| fd86162450 | |||
| a6748f30d9 | |||
| 902b2c0d55 | |||
| fb7a7d9cae | |||
| 1c325f45b4 | |||
| 79c931fc38 | |||
| 915e39b684 | |||
| 0c5153bbd6 | |||
| 01784ee3fd | |||
| f4bc441ca8 | |||
| f010f8c76b | |||
| 
						 | 
					8fbc0c370a | ||
| 
						 | 
					bff92738d5 | ||
| 
						 | 
					754a5af794 | ||
| 
						 | 
					fc7a3038bd | ||
| 
						 | 
					8a96de9666 | ||
| 
						 | 
					0407f4b40c | ||
| 
						 | 
					47be2568ba | ||
| 
						 | 
					5944beb6a2 | ||
| 
						 | 
					4363e3f76a | ||
| 
						 | 
					ee9d0685b3 | ||
| 
						 | 
					421ad8c979 | ||
| 
						 | 
					a88a263d20 | ||
| 
						 | 
					bfce455a7e | ||
| 
						 | 
					55eef8581c | 
							
								
								
									
										3
									
								
								.cargo/config.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								.cargo/config.toml
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
			
		||||
[profile.release]
 | 
			
		||||
lto = "thin"
 | 
			
		||||
strip = true
 | 
			
		||||
							
								
								
									
										92
									
								
								.drone.jsonnet
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										92
									
								
								.drone.jsonnet
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,92 @@
 | 
			
		||||
local executableName = 'l4p';
 | 
			
		||||
local build_image = 'img.kie.rs/jjkiers/rust-cross:rust1.71.1-zig';
 | 
			
		||||
 | 
			
		||||
local archs = [
 | 
			
		||||
  { target: 'aarch64-unknown-linux-musl', short: 'arm64-musl' },
 | 
			
		||||
  { target: 'x86_64-pc-windows-gnu', short: 'windows' },
 | 
			
		||||
  { target: 'x86_64-unknown-linux-musl', short: 'amd64-musl' },
 | 
			
		||||
];
 | 
			
		||||
 | 
			
		||||
local getStepName(arch) = 'Build for ' + arch.short;
 | 
			
		||||
 | 
			
		||||
local builtExecutableName(arch) = executableName + if std.length(std.findSubstr(arch.short, 'windows')) > 0 then '.exe' else '';
 | 
			
		||||
local targetExecutableName(arch) = executableName + '-' + arch.target + if std.length(std.findSubstr(arch.short, 'windows')) > 0 then '.exe' else '';
 | 
			
		||||
 | 
			
		||||
local getVolumeName(arch) = 'target-' + arch.target;
 | 
			
		||||
local getLocalVolumes(arch) = [
 | 
			
		||||
  {
 | 
			
		||||
    name: getVolumeName(arch),
 | 
			
		||||
    temp: {},
 | 
			
		||||
  }
 | 
			
		||||
  for arch in archs
 | 
			
		||||
];
 | 
			
		||||
 | 
			
		||||
local add_build_steps() = [
 | 
			
		||||
  {
 | 
			
		||||
    name: getStepName(arch),
 | 
			
		||||
    image: build_image,
 | 
			
		||||
    commands: [
 | 
			
		||||
      'echo Hello World from Jsonnet on ' + arch.target + '!',
 | 
			
		||||
      'cargo zigbuild --release --target ' + arch.target,
 | 
			
		||||
      'cp target/' + arch.target + '/release/' + builtExecutableName(arch) + ' artifacts/' + targetExecutableName(arch),
 | 
			
		||||
      'rm -rf target/' + arch.target + '/release/*',
 | 
			
		||||
    ],
 | 
			
		||||
    depends_on: ['Prepare'],
 | 
			
		||||
    volumes: [{
 | 
			
		||||
      name: getVolumeName(arch),
 | 
			
		||||
      path: '/drone/src/target',
 | 
			
		||||
    }],
 | 
			
		||||
  }
 | 
			
		||||
  for arch in archs
 | 
			
		||||
];
 | 
			
		||||
 | 
			
		||||
{
 | 
			
		||||
  kind: 'pipeline',
 | 
			
		||||
  type: 'docker',
 | 
			
		||||
  name: 'default',
 | 
			
		||||
  platform: {
 | 
			
		||||
    arch: 'amd64',
 | 
			
		||||
  },
 | 
			
		||||
  steps:
 | 
			
		||||
    [{
 | 
			
		||||
      name: 'Prepare',
 | 
			
		||||
      image: build_image,
 | 
			
		||||
      commands: [
 | 
			
		||||
        'mkdir artifacts',
 | 
			
		||||
        'echo Using image: ' + build_image,
 | 
			
		||||
        'cargo --version',
 | 
			
		||||
        'rustc --version',
 | 
			
		||||
      ],
 | 
			
		||||
    }] +
 | 
			
		||||
    add_build_steps() +
 | 
			
		||||
    [
 | 
			
		||||
      {
 | 
			
		||||
        name: 'Show built artifacts',
 | 
			
		||||
        image: build_image,
 | 
			
		||||
        commands: [
 | 
			
		||||
          'ls -lah artifacts',
 | 
			
		||||
        ],
 | 
			
		||||
        depends_on: [getStepName(a) for a in archs],
 | 
			
		||||
      },
 | 
			
		||||
      {
 | 
			
		||||
        name: 'Create release on gitea',
 | 
			
		||||
        image: 'plugins/gitea-release',
 | 
			
		||||
        settings: {
 | 
			
		||||
          api_key: {
 | 
			
		||||
            from_secret: 'gitea_token',
 | 
			
		||||
          },
 | 
			
		||||
          base_url: 'https://code.kiers.eu',
 | 
			
		||||
          files: 'artifacts/*',
 | 
			
		||||
          checksum: 'sha256',
 | 
			
		||||
        },
 | 
			
		||||
        when: {
 | 
			
		||||
          event: ['tag', 'promote'],
 | 
			
		||||
        },
 | 
			
		||||
        depends_on: ['Show built artifacts'],
 | 
			
		||||
      },
 | 
			
		||||
    ],
 | 
			
		||||
 | 
			
		||||
  volumes: getLocalVolumes(archs),
 | 
			
		||||
 | 
			
		||||
  image_pull_secrets: ['docker_private_repo'],
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										24
									
								
								.github/workflows/rust.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										24
									
								
								.github/workflows/rust.yml
									
									
									
									
										vendored
									
									
								
							@@ -1,24 +0,0 @@
 | 
			
		||||
name: Rust
 | 
			
		||||
 | 
			
		||||
on:
 | 
			
		||||
  push:
 | 
			
		||||
    branches: [ main ]
 | 
			
		||||
  pull_request:
 | 
			
		||||
    branches: [ main ]
 | 
			
		||||
 | 
			
		||||
env:
 | 
			
		||||
  CARGO_TERM_COLOR: always
 | 
			
		||||
 | 
			
		||||
jobs:
 | 
			
		||||
  build:
 | 
			
		||||
 | 
			
		||||
    runs-on: ubuntu-latest
 | 
			
		||||
 | 
			
		||||
    steps:
 | 
			
		||||
    - uses: actions/checkout@v2
 | 
			
		||||
    - name: Upgrade Rust
 | 
			
		||||
      run: rustup update
 | 
			
		||||
    - name: Build
 | 
			
		||||
      run: cargo build --verbose
 | 
			
		||||
    - name: Run tests
 | 
			
		||||
      run: cargo test --verbose
 | 
			
		||||
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -1 +1,2 @@
 | 
			
		||||
/target
 | 
			
		||||
config.yaml
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										779
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										779
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										36
									
								
								Cargo.toml
									
									
									
									
									
								
							
							
						
						
									
										36
									
								
								Cargo.toml
									
									
									
									
									
								
							@@ -1,24 +1,34 @@
 | 
			
		||||
[package]
 | 
			
		||||
name = "fourth"
 | 
			
		||||
version = "0.1.1"
 | 
			
		||||
name = "layer4-proxy"
 | 
			
		||||
version = "0.1.7"
 | 
			
		||||
edition = "2021"
 | 
			
		||||
authors = ["LI Rui <lr_cn@outlook.com>"]
 | 
			
		||||
authors = ["Jacob Kiers <code@kiers.eu>"]
 | 
			
		||||
license = "Apache-2.0"
 | 
			
		||||
description = "Simple and fast layer 4 proxy in Rust"
 | 
			
		||||
readme = "README.md"
 | 
			
		||||
homepage = "https://github.com/KernelErr/fourth"
 | 
			
		||||
repository = "https://github.com/KernelErr/fourth"
 | 
			
		||||
homepage = "https://code.kiers.eu/jjkiers/layer4-proxy"
 | 
			
		||||
repository = "https://code.kiers.eu/jjkiers/layer4-proxy"
 | 
			
		||||
keywords = ["proxy", "network"]
 | 
			
		||||
categories = ["web-programming"]
 | 
			
		||||
 | 
			
		||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 | 
			
		||||
 | 
			
		||||
[dependencies]
 | 
			
		||||
log = "0.4"
 | 
			
		||||
pretty_env_logger = "0.4"
 | 
			
		||||
serde = { version = "1.0", features = ["derive"] }
 | 
			
		||||
serde_yaml = "0.8"
 | 
			
		||||
futures = "0.3"
 | 
			
		||||
tls-parser = "0.11"
 | 
			
		||||
exclude = [".*"]
 | 
			
		||||
 | 
			
		||||
tokio = { version = "1.0", features = ["full"] }
 | 
			
		||||
[[bin]]
 | 
			
		||||
name = "l4p"
 | 
			
		||||
path = "src/main.rs"
 | 
			
		||||
 | 
			
		||||
[dependencies]
 | 
			
		||||
async-trait = "0.1.73"
 | 
			
		||||
byte_string = "1"
 | 
			
		||||
bytes = "1.1"
 | 
			
		||||
futures = "0.3"
 | 
			
		||||
log = "0.4"
 | 
			
		||||
pretty_env_logger = "0.5"
 | 
			
		||||
serde = { version = "1.0", features = ["derive"] }
 | 
			
		||||
serde_yaml = "0.9.21"
 | 
			
		||||
time = { version = "0.3.1", features = ["local-offset", "formatting"] }
 | 
			
		||||
tls-parser = "0.11"
 | 
			
		||||
tokio = { version = "1.0", features = ["full"] }
 | 
			
		||||
url = "2.2.2"
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										58
									
								
								README-EN.md
									
									
									
									
									
								
							
							
						
						
									
										58
									
								
								README-EN.md
									
									
									
									
									
								
							@@ -1,58 +0,0 @@
 | 
			
		||||
# Fourth
 | 
			
		||||
 | 
			
		||||
> Hey, now we are on level 4!
 | 
			
		||||
 | 
			
		||||
[](https://crates.io/crates/fourth) [](https://github.com/KernelErr/fourth/actions/workflows/rust.yml)
 | 
			
		||||
 | 
			
		||||
Fourth is a layer 4 proxy implemented by Rust to listen on specific ports and transfer data to remote addresses according to configuration.
 | 
			
		||||
 | 
			
		||||
## Features
 | 
			
		||||
 | 
			
		||||
- Listen on specific port and proxy to local or remote port
 | 
			
		||||
- SNI-based rule without terminating TLS connection
 | 
			
		||||
 | 
			
		||||
## Installation
 | 
			
		||||
 | 
			
		||||
To gain best performance on your computer's architecture, please consider build the source code. First, you may need [Rust tool chain](https://rustup.rs/).
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
$ cd fourth
 | 
			
		||||
$ cargo build --release
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Binary file will be generated at `target/release/fourth`, or you can use `cargo install --path .` to install.
 | 
			
		||||
 | 
			
		||||
## Configuration
 | 
			
		||||
 | 
			
		||||
Fourth will read yaml format configuration file from `/etc/fourth/config.yaml`, here is an example:
 | 
			
		||||
 | 
			
		||||
```yaml
 | 
			
		||||
version: 1
 | 
			
		||||
log: info
 | 
			
		||||
 | 
			
		||||
servers:
 | 
			
		||||
  example_server:
 | 
			
		||||
    listen:
 | 
			
		||||
      - "0.0.0.0:443"
 | 
			
		||||
      - "[::]:443"
 | 
			
		||||
    tls: true # Enable TLS features like SNI
 | 
			
		||||
    sni:
 | 
			
		||||
      proxy.example.com: proxy
 | 
			
		||||
      www.example.com: nginx
 | 
			
		||||
    default: ban
 | 
			
		||||
  relay_server:
 | 
			
		||||
    listen:
 | 
			
		||||
      - "127.0.0.1:8081"
 | 
			
		||||
    default: remote
 | 
			
		||||
 | 
			
		||||
upstream:
 | 
			
		||||
  nginx: "127.0.0.1:8080"
 | 
			
		||||
  proxy: "127.0.0.1:1024"
 | 
			
		||||
  other: "www.remote.example.com:8082" # proxy to remote address
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Built-in two upstreams: ban(terminate connection immediately), echo
 | 
			
		||||
 | 
			
		||||
## License
 | 
			
		||||
 | 
			
		||||
Fourth is available under terms of Apache-2.0.
 | 
			
		||||
							
								
								
									
										66
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										66
									
								
								README.md
									
									
									
									
									
								
							@@ -1,66 +1,64 @@
 | 
			
		||||
# Fourth
 | 
			
		||||
# l4p
 | 
			
		||||
 | 
			
		||||
> 这一波在第四层。
 | 
			
		||||
> Hey, now we are on level 4!
 | 
			
		||||
 | 
			
		||||
[](https://crates.io/crates/fourth) [](https://github.com/KernelErr/fourth/actions/workflows/rust.yml)
 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
[English](/README-EN.md)
 | 
			
		||||
`l4p` is a layer 4 proxy implemented by Rust to listen on specific ports and transfer TCP/KCP data to remote addresses(only TCP) according to configuration.
 | 
			
		||||
 | 
			
		||||
Fourth是一个Rust实现的Layer 4代理,用于监听指定端口TCP流量,并根据规则转发到指定目标。
 | 
			
		||||
## Features
 | 
			
		||||
 | 
			
		||||
## 功能
 | 
			
		||||
- Listen on specific port and proxy to local or remote port
 | 
			
		||||
- SNI-based rule without terminating TLS connection
 | 
			
		||||
- DNS-based backend with periodic resolution
 | 
			
		||||
 | 
			
		||||
- 监听指定端口代理到本地或远端指定端口
 | 
			
		||||
- 监听指定端口,通过TLS ClientHello消息中的SNI进行分流
 | 
			
		||||
## Installation
 | 
			
		||||
 | 
			
		||||
## 安装方法
 | 
			
		||||
 | 
			
		||||
为了确保获得您架构下的最佳性能,请考虑自行编译,首选需要确保您拥有[Rust工具链](https://rustup.rs/)。
 | 
			
		||||
To gain best performance on your computer's architecture, please consider build the source code. First, you may need [Rust tool chain](https://rustup.rs/).
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
$ cd fourth
 | 
			
		||||
$ cd l4p
 | 
			
		||||
$ cargo build --release
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
将在`target/release/fourth`生成二进制文件,您也可以使用`cargo install --path . `来安装二进制文件。
 | 
			
		||||
Binary file will be generated at `target/release/l4p`, or you can use `cargo install --path .` to install.
 | 
			
		||||
 | 
			
		||||
## 配置
 | 
			
		||||
Or you can use Cargo to install `l4p`:
 | 
			
		||||
 | 
			
		||||
Fourth使用yaml格式的配置文件,默认情况下会读取`/etc/fourth/config.yaml`,如下是一个示例配置。
 | 
			
		||||
```bash
 | 
			
		||||
$ cargo install l4p
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Or you can download binary file form the Release page.
 | 
			
		||||
 | 
			
		||||
## Configuration
 | 
			
		||||
 | 
			
		||||
`l4p` will read yaml format configuration file from `/etc/l4p/l4p.yaml`, and you can set custom path to environment variable `L4P_CONFIG`, here is an minimal viable example:
 | 
			
		||||
 | 
			
		||||
```yaml
 | 
			
		||||
version: 1
 | 
			
		||||
log: info
 | 
			
		||||
 | 
			
		||||
servers:
 | 
			
		||||
  example_server:
 | 
			
		||||
    listen:
 | 
			
		||||
      - "0.0.0.0:443"
 | 
			
		||||
      - "[::]:443"
 | 
			
		||||
    tls: true # 启动SNI分流,将根据TLS请求中的主机名分流
 | 
			
		||||
    sni:
 | 
			
		||||
      proxy.example.com: proxy
 | 
			
		||||
      www.example.com: nginx
 | 
			
		||||
    default: ban
 | 
			
		||||
  relay_server:
 | 
			
		||||
  proxy_server:
 | 
			
		||||
    listen:
 | 
			
		||||
      - "127.0.0.1:8081"
 | 
			
		||||
    default: remote
 | 
			
		||||
 | 
			
		||||
upstream:
 | 
			
		||||
  nginx: "127.0.0.1:8080"
 | 
			
		||||
  proxy: "127.0.0.1:1024"
 | 
			
		||||
  other: "www.remote.example.com:8082" # 代理到远端地址
 | 
			
		||||
  remote: "tcp://www.remote.example.com:8082" # proxy to remote address
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
内置两个的upstream:ban(立即中断连接)、echo(返回读到的数据)。
 | 
			
		||||
There are two upstreams built in:
 | 
			
		||||
* Ban, which terminates the connection immediately
 | 
			
		||||
* Echo, which reflects back with the input
 | 
			
		||||
 | 
			
		||||
## io_uring?
 | 
			
		||||
For detailed configuration, check [this example](./config.yaml.example).
 | 
			
		||||
 | 
			
		||||
尽管经过了很多尝试,我们发现目前一些Rust下面的io_uring实现存在问题,我们使用的io_uring库实现尽管在吞吐量上可以做到单线程20Gbps(相比之下Tokio仅有8Gbps),但在QPS上存在性能损失较大的问题。因此在有成熟的io_uring实现之前,我们仍然选择epoll。之后我们会持续关注相关进展。
 | 
			
		||||
## Thanks
 | 
			
		||||
 | 
			
		||||
可能以后会为Linux高内核版本的用户提供可选的io_uring加速。
 | 
			
		||||
- [`l4p`](https://crates.io/crates/`l4p`), of which this is a heavily modified fork.
 | 
			
		||||
 | 
			
		||||
## 协议
 | 
			
		||||
## License
 | 
			
		||||
 | 
			
		||||
Fourth以Apache-2.0协议开源。
 | 
			
		||||
`l4p` is available under terms of Apache-2.0.
 | 
			
		||||
							
								
								
									
										21
									
								
								config.yaml.example
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								config.yaml.example
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,21 @@
 | 
			
		||||
version: 1
 | 
			
		||||
log: debug
 | 
			
		||||
 | 
			
		||||
servers:
 | 
			
		||||
  first_server:
 | 
			
		||||
    listen:
 | 
			
		||||
      - "0.0.0.0:8443"
 | 
			
		||||
      - "[::]:8443"
 | 
			
		||||
    tls: true # Enable TLS features like SNI filtering
 | 
			
		||||
    sni:
 | 
			
		||||
      api.example.org: example-api
 | 
			
		||||
      www.example.org: proxy
 | 
			
		||||
    default: ban
 | 
			
		||||
 | 
			
		||||
  second-server:
 | 
			
		||||
    listen: [ "127.0.0.1:8080" ]
 | 
			
		||||
    default: echo
 | 
			
		||||
 | 
			
		||||
upstream:
 | 
			
		||||
  proxy: "tcp://new-www.example.org:443" # Connect over IPv4 or IPv6 to new-www.example.org:443
 | 
			
		||||
  example-api: "tcp6://api-v1.example.com:443" # Connect over IPv6 to api-v1.example.com:443
 | 
			
		||||
@@ -15,8 +15,13 @@ servers:
 | 
			
		||||
    listen:
 | 
			
		||||
      - "127.0.0.1:8081"
 | 
			
		||||
    default: remote
 | 
			
		||||
  kcp_server:
 | 
			
		||||
    protocol: kcp # default TCP
 | 
			
		||||
    listen:
 | 
			
		||||
      - "127.0.0.1:8082"
 | 
			
		||||
    default: echo
 | 
			
		||||
 | 
			
		||||
upstream:
 | 
			
		||||
  nginx: "127.0.0.1:8080"
 | 
			
		||||
  proxy: "127.0.0.1:1024"
 | 
			
		||||
  other: "www.remote.example.com:8082" # proxy to remote address
 | 
			
		||||
  nginx: "tcp://127.0.0.1:8080"
 | 
			
		||||
  proxy: "tcp://127.0.0.1:1024"
 | 
			
		||||
  remote: "tcp://www.remote.example.com:8082" # proxy to remote address
 | 
			
		||||
							
								
								
									
										51
									
								
								fourth.service
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								fourth.service
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,51 @@
 | 
			
		||||
[Unit]
 | 
			
		||||
Description=Fourth - Layer 4 proxy
 | 
			
		||||
 | 
			
		||||
After=network-online.target
 | 
			
		||||
Wants=network-online.target
 | 
			
		||||
 | 
			
		||||
[Install]
 | 
			
		||||
WantedBy=default.target
 | 
			
		||||
 | 
			
		||||
[Service]
 | 
			
		||||
Type=simple
 | 
			
		||||
 | 
			
		||||
# Allow read-only access to the config directory
 | 
			
		||||
ReadOnlyPaths=/etc/fourth
 | 
			
		||||
# Path to the binary
 | 
			
		||||
ExecStart=/usr/local/bin/fourth
 | 
			
		||||
 | 
			
		||||
# Needs CAP_NET_BIND_SERVICE in order to bind to lower ports
 | 
			
		||||
# When using ports above 1024, these should be made empty
 | 
			
		||||
AmbientCapabilities=CAP_NET_BIND_SERVICE
 | 
			
		||||
CapabilityBoundingSet=CAP_NET_BIND_SERVICE
 | 
			
		||||
 | 
			
		||||
# Run as a dynamic user
 | 
			
		||||
DynamicUser=yes
 | 
			
		||||
 | 
			
		||||
# Security
 | 
			
		||||
PrivateTmp=yes
 | 
			
		||||
PrivateDevices=yes
 | 
			
		||||
ProtectSystem=strict
 | 
			
		||||
ProtectHome=yes
 | 
			
		||||
SystemCallFilter=@basic-io @file-system @network-io @system-service
 | 
			
		||||
SystemCallFilter=~@privileged
 | 
			
		||||
SystemCallFilter=~@resources
 | 
			
		||||
NoNewPrivileges=yes
 | 
			
		||||
ProtectProc=invisible
 | 
			
		||||
RemoveIPC=yes
 | 
			
		||||
RestrictAddressFamilies=AF_INET AF_INET6
 | 
			
		||||
RestrictNamespaces=yes
 | 
			
		||||
ProtectHostname=yes
 | 
			
		||||
ProtectClock=yes
 | 
			
		||||
ProtectKernelModules=yes
 | 
			
		||||
ProtectKernelLogs=yes
 | 
			
		||||
ProtectControlGroups=yes
 | 
			
		||||
LockPersonality=yes
 | 
			
		||||
MemoryDenyWriteExecute=yes
 | 
			
		||||
RestrictRealtime=yes
 | 
			
		||||
ProcSubset=pid
 | 
			
		||||
UMask=0077
 | 
			
		||||
SystemCallArchitectures=native
 | 
			
		||||
RestrictSUIDSGID=yes
 | 
			
		||||
ProtectKernelTunables=yes
 | 
			
		||||
@@ -1,92 +0,0 @@
 | 
			
		||||
use log::debug;
 | 
			
		||||
use serde::Deserialize;
 | 
			
		||||
use std::collections::HashMap;
 | 
			
		||||
use std::fs::File;
 | 
			
		||||
use std::io::{Error as IOError, Read};
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone)]
 | 
			
		||||
pub struct Config {
 | 
			
		||||
    pub base: BaseConfig,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Default, Deserialize, Clone)]
 | 
			
		||||
pub struct BaseConfig {
 | 
			
		||||
    pub version: i32,
 | 
			
		||||
    pub log: Option<String>,
 | 
			
		||||
    pub servers: HashMap<String, ServerConfig>,
 | 
			
		||||
    pub upstream: HashMap<String, String>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Default, Deserialize, Clone)]
 | 
			
		||||
pub struct ServerConfig {
 | 
			
		||||
    pub listen: Vec<String>,
 | 
			
		||||
    pub tls: Option<bool>,
 | 
			
		||||
    pub sni: Option<HashMap<String, String>>,
 | 
			
		||||
    pub default: Option<String>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub enum ConfigError {
 | 
			
		||||
    IO(IOError),
 | 
			
		||||
    Yaml(serde_yaml::Error),
 | 
			
		||||
    Custom(String),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Config {
 | 
			
		||||
    pub fn new(path: &str) -> Result<Config, ConfigError> {
 | 
			
		||||
        let base = (load_config(path))?;
 | 
			
		||||
 | 
			
		||||
        Ok(Config { base })
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn load_config(path: &str) -> Result<BaseConfig, ConfigError> {
 | 
			
		||||
    let mut contents = String::new();
 | 
			
		||||
    let mut file = (File::open(path))?;
 | 
			
		||||
    (file.read_to_string(&mut contents))?;
 | 
			
		||||
 | 
			
		||||
    let parsed: BaseConfig = serde_yaml::from_str(&contents)?;
 | 
			
		||||
 | 
			
		||||
    if parsed.version != 1 {
 | 
			
		||||
        return Err(ConfigError::Custom(
 | 
			
		||||
            "Unsupported config version".to_string(),
 | 
			
		||||
        ));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let log_level = parsed.log.clone().unwrap_or_else(|| "info".to_string());
 | 
			
		||||
    if !log_level.eq("disable") {
 | 
			
		||||
        std::env::set_var("FOURTH_LOG", log_level.clone());
 | 
			
		||||
        pretty_env_logger::init_custom_env("FOURTH_LOG");
 | 
			
		||||
        debug!("Set log level to {}", log_level);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    debug!("Config version {}", parsed.version);
 | 
			
		||||
 | 
			
		||||
    Ok(parsed)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<IOError> for ConfigError {
 | 
			
		||||
    fn from(err: IOError) -> ConfigError {
 | 
			
		||||
        ConfigError::IO(err)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<serde_yaml::Error> for ConfigError {
 | 
			
		||||
    fn from(err: serde_yaml::Error) -> ConfigError {
 | 
			
		||||
        ConfigError::Yaml(err)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[cfg(test)]
 | 
			
		||||
mod tests {
 | 
			
		||||
    use super::*;
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn test_load_config() {
 | 
			
		||||
        let config = Config::new("tests/config.yaml").unwrap();
 | 
			
		||||
        assert_eq!(config.base.version, 1);
 | 
			
		||||
        assert_eq!(config.base.log.unwrap(), "disable");
 | 
			
		||||
        assert_eq!(config.base.servers.len(), 2);
 | 
			
		||||
        assert_eq!(config.base.upstream.len(), 2);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										227
									
								
								src/config/config.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										227
									
								
								src/config/config.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,227 @@
 | 
			
		||||
use crate::upstreams::ProxyToUpstream;
 | 
			
		||||
use crate::upstreams::Upstream;
 | 
			
		||||
use log::{debug, warn};
 | 
			
		||||
use serde::Deserialize;
 | 
			
		||||
use std::collections::{HashMap, HashSet};
 | 
			
		||||
use std::fs::File;
 | 
			
		||||
use std::io::{Error as IOError, Read};
 | 
			
		||||
use url::Url;
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone)]
 | 
			
		||||
pub struct ConfigV1 {
 | 
			
		||||
    pub base: ParsedConfigV1,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Default, Deserialize, Clone)]
 | 
			
		||||
pub struct ParsedConfigV1 {
 | 
			
		||||
    pub version: i32,
 | 
			
		||||
    pub log: Option<String>,
 | 
			
		||||
    pub servers: HashMap<String, ServerConfig>,
 | 
			
		||||
    pub upstream: HashMap<String, Upstream>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Default, Deserialize, Clone)]
 | 
			
		||||
pub struct BaseConfig {
 | 
			
		||||
    pub version: i32,
 | 
			
		||||
    pub log: Option<String>,
 | 
			
		||||
    pub servers: HashMap<String, ServerConfig>,
 | 
			
		||||
    pub upstream: HashMap<String, String>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Default, Deserialize, Clone)]
 | 
			
		||||
pub struct ServerConfig {
 | 
			
		||||
    pub listen: Vec<String>,
 | 
			
		||||
    pub protocol: Option<String>,
 | 
			
		||||
    pub tls: Option<bool>,
 | 
			
		||||
    pub sni: Option<HashMap<String, String>>,
 | 
			
		||||
    pub default: Option<String>,
 | 
			
		||||
}
 | 
			
		||||
impl TryInto<ProxyToUpstream> for &str {
 | 
			
		||||
    type Error = ConfigError;
 | 
			
		||||
 | 
			
		||||
    fn try_into(self) -> Result<ProxyToUpstream, Self::Error> {
 | 
			
		||||
        let upstream_url = match Url::parse(self) {
 | 
			
		||||
            Ok(url) => url,
 | 
			
		||||
            Err(_) => {
 | 
			
		||||
                return Err(ConfigError::Custom(format!(
 | 
			
		||||
                    "Invalid upstream url {}",
 | 
			
		||||
                    self
 | 
			
		||||
                )))
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let upstream_host = match upstream_url.host_str() {
 | 
			
		||||
            Some(host) => host,
 | 
			
		||||
            None => {
 | 
			
		||||
                return Err(ConfigError::Custom(format!(
 | 
			
		||||
                    "Invalid upstream url {}",
 | 
			
		||||
                    self
 | 
			
		||||
                )))
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let upstream_port = match upstream_url.port_or_known_default() {
 | 
			
		||||
            Some(port) => port,
 | 
			
		||||
            None => {
 | 
			
		||||
                return Err(ConfigError::Custom(format!(
 | 
			
		||||
                    "Invalid upstream url {}",
 | 
			
		||||
                    self
 | 
			
		||||
                )))
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        match upstream_url.scheme() {
 | 
			
		||||
            "tcp" | "tcp4" | "tcp6" => {}
 | 
			
		||||
            _ => {
 | 
			
		||||
                return Err(ConfigError::Custom(format!(
 | 
			
		||||
                    "Invalid upstream scheme {}",
 | 
			
		||||
                    self
 | 
			
		||||
                )))
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(ProxyToUpstream::new(
 | 
			
		||||
            format!("{}:{}", upstream_host, upstream_port),
 | 
			
		||||
            upstream_url.scheme().to_string(),
 | 
			
		||||
        ))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub enum ConfigError {
 | 
			
		||||
    IO(IOError),
 | 
			
		||||
    Yaml(serde_yaml::Error),
 | 
			
		||||
    Custom(String),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ConfigV1 {
 | 
			
		||||
    pub fn new(path: &str) -> Result<ConfigV1, ConfigError> {
 | 
			
		||||
        let base = load_config(path)?;
 | 
			
		||||
 | 
			
		||||
        Ok(ConfigV1 { base })
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn load_config(path: &str) -> Result<ParsedConfigV1, ConfigError> {
 | 
			
		||||
    let mut contents = String::new();
 | 
			
		||||
    let mut file = File::open(path)?;
 | 
			
		||||
    file.read_to_string(&mut contents)?;
 | 
			
		||||
 | 
			
		||||
    let base: BaseConfig = serde_yaml::from_str(&contents)?;
 | 
			
		||||
 | 
			
		||||
    if base.version != 1 {
 | 
			
		||||
        return Err(ConfigError::Custom(
 | 
			
		||||
            "Unsupported config version".to_string(),
 | 
			
		||||
        ));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let log_level = base.log.clone().unwrap_or_else(|| "info".to_string());
 | 
			
		||||
    if !log_level.eq("disable") {
 | 
			
		||||
        std::env::set_var("FOURTH_LOG", log_level.clone());
 | 
			
		||||
        pretty_env_logger::init_custom_env("FOURTH_LOG");
 | 
			
		||||
        debug!("Set log level to {}", log_level);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    debug!("Config version {}", base.version);
 | 
			
		||||
 | 
			
		||||
    let mut parsed_upstream: HashMap<String, Upstream> = HashMap::new();
 | 
			
		||||
 | 
			
		||||
    parsed_upstream.insert("ban".to_string(), Upstream::Ban);
 | 
			
		||||
    parsed_upstream.insert("echo".to_string(), Upstream::Echo);
 | 
			
		||||
 | 
			
		||||
    for (name, upstream) in base.upstream.iter() {
 | 
			
		||||
        let ups = upstream.as_str().try_into()?;
 | 
			
		||||
        parsed_upstream.insert(name.to_string(), Upstream::Proxy(ups));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let parsed = ParsedConfigV1 {
 | 
			
		||||
        version: base.version,
 | 
			
		||||
        log: base.log,
 | 
			
		||||
        servers: base.servers,
 | 
			
		||||
        upstream: parsed_upstream,
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    verify_config(parsed)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn verify_config(config: ParsedConfigV1) -> Result<ParsedConfigV1, ConfigError> {
 | 
			
		||||
    let mut used_upstreams: HashSet<String> = HashSet::new();
 | 
			
		||||
    let mut upstream_names: HashSet<String> = HashSet::new();
 | 
			
		||||
    let mut listen_addresses: HashSet<String> = HashSet::new();
 | 
			
		||||
 | 
			
		||||
    // Check for duplicate upstream names
 | 
			
		||||
    for (name, _) in config.upstream.iter() {
 | 
			
		||||
        if upstream_names.contains(name) {
 | 
			
		||||
            return Err(ConfigError::Custom(format!(
 | 
			
		||||
                "Duplicate upstream name {}",
 | 
			
		||||
                name
 | 
			
		||||
            )));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        upstream_names.insert(name.to_string());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    for (_, server) in config.servers.clone() {
 | 
			
		||||
        // check for duplicate listen addresses
 | 
			
		||||
        for listen in server.listen {
 | 
			
		||||
            if listen_addresses.contains(&listen) {
 | 
			
		||||
                return Err(ConfigError::Custom(format!(
 | 
			
		||||
                    "Duplicate listen address {}",
 | 
			
		||||
                    listen
 | 
			
		||||
                )));
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            listen_addresses.insert(listen.to_string());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if server.tls.unwrap_or_default() && server.sni.is_some() {
 | 
			
		||||
            for (_, val) in server.sni.unwrap() {
 | 
			
		||||
                used_upstreams.insert(val.to_string());
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if server.default.is_some() {
 | 
			
		||||
            used_upstreams.insert(server.default.unwrap().to_string());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        for key in &used_upstreams {
 | 
			
		||||
            if !config.upstream.contains_key(key) {
 | 
			
		||||
                return Err(ConfigError::Custom(format!("Upstream {} not found", key)));
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    for key in &upstream_names {
 | 
			
		||||
        if !used_upstreams.contains(key) && !key.eq("echo") && !key.eq("ban") {
 | 
			
		||||
            warn!("Upstream {} not used", key);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Ok(config)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<IOError> for ConfigError {
 | 
			
		||||
    fn from(err: IOError) -> ConfigError {
 | 
			
		||||
        ConfigError::IO(err)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<serde_yaml::Error> for ConfigError {
 | 
			
		||||
    fn from(err: serde_yaml::Error) -> ConfigError {
 | 
			
		||||
        ConfigError::Yaml(err)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[cfg(test)]
 | 
			
		||||
mod tests {
 | 
			
		||||
    use super::*;
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn test_load_config() {
 | 
			
		||||
        let config = ConfigV1::new("tests/config.yaml").unwrap();
 | 
			
		||||
        assert_eq!(config.base.version, 1);
 | 
			
		||||
        assert_eq!(config.base.log.unwrap(), "disable");
 | 
			
		||||
        assert_eq!(config.base.servers.len(), 5);
 | 
			
		||||
        assert_eq!(config.base.upstream.len(), 3 + 2); // Add ban and echo upstreams
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										3
									
								
								src/config/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								src/config/mod.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
			
		||||
mod config;
 | 
			
		||||
pub(crate) use config::ConfigV1;
 | 
			
		||||
pub(crate) use config::ParsedConfigV1;
 | 
			
		||||
							
								
								
									
										50
									
								
								src/main.rs
									
									
									
									
									
								
							
							
						
						
									
										50
									
								
								src/main.rs
									
									
									
									
									
								
							@@ -1,13 +1,27 @@
 | 
			
		||||
mod config;
 | 
			
		||||
mod plugins;
 | 
			
		||||
mod servers;
 | 
			
		||||
mod upstreams;
 | 
			
		||||
 | 
			
		||||
use crate::config::Config;
 | 
			
		||||
use crate::config::ConfigV1;
 | 
			
		||||
use crate::servers::Server;
 | 
			
		||||
 | 
			
		||||
use log::{debug, error};
 | 
			
		||||
use std::path::PathBuf;
 | 
			
		||||
 | 
			
		||||
fn main() {
 | 
			
		||||
    let config = match Config::new("/etc/fourth/config.yaml") {
 | 
			
		||||
    let config_path = match find_config() {
 | 
			
		||||
        Ok(p) => p,
 | 
			
		||||
        Err(paths) => {
 | 
			
		||||
            println!("Could not find config file. Tried paths:");
 | 
			
		||||
            for p in paths {
 | 
			
		||||
                println!("- {}", p);
 | 
			
		||||
            }
 | 
			
		||||
            std::process::exit(1);
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    let config = match ConfigV1::new(&config_path) {
 | 
			
		||||
        Ok(config) => config,
 | 
			
		||||
        Err(e) => {
 | 
			
		||||
            println!("Could not load config: {:?}", e);
 | 
			
		||||
@@ -16,9 +30,35 @@ fn main() {
 | 
			
		||||
    };
 | 
			
		||||
    debug!("{:?}", config);
 | 
			
		||||
 | 
			
		||||
    let mut server = Server::new(config.base);
 | 
			
		||||
    let mut server = Server::new_from_v1_config(config.base);
 | 
			
		||||
    debug!("{:?}", server);
 | 
			
		||||
 | 
			
		||||
    let res = server.run();
 | 
			
		||||
    error!("Server returned an error: {:?}", res);
 | 
			
		||||
    let _ = server.run();
 | 
			
		||||
    error!("Server ended with errors");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn find_config() -> Result<String, Vec<String>> {
 | 
			
		||||
    let possible_paths = ["/etc/l4p", ""];
 | 
			
		||||
    let possible_names = ["l4p.yaml", "config.yaml"];
 | 
			
		||||
 | 
			
		||||
    let mut tried_paths = Vec::<String>::new();
 | 
			
		||||
 | 
			
		||||
    for path in possible_paths
 | 
			
		||||
        .iter()
 | 
			
		||||
        .flat_map(|&path| {
 | 
			
		||||
            possible_names
 | 
			
		||||
                .iter()
 | 
			
		||||
                .map(move |&file| PathBuf::new().join(path).join(file))
 | 
			
		||||
        })
 | 
			
		||||
        .collect::<Vec<PathBuf>>()
 | 
			
		||||
    {
 | 
			
		||||
        let path_str = path.to_string_lossy().to_string();
 | 
			
		||||
        if path.exists() {
 | 
			
		||||
            return Ok(path_str);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        tried_paths.push(path_str);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Err(tried_paths)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										110
									
								
								src/plugins/kcp/config.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								src/plugins/kcp/config.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,110 @@
 | 
			
		||||
use std::{io::Write, time::Duration};
 | 
			
		||||
 | 
			
		||||
use kcp::Kcp;
 | 
			
		||||
 | 
			
		||||
/// Kcp Delay Config
 | 
			
		||||
#[derive(Debug, Clone, Copy)]
 | 
			
		||||
pub struct KcpNoDelayConfig {
 | 
			
		||||
    /// Enable nodelay
 | 
			
		||||
    pub nodelay: bool,
 | 
			
		||||
    /// Internal update interval (ms)
 | 
			
		||||
    pub interval: i32,
 | 
			
		||||
    /// ACK number to enable fast resend
 | 
			
		||||
    pub resend: i32,
 | 
			
		||||
    /// Disable congetion control
 | 
			
		||||
    pub nc: bool,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Default for KcpNoDelayConfig {
 | 
			
		||||
    fn default() -> KcpNoDelayConfig {
 | 
			
		||||
        KcpNoDelayConfig {
 | 
			
		||||
            nodelay: false,
 | 
			
		||||
            interval: 100,
 | 
			
		||||
            resend: 0,
 | 
			
		||||
            nc: false,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[allow(unused)]
 | 
			
		||||
impl KcpNoDelayConfig {
 | 
			
		||||
    /// Get a fastest configuration
 | 
			
		||||
    ///
 | 
			
		||||
    /// 1. Enable NoDelay
 | 
			
		||||
    /// 2. Set ticking interval to be 10ms
 | 
			
		||||
    /// 3. Set fast resend to be 2
 | 
			
		||||
    /// 4. Disable congestion control
 | 
			
		||||
    pub fn fastest() -> KcpNoDelayConfig {
 | 
			
		||||
        KcpNoDelayConfig {
 | 
			
		||||
            nodelay: true,
 | 
			
		||||
            interval: 10,
 | 
			
		||||
            resend: 2,
 | 
			
		||||
            nc: true,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Get a normal configuration
 | 
			
		||||
    ///
 | 
			
		||||
    /// 1. Disable NoDelay
 | 
			
		||||
    /// 2. Set ticking interval to be 40ms
 | 
			
		||||
    /// 3. Disable fast resend
 | 
			
		||||
    /// 4. Enable congestion control
 | 
			
		||||
    pub fn normal() -> KcpNoDelayConfig {
 | 
			
		||||
        KcpNoDelayConfig {
 | 
			
		||||
            nodelay: false,
 | 
			
		||||
            interval: 40,
 | 
			
		||||
            resend: 0,
 | 
			
		||||
            nc: false,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Kcp Config
 | 
			
		||||
#[derive(Debug, Clone, Copy)]
 | 
			
		||||
pub struct KcpConfig {
 | 
			
		||||
    /// Max Transmission Unit
 | 
			
		||||
    pub mtu: usize,
 | 
			
		||||
    /// nodelay
 | 
			
		||||
    pub nodelay: KcpNoDelayConfig,
 | 
			
		||||
    /// Send window size
 | 
			
		||||
    pub wnd_size: (u16, u16),
 | 
			
		||||
    /// Session expire duration, default is 90 seconds
 | 
			
		||||
    pub session_expire: Duration,
 | 
			
		||||
    /// Flush KCP state immediately after write
 | 
			
		||||
    pub flush_write: bool,
 | 
			
		||||
    /// Flush ACKs immediately after input
 | 
			
		||||
    pub flush_acks_input: bool,
 | 
			
		||||
    /// Stream mode
 | 
			
		||||
    pub stream: bool,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Default for KcpConfig {
 | 
			
		||||
    fn default() -> KcpConfig {
 | 
			
		||||
        KcpConfig {
 | 
			
		||||
            mtu: 1400,
 | 
			
		||||
            nodelay: KcpNoDelayConfig::normal(),
 | 
			
		||||
            wnd_size: (256, 256),
 | 
			
		||||
            session_expire: Duration::from_secs(90),
 | 
			
		||||
            flush_write: false,
 | 
			
		||||
            flush_acks_input: false,
 | 
			
		||||
            stream: true,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl KcpConfig {
 | 
			
		||||
    /// Applies config onto `Kcp`
 | 
			
		||||
    #[doc(hidden)]
 | 
			
		||||
    pub fn apply_config<W: Write>(&self, k: &mut Kcp<W>) {
 | 
			
		||||
        k.set_mtu(self.mtu).expect("invalid MTU");
 | 
			
		||||
 | 
			
		||||
        k.set_nodelay(
 | 
			
		||||
            self.nodelay.nodelay,
 | 
			
		||||
            self.nodelay.interval,
 | 
			
		||||
            self.nodelay.resend,
 | 
			
		||||
            self.nodelay.nc,
 | 
			
		||||
        );
 | 
			
		||||
 | 
			
		||||
        k.set_wndsize(self.wnd_size.0, self.wnd_size.1);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										128
									
								
								src/plugins/kcp/listener.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										128
									
								
								src/plugins/kcp/listener.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,128 @@
 | 
			
		||||
use std::{
 | 
			
		||||
    io::{self, ErrorKind},
 | 
			
		||||
    net::SocketAddr,
 | 
			
		||||
    sync::Arc,
 | 
			
		||||
    time::Duration,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use byte_string::ByteStr;
 | 
			
		||||
use kcp::{Error as KcpError, KcpResult};
 | 
			
		||||
use log::{debug, error, trace};
 | 
			
		||||
use tokio::{
 | 
			
		||||
    net::{ToSocketAddrs, UdpSocket},
 | 
			
		||||
    sync::mpsc,
 | 
			
		||||
    task::JoinHandle,
 | 
			
		||||
    time,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use crate::plugins::kcp::{config::KcpConfig, session::KcpSessionManager, stream::KcpStream};
 | 
			
		||||
 | 
			
		||||
#[allow(unused)]
 | 
			
		||||
pub struct KcpListener {
 | 
			
		||||
    udp: Arc<UdpSocket>,
 | 
			
		||||
    accept_rx: mpsc::Receiver<(KcpStream, SocketAddr)>,
 | 
			
		||||
    task_watcher: JoinHandle<()>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Drop for KcpListener {
 | 
			
		||||
    fn drop(&mut self) {
 | 
			
		||||
        self.task_watcher.abort();
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl KcpListener {
 | 
			
		||||
    pub async fn bind<A: ToSocketAddrs>(config: KcpConfig, addr: A) -> KcpResult<KcpListener> {
 | 
			
		||||
        let udp = UdpSocket::bind(addr).await?;
 | 
			
		||||
        let udp = Arc::new(udp);
 | 
			
		||||
        let server_udp = udp.clone();
 | 
			
		||||
 | 
			
		||||
        let (accept_tx, accept_rx) = mpsc::channel(1024 /* backlogs */);
 | 
			
		||||
        let task_watcher = tokio::spawn(async move {
 | 
			
		||||
            let (close_tx, mut close_rx) = mpsc::channel(64);
 | 
			
		||||
 | 
			
		||||
            let mut sessions = KcpSessionManager::new();
 | 
			
		||||
            let mut packet_buffer = [0u8; 65536];
 | 
			
		||||
            loop {
 | 
			
		||||
                tokio::select! {
 | 
			
		||||
                    conv = close_rx.recv() => {
 | 
			
		||||
                        let conv = conv.expect("close_tx closed unexpectly");
 | 
			
		||||
                        sessions.close_conv(conv);
 | 
			
		||||
                        trace!("session conv: {} removed", conv);
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    recv_res = udp.recv_from(&mut packet_buffer) => {
 | 
			
		||||
                        match recv_res {
 | 
			
		||||
                            Err(err) => {
 | 
			
		||||
                                error!("udp.recv_from failed, error: {}", err);
 | 
			
		||||
                                time::sleep(Duration::from_secs(1)).await;
 | 
			
		||||
                            }
 | 
			
		||||
                            Ok((n, peer_addr)) => {
 | 
			
		||||
                                let packet = &mut packet_buffer[..n];
 | 
			
		||||
 | 
			
		||||
                                log::trace!("received peer: {}, {:?}", peer_addr, ByteStr::new(packet));
 | 
			
		||||
 | 
			
		||||
                                let mut conv = kcp::get_conv(packet);
 | 
			
		||||
                                if conv == 0 {
 | 
			
		||||
                                    // Allocate a conv for client.
 | 
			
		||||
                                    conv = sessions.alloc_conv();
 | 
			
		||||
                                    debug!("allocate {} conv for peer: {}", conv, peer_addr);
 | 
			
		||||
 | 
			
		||||
                                    kcp::set_conv(packet, conv);
 | 
			
		||||
                                }
 | 
			
		||||
 | 
			
		||||
                                let session = match sessions.get_or_create(&config, conv, &udp, peer_addr, &close_tx) {
 | 
			
		||||
                                    Ok((s, created)) => {
 | 
			
		||||
                                        if created {
 | 
			
		||||
                                            // Created a new session, constructed a new accepted client
 | 
			
		||||
                                            let stream = KcpStream::with_session(s.clone());
 | 
			
		||||
                                            if let Err(..) = accept_tx.try_send((stream, peer_addr)) {
 | 
			
		||||
                                                debug!("failed to create accepted stream due to channel failure");
 | 
			
		||||
 | 
			
		||||
                                                // remove it from session
 | 
			
		||||
                                                sessions.close_conv(conv);
 | 
			
		||||
                                                continue;
 | 
			
		||||
                                            }
 | 
			
		||||
                                        }
 | 
			
		||||
 | 
			
		||||
                                        s
 | 
			
		||||
                                    },
 | 
			
		||||
                                    Err(err) => {
 | 
			
		||||
                                        error!("failed to create session, error: {}, peer: {}, conv: {}", err, peer_addr, conv);
 | 
			
		||||
                                        continue;
 | 
			
		||||
                                    }
 | 
			
		||||
                                };
 | 
			
		||||
 | 
			
		||||
                                // let mut kcp = session.kcp_socket().lock().await;
 | 
			
		||||
                                // if let Err(err) = kcp.input(packet) {
 | 
			
		||||
                                //     error!("kcp.input failed, peer: {}, conv: {}, error: {}, packet: {:?}", peer_addr, conv, err, ByteStr::new(packet));
 | 
			
		||||
                                // }
 | 
			
		||||
                                session.input(packet).await;
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
        Ok(KcpListener {
 | 
			
		||||
            udp: server_udp,
 | 
			
		||||
            accept_rx,
 | 
			
		||||
            task_watcher,
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn accept(&mut self) -> KcpResult<(KcpStream, SocketAddr)> {
 | 
			
		||||
        match self.accept_rx.recv().await {
 | 
			
		||||
            Some(s) => Ok(s),
 | 
			
		||||
            None => Err(KcpError::IoError(io::Error::new(
 | 
			
		||||
                ErrorKind::Other,
 | 
			
		||||
                "accept channel closed unexpectly",
 | 
			
		||||
            ))),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(unused)]
 | 
			
		||||
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
 | 
			
		||||
        self.udp.local_addr()
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										14
									
								
								src/plugins/kcp/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								src/plugins/kcp/mod.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,14 @@
 | 
			
		||||
//! Library of KCP on Tokio
 | 
			
		||||
 | 
			
		||||
pub use self::{
 | 
			
		||||
    config::{KcpConfig, KcpNoDelayConfig},
 | 
			
		||||
    listener::KcpListener,
 | 
			
		||||
    stream::KcpStream,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
mod config;
 | 
			
		||||
mod listener;
 | 
			
		||||
mod session;
 | 
			
		||||
mod skcp;
 | 
			
		||||
mod stream;
 | 
			
		||||
mod utils;
 | 
			
		||||
							
								
								
									
										256
									
								
								src/plugins/kcp/session.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										256
									
								
								src/plugins/kcp/session.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,256 @@
 | 
			
		||||
use std::{
 | 
			
		||||
    collections::{hash_map::Entry, HashMap},
 | 
			
		||||
    net::SocketAddr,
 | 
			
		||||
    sync::{
 | 
			
		||||
        atomic::{AtomicBool, Ordering},
 | 
			
		||||
        Arc,
 | 
			
		||||
    },
 | 
			
		||||
    time::Duration,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use byte_string::ByteStr;
 | 
			
		||||
use kcp::KcpResult;
 | 
			
		||||
use log::{error, trace};
 | 
			
		||||
use tokio::{
 | 
			
		||||
    net::UdpSocket,
 | 
			
		||||
    sync::{mpsc, Mutex},
 | 
			
		||||
    time::{self, Instant},
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use crate::plugins::kcp::{skcp::KcpSocket, KcpConfig};
 | 
			
		||||
 | 
			
		||||
pub struct KcpSession {
 | 
			
		||||
    socket: Mutex<KcpSocket>,
 | 
			
		||||
    closed: AtomicBool,
 | 
			
		||||
    session_expire: Duration,
 | 
			
		||||
    session_close_notifier: Option<mpsc::Sender<u32>>,
 | 
			
		||||
    input_tx: mpsc::Sender<Vec<u8>>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl KcpSession {
 | 
			
		||||
    fn new(
 | 
			
		||||
        socket: KcpSocket,
 | 
			
		||||
        session_expire: Duration,
 | 
			
		||||
        session_close_notifier: Option<mpsc::Sender<u32>>,
 | 
			
		||||
        input_tx: mpsc::Sender<Vec<u8>>,
 | 
			
		||||
    ) -> KcpSession {
 | 
			
		||||
        KcpSession {
 | 
			
		||||
            socket: Mutex::new(socket),
 | 
			
		||||
            closed: AtomicBool::new(false),
 | 
			
		||||
            session_expire,
 | 
			
		||||
            session_close_notifier,
 | 
			
		||||
            input_tx,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn new_shared(
 | 
			
		||||
        socket: KcpSocket,
 | 
			
		||||
        session_expire: Duration,
 | 
			
		||||
        session_close_notifier: Option<mpsc::Sender<u32>>,
 | 
			
		||||
    ) -> Arc<KcpSession> {
 | 
			
		||||
        let is_client = session_close_notifier.is_none();
 | 
			
		||||
 | 
			
		||||
        let (input_tx, mut input_rx) = mpsc::channel(64);
 | 
			
		||||
 | 
			
		||||
        let udp_socket = socket.udp_socket().clone();
 | 
			
		||||
 | 
			
		||||
        let session = Arc::new(KcpSession::new(
 | 
			
		||||
            socket,
 | 
			
		||||
            session_expire,
 | 
			
		||||
            session_close_notifier,
 | 
			
		||||
            input_tx,
 | 
			
		||||
        ));
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
            let session = session.clone();
 | 
			
		||||
            tokio::spawn(async move {
 | 
			
		||||
                let mut input_buffer = [0u8; 65536];
 | 
			
		||||
                let update_timer = time::sleep(Duration::from_millis(10));
 | 
			
		||||
                tokio::pin!(update_timer);
 | 
			
		||||
 | 
			
		||||
                loop {
 | 
			
		||||
                    tokio::select! {
 | 
			
		||||
                        // recv() then input()
 | 
			
		||||
                        // Drives the KCP machine forward
 | 
			
		||||
                        recv_result = udp_socket.recv(&mut input_buffer), if is_client => {
 | 
			
		||||
                            match recv_result {
 | 
			
		||||
                                Err(err) => {
 | 
			
		||||
                                    error!("[SESSION] UDP recv failed, error: {}", err);
 | 
			
		||||
                                }
 | 
			
		||||
                                Ok(n) => {
 | 
			
		||||
                                    let input_buffer = &input_buffer[..n];
 | 
			
		||||
                                    trace!("[SESSION] UDP recv {} bytes, going to input {:?}", n, ByteStr::new(input_buffer));
 | 
			
		||||
 | 
			
		||||
                                    let mut socket = session.socket.lock().await;
 | 
			
		||||
 | 
			
		||||
                                    match socket.input(input_buffer) {
 | 
			
		||||
                                        Ok(true) => {
 | 
			
		||||
                                            trace!("[SESSION] UDP input {} bytes and waked sender/receiver", n);
 | 
			
		||||
                                        }
 | 
			
		||||
                                        Ok(false) => {}
 | 
			
		||||
                                        Err(err) => {
 | 
			
		||||
                                            error!("[SESSION] UDP input {} bytes error: {}, input buffer {:?}", n, err, ByteStr::new(input_buffer));
 | 
			
		||||
                                        }
 | 
			
		||||
                                    }
 | 
			
		||||
                                }
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
 | 
			
		||||
                        // bytes received from listener socket
 | 
			
		||||
                        input_opt = input_rx.recv() => {
 | 
			
		||||
                            if let Some(input_buffer) = input_opt {
 | 
			
		||||
                                let mut socket = session.socket.lock().await;
 | 
			
		||||
                                match socket.input(&input_buffer) {
 | 
			
		||||
                                    Ok(..) => {
 | 
			
		||||
                                        trace!("[SESSION] UDP input {} bytes from channel {:?}", input_buffer.len(), ByteStr::new(&input_buffer));
 | 
			
		||||
                                    }
 | 
			
		||||
                                    Err(err) => {
 | 
			
		||||
                                        error!("[SESSION] UDP input {} bytes from channel failed, error: {}, input buffer {:?}",
 | 
			
		||||
                                               input_buffer.len(), err, ByteStr::new(&input_buffer));
 | 
			
		||||
                                    }
 | 
			
		||||
                                }
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
 | 
			
		||||
                        // Call update() in period
 | 
			
		||||
                        _ = &mut update_timer => {
 | 
			
		||||
                            let mut socket = session.socket.lock().await;
 | 
			
		||||
 | 
			
		||||
                            let is_closed = session.closed.load(Ordering::Acquire);
 | 
			
		||||
                            if is_closed && socket.can_close() {
 | 
			
		||||
                                trace!("[SESSION] KCP session closed");
 | 
			
		||||
                                break;
 | 
			
		||||
                            }
 | 
			
		||||
 | 
			
		||||
                            // server socket expires
 | 
			
		||||
                            if !is_client {
 | 
			
		||||
                                // If this is a server stream, close it automatically after a period of time
 | 
			
		||||
                                let last_update_time = socket.last_update_time();
 | 
			
		||||
                                let elapsed = last_update_time.elapsed();
 | 
			
		||||
 | 
			
		||||
                                if elapsed > session.session_expire {
 | 
			
		||||
                                    if elapsed > session.session_expire * 2 {
 | 
			
		||||
                                        // Force close. Client may have already gone.
 | 
			
		||||
                                        trace!(
 | 
			
		||||
                                            "[SESSION] force close inactive session, conv: {}, last_update: {}s ago",
 | 
			
		||||
                                            socket.conv(),
 | 
			
		||||
                                            elapsed.as_secs()
 | 
			
		||||
                                        );
 | 
			
		||||
                                        break;
 | 
			
		||||
                                    }
 | 
			
		||||
 | 
			
		||||
                                    if !is_closed {
 | 
			
		||||
                                        trace!(
 | 
			
		||||
                                            "[SESSION] closing inactive session, conv: {}, last_update: {}s ago",
 | 
			
		||||
                                            socket.conv(),
 | 
			
		||||
                                            elapsed.as_secs()
 | 
			
		||||
                                        );
 | 
			
		||||
                                        session.closed.store(true, Ordering::Release);
 | 
			
		||||
                                    }
 | 
			
		||||
                                }
 | 
			
		||||
                            }
 | 
			
		||||
 | 
			
		||||
                            match socket.update() {
 | 
			
		||||
                                Ok(next_next) => {
 | 
			
		||||
                                    update_timer.as_mut().reset(Instant::from_std(next_next));
 | 
			
		||||
                                }
 | 
			
		||||
                                Err(err) => {
 | 
			
		||||
                                    error!("[SESSION] KCP update failed, error: {}", err);
 | 
			
		||||
                                    update_timer.as_mut().reset(Instant::now() + Duration::from_millis(10));
 | 
			
		||||
                                }
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                {
 | 
			
		||||
                    // Close the socket.
 | 
			
		||||
                    // Wake all pending tasks and let all send/recv return EOF
 | 
			
		||||
 | 
			
		||||
                    let mut socket = session.socket.lock().await;
 | 
			
		||||
                    socket.close();
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                if let Some(ref notifier) = session.session_close_notifier {
 | 
			
		||||
                    let socket = session.socket.lock().await;
 | 
			
		||||
                    let _ = notifier.send(socket.conv()).await;
 | 
			
		||||
                }
 | 
			
		||||
            });
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        session
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn kcp_socket(&self) -> &Mutex<KcpSocket> {
 | 
			
		||||
        &self.socket
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn close(&self) {
 | 
			
		||||
        self.closed.store(true, Ordering::Release);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn input(&self, buf: &[u8]) {
 | 
			
		||||
        self.input_tx
 | 
			
		||||
            .send(buf.to_owned())
 | 
			
		||||
            .await
 | 
			
		||||
            .expect("input channel closed")
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct KcpSessionManager {
 | 
			
		||||
    sessions: HashMap<u32, Arc<KcpSession>>,
 | 
			
		||||
    next_free_conv: u32,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl KcpSessionManager {
 | 
			
		||||
    pub fn new() -> KcpSessionManager {
 | 
			
		||||
        KcpSessionManager {
 | 
			
		||||
            sessions: HashMap::new(),
 | 
			
		||||
            next_free_conv: 0,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn close_conv(&mut self, conv: u32) {
 | 
			
		||||
        self.sessions.remove(&conv);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn alloc_conv(&mut self) -> u32 {
 | 
			
		||||
        loop {
 | 
			
		||||
            let (mut c, _) = self.next_free_conv.overflowing_add(1);
 | 
			
		||||
            if c == 0 {
 | 
			
		||||
                let (nc, _) = c.overflowing_add(1);
 | 
			
		||||
                c = nc;
 | 
			
		||||
            }
 | 
			
		||||
            self.next_free_conv = c;
 | 
			
		||||
 | 
			
		||||
            if self.sessions.get(&self.next_free_conv).is_none() {
 | 
			
		||||
                let conv = self.next_free_conv;
 | 
			
		||||
                return conv;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn get_or_create(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        config: &KcpConfig,
 | 
			
		||||
        conv: u32,
 | 
			
		||||
        udp: &Arc<UdpSocket>,
 | 
			
		||||
        peer_addr: SocketAddr,
 | 
			
		||||
        session_close_notifier: &mpsc::Sender<u32>,
 | 
			
		||||
    ) -> KcpResult<(Arc<KcpSession>, bool)> {
 | 
			
		||||
        match self.sessions.entry(conv) {
 | 
			
		||||
            Entry::Occupied(occ) => Ok((occ.get().clone(), false)),
 | 
			
		||||
            Entry::Vacant(vac) => {
 | 
			
		||||
                let socket = KcpSocket::new(config, conv, udp.clone(), peer_addr, config.stream)?;
 | 
			
		||||
                let session = KcpSession::new_shared(
 | 
			
		||||
                    socket,
 | 
			
		||||
                    config.session_expire,
 | 
			
		||||
                    Some(session_close_notifier.clone()),
 | 
			
		||||
                );
 | 
			
		||||
                trace!("created session for conv: {}, peer: {}", conv, peer_addr);
 | 
			
		||||
                vac.insert(session.clone());
 | 
			
		||||
                Ok((session, true))
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										288
									
								
								src/plugins/kcp/skcp.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										288
									
								
								src/plugins/kcp/skcp.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,288 @@
 | 
			
		||||
use std::{
 | 
			
		||||
    io::{self, ErrorKind, Write},
 | 
			
		||||
    net::SocketAddr,
 | 
			
		||||
    sync::Arc,
 | 
			
		||||
    task::{Context, Poll, Waker},
 | 
			
		||||
    time::{Duration, Instant},
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use futures::future;
 | 
			
		||||
use kcp::{Error as KcpError, Kcp, KcpResult};
 | 
			
		||||
use log::{error, trace};
 | 
			
		||||
use tokio::{net::UdpSocket, sync::mpsc};
 | 
			
		||||
 | 
			
		||||
use crate::plugins::kcp::{utils::now_millis, KcpConfig};
 | 
			
		||||
 | 
			
		||||
/// Writer for sending packets to the underlying UdpSocket
 | 
			
		||||
struct UdpOutput {
 | 
			
		||||
    socket: Arc<UdpSocket>,
 | 
			
		||||
    target_addr: SocketAddr,
 | 
			
		||||
    delay_tx: mpsc::UnboundedSender<Vec<u8>>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl UdpOutput {
 | 
			
		||||
    /// Create a new Writer for writing packets to UdpSocket
 | 
			
		||||
    pub fn new(socket: Arc<UdpSocket>, target_addr: SocketAddr) -> UdpOutput {
 | 
			
		||||
        let (delay_tx, mut delay_rx) = mpsc::unbounded_channel::<Vec<u8>>();
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
            let socket = socket.clone();
 | 
			
		||||
            tokio::spawn(async move {
 | 
			
		||||
                while let Some(buf) = delay_rx.recv().await {
 | 
			
		||||
                    if let Err(err) = socket.send_to(&buf, target_addr).await {
 | 
			
		||||
                        error!("[SEND] UDP delayed send failed, error: {}", err);
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            });
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        UdpOutput {
 | 
			
		||||
            socket,
 | 
			
		||||
            target_addr,
 | 
			
		||||
            delay_tx,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Write for UdpOutput {
 | 
			
		||||
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
 | 
			
		||||
        match self.socket.try_send_to(buf, self.target_addr) {
 | 
			
		||||
            Ok(n) => Ok(n),
 | 
			
		||||
            Err(ref err) if err.kind() == ErrorKind::WouldBlock => {
 | 
			
		||||
                // send return EAGAIN
 | 
			
		||||
                // ignored as packet was lost in transmission
 | 
			
		||||
                trace!(
 | 
			
		||||
                    "[SEND] UDP send EAGAIN, packet.size: {} bytes, delayed send",
 | 
			
		||||
                    buf.len()
 | 
			
		||||
                );
 | 
			
		||||
 | 
			
		||||
                self.delay_tx
 | 
			
		||||
                    .send(buf.to_owned())
 | 
			
		||||
                    .expect("channel closed unexpectly");
 | 
			
		||||
 | 
			
		||||
                Ok(buf.len())
 | 
			
		||||
            }
 | 
			
		||||
            Err(err) => Err(err),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn flush(&mut self) -> io::Result<()> {
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct KcpSocket {
 | 
			
		||||
    kcp: Kcp<UdpOutput>,
 | 
			
		||||
    last_update: Instant,
 | 
			
		||||
    socket: Arc<UdpSocket>,
 | 
			
		||||
    flush_write: bool,
 | 
			
		||||
    flush_ack_input: bool,
 | 
			
		||||
    sent_first: bool,
 | 
			
		||||
    pending_sender: Option<Waker>,
 | 
			
		||||
    pending_receiver: Option<Waker>,
 | 
			
		||||
    closed: bool,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl KcpSocket {
 | 
			
		||||
    pub fn new(
 | 
			
		||||
        c: &KcpConfig,
 | 
			
		||||
        conv: u32,
 | 
			
		||||
        socket: Arc<UdpSocket>,
 | 
			
		||||
        target_addr: SocketAddr,
 | 
			
		||||
        stream: bool,
 | 
			
		||||
    ) -> KcpResult<KcpSocket> {
 | 
			
		||||
        let output = UdpOutput::new(socket.clone(), target_addr);
 | 
			
		||||
        let mut kcp = if stream {
 | 
			
		||||
            Kcp::new_stream(conv, output)
 | 
			
		||||
        } else {
 | 
			
		||||
            Kcp::new(conv, output)
 | 
			
		||||
        };
 | 
			
		||||
        c.apply_config(&mut kcp);
 | 
			
		||||
 | 
			
		||||
        // Ask server to allocate one
 | 
			
		||||
        if conv == 0 {
 | 
			
		||||
            kcp.input_conv();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        kcp.update(now_millis())?;
 | 
			
		||||
 | 
			
		||||
        Ok(KcpSocket {
 | 
			
		||||
            kcp,
 | 
			
		||||
            last_update: Instant::now(),
 | 
			
		||||
            socket,
 | 
			
		||||
            flush_write: c.flush_write,
 | 
			
		||||
            flush_ack_input: c.flush_acks_input,
 | 
			
		||||
            sent_first: false,
 | 
			
		||||
            pending_sender: None,
 | 
			
		||||
            pending_receiver: None,
 | 
			
		||||
            closed: false,
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Call every time you got data from transmission
 | 
			
		||||
    pub fn input(&mut self, buf: &[u8]) -> KcpResult<bool> {
 | 
			
		||||
        match self.kcp.input(buf) {
 | 
			
		||||
            Ok(..) => {}
 | 
			
		||||
            Err(KcpError::ConvInconsistent(expected, actual)) => {
 | 
			
		||||
                trace!(
 | 
			
		||||
                    "[INPUT] Conv expected={} actual={} ignored",
 | 
			
		||||
                    expected,
 | 
			
		||||
                    actual
 | 
			
		||||
                );
 | 
			
		||||
                return Ok(false);
 | 
			
		||||
            }
 | 
			
		||||
            Err(err) => return Err(err),
 | 
			
		||||
        }
 | 
			
		||||
        self.last_update = Instant::now();
 | 
			
		||||
 | 
			
		||||
        if self.flush_ack_input {
 | 
			
		||||
            self.kcp.flush_ack()?;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(self.try_wake_pending_waker())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Call if you want to send some data
 | 
			
		||||
    pub fn poll_send(&mut self, cx: &mut Context<'_>, mut buf: &[u8]) -> Poll<KcpResult<usize>> {
 | 
			
		||||
        if self.closed {
 | 
			
		||||
            return Ok(0).into();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // If:
 | 
			
		||||
        //     1. Have sent the first packet (asking for conv)
 | 
			
		||||
        //     2. Too many pending packets
 | 
			
		||||
        if self.sent_first
 | 
			
		||||
            && (self.kcp.wait_snd() >= self.kcp.snd_wnd() as usize || self.kcp.waiting_conv())
 | 
			
		||||
        {
 | 
			
		||||
            trace!(
 | 
			
		||||
                "[SEND] waitsnd={} sndwnd={} excceeded or waiting conv={}",
 | 
			
		||||
                self.kcp.wait_snd(),
 | 
			
		||||
                self.kcp.snd_wnd(),
 | 
			
		||||
                self.kcp.waiting_conv()
 | 
			
		||||
            );
 | 
			
		||||
            self.pending_sender = Some(cx.waker().clone());
 | 
			
		||||
            return Poll::Pending;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if !self.sent_first && self.kcp.waiting_conv() && buf.len() > self.kcp.mss() as usize {
 | 
			
		||||
            buf = &buf[..self.kcp.mss() as usize];
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let n = self.kcp.send(buf)?;
 | 
			
		||||
        self.sent_first = true;
 | 
			
		||||
        self.last_update = Instant::now();
 | 
			
		||||
 | 
			
		||||
        if self.flush_write {
 | 
			
		||||
            self.kcp.flush()?;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Ok(n).into()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Call if you want to send some data
 | 
			
		||||
    #[allow(dead_code)]
 | 
			
		||||
    pub async fn send(&mut self, buf: &[u8]) -> KcpResult<usize> {
 | 
			
		||||
        future::poll_fn(|cx| self.poll_send(cx, buf)).await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(dead_code)]
 | 
			
		||||
    pub fn try_recv(&mut self, buf: &mut [u8]) -> KcpResult<usize> {
 | 
			
		||||
        if self.closed {
 | 
			
		||||
            return Ok(0);
 | 
			
		||||
        }
 | 
			
		||||
        self.kcp.recv(buf)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn poll_recv(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<KcpResult<usize>> {
 | 
			
		||||
        if self.closed {
 | 
			
		||||
            return Ok(0).into();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        match self.kcp.recv(buf) {
 | 
			
		||||
            Ok(n) => Ok(n).into(),
 | 
			
		||||
            Err(KcpError::RecvQueueEmpty) => {
 | 
			
		||||
                self.pending_receiver = Some(cx.waker().clone());
 | 
			
		||||
                Poll::Pending
 | 
			
		||||
            }
 | 
			
		||||
            Err(err) => Err(err).into(),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(dead_code)]
 | 
			
		||||
    pub async fn recv(&mut self, buf: &mut [u8]) -> KcpResult<usize> {
 | 
			
		||||
        future::poll_fn(|cx| self.poll_recv(cx, buf)).await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn flush(&mut self) -> KcpResult<()> {
 | 
			
		||||
        self.kcp.flush()?;
 | 
			
		||||
        self.last_update = Instant::now();
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn try_wake_pending_waker(&mut self) -> bool {
 | 
			
		||||
        let mut waked = false;
 | 
			
		||||
 | 
			
		||||
        if self.pending_sender.is_some()
 | 
			
		||||
            && self.kcp.wait_snd() < self.kcp.snd_wnd() as usize
 | 
			
		||||
            && !self.kcp.waiting_conv()
 | 
			
		||||
        {
 | 
			
		||||
            let waker = self.pending_sender.take().unwrap();
 | 
			
		||||
            waker.wake();
 | 
			
		||||
 | 
			
		||||
            waked = true;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if self.pending_receiver.is_some() {
 | 
			
		||||
            if let Ok(peek) = self.kcp.peeksize() {
 | 
			
		||||
                if peek > 0 {
 | 
			
		||||
                    let waker = self.pending_receiver.take().unwrap();
 | 
			
		||||
                    waker.wake();
 | 
			
		||||
 | 
			
		||||
                    waked = true;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        waked
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn update(&mut self) -> KcpResult<Instant> {
 | 
			
		||||
        let now = now_millis();
 | 
			
		||||
        self.kcp.update(now)?;
 | 
			
		||||
        let next = self.kcp.check(now);
 | 
			
		||||
 | 
			
		||||
        self.try_wake_pending_waker();
 | 
			
		||||
 | 
			
		||||
        Ok(Instant::now() + Duration::from_millis(next as u64))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn close(&mut self) {
 | 
			
		||||
        self.closed = true;
 | 
			
		||||
        if let Some(w) = self.pending_sender.take() {
 | 
			
		||||
            w.wake();
 | 
			
		||||
        }
 | 
			
		||||
        if let Some(w) = self.pending_receiver.take() {
 | 
			
		||||
            w.wake();
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn udp_socket(&self) -> &Arc<UdpSocket> {
 | 
			
		||||
        &self.socket
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn can_close(&self) -> bool {
 | 
			
		||||
        self.kcp.wait_snd() == 0
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn conv(&self) -> u32 {
 | 
			
		||||
        self.kcp.conv()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn peek_size(&self) -> KcpResult<usize> {
 | 
			
		||||
        self.kcp.peeksize()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn last_update_time(&self) -> Instant {
 | 
			
		||||
        self.last_update
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										183
									
								
								src/plugins/kcp/stream.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										183
									
								
								src/plugins/kcp/stream.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,183 @@
 | 
			
		||||
use std::{
 | 
			
		||||
    io::{self, ErrorKind},
 | 
			
		||||
    net::{IpAddr, SocketAddr},
 | 
			
		||||
    pin::Pin,
 | 
			
		||||
    sync::Arc,
 | 
			
		||||
    task::{Context, Poll},
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use futures::{future, ready};
 | 
			
		||||
use kcp::{Error as KcpError, KcpResult};
 | 
			
		||||
use log::trace;
 | 
			
		||||
use tokio::{
 | 
			
		||||
    io::{AsyncRead, AsyncWrite, ReadBuf},
 | 
			
		||||
    net::UdpSocket,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
use crate::plugins::kcp::{config::KcpConfig, session::KcpSession, skcp::KcpSocket};
 | 
			
		||||
 | 
			
		||||
pub struct KcpStream {
 | 
			
		||||
    session: Arc<KcpSession>,
 | 
			
		||||
    recv_buffer: Vec<u8>,
 | 
			
		||||
    recv_buffer_pos: usize,
 | 
			
		||||
    recv_buffer_cap: usize,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Drop for KcpStream {
 | 
			
		||||
    fn drop(&mut self) {
 | 
			
		||||
        self.session.close();
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[allow(unused)]
 | 
			
		||||
impl KcpStream {
 | 
			
		||||
    pub async fn connect(config: &KcpConfig, addr: SocketAddr) -> KcpResult<KcpStream> {
 | 
			
		||||
        let udp = match addr.ip() {
 | 
			
		||||
            IpAddr::V4(..) => UdpSocket::bind("0.0.0.0:0").await?,
 | 
			
		||||
            IpAddr::V6(..) => UdpSocket::bind("[::]:0").await?,
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let udp = Arc::new(udp);
 | 
			
		||||
        let socket = KcpSocket::new(config, 0, udp, addr, config.stream)?;
 | 
			
		||||
 | 
			
		||||
        let session = KcpSession::new_shared(socket, config.session_expire, None);
 | 
			
		||||
 | 
			
		||||
        Ok(KcpStream::with_session(session))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub(crate) fn with_session(session: Arc<KcpSession>) -> KcpStream {
 | 
			
		||||
        KcpStream {
 | 
			
		||||
            session,
 | 
			
		||||
            recv_buffer: Vec::new(),
 | 
			
		||||
            recv_buffer_pos: 0,
 | 
			
		||||
            recv_buffer_cap: 0,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn poll_send(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<KcpResult<usize>> {
 | 
			
		||||
        // Mutex doesn't have poll_lock, spinning on it.
 | 
			
		||||
        let socket = self.session.kcp_socket();
 | 
			
		||||
        let mut kcp = match socket.try_lock() {
 | 
			
		||||
            Ok(guard) => guard,
 | 
			
		||||
            Err(..) => {
 | 
			
		||||
                cx.waker().wake_by_ref();
 | 
			
		||||
                return Poll::Pending;
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        kcp.poll_send(cx, buf)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn send(&mut self, buf: &[u8]) -> KcpResult<usize> {
 | 
			
		||||
        future::poll_fn(|cx| self.poll_send(cx, buf)).await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn poll_recv(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<KcpResult<usize>> {
 | 
			
		||||
        loop {
 | 
			
		||||
            // Consumes all data in buffer
 | 
			
		||||
            if self.recv_buffer_pos < self.recv_buffer_cap {
 | 
			
		||||
                let remaining = self.recv_buffer_cap - self.recv_buffer_pos;
 | 
			
		||||
                let copy_length = remaining.min(buf.len());
 | 
			
		||||
 | 
			
		||||
                buf.copy_from_slice(
 | 
			
		||||
                    &self.recv_buffer[self.recv_buffer_pos..self.recv_buffer_pos + copy_length],
 | 
			
		||||
                );
 | 
			
		||||
                self.recv_buffer_pos += copy_length;
 | 
			
		||||
                return Ok(copy_length).into();
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // Mutex doesn't have poll_lock, spinning on it.
 | 
			
		||||
            let socket = self.session.kcp_socket();
 | 
			
		||||
            let mut kcp = match socket.try_lock() {
 | 
			
		||||
                Ok(guard) => guard,
 | 
			
		||||
                Err(..) => {
 | 
			
		||||
                    cx.waker().wake_by_ref();
 | 
			
		||||
                    return Poll::Pending;
 | 
			
		||||
                }
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            // Try to read from KCP
 | 
			
		||||
            // 1. Read directly with user provided `buf`
 | 
			
		||||
            match ready!(kcp.poll_recv(cx, buf)) {
 | 
			
		||||
                Ok(n) => {
 | 
			
		||||
                    trace!("[CLIENT] recv directly {} bytes", n);
 | 
			
		||||
                    return Ok(n).into();
 | 
			
		||||
                }
 | 
			
		||||
                Err(KcpError::UserBufTooSmall) => {}
 | 
			
		||||
                Err(err) => return Err(err).into(),
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // 2. User `buf` too small, read to recv_buffer
 | 
			
		||||
            let required_size = kcp.peek_size()?;
 | 
			
		||||
            if self.recv_buffer.len() < required_size {
 | 
			
		||||
                self.recv_buffer.resize(required_size, 0);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            match ready!(kcp.poll_recv(cx, &mut self.recv_buffer)) {
 | 
			
		||||
                Ok(n) => {
 | 
			
		||||
                    trace!("[CLIENT] recv buffered {} bytes", n);
 | 
			
		||||
                    self.recv_buffer_pos = 0;
 | 
			
		||||
                    self.recv_buffer_cap = n;
 | 
			
		||||
                }
 | 
			
		||||
                Err(err) => return Err(err).into(),
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn recv(&mut self, buf: &mut [u8]) -> KcpResult<usize> {
 | 
			
		||||
        future::poll_fn(|cx| self.poll_recv(cx, buf)).await
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl AsyncRead for KcpStream {
 | 
			
		||||
    fn poll_read(
 | 
			
		||||
        mut self: Pin<&mut Self>,
 | 
			
		||||
        cx: &mut Context<'_>,
 | 
			
		||||
        buf: &mut ReadBuf<'_>,
 | 
			
		||||
    ) -> Poll<io::Result<()>> {
 | 
			
		||||
        match ready!(self.poll_recv(cx, buf.initialize_unfilled())) {
 | 
			
		||||
            Ok(n) => {
 | 
			
		||||
                buf.advance(n);
 | 
			
		||||
                Ok(()).into()
 | 
			
		||||
            }
 | 
			
		||||
            Err(KcpError::IoError(err)) => Err(err).into(),
 | 
			
		||||
            Err(err) => Err(io::Error::new(ErrorKind::Other, err)).into(),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl AsyncWrite for KcpStream {
 | 
			
		||||
    fn poll_write(
 | 
			
		||||
        mut self: Pin<&mut Self>,
 | 
			
		||||
        cx: &mut Context<'_>,
 | 
			
		||||
        buf: &[u8],
 | 
			
		||||
    ) -> Poll<io::Result<usize>> {
 | 
			
		||||
        match ready!(self.poll_send(cx, buf)) {
 | 
			
		||||
            Ok(n) => Ok(n).into(),
 | 
			
		||||
            Err(KcpError::IoError(err)) => Err(err).into(),
 | 
			
		||||
            Err(err) => Err(io::Error::new(ErrorKind::Other, err)).into(),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
 | 
			
		||||
        // Mutex doesn't have poll_lock, spinning on it.
 | 
			
		||||
        let socket = self.session.kcp_socket();
 | 
			
		||||
        let mut kcp = match socket.try_lock() {
 | 
			
		||||
            Ok(guard) => guard,
 | 
			
		||||
            Err(..) => {
 | 
			
		||||
                cx.waker().wake_by_ref();
 | 
			
		||||
                return Poll::Pending;
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        match kcp.flush() {
 | 
			
		||||
            Ok(..) => Ok(()).into(),
 | 
			
		||||
            Err(KcpError::IoError(err)) => Err(err).into(),
 | 
			
		||||
            Err(err) => Err(io::Error::new(ErrorKind::Other, err)).into(),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
 | 
			
		||||
        Ok(()).into()
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										10
									
								
								src/plugins/kcp/utils.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								src/plugins/kcp/utils.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,10 @@
 | 
			
		||||
use std::time::{SystemTime, UNIX_EPOCH};
 | 
			
		||||
 | 
			
		||||
#[inline]
 | 
			
		||||
pub fn now_millis() -> u32 {
 | 
			
		||||
    let start = SystemTime::now();
 | 
			
		||||
    let since_the_epoch = start
 | 
			
		||||
        .duration_since(UNIX_EPOCH)
 | 
			
		||||
        .expect("time went afterwards");
 | 
			
		||||
    (since_the_epoch.as_secs() * 1000 + since_the_epoch.subsec_millis() as u64 / 1_000_000) as u32
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										1
									
								
								src/plugins/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								src/plugins/mod.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1 @@
 | 
			
		||||
//pub mod kcp;
 | 
			
		||||
@@ -1,41 +1,41 @@
 | 
			
		||||
use futures::future::try_join;
 | 
			
		||||
use log::{debug, error, info, warn};
 | 
			
		||||
use log::{error, info};
 | 
			
		||||
use std::collections::{HashMap, HashSet};
 | 
			
		||||
use std::net::SocketAddr;
 | 
			
		||||
use std::sync::Arc;
 | 
			
		||||
use tokio::io;
 | 
			
		||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
 | 
			
		||||
use tokio::net::{TcpListener, TcpStream};
 | 
			
		||||
 | 
			
		||||
use tokio::task::JoinHandle;
 | 
			
		||||
 | 
			
		||||
mod tls;
 | 
			
		||||
use self::tls::get_sni;
 | 
			
		||||
use crate::config::BaseConfig;
 | 
			
		||||
mod protocol;
 | 
			
		||||
pub(crate) mod upstream_address;
 | 
			
		||||
 | 
			
		||||
use crate::config::ParsedConfigV1;
 | 
			
		||||
use crate::upstreams::Upstream;
 | 
			
		||||
use protocol::tcp;
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub struct Server {
 | 
			
		||||
pub(crate) struct Server {
 | 
			
		||||
    pub proxies: Vec<Arc<Proxy>>,
 | 
			
		||||
    pub config: BaseConfig,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone)]
 | 
			
		||||
pub struct Proxy {
 | 
			
		||||
pub(crate) struct Proxy {
 | 
			
		||||
    pub name: String,
 | 
			
		||||
    pub listen: SocketAddr,
 | 
			
		||||
    pub protocol: String,
 | 
			
		||||
    pub tls: bool,
 | 
			
		||||
    pub sni: Option<HashMap<String, String>>,
 | 
			
		||||
    pub default: String,
 | 
			
		||||
    pub upstream: HashMap<String, String>,
 | 
			
		||||
    pub default_action: String,
 | 
			
		||||
    pub upstream: HashMap<String, Upstream>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Server {
 | 
			
		||||
    pub fn new(config: BaseConfig) -> Self {
 | 
			
		||||
    pub fn new_from_v1_config(config: ParsedConfigV1) -> Self {
 | 
			
		||||
        let mut new_server = Server {
 | 
			
		||||
            proxies: Vec::new(),
 | 
			
		||||
            config: config.clone(),
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        for (name, proxy) in config.servers.iter() {
 | 
			
		||||
            let protocol = proxy.protocol.clone().unwrap_or_else(|| "tcp".to_string());
 | 
			
		||||
            let tls = proxy.tls.unwrap_or(false);
 | 
			
		||||
            let sni = proxy.sni.clone();
 | 
			
		||||
            let default = proxy.default.clone().unwrap_or_else(|| "ban".to_string());
 | 
			
		||||
@@ -48,7 +48,6 @@ impl Server {
 | 
			
		||||
                upstream_set.insert(key.clone());
 | 
			
		||||
            }
 | 
			
		||||
            for listen in proxy.listen.clone() {
 | 
			
		||||
                println!("{:?}", listen);
 | 
			
		||||
                let listen_addr: SocketAddr = match listen.parse() {
 | 
			
		||||
                    Ok(addr) => addr,
 | 
			
		||||
                    Err(_) => {
 | 
			
		||||
@@ -56,12 +55,14 @@ impl Server {
 | 
			
		||||
                        continue;
 | 
			
		||||
                    }
 | 
			
		||||
                };
 | 
			
		||||
 | 
			
		||||
                let proxy = Proxy {
 | 
			
		||||
                    name: name.clone(),
 | 
			
		||||
                    listen: listen_addr,
 | 
			
		||||
                    protocol: protocol.clone(),
 | 
			
		||||
                    tls,
 | 
			
		||||
                    sni: sni.clone(),
 | 
			
		||||
                    default: default.clone(),
 | 
			
		||||
                    default_action: default.clone(),
 | 
			
		||||
                    upstream: upstream.clone(),
 | 
			
		||||
                };
 | 
			
		||||
                new_server.proxies.push(Arc::new(proxy));
 | 
			
		||||
@@ -77,9 +78,28 @@ impl Server {
 | 
			
		||||
        let mut handles: Vec<JoinHandle<()>> = Vec::new();
 | 
			
		||||
 | 
			
		||||
        for config in proxies {
 | 
			
		||||
            info!("Starting server {} on {}", config.name, config.listen);
 | 
			
		||||
            info!(
 | 
			
		||||
                "Starting {} server {} on {}",
 | 
			
		||||
                config.protocol, config.name, config.listen
 | 
			
		||||
            );
 | 
			
		||||
            let handle = tokio::spawn(async move {
 | 
			
		||||
                let _ = proxy(config).await;
 | 
			
		||||
                match config.protocol.as_ref() {
 | 
			
		||||
                    "tcp" | "tcp4" | "tcp6" => {
 | 
			
		||||
                        let res = tcp::proxy(config.clone()).await;
 | 
			
		||||
                        if res.is_err() {
 | 
			
		||||
                            error!("Failed to start {}: {}", config.name, res.err().unwrap());
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                    // "kcp" => {
 | 
			
		||||
                    //     let res = kcp::proxy(config.clone()).await;
 | 
			
		||||
                    //     if res.is_err() {
 | 
			
		||||
                    //         error!("Failed to start {}: {}", config.name, res.err().unwrap());
 | 
			
		||||
                    //     }
 | 
			
		||||
                    // }
 | 
			
		||||
                    _ => {
 | 
			
		||||
                        error!("Invalid protocol: {}", config.protocol)
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            });
 | 
			
		||||
            handles.push(handle);
 | 
			
		||||
        }
 | 
			
		||||
@@ -91,145 +111,92 @@ impl Server {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
 | 
			
		||||
    let listener = TcpListener::bind(config.listen).await?;
 | 
			
		||||
    let config = config.clone();
 | 
			
		||||
 | 
			
		||||
    loop {
 | 
			
		||||
        let thread_proxy = config.clone();
 | 
			
		||||
        match listener.accept().await {
 | 
			
		||||
            Err(err) => {
 | 
			
		||||
                error!("Failed to accept connection: {}", err);
 | 
			
		||||
                return Err(Box::new(err));
 | 
			
		||||
            }
 | 
			
		||||
            Ok((stream, _)) => {
 | 
			
		||||
                tokio::spawn(async move {
 | 
			
		||||
                    match accept(stream, thread_proxy).await {
 | 
			
		||||
                        Ok(_) => {}
 | 
			
		||||
                        Err(err) => {
 | 
			
		||||
                            error!("Relay thread returned an error: {}", err);
 | 
			
		||||
                        }
 | 
			
		||||
                    };
 | 
			
		||||
                });
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
 | 
			
		||||
    debug!("New connection from {:?}", inbound.peer_addr()?);
 | 
			
		||||
 | 
			
		||||
    let upstream_name = match proxy.tls {
 | 
			
		||||
        false => proxy.default.clone(),
 | 
			
		||||
        true => {
 | 
			
		||||
            let mut hello_buf = [0u8; 1024];
 | 
			
		||||
            inbound.peek(&mut hello_buf).await?;
 | 
			
		||||
            let snis = get_sni(&hello_buf);
 | 
			
		||||
            if snis.is_empty() {
 | 
			
		||||
                proxy.default.clone()
 | 
			
		||||
            } else {
 | 
			
		||||
                match proxy.sni.clone() {
 | 
			
		||||
                    Some(sni_map) => {
 | 
			
		||||
                        let mut upstream = proxy.default.clone();
 | 
			
		||||
                        for sni in snis {
 | 
			
		||||
                            let m = sni_map.get(&sni);
 | 
			
		||||
                            if m.is_some() {
 | 
			
		||||
                                upstream = m.unwrap().clone();
 | 
			
		||||
                                break;
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                        upstream
 | 
			
		||||
                    }
 | 
			
		||||
                    None => proxy.default.clone(),
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    debug!("Upstream: {}", upstream_name);
 | 
			
		||||
 | 
			
		||||
    let upstream = match proxy.upstream.get(&upstream_name) {
 | 
			
		||||
        Some(upstream) => upstream,
 | 
			
		||||
        None => {
 | 
			
		||||
            warn!(
 | 
			
		||||
                "No upstream named {:?} on server {:?}",
 | 
			
		||||
                proxy.default, proxy.name
 | 
			
		||||
            );
 | 
			
		||||
            return process(inbound, &proxy.default).await;
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
    return process(inbound, upstream).await;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn process(mut inbound: TcpStream, upstream: &str) -> Result<(), Box<dyn std::error::Error>> {
 | 
			
		||||
    if upstream == "ban" {
 | 
			
		||||
        let _ = inbound.shutdown();
 | 
			
		||||
        return Ok(());
 | 
			
		||||
    } else if upstream == "echo" {
 | 
			
		||||
        loop {
 | 
			
		||||
            let mut buf = [0u8; 1];
 | 
			
		||||
            let b = inbound.read(&mut buf).await?;
 | 
			
		||||
            if b == 0 {
 | 
			
		||||
                break;
 | 
			
		||||
            } else {
 | 
			
		||||
                inbound.write(&buf).await?;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        return Ok(());
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let outbound = TcpStream::connect(upstream).await?;
 | 
			
		||||
 | 
			
		||||
    let (mut ri, mut wi) = io::split(inbound);
 | 
			
		||||
    let (mut ro, mut wo) = io::split(outbound);
 | 
			
		||||
 | 
			
		||||
    let inbound_to_outbound = copy(&mut ri, &mut wo);
 | 
			
		||||
    let outbound_to_inbound = copy(&mut ro, &mut wi);
 | 
			
		||||
 | 
			
		||||
    let (bytes_tx, bytes_rx) = try_join(inbound_to_outbound, outbound_to_inbound).await?;
 | 
			
		||||
 | 
			
		||||
    debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx);
 | 
			
		||||
 | 
			
		||||
    Ok(())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
 | 
			
		||||
where
 | 
			
		||||
    R: AsyncRead + Unpin + ?Sized,
 | 
			
		||||
    W: AsyncWrite + Unpin + ?Sized,
 | 
			
		||||
{
 | 
			
		||||
    match io::copy(reader, writer).await {
 | 
			
		||||
        Ok(u64) => {
 | 
			
		||||
            let _ = writer.shutdown().await;
 | 
			
		||||
            Ok(u64)
 | 
			
		||||
        }
 | 
			
		||||
        Err(_) => Ok(0),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[cfg(test)]
 | 
			
		||||
mod test {
 | 
			
		||||
mod tests {
 | 
			
		||||
    //use crate::plugins::kcp::{KcpConfig, KcpStream};
 | 
			
		||||
    use std::thread::{self, sleep};
 | 
			
		||||
    use std::time::Duration;
 | 
			
		||||
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
 | 
			
		||||
    use tokio::net::TcpListener;
 | 
			
		||||
 | 
			
		||||
    use super::*;
 | 
			
		||||
 | 
			
		||||
    #[tokio::main]
 | 
			
		||||
    async fn tcp_mock_server() {
 | 
			
		||||
        let server_addr: SocketAddr = "127.0.0.1:54599".parse().unwrap();
 | 
			
		||||
        let listener = TcpListener::bind(server_addr).await.unwrap();
 | 
			
		||||
        loop {
 | 
			
		||||
            let (mut stream, _) = listener.accept().await.unwrap();
 | 
			
		||||
            let mut buf = [0u8; 2];
 | 
			
		||||
            let mut n = stream.read(&mut buf).await.unwrap();
 | 
			
		||||
            while n > 0 {
 | 
			
		||||
                let _ = stream.write(b"hello").await.unwrap();
 | 
			
		||||
                if buf.eq(b"by") {
 | 
			
		||||
                    stream.shutdown().await.unwrap();
 | 
			
		||||
                    break;
 | 
			
		||||
                }
 | 
			
		||||
                n = stream.read(&mut buf).await.unwrap();
 | 
			
		||||
            }
 | 
			
		||||
            stream.shutdown().await.unwrap();
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[tokio::test]
 | 
			
		||||
    async fn test_echo_server() {
 | 
			
		||||
        use crate::config::Config;
 | 
			
		||||
        let config = Config::new("tests/config.yaml").unwrap();
 | 
			
		||||
        let mut server = Server::new(config.base);
 | 
			
		||||
    async fn test_proxy() {
 | 
			
		||||
        use crate::config::ConfigV1;
 | 
			
		||||
        let config = ConfigV1::new("tests/config.yaml").unwrap();
 | 
			
		||||
        let mut server = Server::new_from_v1_config(config.base);
 | 
			
		||||
        thread::spawn(move || {
 | 
			
		||||
            tcp_mock_server();
 | 
			
		||||
        });
 | 
			
		||||
        sleep(Duration::from_secs(1)); // wait for server to start
 | 
			
		||||
        thread::spawn(move || {
 | 
			
		||||
            let _ = server.run();
 | 
			
		||||
        });
 | 
			
		||||
        sleep(Duration::from_secs(1)); // wait for server to start
 | 
			
		||||
        let mut conn = TcpStream::connect("127.0.0.1:54956").await.unwrap();
 | 
			
		||||
 | 
			
		||||
        // test TCP proxy
 | 
			
		||||
        let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54500")
 | 
			
		||||
            .await
 | 
			
		||||
            .unwrap();
 | 
			
		||||
        let mut buf = [0u8; 5];
 | 
			
		||||
        let _ = conn.write(b"hi").await.unwrap();
 | 
			
		||||
        let _ = conn.read(&mut buf).await.unwrap();
 | 
			
		||||
        assert_eq!(&buf, b"hello");
 | 
			
		||||
        conn.shutdown().await.unwrap();
 | 
			
		||||
 | 
			
		||||
        // test TCP echo
 | 
			
		||||
        let mut conn = tokio::net::TcpStream::connect("127.0.0.1:54956")
 | 
			
		||||
            .await
 | 
			
		||||
            .unwrap();
 | 
			
		||||
        let mut buf = [0u8; 1];
 | 
			
		||||
        for i in 0..=255u8 {
 | 
			
		||||
            conn.write(&[i]).await.unwrap();
 | 
			
		||||
            conn.read(&mut buf).await.unwrap();
 | 
			
		||||
        for i in 0..=10u8 {
 | 
			
		||||
            let _ = conn.write(&[i]).await.unwrap();
 | 
			
		||||
            let _ = conn.read(&mut buf).await.unwrap();
 | 
			
		||||
            assert_eq!(&buf, &[i]);
 | 
			
		||||
        }
 | 
			
		||||
        conn.shutdown().await.unwrap();
 | 
			
		||||
 | 
			
		||||
        // test KCP echo
 | 
			
		||||
        // let kcp_config = KcpConfig::default();
 | 
			
		||||
        // let server_addr: SocketAddr = "127.0.0.1:54959".parse().unwrap();
 | 
			
		||||
        // let mut conn = KcpStream::connect(&kcp_config, server_addr).await.unwrap();
 | 
			
		||||
        // let mut buf = [0u8; 1];
 | 
			
		||||
        // for i in 0..=10u8 {
 | 
			
		||||
        //     conn.write(&[i]).await.unwrap();
 | 
			
		||||
        //     conn.read(&mut buf).await.unwrap();
 | 
			
		||||
        //     assert_eq!(&buf, &[i]);
 | 
			
		||||
        // }
 | 
			
		||||
        // conn.shutdown().await.unwrap();
 | 
			
		||||
        //
 | 
			
		||||
        // // test KCP proxy and close mock server
 | 
			
		||||
        // let kcp_config = KcpConfig::default();
 | 
			
		||||
        // let server_addr: SocketAddr = "127.0.0.1:54958".parse().unwrap();
 | 
			
		||||
        // let mut conn = KcpStream::connect(&kcp_config, server_addr).await.unwrap();
 | 
			
		||||
        // let mut buf = [0u8; 5];
 | 
			
		||||
        // conn.write(b"by").await.unwrap();
 | 
			
		||||
        // conn.read(&mut buf).await.unwrap();
 | 
			
		||||
        // assert_eq!(&buf, b"hello");
 | 
			
		||||
        // conn.shutdown().await.unwrap();
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										98
									
								
								src/servers/protocol/kcp.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								src/servers/protocol/kcp.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,98 @@
 | 
			
		||||
use crate::config::Upstream;
 | 
			
		||||
use crate::plugins::kcp::{KcpConfig, KcpListener, KcpStream};
 | 
			
		||||
use crate::servers::{copy, Proxy};
 | 
			
		||||
use futures::future::try_join;
 | 
			
		||||
use log::{debug, error, warn};
 | 
			
		||||
use std::net::SocketAddr;
 | 
			
		||||
use std::sync::Arc;
 | 
			
		||||
use tokio::io;
 | 
			
		||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
 | 
			
		||||
use tokio::net::TcpStream;
 | 
			
		||||
 | 
			
		||||
pub async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn std::error::Error>> {
 | 
			
		||||
    let kcp_config = KcpConfig::default();
 | 
			
		||||
    let mut listener = KcpListener::bind(kcp_config, config.listen).await?;
 | 
			
		||||
    let config = config.clone();
 | 
			
		||||
 | 
			
		||||
    loop {
 | 
			
		||||
        let thread_proxy = config.clone();
 | 
			
		||||
        match listener.accept().await {
 | 
			
		||||
            Err(err) => {
 | 
			
		||||
                error!("Failed to accept connection: {}", err);
 | 
			
		||||
                return Err(Box::new(err));
 | 
			
		||||
            }
 | 
			
		||||
            Ok((stream, peer)) => {
 | 
			
		||||
                tokio::spawn(async move {
 | 
			
		||||
                    match accept(stream, peer, thread_proxy).await {
 | 
			
		||||
                        Ok(_) => {}
 | 
			
		||||
                        Err(err) => {
 | 
			
		||||
                            error!("Relay thread returned an error: {}", err);
 | 
			
		||||
                        }
 | 
			
		||||
                    };
 | 
			
		||||
                });
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn accept(
 | 
			
		||||
    inbound: KcpStream,
 | 
			
		||||
    peer: SocketAddr,
 | 
			
		||||
    proxy: Arc<Proxy>,
 | 
			
		||||
) -> Result<(), Box<dyn std::error::Error>> {
 | 
			
		||||
    debug!("New connection from {:?}", peer);
 | 
			
		||||
 | 
			
		||||
    let upstream_name = proxy.default_action.clone();
 | 
			
		||||
 | 
			
		||||
    debug!("Upstream: {}", upstream_name);
 | 
			
		||||
 | 
			
		||||
    let upstream = match proxy.upstream.get(&upstream_name) {
 | 
			
		||||
        Some(upstream) => upstream,
 | 
			
		||||
        None => {
 | 
			
		||||
            warn!(
 | 
			
		||||
                "No upstream named {:?} on server {:?}",
 | 
			
		||||
                proxy.default, proxy.name
 | 
			
		||||
            );
 | 
			
		||||
            return process(inbound, proxy.upstream.get(&proxy.default_action).unwrap()).await;
 | 
			
		||||
            // ToDo: Remove unwrap and check default option
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
    return process(inbound, upstream).await;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn process(
 | 
			
		||||
    mut inbound: KcpStream,
 | 
			
		||||
    upstream: &Upstream,
 | 
			
		||||
) -> Result<(), Box<dyn std::error::Error>> {
 | 
			
		||||
    match upstream {
 | 
			
		||||
        Upstream::Ban => {
 | 
			
		||||
            let _ = inbound.shutdown();
 | 
			
		||||
        }
 | 
			
		||||
        Upstream::Echo => {
 | 
			
		||||
            let (mut ri, mut wi) = io::split(inbound);
 | 
			
		||||
            let inbound_to_inbound = copy(&mut ri, &mut wi);
 | 
			
		||||
            let bytes_tx = inbound_to_inbound.await;
 | 
			
		||||
            debug!("Bytes read: {:?}", bytes_tx);
 | 
			
		||||
        }
 | 
			
		||||
        Upstream::Custom(custom) => match custom.protocol.as_ref() {
 | 
			
		||||
            "tcp" => {
 | 
			
		||||
                let outbound = TcpStream::connect(custom.addr.clone()).await?;
 | 
			
		||||
 | 
			
		||||
                let (mut ri, mut wi) = io::split(inbound);
 | 
			
		||||
                let (mut ro, mut wo) = io::split(outbound);
 | 
			
		||||
 | 
			
		||||
                let inbound_to_outbound = copy(&mut ri, &mut wo);
 | 
			
		||||
                let outbound_to_inbound = copy(&mut ro, &mut wi);
 | 
			
		||||
 | 
			
		||||
                let (bytes_tx, bytes_rx) =
 | 
			
		||||
                    try_join(inbound_to_outbound, outbound_to_inbound).await?;
 | 
			
		||||
 | 
			
		||||
                debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx);
 | 
			
		||||
            }
 | 
			
		||||
            _ => {
 | 
			
		||||
                error!("Reached unknown protocol: {:?}", custom.protocol);
 | 
			
		||||
            }
 | 
			
		||||
        },
 | 
			
		||||
    };
 | 
			
		||||
    Ok(())
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										3
									
								
								src/servers/protocol/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								src/servers/protocol/mod.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
			
		||||
//pub mod kcp;
 | 
			
		||||
pub mod tcp;
 | 
			
		||||
pub mod tls;
 | 
			
		||||
							
								
								
									
										77
									
								
								src/servers/protocol/tcp.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								src/servers/protocol/tcp.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,77 @@
 | 
			
		||||
use crate::servers::protocol::tls::get_sni;
 | 
			
		||||
use crate::servers::Proxy;
 | 
			
		||||
use log::{debug, error, info, warn};
 | 
			
		||||
use std::error::Error;
 | 
			
		||||
use std::sync::Arc;
 | 
			
		||||
use tokio::net::{TcpListener, TcpStream};
 | 
			
		||||
 | 
			
		||||
pub(crate) async fn proxy(config: Arc<Proxy>) -> Result<(), Box<dyn Error>> {
 | 
			
		||||
    let listener = TcpListener::bind(config.listen).await?;
 | 
			
		||||
    let config = config.clone();
 | 
			
		||||
 | 
			
		||||
    loop {
 | 
			
		||||
        let thread_proxy = config.clone();
 | 
			
		||||
        match listener.accept().await {
 | 
			
		||||
            Err(err) => {
 | 
			
		||||
                error!("Failed to accept connection: {}", err);
 | 
			
		||||
                return Err(Box::new(err));
 | 
			
		||||
            }
 | 
			
		||||
            Ok((stream, _)) => {
 | 
			
		||||
                tokio::spawn(async move {
 | 
			
		||||
                    match accept(stream, thread_proxy).await {
 | 
			
		||||
                        Ok(_) => {}
 | 
			
		||||
                        Err(err) => {
 | 
			
		||||
                            error!("Relay thread returned an error: {}", err);
 | 
			
		||||
                        }
 | 
			
		||||
                    };
 | 
			
		||||
                });
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn accept(inbound: TcpStream, proxy: Arc<Proxy>) -> Result<(), Box<dyn Error>> {
 | 
			
		||||
    info!("New connection from {:?}", inbound.peer_addr()?);
 | 
			
		||||
 | 
			
		||||
    let upstream_name = match proxy.tls {
 | 
			
		||||
        false => proxy.default_action.clone(),
 | 
			
		||||
        true => {
 | 
			
		||||
            let mut hello_buf = [0u8; 1024];
 | 
			
		||||
            inbound.peek(&mut hello_buf).await?;
 | 
			
		||||
            let snis = get_sni(&hello_buf);
 | 
			
		||||
            if snis.is_empty() {
 | 
			
		||||
                proxy.default_action.clone()
 | 
			
		||||
            } else {
 | 
			
		||||
                match proxy.sni.clone() {
 | 
			
		||||
                    Some(sni_map) => {
 | 
			
		||||
                        let mut upstream = proxy.default_action.clone();
 | 
			
		||||
                        for sni in snis {
 | 
			
		||||
                            let m = sni_map.get(&sni);
 | 
			
		||||
                            if m.is_some() {
 | 
			
		||||
                                upstream = m.unwrap().clone();
 | 
			
		||||
                                break;
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                        upstream
 | 
			
		||||
                    }
 | 
			
		||||
                    None => proxy.default_action.clone(),
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    debug!("Upstream: {}", upstream_name);
 | 
			
		||||
 | 
			
		||||
    let upstream = match proxy.upstream.get(&upstream_name) {
 | 
			
		||||
        Some(upstream) => upstream,
 | 
			
		||||
        None => {
 | 
			
		||||
            warn!(
 | 
			
		||||
                "No upstream named {:?} on server {:?}",
 | 
			
		||||
                proxy.default_action, proxy.name
 | 
			
		||||
            );
 | 
			
		||||
            proxy.upstream.get(&proxy.default_action).unwrap()
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    upstream.process(inbound).await
 | 
			
		||||
}
 | 
			
		||||
@@ -49,6 +49,7 @@ pub fn get_sni(buf: &[u8]) -> Vec<String> {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    debug!("Found SNIs: {:?}", &snis);
 | 
			
		||||
    snis
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -98,6 +99,6 @@ mod tests {
 | 
			
		||||
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
 | 
			
		||||
        ];
 | 
			
		||||
        let sni = get_sni(&BUF);
 | 
			
		||||
        assert!(sni[0] == "www.lirui.tech".to_string());
 | 
			
		||||
        assert!(sni[0] == *"www.lirui.tech");
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										141
									
								
								src/servers/upstream_address.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										141
									
								
								src/servers/upstream_address.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,141 @@
 | 
			
		||||
use log::debug;
 | 
			
		||||
use std::fmt::{Display, Formatter};
 | 
			
		||||
use std::io::Result;
 | 
			
		||||
use std::net::SocketAddr;
 | 
			
		||||
use time::{Duration, Instant, OffsetDateTime};
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone, Default)]
 | 
			
		||||
pub(crate) struct UpstreamAddress {
 | 
			
		||||
    address: String,
 | 
			
		||||
    resolved_addresses: Vec<SocketAddr>,
 | 
			
		||||
    resolved_time: Option<Instant>,
 | 
			
		||||
    ttl: Option<Duration>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Display for UpstreamAddress {
 | 
			
		||||
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 | 
			
		||||
        self.address.fmt(f)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl UpstreamAddress {
 | 
			
		||||
    pub fn new(address: String) -> Self {
 | 
			
		||||
        UpstreamAddress {
 | 
			
		||||
            address,
 | 
			
		||||
            ..Default::default()
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn is_valid(&self) -> bool {
 | 
			
		||||
        if let Some(resolved) = self.resolved_time {
 | 
			
		||||
            if let Some(ttl) = self.ttl {
 | 
			
		||||
                return resolved.elapsed() < ttl;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        false
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn is_resolved(&self) -> bool {
 | 
			
		||||
        !self.resolved_addresses.is_empty()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn time_remaining(&self) -> Duration {
 | 
			
		||||
        if !self.is_valid() {
 | 
			
		||||
            return Duration::seconds(0);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        self.ttl.unwrap() - self.resolved_time.unwrap().elapsed()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn resolve(&mut self, mode: ResolutionMode) -> Result<Vec<SocketAddr>> {
 | 
			
		||||
        if self.is_resolved() && self.is_valid() {
 | 
			
		||||
            debug!(
 | 
			
		||||
                "Already got address {:?}, still valid for {:.3}s",
 | 
			
		||||
                &self.resolved_addresses,
 | 
			
		||||
                self.time_remaining().as_seconds_f64()
 | 
			
		||||
            );
 | 
			
		||||
            return Ok(self.resolved_addresses.clone());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        debug!(
 | 
			
		||||
            "Resolving addresses for {} with mode {:?}",
 | 
			
		||||
            &self.address, &mode
 | 
			
		||||
        );
 | 
			
		||||
 | 
			
		||||
        let lookup_result = tokio::net::lookup_host(&self.address).await;
 | 
			
		||||
 | 
			
		||||
        let resolved_addresses: Vec<SocketAddr> = match lookup_result {
 | 
			
		||||
            Ok(resolved_addresses) => resolved_addresses.into_iter().collect(),
 | 
			
		||||
            Err(e) => {
 | 
			
		||||
                debug!("Failed looking up {}: {}", &self.address, &e);
 | 
			
		||||
                // Protect against DNS flooding. Cache the result for 1 second.
 | 
			
		||||
                self.resolved_time = Some(Instant::now());
 | 
			
		||||
                self.ttl = Some(Duration::seconds(3));
 | 
			
		||||
                return Err(e);
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        debug!("Resolved addresses: {:?}", &resolved_addresses);
 | 
			
		||||
 | 
			
		||||
        let addresses: Vec<SocketAddr> = match mode {
 | 
			
		||||
            ResolutionMode::Ipv4 => resolved_addresses
 | 
			
		||||
                .into_iter()
 | 
			
		||||
                .filter(|a| a.is_ipv4())
 | 
			
		||||
                .collect(),
 | 
			
		||||
 | 
			
		||||
            ResolutionMode::Ipv6 => resolved_addresses
 | 
			
		||||
                .into_iter()
 | 
			
		||||
                .filter(|a| a.is_ipv6())
 | 
			
		||||
                .collect(),
 | 
			
		||||
 | 
			
		||||
            _ => resolved_addresses,
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        debug!(
 | 
			
		||||
            "Got {} addresses for {}: {:?}",
 | 
			
		||||
            &mode, &self.address, &addresses
 | 
			
		||||
        );
 | 
			
		||||
        debug!(
 | 
			
		||||
            "Resolved at {}",
 | 
			
		||||
            OffsetDateTime::now_utc()
 | 
			
		||||
                .format(&time::format_description::well_known::Rfc3339)
 | 
			
		||||
                .expect("Format")
 | 
			
		||||
        );
 | 
			
		||||
 | 
			
		||||
        self.resolved_addresses = addresses;
 | 
			
		||||
        self.resolved_time = Some(Instant::now());
 | 
			
		||||
        self.ttl = Some(Duration::minutes(1));
 | 
			
		||||
 | 
			
		||||
        Ok(self.resolved_addresses.clone())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Default, Clone)]
 | 
			
		||||
pub(crate) enum ResolutionMode {
 | 
			
		||||
    #[default]
 | 
			
		||||
    Ipv4AndIpv6,
 | 
			
		||||
    Ipv4,
 | 
			
		||||
    Ipv6,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<&str> for ResolutionMode {
 | 
			
		||||
    fn from(value: &str) -> Self {
 | 
			
		||||
        match value {
 | 
			
		||||
            "tcp4" => ResolutionMode::Ipv4,
 | 
			
		||||
            "tcp6" => ResolutionMode::Ipv6,
 | 
			
		||||
            "tcp" => ResolutionMode::Ipv4AndIpv6,
 | 
			
		||||
            _ => panic!("This should never happen. Please check configuration parser."),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Display for ResolutionMode {
 | 
			
		||||
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 | 
			
		||||
        match self {
 | 
			
		||||
            ResolutionMode::Ipv4 => write!(f, "IPv4Only"),
 | 
			
		||||
            ResolutionMode::Ipv6 => write!(f, "IPv6Only"),
 | 
			
		||||
            ResolutionMode::Ipv4AndIpv6 => write!(f, "IPv4 and IPv6"),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										51
									
								
								src/upstreams/mod.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								src/upstreams/mod.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,51 @@
 | 
			
		||||
mod proxy_to_upstream;
 | 
			
		||||
 | 
			
		||||
use log::debug;
 | 
			
		||||
use serde::Deserialize;
 | 
			
		||||
use std::error::Error;
 | 
			
		||||
use tokio::io;
 | 
			
		||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
 | 
			
		||||
use tokio::net::TcpStream;
 | 
			
		||||
 | 
			
		||||
pub use crate::upstreams::proxy_to_upstream::ProxyToUpstream;
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone, Deserialize)]
 | 
			
		||||
pub enum Upstream {
 | 
			
		||||
    Ban,
 | 
			
		||||
    Echo,
 | 
			
		||||
    Proxy(ProxyToUpstream),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Upstream {
 | 
			
		||||
    pub(crate) async fn process(&self, mut inbound: TcpStream) -> Result<(), Box<dyn Error>> {
 | 
			
		||||
        match self {
 | 
			
		||||
            Upstream::Ban => {
 | 
			
		||||
                inbound.shutdown().await?;
 | 
			
		||||
            }
 | 
			
		||||
            Upstream::Echo => {
 | 
			
		||||
                let (mut ri, mut wi) = io::split(inbound);
 | 
			
		||||
                let inbound_to_inbound = copy(&mut ri, &mut wi);
 | 
			
		||||
                let bytes_tx = inbound_to_inbound.await;
 | 
			
		||||
                debug!("Bytes read: {:?}", bytes_tx);
 | 
			
		||||
            }
 | 
			
		||||
            Upstream::Proxy(config) => {
 | 
			
		||||
                config.proxy(inbound).await?;
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64>
 | 
			
		||||
where
 | 
			
		||||
    R: AsyncRead + Unpin + ?Sized,
 | 
			
		||||
    W: AsyncWrite + Unpin + ?Sized,
 | 
			
		||||
{
 | 
			
		||||
    match io::copy(reader, writer).await {
 | 
			
		||||
        Ok(u64) => {
 | 
			
		||||
            let _ = writer.shutdown().await;
 | 
			
		||||
            Ok(u64)
 | 
			
		||||
        }
 | 
			
		||||
        Err(_) => Ok(0),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										68
									
								
								src/upstreams/proxy_to_upstream.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								src/upstreams/proxy_to_upstream.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,68 @@
 | 
			
		||||
use crate::servers::upstream_address::UpstreamAddress;
 | 
			
		||||
 | 
			
		||||
use crate::upstreams::copy;
 | 
			
		||||
use futures::future::try_join;
 | 
			
		||||
use log::{debug, error};
 | 
			
		||||
use serde::Deserialize;
 | 
			
		||||
use std::net::SocketAddr;
 | 
			
		||||
use tokio::io;
 | 
			
		||||
use tokio::net::TcpStream;
 | 
			
		||||
use tokio::sync::Mutex;
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Default)]
 | 
			
		||||
struct Addr(Mutex<UpstreamAddress>);
 | 
			
		||||
 | 
			
		||||
impl Clone for Addr {
 | 
			
		||||
    fn clone(&self) -> Self {
 | 
			
		||||
        tokio::task::block_in_place(|| Self(Mutex::new(self.0.blocking_lock().clone())))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, Clone, Deserialize, Default)]
 | 
			
		||||
pub struct ProxyToUpstream {
 | 
			
		||||
    pub addr: String,
 | 
			
		||||
    pub protocol: String,
 | 
			
		||||
    #[serde(skip_deserializing)]
 | 
			
		||||
    addresses: Addr,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl ProxyToUpstream {
 | 
			
		||||
    pub async fn resolve_addresses(&self) -> std::io::Result<Vec<SocketAddr>> {
 | 
			
		||||
        let mut addr = self.addresses.0.lock().await;
 | 
			
		||||
        addr.resolve((*self.protocol).into()).await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn new(address: String, protocol: String) -> Self {
 | 
			
		||||
        Self {
 | 
			
		||||
            addr: address.clone(),
 | 
			
		||||
            protocol,
 | 
			
		||||
            addresses: Addr(Mutex::new(UpstreamAddress::new(address))),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub(crate) async fn proxy(&self, inbound: TcpStream) -> Result<(), Box<dyn std::error::Error>> {
 | 
			
		||||
        let outbound = match self.protocol.as_ref() {
 | 
			
		||||
            "tcp4" | "tcp6" | "tcp" => {
 | 
			
		||||
                TcpStream::connect(self.resolve_addresses().await?.as_slice()).await?
 | 
			
		||||
            }
 | 
			
		||||
            _ => {
 | 
			
		||||
                error!("Reached unknown protocol: {:?}", self.protocol);
 | 
			
		||||
                return Err("Reached unknown protocol".into());
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        debug!("Connected to {:?}", outbound.peer_addr().unwrap());
 | 
			
		||||
 | 
			
		||||
        let (mut ri, mut wi) = io::split(inbound);
 | 
			
		||||
        let (mut ro, mut wo) = io::split(outbound);
 | 
			
		||||
 | 
			
		||||
        let inbound_to_outbound = copy(&mut ri, &mut wo);
 | 
			
		||||
        let outbound_to_inbound = copy(&mut ro, &mut wi);
 | 
			
		||||
 | 
			
		||||
        let (bytes_tx, bytes_rx) = try_join(inbound_to_outbound, outbound_to_inbound).await?;
 | 
			
		||||
 | 
			
		||||
        debug!("Bytes read: {:?} write: {:?}", bytes_tx, bytes_rx);
 | 
			
		||||
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -11,11 +11,26 @@ servers:
 | 
			
		||||
      proxy.test.com: proxy
 | 
			
		||||
      www.test.com: web
 | 
			
		||||
    default: ban
 | 
			
		||||
  echo_server:
 | 
			
		||||
  tcp_server:
 | 
			
		||||
    listen:
 | 
			
		||||
      - "127.0.0.1:54500"
 | 
			
		||||
    default: tester
 | 
			
		||||
  tcp_echo_server:
 | 
			
		||||
    listen:
 | 
			
		||||
      - "0.0.0.0:54956"
 | 
			
		||||
    default: echo
 | 
			
		||||
  kcp_server:
 | 
			
		||||
    protocol: kcp
 | 
			
		||||
    listen:
 | 
			
		||||
      - "127.0.0.1:54958"
 | 
			
		||||
    default: tester
 | 
			
		||||
  kcp_echo_server:
 | 
			
		||||
    protocol: kcp
 | 
			
		||||
    listen:
 | 
			
		||||
      - "127.0.0.1:54959"
 | 
			
		||||
    default: echo
 | 
			
		||||
 | 
			
		||||
upstream:
 | 
			
		||||
  web: "127.0.0.1:8080"
 | 
			
		||||
  proxy: "www.example.com:1024"
 | 
			
		||||
  web: "tcp://127.0.0.1:8080"
 | 
			
		||||
  proxy: "tcp://www.example.com:1024"
 | 
			
		||||
  tester: "tcp://127.0.0.1:54599"
 | 
			
		||||
		Reference in New Issue
	
	Block a user