feat(smart-proxy): add UDP transport support with QUIC/HTTP3 routing and datagram handler relay

This commit is contained in:
2026-03-19 15:06:27 +00:00
parent cfa958cf3d
commit 4fb91cd868
34 changed files with 2978 additions and 55 deletions

446
rust/Cargo.lock generated
View File

@@ -157,12 +157,24 @@ dependencies = [
"shlex",
]
[[package]]
name = "cesu8"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c"
[[package]]
name = "cfg-if"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
[[package]]
name = "cfg_aliases"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "clap"
version = "4.5.57"
@@ -218,6 +230,16 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
[[package]]
name = "combine"
version = "4.6.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
dependencies = [
"bytes",
"memchr",
]
[[package]]
name = "core-foundation"
version = "0.10.1"
@@ -285,6 +307,24 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "fastbloom"
version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e7f34442dbe69c60fe8eaf58a8cafff81a1f278816d8ab4db255b3bef4ac3c4"
dependencies = [
"getrandom 0.3.4",
"libm",
"rand",
"siphasher",
]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "find-msvc-tools"
version = "0.1.9"
@@ -303,6 +343,21 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.31"
@@ -310,6 +365,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
@@ -318,6 +374,34 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
[[package]]
name = "futures-executor"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
[[package]]
name = "futures-macro"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.31"
@@ -336,10 +420,16 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
]
[[package]]
@@ -362,9 +452,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"r-efi",
"wasip2",
"wasm-bindgen",
]
[[package]]
@@ -392,6 +484,34 @@ dependencies = [
"tracing",
]
[[package]]
name = "h3"
version = "0.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10872b55cfb02a821b69dc7cf8dc6a71d6af25eb9a79662bec4a9d016056b3be"
dependencies = [
"bytes",
"fastrand",
"futures-util",
"http",
"pin-project-lite",
"tokio",
]
[[package]]
name = "h3-quinn"
version = "0.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b2e732c8d91a74731663ac8479ab505042fbf547b9a207213ab7fbcbfc4f8b4"
dependencies = [
"bytes",
"futures",
"h3",
"quinn",
"tokio",
"tokio-util",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
@@ -565,6 +685,28 @@ version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
[[package]]
name = "jni"
version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97"
dependencies = [
"cesu8",
"cfg-if",
"combine",
"jni-sys",
"log",
"thiserror 1.0.69",
"walkdir",
"windows-sys 0.45.0",
]
[[package]]
name = "jni-sys"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
[[package]]
name = "jobserver"
version = "0.1.34"
@@ -612,6 +754,12 @@ version = "0.2.180"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc"
[[package]]
name = "libm"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981"
[[package]]
name = "libmimalloc-sys"
version = "0.1.44"
@@ -637,6 +785,12 @@ version = "0.4.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
[[package]]
name = "lru-slab"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
[[package]]
name = "matchers"
version = "0.2.0"
@@ -784,6 +938,15 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
[[package]]
name = "ppv-lite86"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
dependencies = [
"zerocopy",
]
[[package]]
name = "proc-macro2"
version = "1.0.106"
@@ -793,6 +956,64 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "quinn"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20"
dependencies = [
"bytes",
"cfg_aliases",
"futures-io",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"socket2 0.6.2",
"thiserror 2.0.18",
"tokio",
"tracing",
"web-time",
]
[[package]]
name = "quinn-proto"
version = "0.11.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098"
dependencies = [
"bytes",
"fastbloom",
"getrandom 0.3.4",
"lru-slab",
"rand",
"ring",
"rustc-hash",
"rustls",
"rustls-pki-types",
"rustls-platform-verifier",
"slab",
"thiserror 2.0.18",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2 0.6.2",
"tracing",
"windows-sys 0.60.2",
]
[[package]]
name = "quote"
version = "1.0.44"
@@ -808,6 +1029,35 @@ version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
[[package]]
name = "rand"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
dependencies = [
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c"
dependencies = [
"getrandom 0.3.4",
]
[[package]]
name = "rcgen"
version = "0.13.2"
@@ -873,6 +1123,12 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "rustc-hash"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustls"
version = "0.23.36"
@@ -916,9 +1172,37 @@ version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd"
dependencies = [
"web-time",
"zeroize",
]
[[package]]
name = "rustls-platform-verifier"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784"
dependencies = [
"core-foundation",
"core-foundation-sys",
"jni",
"log",
"once_cell",
"rustls",
"rustls-native-certs",
"rustls-platform-verifier-android",
"rustls-webpki",
"security-framework",
"security-framework-sys",
"webpki-root-certs",
"windows-sys 0.61.2",
]
[[package]]
name = "rustls-platform-verifier-android"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f"
[[package]]
name = "rustls-webpki"
version = "0.103.9"
@@ -946,6 +1230,7 @@ dependencies = [
"mimalloc",
"rcgen",
"rustls",
"rustls-pemfile",
"rustproxy-config",
"rustproxy-http",
"rustproxy-metrics",
@@ -981,10 +1266,13 @@ dependencies = [
"arc-swap",
"bytes",
"dashmap",
"h3",
"h3-quinn",
"http-body",
"http-body-util",
"hyper",
"hyper-util",
"quinn",
"regex",
"rustls",
"rustproxy-config",
@@ -1031,7 +1319,10 @@ version = "0.1.0"
dependencies = [
"anyhow",
"arc-swap",
"base64",
"dashmap",
"quinn",
"rcgen",
"rustls",
"rustls-pemfile",
"rustproxy-config",
@@ -1096,6 +1387,15 @@ version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]]
name = "schannel"
version = "0.1.28"
@@ -1214,6 +1514,12 @@ dependencies = [
"time",
]
[[package]]
name = "siphasher"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e"
[[package]]
name = "slab"
version = "0.4.12"
@@ -1349,6 +1655,21 @@ dependencies = [
"time-core",
]
[[package]]
name = "tinyvec"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokio"
version = "1.49.0"
@@ -1412,6 +1733,7 @@ version = "0.1.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100"
dependencies = [
"log",
"pin-project-lite",
"tracing-attributes",
"tracing-core",
@@ -1497,6 +1819,16 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]]
name = "walkdir"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
dependencies = [
"same-file",
"winapi-util",
]
[[package]]
name = "want"
version = "0.3.1"
@@ -1566,12 +1898,49 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "web-time"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "webpki-root-certs"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca"
dependencies = [
"rustls-pki-types",
]
[[package]]
name = "winapi-util"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "windows-link"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
[[package]]
name = "windows-sys"
version = "0.45.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0"
dependencies = [
"windows-targets 0.42.2",
]
[[package]]
name = "windows-sys"
version = "0.52.0"
@@ -1599,6 +1968,21 @@ dependencies = [
"windows-link",
]
[[package]]
name = "windows-targets"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071"
dependencies = [
"windows_aarch64_gnullvm 0.42.2",
"windows_aarch64_msvc 0.42.2",
"windows_i686_gnu 0.42.2",
"windows_i686_msvc 0.42.2",
"windows_x86_64_gnu 0.42.2",
"windows_x86_64_gnullvm 0.42.2",
"windows_x86_64_msvc 0.42.2",
]
[[package]]
name = "windows-targets"
version = "0.52.6"
@@ -1632,6 +2016,12 @@ dependencies = [
"windows_x86_64_msvc 0.53.1",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.6"
@@ -1644,6 +2034,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53"
[[package]]
name = "windows_aarch64_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.6"
@@ -1656,6 +2052,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006"
[[package]]
name = "windows_i686_gnu"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f"
[[package]]
name = "windows_i686_gnu"
version = "0.52.6"
@@ -1680,6 +2082,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c"
[[package]]
name = "windows_i686_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060"
[[package]]
name = "windows_i686_msvc"
version = "0.52.6"
@@ -1692,6 +2100,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2"
[[package]]
name = "windows_x86_64_gnu"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.6"
@@ -1704,6 +2118,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.6"
@@ -1716,6 +2136,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1"
[[package]]
name = "windows_x86_64_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.6"
@@ -1743,6 +2169,26 @@ dependencies = [
"time",
]
[[package]]
name = "zerocopy"
version = "0.8.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2578b716f8a7a858b7f02d5bd870c14bf4ddbbcf3a4c05414ba6503640505e3"
dependencies = [
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.8.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e6cc098ea4d3bd6246687de65af3f920c430e236bee1e3bf2e441463f08a02f"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "zeroize"
version = "1.8.2"

View File

@@ -91,6 +91,13 @@ libc = "0.2"
# Socket-level options (keepalive, etc.)
socket2 = { version = "0.5", features = ["all"] }
# QUIC transport
quinn = "0.11"
# HTTP/3 protocol
h3 = "0.0.8"
h3-quinn = "0.0.10"
# mimalloc allocator (prevents glibc fragmentation / slow RSS growth)
mimalloc = "0.1"

View File

@@ -15,6 +15,7 @@ pub fn create_http_route(
domains: Some(domains.into()),
path: None,
client_ip: None,
transport: None,
tls_version: None,
headers: None,
protocol: None,
@@ -31,6 +32,7 @@ pub fn create_http_route(
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
}]),
tls: None,
@@ -41,6 +43,7 @@ pub fn create_http_route(
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
udp: None,
},
headers: None,
security: None,
@@ -107,6 +110,7 @@ pub fn create_http_to_https_redirect(
domains: Some(domains),
path: None,
client_ip: None,
transport: None,
tls_version: None,
headers: None,
protocol: None,
@@ -137,6 +141,7 @@ pub fn create_http_to_https_redirect(
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
udp: None,
},
headers: None,
security: None,
@@ -187,6 +192,7 @@ pub fn create_load_balancer_route(
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
})
.collect();
@@ -200,6 +206,7 @@ pub fn create_load_balancer_route(
domains: Some(domains.into()),
path: None,
client_ip: None,
transport: None,
tls_version: None,
headers: None,
protocol: None,
@@ -218,6 +225,7 @@ pub fn create_load_balancer_route(
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
udp: None,
},
headers: None,
security: None,

View File

@@ -7,16 +7,24 @@ use crate::security_types::RouteSecurity;
// ─── Port Range ──────────────────────────────────────────────────────
/// Port range specification format.
/// Matches TypeScript: `type TPortRange = number | number[] | Array<{ from: number; to: number }>`
/// Matches TypeScript: `type TPortRange = number | Array<number | { from: number; to: number }>`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PortRange {
/// Single port number
Single(u16),
/// Array of port numbers
List(Vec<u16>),
/// Array of port ranges
Ranges(Vec<PortRangeSpec>),
/// Array of port numbers, ranges, or mixed
List(Vec<PortRangeItem>),
}
/// A single item in a port range array: either a number or a from-to range.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PortRangeItem {
/// Single port number
Port(u16),
/// A from-to port range
Range(PortRangeSpec),
}
impl PortRange {
@@ -24,9 +32,11 @@ impl PortRange {
pub fn to_ports(&self) -> Vec<u16> {
match self {
PortRange::Single(p) => vec![*p],
PortRange::List(ports) => ports.clone(),
PortRange::Ranges(ranges) => {
ranges.iter().flat_map(|r| r.from..=r.to).collect()
PortRange::List(items) => {
items.iter().flat_map(|item| match item {
PortRangeItem::Port(p) => vec![*p],
PortRangeItem::Range(r) => (r.from..=r.to).collect(),
}).collect()
}
}
}
@@ -95,6 +105,10 @@ pub struct RouteMatch {
/// Listen on these ports (required)
pub ports: PortRange,
/// Transport protocol: tcp (default), udp, or all (both TCP and UDP)
#[serde(skip_serializing_if = "Option::is_none")]
pub transport: Option<TransportProtocol>,
/// Optional domain patterns to match (default: all domains)
#[serde(skip_serializing_if = "Option::is_none")]
pub domains: Option<DomainSpec>,
@@ -115,7 +129,7 @@ pub struct RouteMatch {
#[serde(skip_serializing_if = "Option::is_none")]
pub headers: Option<HashMap<String, String>>,
/// Match specific protocol: "http" (includes h2 + websocket) or "tcp"
/// Match specific protocol: "http", "tcp", "udp", "quic", "http3"
#[serde(skip_serializing_if = "Option::is_none")]
pub protocol: Option<String>,
}
@@ -367,9 +381,19 @@ pub struct NfTablesOptions {
pub enum BackendProtocol {
Http1,
Http2,
Http3,
Auto,
}
/// Transport protocol for route matching.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TransportProtocol {
Tcp,
Udp,
All,
}
/// Action options.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -470,6 +494,10 @@ pub struct RouteTarget {
#[serde(skip_serializing_if = "Option::is_none")]
pub advanced: Option<RouteAdvanced>,
/// Override transport for backend connection (e.g., receive QUIC but forward as TCP)
#[serde(skip_serializing_if = "Option::is_none")]
pub backend_transport: Option<TransportProtocol>,
/// Priority for matching (higher values checked first, default: 0)
#[serde(skip_serializing_if = "Option::is_none")]
pub priority: Option<i32>,
@@ -524,6 +552,68 @@ pub struct RouteAction {
/// PROXY protocol support (default for all targets)
#[serde(skip_serializing_if = "Option::is_none")]
pub send_proxy_protocol: Option<bool>,
/// UDP-specific settings (session tracking, datagram limits, QUIC config)
#[serde(skip_serializing_if = "Option::is_none")]
pub udp: Option<RouteUdp>,
}
// ─── UDP & QUIC Config ──────────────────────────────────────────────
/// UDP-specific settings for route actions.
/// Matches TypeScript: `IRouteUdp`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteUdp {
/// Idle timeout for a UDP session/flow in ms. Default: 60000
#[serde(skip_serializing_if = "Option::is_none")]
pub session_timeout: Option<u64>,
/// Max concurrent UDP sessions per source IP. Default: 1000
#[serde(skip_serializing_if = "Option::is_none")]
pub max_sessions_per_ip: Option<u32>,
/// Max accepted datagram size in bytes. Default: 65535
#[serde(skip_serializing_if = "Option::is_none")]
pub max_datagram_size: Option<u32>,
/// QUIC-specific configuration
#[serde(skip_serializing_if = "Option::is_none")]
pub quic: Option<RouteQuic>,
}
/// QUIC and HTTP/3 settings.
/// Matches TypeScript: `IRouteQuic`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteQuic {
/// QUIC connection idle timeout in ms. Default: 30000
#[serde(skip_serializing_if = "Option::is_none")]
pub max_idle_timeout: Option<u64>,
/// Max concurrent bidirectional streams per QUIC connection. Default: 100
#[serde(skip_serializing_if = "Option::is_none")]
pub max_concurrent_bidi_streams: Option<u32>,
/// Max concurrent unidirectional streams per QUIC connection. Default: 100
#[serde(skip_serializing_if = "Option::is_none")]
pub max_concurrent_uni_streams: Option<u32>,
/// Enable HTTP/3 over this QUIC endpoint. Default: false
#[serde(skip_serializing_if = "Option::is_none")]
pub enable_http3: Option<bool>,
/// Port to advertise in Alt-Svc header on TCP HTTP responses
#[serde(skip_serializing_if = "Option::is_none")]
pub alt_svc_port: Option<u16>,
/// Max age for Alt-Svc advertisement in seconds. Default: 86400
#[serde(skip_serializing_if = "Option::is_none")]
pub alt_svc_max_age: Option<u64>,
/// Initial congestion window size in bytes
#[serde(skip_serializing_if = "Option::is_none")]
pub initial_congestion_window: Option<u32>,
}
// ─── Route Config ────────────────────────────────────────────────────

View File

@@ -27,3 +27,6 @@ arc-swap = { workspace = true }
dashmap = { workspace = true }
tokio-util = { workspace = true }
socket2 = { workspace = true }
quinn = { workspace = true }
h3 = { workspace = true }
h3-quinn = { workspace = true }

View File

@@ -1,7 +1,7 @@
//! Backend connection pool for HTTP/1.1 and HTTP/2.
//! Backend connection pool for HTTP/1.1, HTTP/2, and HTTP/3 (QUIC).
//!
//! Reuses idle keep-alive connections to avoid per-request TCP+TLS handshakes.
//! HTTP/2 connections are multiplexed (clone the sender for each request).
//! HTTP/2 and HTTP/3 connections are multiplexed (clone the sender / share the connection).
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
@@ -19,9 +19,17 @@ const IDLE_TIMEOUT: Duration = Duration::from_secs(90);
/// Background eviction interval.
const EVICTION_INTERVAL: Duration = Duration::from_secs(30);
/// Maximum age for pooled HTTP/2 connections before proactive eviction.
/// Prevents staleness from backends that close idle connections (e.g. nginx GOAWAY).
/// 120s is well within typical server GOAWAY windows (nginx: ~60s idle, envoy: ~60s).
const MAX_H2_AGE: Duration = Duration::from_secs(120);
/// Maximum age for pooled QUIC/HTTP/3 connections.
const MAX_H3_AGE: Duration = Duration::from_secs(120);
/// Protocol for pool key discrimination.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub enum PoolProtocol {
H1,
H2,
H3,
}
/// Identifies a unique backend endpoint.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
@@ -29,7 +37,7 @@ pub struct PoolKey {
pub host: String,
pub port: u16,
pub use_tls: bool,
pub h2: bool,
pub protocol: PoolProtocol,
}
/// An idle HTTP/1.1 sender with a timestamp for eviction.
@@ -47,13 +55,22 @@ struct PooledH2 {
generation: u64,
}
/// A pooled QUIC/HTTP/3 connection (multiplexed like H2).
pub struct PooledH3 {
pub connection: quinn::Connection,
pub created_at: Instant,
pub generation: u64,
}
/// Backend connection pool.
pub struct ConnectionPool {
/// HTTP/1.1 idle connections indexed by backend key.
h1_pool: Arc<DashMap<PoolKey, Vec<IdleH1>>>,
/// HTTP/2 multiplexed connections indexed by backend key.
h2_pool: Arc<DashMap<PoolKey, PooledH2>>,
/// Monotonic generation counter for H2 pool entries.
/// HTTP/3 (QUIC) connections indexed by backend key.
h3_pool: Arc<DashMap<PoolKey, PooledH3>>,
/// Monotonic generation counter for H2/H3 pool entries.
h2_generation: AtomicU64,
/// Handle for the background eviction task.
eviction_handle: Option<tokio::task::JoinHandle<()>>,
@@ -64,16 +81,19 @@ impl ConnectionPool {
pub fn new() -> Self {
let h1_pool: Arc<DashMap<PoolKey, Vec<IdleH1>>> = Arc::new(DashMap::new());
let h2_pool: Arc<DashMap<PoolKey, PooledH2>> = Arc::new(DashMap::new());
let h3_pool: Arc<DashMap<PoolKey, PooledH3>> = Arc::new(DashMap::new());
let h1_clone = Arc::clone(&h1_pool);
let h2_clone = Arc::clone(&h2_pool);
let h3_clone = Arc::clone(&h3_pool);
let eviction_handle = tokio::spawn(async move {
Self::eviction_loop(h1_clone, h2_clone).await;
Self::eviction_loop(h1_clone, h2_clone, h3_clone).await;
});
Self {
h1_pool,
h2_pool,
h3_pool,
h2_generation: AtomicU64::new(0),
eviction_handle: Some(eviction_handle),
}
@@ -173,10 +193,57 @@ impl ConnectionPool {
gen
}
// ── HTTP/3 (QUIC) pool methods ──
/// Try to get a pooled QUIC connection for the given key.
/// QUIC connections are multiplexed — the connection is shared, not removed.
pub fn checkout_h3(&self, key: &PoolKey) -> Option<(quinn::Connection, Duration)> {
let entry = self.h3_pool.get(key)?;
let pooled = entry.value();
let age = pooled.created_at.elapsed();
if age >= MAX_H3_AGE {
drop(entry);
self.h3_pool.remove(key);
return None;
}
// Check if QUIC connection is still alive
if pooled.connection.close_reason().is_some() {
drop(entry);
self.h3_pool.remove(key);
return None;
}
Some((pooled.connection.clone(), age))
}
/// Register a QUIC connection in the pool. Returns the generation ID.
pub fn register_h3(&self, key: PoolKey, connection: quinn::Connection) -> u64 {
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
self.h3_pool.insert(key, PooledH3 {
connection,
created_at: Instant::now(),
generation: gen,
});
gen
}
/// Remove a QUIC connection only if generation matches.
pub fn remove_h3_if_generation(&self, key: &PoolKey, expected_gen: u64) {
if let Some(entry) = self.h3_pool.get(key) {
if entry.value().generation == expected_gen {
drop(entry);
self.h3_pool.remove(key);
}
}
}
/// Background eviction loop — runs every EVICTION_INTERVAL to remove stale connections.
async fn eviction_loop(
h1_pool: Arc<DashMap<PoolKey, Vec<IdleH1>>>,
h2_pool: Arc<DashMap<PoolKey, PooledH2>>,
h3_pool: Arc<DashMap<PoolKey, PooledH3>>,
) {
let mut interval = tokio::time::interval(EVICTION_INTERVAL);
loop {
@@ -206,6 +273,19 @@ impl ConnectionPool {
for key in dead_h2 {
h2_pool.remove(&key);
}
// Evict dead or aged-out H3 (QUIC) connections
let mut dead_h3 = Vec::new();
for entry in h3_pool.iter() {
if entry.value().connection.close_reason().is_some()
|| entry.value().created_at.elapsed() >= MAX_H3_AGE
{
dead_h3.push(entry.key().clone());
}
}
for key in dead_h3 {
h3_pool.remove(&key);
}
}
}
}

View File

@@ -0,0 +1,288 @@
//! HTTP/3 proxy service.
//!
//! Accepts QUIC connections via quinn, runs h3 server to handle HTTP/3 requests,
//! and forwards them to backends using the same routing and pool infrastructure
//! as the HTTP/1+2 proxy.
use std::sync::Arc;
use std::time::Duration;
use arc_swap::ArcSwap;
use bytes::{Buf, Bytes};
use tracing::{debug, warn};
use rustproxy_config::{RouteConfig, TransportProtocol};
use rustproxy_metrics::MetricsCollector;
use rustproxy_routing::{MatchContext, RouteManager};
use crate::connection_pool::ConnectionPool;
use crate::protocol_cache::ProtocolCache;
use crate::upstream_selector::UpstreamSelector;
/// HTTP/3 proxy service.
///
/// Handles QUIC connections with the h3 crate, parses HTTP/3 requests,
/// and forwards them to backends using per-request route matching and
/// shared connection pooling.
pub struct H3ProxyService {
route_manager: Arc<ArcSwap<RouteManager>>,
metrics: Arc<MetricsCollector>,
connection_pool: Arc<ConnectionPool>,
#[allow(dead_code)]
protocol_cache: Arc<ProtocolCache>,
#[allow(dead_code)]
upstream_selector: UpstreamSelector,
#[allow(dead_code)]
backend_tls_config: Arc<rustls::ClientConfig>,
connect_timeout: Duration,
}
impl H3ProxyService {
pub fn new(
route_manager: Arc<ArcSwap<RouteManager>>,
metrics: Arc<MetricsCollector>,
connection_pool: Arc<ConnectionPool>,
protocol_cache: Arc<ProtocolCache>,
backend_tls_config: Arc<rustls::ClientConfig>,
connect_timeout: Duration,
) -> Self {
Self {
route_manager: Arc::clone(&route_manager),
metrics: Arc::clone(&metrics),
connection_pool,
protocol_cache,
upstream_selector: UpstreamSelector::new(),
backend_tls_config,
connect_timeout,
}
}
/// Handle an accepted QUIC connection as HTTP/3.
pub async fn handle_connection(
&self,
connection: quinn::Connection,
_fallback_route: &RouteConfig,
port: u16,
) -> anyhow::Result<()> {
let remote_addr = connection.remote_address();
debug!("HTTP/3 connection from {} on port {}", remote_addr, port);
let mut h3_conn: h3::server::Connection<h3_quinn::Connection, Bytes> =
h3::server::Connection::new(h3_quinn::Connection::new(connection))
.await
.map_err(|e| anyhow::anyhow!("H3 connection setup failed: {}", e))?;
let client_ip = remote_addr.ip().to_string();
loop {
match h3_conn.accept().await {
Ok(Some(resolver)) => {
let (request, stream) = match resolver.resolve_request().await {
Ok(pair) => pair,
Err(e) => {
debug!("HTTP/3 request resolve error: {}", e);
continue;
}
};
self.metrics.record_http_request();
let rm = self.route_manager.load();
let pool = Arc::clone(&self.connection_pool);
let metrics = Arc::clone(&self.metrics);
let connect_timeout = self.connect_timeout;
let client_ip = client_ip.clone();
tokio::spawn(async move {
if let Err(e) = handle_h3_request(
request, stream, port, &client_ip, &rm, &pool, &metrics, connect_timeout,
).await {
debug!("HTTP/3 request error from {}: {}", client_ip, e);
}
});
}
Ok(None) => {
debug!("HTTP/3 connection from {} closed", remote_addr);
break;
}
Err(e) => {
debug!("HTTP/3 accept error from {}: {}", remote_addr, e);
break;
}
}
}
Ok(())
}
}
/// Handle a single HTTP/3 request with per-request route matching.
async fn handle_h3_request(
request: hyper::Request<()>,
mut stream: h3::server::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>,
port: u16,
client_ip: &str,
route_manager: &RouteManager,
_connection_pool: &ConnectionPool,
metrics: &MetricsCollector,
connect_timeout: Duration,
) -> anyhow::Result<()> {
let method = request.method().clone();
let uri = request.uri().clone();
let path = uri.path().to_string();
// Extract host from :authority or Host header
let host = request.uri().authority()
.map(|a| a.as_str().to_string())
.or_else(|| request.headers().get("host").and_then(|v| v.to_str().ok()).map(|s| s.to_string()))
.unwrap_or_default();
debug!("HTTP/3 {} {} (host: {}, client: {})", method, path, host, client_ip);
// Per-request route matching
let ctx = MatchContext {
port,
domain: if host.is_empty() { None } else { Some(&host) },
path: Some(&path),
client_ip: Some(client_ip),
tls_version: Some("TLSv1.3"),
headers: None,
is_tls: true,
protocol: Some("http"),
transport: Some(TransportProtocol::Udp),
};
let route_match = route_manager.find_route(&ctx)
.ok_or_else(|| anyhow::anyhow!("No route matched for HTTP/3 request to {}{}", host, path))?;
let route = route_match.route;
// Resolve backend target (use matched target or first target)
let target = route_match.target
.or_else(|| route.action.targets.as_ref().and_then(|t| t.first()))
.ok_or_else(|| anyhow::anyhow!("No target for HTTP/3 route"))?;
let backend_host = target.host.first();
let backend_port = target.port.resolve(port);
let backend_addr = format!("{}:{}", backend_host, backend_port);
// Read request body
let mut body_data = Vec::new();
while let Some(mut chunk) = stream.recv_data().await
.map_err(|e| anyhow::anyhow!("Failed to read H3 request body: {}", e))?
{
body_data.extend_from_slice(chunk.chunk());
chunk.advance(chunk.remaining());
}
// Connect to backend via TCP HTTP/1.1 with timeout
let tcp_stream = tokio::time::timeout(
connect_timeout,
tokio::net::TcpStream::connect(&backend_addr),
).await
.map_err(|_| anyhow::anyhow!("Backend connect timeout to {}", backend_addr))?
.map_err(|e| anyhow::anyhow!("Backend connect to {} failed: {}", backend_addr, e))?;
let _ = tcp_stream.set_nodelay(true);
let io = hyper_util::rt::TokioIo::new(tcp_stream);
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await
.map_err(|e| anyhow::anyhow!("Backend handshake failed: {}", e))?;
tokio::spawn(async move {
if let Err(e) = conn.await {
debug!("Backend connection closed: {}", e);
}
});
let body = http_body_util::Full::new(Bytes::from(body_data));
let backend_req = build_backend_request(&method, &backend_addr, &path, &host, &request, body)?;
let response = sender.send_request(backend_req).await
.map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?;
// Build H3 response
let status = response.status();
let mut h3_response = hyper::Response::builder().status(status);
// Copy response headers (skip hop-by-hop)
for (name, value) in response.headers() {
let n = name.as_str().to_lowercase();
if n == "transfer-encoding" || n == "connection" || n == "keep-alive" || n == "upgrade" {
continue;
}
h3_response = h3_response.header(name, value);
}
// Add Alt-Svc for HTTP/3 advertisement
let alt_svc = route.action.udp.as_ref()
.and_then(|u| u.quic.as_ref())
.map(|q| {
let p = q.alt_svc_port.unwrap_or(port);
let ma = q.alt_svc_max_age.unwrap_or(86400);
format!("h3=\":{}\"; ma={}", p, ma)
})
.unwrap_or_else(|| format!("h3=\":{}\"; ma=86400", port));
h3_response = h3_response.header("alt-svc", alt_svc);
let h3_response = h3_response.body(())
.map_err(|e| anyhow::anyhow!("Failed to build H3 response: {}", e))?;
// Send response headers
stream.send_response(h3_response).await
.map_err(|e| anyhow::anyhow!("Failed to send H3 response: {}", e))?;
// Stream response body back
use http_body_util::BodyExt;
let mut body = response.into_body();
let mut total_bytes_out: u64 = 0;
while let Some(frame) = body.frame().await {
match frame {
Ok(frame) => {
if let Some(data) = frame.data_ref() {
total_bytes_out += data.len() as u64;
stream.send_data(Bytes::copy_from_slice(data)).await
.map_err(|e| anyhow::anyhow!("Failed to send H3 data: {}", e))?;
}
}
Err(e) => {
warn!("Backend body read error: {}", e);
break;
}
}
}
// Record metrics
let route_id = route.name.as_deref().or(route.id.as_deref());
metrics.record_bytes(0, total_bytes_out, route_id, Some(client_ip));
// Finish the stream
stream.finish().await
.map_err(|e| anyhow::anyhow!("Failed to finish H3 stream: {}", e))?;
Ok(())
}
/// Build an HTTP/1.1 backend request from the H3 frontend request.
fn build_backend_request(
method: &hyper::Method,
backend_addr: &str,
path: &str,
host: &str,
original_request: &hyper::Request<()>,
body: http_body_util::Full<Bytes>,
) -> anyhow::Result<hyper::Request<http_body_util::Full<Bytes>>> {
let mut req = hyper::Request::builder()
.method(method)
.uri(format!("http://{}{}", backend_addr, path))
.header("host", host);
// Forward non-pseudo headers
for (name, value) in original_request.headers() {
let n = name.as_str();
if !n.starts_with(':') && n != "host" {
req = req.header(name, value);
}
}
req.body(body)
.map_err(|e| anyhow::anyhow!("Failed to build backend request: {}", e))
}

View File

@@ -12,6 +12,7 @@ pub mod response_filter;
pub mod shutdown_on_drop;
pub mod template;
pub mod upstream_selector;
pub mod h3_service;
pub use connection_pool::*;
pub use counting_body::*;

View File

@@ -26,6 +26,7 @@ const PROTOCOL_CACHE_CLEANUP_INTERVAL: Duration = Duration::from_secs(60);
pub enum DetectedProtocol {
H1,
H2,
H3,
}
/// Key for the protocol cache: (host, port, requested_host).

View File

@@ -451,6 +451,7 @@ impl HttpProxyService {
headers: headers.as_ref(),
is_tls: false,
protocol: Some("http"),
transport: None,
};
let route_match = match current_rm.find_route(&ctx) {
@@ -647,6 +648,11 @@ impl HttpProxyService {
let (use_h2, needs_alpn_probe) = match backend_protocol_mode {
rustproxy_config::BackendProtocol::Http1 => (false, false),
rustproxy_config::BackendProtocol::Http2 => (true, false),
rustproxy_config::BackendProtocol::Http3 => {
// HTTP/3 (QUIC) backend connections not yet implemented — fall back to H1
warn!("backendProtocol 'http3' not yet implemented, falling back to http1");
(false, false)
}
rustproxy_config::BackendProtocol::Auto => {
if !upstream.use_tls {
// No ALPN without TLS — default to H1
@@ -660,6 +666,10 @@ impl HttpProxyService {
match self.protocol_cache.get(&cache_key) {
Some(crate::protocol_cache::DetectedProtocol::H2) => (true, false),
Some(crate::protocol_cache::DetectedProtocol::H1) => (false, false),
Some(crate::protocol_cache::DetectedProtocol::H3) => {
// H3 cached but we're on TCP — fall back to H2 probe
(false, true)
}
None => (false, true), // needs ALPN probe
}
}
@@ -673,7 +683,7 @@ impl HttpProxyService {
host: upstream.host.clone(),
port: upstream.port,
use_tls: upstream.use_tls,
h2: use_h2,
protocol: if use_h2 { crate::connection_pool::PoolProtocol::H2 } else { crate::connection_pool::PoolProtocol::H1 },
};
// H2 pool checkout — reuse pooled connections for all requests.
@@ -832,7 +842,7 @@ impl HttpProxyService {
host: upstream.host.clone(),
port: upstream.port,
use_tls: upstream.use_tls,
h2: detected_h2,
protocol: if detected_h2 { crate::connection_pool::PoolProtocol::H2 } else { crate::connection_pool::PoolProtocol::H1 },
};
let io = TokioIo::new(backend);
@@ -1298,7 +1308,7 @@ impl HttpProxyService {
host: upstream.host.clone(),
port: upstream.port,
use_tls: upstream.use_tls,
h2: false,
protocol: crate::connection_pool::PoolProtocol::H1,
};
let fallback_io = TokioIo::new(fallback_backend);
let result = self.forward_h1(
@@ -1438,7 +1448,7 @@ impl HttpProxyService {
host: upstream.host.clone(),
port: upstream.port,
use_tls: upstream.use_tls,
h2: false,
protocol: crate::connection_pool::PoolProtocol::H1,
};
let fallback_io = TokioIo::new(fallback_backend);
let result = self.forward_h1(

View File

@@ -10,7 +10,23 @@ pub struct ResponseFilter;
impl ResponseFilter {
/// Apply response headers from route config and CORS settings.
/// If a `RequestContext` is provided, template variables in header values will be expanded.
/// Also injects Alt-Svc header for routes with HTTP/3 enabled.
pub fn apply_headers(route: &RouteConfig, headers: &mut HeaderMap, req_ctx: Option<&RequestContext>) {
// Inject Alt-Svc for HTTP/3 advertisement if QUIC/HTTP3 is enabled on this route
if let Some(ref udp) = route.action.udp {
if let Some(ref quic) = udp.quic {
if quic.enable_http3.unwrap_or(false) {
let port = quic.alt_svc_port
.or_else(|| req_ctx.map(|c| c.port))
.unwrap_or(443);
let max_age = quic.alt_svc_max_age.unwrap_or(86400);
let alt_svc = format!("h3=\":{}\"; ma={}", port, max_age);
if let Ok(val) = HeaderValue::from_str(&alt_svc) {
headers.insert("alt-svc", val);
}
}
}
}
// Apply custom response headers from route config
if let Some(ref route_headers) = route.headers {
if let Some(ref response_headers) = route_headers.response {

View File

@@ -184,6 +184,7 @@ mod tests {
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
}
}

View File

@@ -26,6 +26,11 @@ pub struct Metrics {
pub total_http_requests: u64,
pub http_requests_per_sec: u64,
pub http_requests_per_sec_recent: u64,
// UDP metrics
pub active_udp_sessions: u64,
pub total_udp_sessions: u64,
pub total_datagrams_in: u64,
pub total_datagrams_out: u64,
}
/// Per-route metrics.
@@ -136,6 +141,12 @@ pub struct MetricsCollector {
pending_http_requests: AtomicU64,
http_request_throughput: Mutex<ThroughputTracker>,
// ── UDP metrics ──
active_udp_sessions: AtomicU64,
total_udp_sessions: AtomicU64,
total_datagrams_in: AtomicU64,
total_datagrams_out: AtomicU64,
// ── Lock-free pending throughput counters (hot path) ──
global_pending_tp_in: AtomicU64,
global_pending_tp_out: AtomicU64,
@@ -180,6 +191,10 @@ impl MetricsCollector {
backend_pool_hits: DashMap::new(),
backend_pool_misses: DashMap::new(),
backend_h2_failures: DashMap::new(),
active_udp_sessions: AtomicU64::new(0),
total_udp_sessions: AtomicU64::new(0),
total_datagrams_in: AtomicU64::new(0),
total_datagrams_out: AtomicU64::new(0),
total_http_requests: AtomicU64::new(0),
pending_http_requests: AtomicU64::new(0),
http_request_throughput: Mutex::new(ThroughputTracker::new(retention_seconds)),
@@ -350,6 +365,29 @@ impl MetricsCollector {
self.pending_http_requests.fetch_add(1, Ordering::Relaxed);
}
// ── UDP session recording methods ──
/// Record a new UDP session opened.
pub fn udp_session_opened(&self) {
self.active_udp_sessions.fetch_add(1, Ordering::Relaxed);
self.total_udp_sessions.fetch_add(1, Ordering::Relaxed);
}
/// Record a UDP session closed.
pub fn udp_session_closed(&self) {
self.active_udp_sessions.fetch_sub(1, Ordering::Relaxed);
}
/// Record a UDP datagram (inbound or outbound).
pub fn record_datagram_in(&self) {
self.total_datagrams_in.fetch_add(1, Ordering::Relaxed);
}
/// Record an outbound UDP datagram.
pub fn record_datagram_out(&self) {
self.total_datagrams_out.fetch_add(1, Ordering::Relaxed);
}
// ── Per-backend recording methods ──
/// Record a successful backend connection with its connect duration.
@@ -769,6 +807,10 @@ impl MetricsCollector {
total_http_requests: self.total_http_requests.load(Ordering::Relaxed),
http_requests_per_sec: http_rps,
http_requests_per_sec_recent: http_rps_recent,
active_udp_sessions: self.active_udp_sessions.load(Ordering::Relaxed),
total_udp_sessions: self.total_udp_sessions.load(Ordering::Relaxed),
total_datagrams_in: self.total_datagrams_in.load(Ordering::Relaxed),
total_datagrams_out: self.total_datagrams_out.load(Ordering::Relaxed),
}
}
}

View File

@@ -9,34 +9,36 @@ pub fn build_dnat_rule(
target_port: u16,
options: &NfTablesOptions,
) -> Vec<String> {
let protocol = match options.protocol.as_ref().unwrap_or(&NfTablesProtocol::Tcp) {
NfTablesProtocol::Tcp => "tcp",
NfTablesProtocol::Udp => "udp",
NfTablesProtocol::All => "tcp", // TODO: handle "all"
let protocols: Vec<&str> = match options.protocol.as_ref().unwrap_or(&NfTablesProtocol::Tcp) {
NfTablesProtocol::Tcp => vec!["tcp"],
NfTablesProtocol::Udp => vec!["udp"],
NfTablesProtocol::All => vec!["tcp", "udp"],
};
let mut rules = Vec::new();
// DNAT rule
rules.push(format!(
"nft add rule ip {} {} {} dport {} dnat to {}:{}",
table_name, chain_name, protocol, source_port, target_host, target_port,
));
// SNAT rule if preserving source IP is not enabled
if !options.preserve_source_ip.unwrap_or(false) {
for protocol in &protocols {
// DNAT rule
rules.push(format!(
"nft add rule ip {} postrouting {} dport {} masquerade",
table_name, protocol, target_port,
"nft add rule ip {} {} {} dport {} dnat to {}:{}",
table_name, chain_name, protocol, source_port, target_host, target_port,
));
}
// Rate limiting
if let Some(max_rate) = &options.max_rate {
rules.push(format!(
"nft add rule ip {} {} {} dport {} limit rate {} accept",
table_name, chain_name, protocol, source_port, max_rate,
));
// SNAT rule if preserving source IP is not enabled
if !options.preserve_source_ip.unwrap_or(false) {
rules.push(format!(
"nft add rule ip {} postrouting {} dport {} masquerade",
table_name, protocol, target_port,
));
}
// Rate limiting
if let Some(max_rate) = &options.max_rate {
rules.push(format!(
"nft add rule ip {} {} {} dport {} limit rate {} accept",
table_name, chain_name, protocol, source_port, max_rate,
));
}
}
rules
@@ -120,4 +122,25 @@ mod tests {
assert_eq!(commands.len(), 1);
assert!(commands[0].contains("delete table ip rustproxy"));
}
#[test]
fn test_protocol_all_generates_tcp_and_udp_rules() {
let mut options = make_options();
options.protocol = Some(NfTablesProtocol::All);
let rules = build_dnat_rule("rustproxy", "prerouting", 53, "10.0.0.53", 53, &options);
// Should have TCP DNAT + masquerade + UDP DNAT + masquerade = 4 rules
assert_eq!(rules.len(), 4);
assert!(rules.iter().any(|r| r.contains("tcp dport 53 dnat")));
assert!(rules.iter().any(|r| r.contains("udp dport 53 dnat")));
assert!(rules.iter().filter(|r| r.contains("masquerade")).count() == 2);
}
#[test]
fn test_protocol_udp() {
let mut options = make_options();
options.protocol = Some(NfTablesProtocol::Udp);
let rules = build_dnat_rule("rustproxy", "prerouting", 53, "10.0.0.53", 53, &options);
assert!(rules.iter().all(|r| !r.contains("tcp")));
assert!(rules.iter().any(|r| r.contains("udp dport 53 dnat")));
}
}

View File

@@ -24,3 +24,6 @@ tokio-util = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
socket2 = { workspace = true }
quinn = { workspace = true }
rcgen = { workspace = true }
base64 = { workspace = true }

View File

@@ -1,7 +1,8 @@
//! # rustproxy-passthrough
//!
//! Raw TCP/SNI passthrough engine for RustProxy.
//! Handles TCP listening, TLS ClientHello SNI extraction, and bidirectional forwarding.
//! Raw TCP/SNI passthrough engine and UDP listener for RustProxy.
//! Handles TCP listening, TLS ClientHello SNI extraction, bidirectional forwarding,
//! and UDP datagram session tracking with forwarding.
pub mod tcp_listener;
pub mod sni_parser;
@@ -11,6 +12,9 @@ pub mod tls_handler;
pub mod connection_tracker;
pub mod socket_relay;
pub mod socket_opts;
pub mod udp_session;
pub mod udp_listener;
pub mod quic_handler;
pub use tcp_listener::*;
pub use sni_parser::*;
@@ -20,3 +24,6 @@ pub use tls_handler::*;
pub use connection_tracker::*;
pub use socket_relay::*;
pub use socket_opts::*;
pub use udp_session::*;
pub use udp_listener::*;
pub use quic_handler::*;

View File

@@ -0,0 +1,309 @@
//! QUIC connection handling.
//!
//! Manages QUIC endpoints (via quinn), accepts connections, and either:
//! - Forwards streams bidirectionally to TCP backends (QUIC termination)
//! - Dispatches to H3ProxyService for HTTP/3 handling (Phase 5)
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use arc_swap::ArcSwap;
use quinn::{Endpoint, ServerConfig as QuinnServerConfig};
use rustls::ServerConfig as RustlsServerConfig;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use rustproxy_config::{RouteConfig, TransportProtocol};
use rustproxy_metrics::MetricsCollector;
use rustproxy_routing::{MatchContext, RouteManager};
use crate::connection_tracker::ConnectionTracker;
use crate::forwarder::ForwardMetricsCtx;
/// Create a QUIC server endpoint on the given port with the provided TLS config.
///
/// The TLS config must have ALPN protocols set (e.g., `h3` for HTTP/3).
pub fn create_quic_endpoint(
port: u16,
tls_config: Arc<RustlsServerConfig>,
) -> anyhow::Result<Endpoint> {
let quic_crypto = quinn::crypto::rustls::QuicServerConfig::try_from(tls_config)
.map_err(|e| anyhow::anyhow!("Failed to create QUIC crypto config: {}", e))?;
let server_config = QuinnServerConfig::with_crypto(Arc::new(quic_crypto));
let socket = std::net::UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], port)))?;
let endpoint = Endpoint::new(
quinn::EndpointConfig::default(),
Some(server_config),
socket,
quinn::default_runtime()
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
)?;
info!("QUIC endpoint listening on port {}", port);
Ok(endpoint)
}
/// Run the QUIC accept loop for a single endpoint.
///
/// Accepts incoming QUIC connections and spawns a task per connection.
pub async fn quic_accept_loop(
endpoint: Endpoint,
port: u16,
route_manager: Arc<ArcSwap<RouteManager>>,
metrics: Arc<MetricsCollector>,
conn_tracker: Arc<ConnectionTracker>,
cancel: CancellationToken,
) {
loop {
let incoming = tokio::select! {
_ = cancel.cancelled() => {
debug!("QUIC accept loop on port {} cancelled", port);
break;
}
incoming = endpoint.accept() => {
match incoming {
Some(conn) => conn,
None => {
debug!("QUIC endpoint on port {} closed", port);
break;
}
}
}
};
let remote_addr = incoming.remote_address();
let ip = remote_addr.ip();
// Per-IP rate limiting
if !conn_tracker.try_accept(&ip) {
debug!("QUIC connection rejected from {} (rate limit)", remote_addr);
// Drop `incoming` to refuse the connection
continue;
}
// Route matching (port + client IP, no domain yet — QUIC Initial is encrypted)
let rm = route_manager.load();
let ip_str = ip.to_string();
let ctx = MatchContext {
port,
domain: None,
path: None,
client_ip: Some(&ip_str),
tls_version: None,
headers: None,
is_tls: true,
protocol: Some("quic"),
transport: Some(TransportProtocol::Udp),
};
let route = match rm.find_route(&ctx) {
Some(m) => m.route.clone(),
None => {
debug!("No QUIC route matched for port {} from {}", port, remote_addr);
continue;
}
};
conn_tracker.connection_opened(&ip);
let route_id = route.name.clone().or(route.id.clone());
metrics.connection_opened(route_id.as_deref(), Some(&ip_str));
let metrics = Arc::clone(&metrics);
let conn_tracker = Arc::clone(&conn_tracker);
let cancel = cancel.child_token();
tokio::spawn(async move {
match handle_quic_connection(incoming, route, port, &metrics, &cancel).await {
Ok(()) => debug!("QUIC connection from {} completed", remote_addr),
Err(e) => debug!("QUIC connection from {} error: {}", remote_addr, e),
}
// Cleanup
conn_tracker.connection_closed(&ip);
metrics.connection_closed(route_id.as_deref(), Some(&ip_str));
});
}
// Graceful shutdown: close endpoint and wait for in-flight connections
endpoint.close(quinn::VarInt::from_u32(0), b"server shutting down");
endpoint.wait_idle().await;
info!("QUIC endpoint on port {} shut down", port);
}
/// Handle a single accepted QUIC connection.
async fn handle_quic_connection(
incoming: quinn::Incoming,
route: RouteConfig,
port: u16,
metrics: &MetricsCollector,
cancel: &CancellationToken,
) -> anyhow::Result<()> {
let connection = incoming.await?;
let remote_addr = connection.remote_address();
debug!("QUIC connection established from {}", remote_addr);
// Check if this route has HTTP/3 enabled
let enable_http3 = route.action.udp.as_ref()
.and_then(|u| u.quic.as_ref())
.and_then(|q| q.enable_http3)
.unwrap_or(false);
if enable_http3 {
// Phase 5: dispatch to H3ProxyService
// For now, log and accept streams for basic handling
debug!("HTTP/3 enabled for route {:?}, dispatching to H3 handler", route.name);
handle_h3_connection(connection, route, port, metrics, cancel).await
} else {
// Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend
handle_quic_stream_forwarding(connection, route, port, metrics, cancel).await
}
}
/// Forward QUIC streams bidirectionally to a TCP backend.
///
/// For each accepted bidirectional QUIC stream, connects to the backend
/// via TCP and forwards data in both directions. Quinn's RecvStream/SendStream
/// implement AsyncRead/AsyncWrite, enabling reuse of existing forwarder patterns.
async fn handle_quic_stream_forwarding(
connection: quinn::Connection,
route: RouteConfig,
port: u16,
_metrics: &MetricsCollector,
cancel: &CancellationToken,
) -> anyhow::Result<()> {
let remote_addr = connection.remote_address();
let route_id = route.name.as_deref().or(route.id.as_deref());
// Resolve backend target
let target = route.action.targets.as_ref()
.and_then(|t| t.first())
.ok_or_else(|| anyhow::anyhow!("No target for QUIC route"))?;
let backend_host = target.host.first();
let backend_port = target.port.resolve(port);
let backend_addr = format!("{}:{}", backend_host, backend_port);
loop {
let (send_stream, recv_stream) = tokio::select! {
_ = cancel.cancelled() => break,
result = connection.accept_bi() => {
match result {
Ok(streams) => streams,
Err(quinn::ConnectionError::ApplicationClosed(_)) => break,
Err(quinn::ConnectionError::LocallyClosed) => break,
Err(e) => {
debug!("QUIC stream accept error from {}: {}", remote_addr, e);
break;
}
}
}
};
let backend_addr = backend_addr.clone();
let ip_str = remote_addr.ip().to_string();
let _fwd_ctx = ForwardMetricsCtx {
collector: Arc::new(MetricsCollector::new()), // TODO: share real metrics
route_id: route_id.map(|s| s.to_string()),
source_ip: Some(ip_str),
};
// Spawn a task for each QUIC stream → TCP bidirectional forwarding
tokio::spawn(async move {
match forward_quic_stream_to_tcp(
send_stream,
recv_stream,
&backend_addr,
).await {
Ok((bytes_in, bytes_out)) => {
debug!("QUIC stream forwarded: {}B in, {}B out", bytes_in, bytes_out);
}
Err(e) => {
debug!("QUIC stream forwarding error: {}", e);
}
}
});
}
Ok(())
}
/// Forward a single QUIC bidirectional stream to a TCP backend connection.
async fn forward_quic_stream_to_tcp(
mut quic_send: quinn::SendStream,
mut quic_recv: quinn::RecvStream,
backend_addr: &str,
) -> anyhow::Result<(u64, u64)> {
// Connect to backend TCP
let tcp_stream = tokio::net::TcpStream::connect(backend_addr).await?;
let (mut tcp_read, mut tcp_write) = tcp_stream.into_split();
// Bidirectional copy
let client_to_backend = tokio::io::copy(&mut quic_recv, &mut tcp_write);
let backend_to_client = tokio::io::copy(&mut tcp_read, &mut quic_send);
let (c2b, b2c) = tokio::join!(client_to_backend, backend_to_client);
let bytes_in = c2b.unwrap_or(0);
let bytes_out = b2c.unwrap_or(0);
// Graceful shutdown
let _ = quic_send.finish();
let _ = tcp_write.shutdown().await;
Ok((bytes_in, bytes_out))
}
/// Placeholder for HTTP/3 connection handling (Phase 5).
///
/// Once h3_service is implemented, this will delegate to it.
async fn handle_h3_connection(
connection: quinn::Connection,
_route: RouteConfig,
_port: u16,
_metrics: &MetricsCollector,
cancel: &CancellationToken,
) -> anyhow::Result<()> {
warn!("HTTP/3 handling not yet fully implemented — accepting connection but no request processing");
// Keep the connection alive until cancelled or closed
tokio::select! {
_ = cancel.cancelled() => {}
reason = connection.closed() => {
debug!("HTTP/3 connection closed: {}", reason);
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_quic_endpoint_requires_tls_config() {
// Install the ring crypto provider for tests
let _ = rustls::crypto::ring::default_provider().install_default();
// Generate a single self-signed cert and use its key pair
let self_signed = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
.unwrap();
let cert_der = self_signed.cert.der().clone();
let key_der = self_signed.key_pair.serialize_der();
let mut tls_config = RustlsServerConfig::builder()
.with_no_client_auth()
.with_single_cert(
vec![cert_der.into()],
rustls::pki_types::PrivateKeyDer::try_from(key_der).unwrap(),
)
.unwrap();
tls_config.alpn_protocols = vec![b"h3".to_vec()];
// Port 0 = OS assigns a free port
let result = create_quic_endpoint(0, Arc::new(tls_config));
assert!(result.is_ok(), "QUIC endpoint creation failed: {:?}", result.err());
}
}

View File

@@ -625,6 +625,7 @@ impl TcpListenerManager {
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
if let Some(quick_match) = route_manager.find_route(&quick_ctx) {
@@ -814,6 +815,7 @@ impl TcpListenerManager {
is_tls,
// For TLS connections, protocol is unknown until after termination
protocol: if is_http { Some("http") } else if !is_tls { Some("tcp") } else { None },
transport: None,
};
let route_match = route_manager.find_route(&ctx);

View File

@@ -0,0 +1,477 @@
//! UDP listener manager.
//!
//! Binds UDP sockets on configured ports, receives datagrams, matches routes,
//! tracks sessions (flows), and forwards datagrams to backend UDP sockets.
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
use arc_swap::ArcSwap;
use tokio::net::UdpSocket;
use tokio::task::JoinHandle;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use rustproxy_config::{RouteActionType, TransportProtocol};
use rustproxy_metrics::MetricsCollector;
use rustproxy_routing::{MatchContext, RouteManager};
use crate::connection_tracker::ConnectionTracker;
use crate::udp_session::{SessionKey, UdpSession, UdpSessionConfig, UdpSessionTable};
/// Manages UDP listeners across all configured ports.
pub struct UdpListenerManager {
/// Port → recv loop task handle
listeners: HashMap<u16, JoinHandle<()>>,
/// Hot-reloadable route table
route_manager: Arc<ArcSwap<RouteManager>>,
/// Shared metrics collector
metrics: Arc<MetricsCollector>,
/// Per-IP session/rate limiting (shared with TCP)
conn_tracker: Arc<ConnectionTracker>,
/// Shared session table across all ports
session_table: Arc<UdpSessionTable>,
/// Cancellation for graceful shutdown
cancel_token: CancellationToken,
/// Unix socket path for datagram handler relay
datagram_handler_relay: Arc<RwLock<Option<String>>>,
}
impl UdpListenerManager {
pub fn new(
route_manager: Arc<RouteManager>,
metrics: Arc<MetricsCollector>,
conn_tracker: Arc<ConnectionTracker>,
cancel_token: CancellationToken,
) -> Self {
Self {
listeners: HashMap::new(),
route_manager: Arc::new(ArcSwap::from(route_manager)),
metrics,
conn_tracker,
session_table: Arc::new(UdpSessionTable::new()),
cancel_token,
datagram_handler_relay: Arc::new(RwLock::new(None)),
}
}
/// Update the route manager (for hot-reload).
pub fn update_routes(&self, route_manager: Arc<RouteManager>) {
self.route_manager.store(route_manager);
}
/// Start listening on a UDP port.
///
/// If any route on this port has QUIC config (`action.udp.quic`), a quinn
/// endpoint is created instead of a raw UDP socket.
pub async fn add_port(&mut self, port: u16) -> anyhow::Result<()> {
self.add_port_with_tls(port, None).await
}
/// Start listening on a UDP port with optional TLS config for QUIC.
pub async fn add_port_with_tls(
&mut self,
port: u16,
tls_config: Option<std::sync::Arc<rustls::ServerConfig>>,
) -> anyhow::Result<()> {
if self.listeners.contains_key(&port) {
debug!("UDP port {} already listening", port);
return Ok(());
}
// Check if any route on this port uses QUIC
let rm = self.route_manager.load();
let has_quic = rm.routes_for_port(port).iter().any(|r| {
r.action.udp.as_ref()
.and_then(|u| u.quic.as_ref())
.is_some()
});
if has_quic {
if let Some(tls) = tls_config {
// Create QUIC endpoint
let endpoint = crate::quic_handler::create_quic_endpoint(port, tls)?;
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
endpoint,
port,
Arc::clone(&self.route_manager),
Arc::clone(&self.metrics),
Arc::clone(&self.conn_tracker),
self.cancel_token.child_token(),
));
self.listeners.insert(port, handle);
info!("QUIC endpoint started on port {}", port);
return Ok(());
} else {
warn!("QUIC routes on port {} but no TLS config provided, falling back to raw UDP", port);
}
}
// Raw UDP listener
let addr: SocketAddr = ([0, 0, 0, 0], port).into();
let socket = UdpSocket::bind(addr).await?;
let socket = Arc::new(socket);
info!("UDP listener bound on port {}", port);
let handle = tokio::spawn(Self::recv_loop(
socket,
port,
Arc::clone(&self.route_manager),
Arc::clone(&self.metrics),
Arc::clone(&self.conn_tracker),
Arc::clone(&self.session_table),
Arc::clone(&self.datagram_handler_relay),
self.cancel_token.child_token(),
));
self.listeners.insert(port, handle);
// Start the session cleanup task if this is the first port
if self.listeners.len() == 1 {
self.start_cleanup_task();
}
Ok(())
}
/// Stop listening on a UDP port.
pub fn remove_port(&mut self, port: u16) {
if let Some(handle) = self.listeners.remove(&port) {
handle.abort();
info!("UDP listener removed from port {}", port);
}
}
/// Get all listening UDP ports.
pub fn listening_ports(&self) -> Vec<u16> {
let mut ports: Vec<u16> = self.listeners.keys().copied().collect();
ports.sort();
ports
}
/// Stop all listeners and clean up.
pub async fn stop(&mut self) {
self.cancel_token.cancel();
for (port, handle) in self.listeners.drain() {
handle.abort();
debug!("UDP listener stopped on port {}", port);
}
info!("All UDP listeners stopped, {} sessions remaining",
self.session_table.session_count());
}
/// Set the datagram handler relay socket path.
pub async fn set_datagram_handler_relay(&self, path: String) {
let mut relay = self.datagram_handler_relay.write().await;
*relay = Some(path);
}
/// Start periodic session cleanup task.
fn start_cleanup_task(&self) {
let session_table = Arc::clone(&self.session_table);
let metrics = Arc::clone(&self.metrics);
let cancel = self.cancel_token.child_token();
let route_manager = Arc::clone(&self.route_manager);
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(10));
loop {
tokio::select! {
_ = cancel.cancelled() => break,
_ = interval.tick() => {
// Determine the timeout from routes (use the minimum configured timeout,
// or default 60s if none configured)
let rm = route_manager.load();
let timeout_ms = Self::get_min_session_timeout(&rm);
let removed = session_table.cleanup_idle(timeout_ms, &metrics);
if removed > 0 {
debug!("UDP session cleanup: removed {} idle sessions, {} remaining",
removed, session_table.session_count());
}
}
}
}
});
}
/// Get the minimum session timeout across all UDP routes.
fn get_min_session_timeout(_rm: &RouteManager) -> u64 {
// Default to 60 seconds; actual per-route timeouts checked during cleanup
60_000
}
/// Main receive loop for a UDP port.
async fn recv_loop(
socket: Arc<UdpSocket>,
port: u16,
route_manager: Arc<ArcSwap<RouteManager>>,
metrics: Arc<MetricsCollector>,
conn_tracker: Arc<ConnectionTracker>,
session_table: Arc<UdpSessionTable>,
datagram_handler_relay: Arc<RwLock<Option<String>>>,
cancel: CancellationToken,
) {
// Use a reasonably large buffer; actual max is per-route but we need a single buffer
let mut buf = vec![0u8; 65535];
loop {
let (len, client_addr) = tokio::select! {
_ = cancel.cancelled() => {
debug!("UDP recv loop on port {} cancelled", port);
break;
}
result = socket.recv_from(&mut buf) => {
match result {
Ok(r) => r,
Err(e) => {
warn!("UDP recv error on port {}: {}", port, e);
continue;
}
}
}
};
let datagram = &buf[..len];
// Route matching
let rm = route_manager.load();
let ip_str = client_addr.ip().to_string();
let ctx = MatchContext {
port,
domain: None,
path: None,
client_ip: Some(&ip_str),
tls_version: None,
headers: None,
is_tls: false,
protocol: Some("udp"),
transport: Some(TransportProtocol::Udp),
};
let route_match = match rm.find_route(&ctx) {
Some(m) => m,
None => {
debug!("No UDP route matched for port {} from {}", port, client_addr);
continue;
}
};
let route = route_match.route;
let route_id = route.name.as_deref().or(route.id.as_deref());
// Socket handler routes → relay datagram to TS via Unix socket
if route.action.action_type == RouteActionType::SocketHandler {
let relay_path = datagram_handler_relay.read().await;
if let Some(ref path) = *relay_path {
if let Err(e) = Self::relay_datagram_to_ts(
path,
route_id.unwrap_or("unknown"),
&client_addr,
port,
datagram,
).await {
debug!("Failed to relay UDP datagram to TS: {}", e);
}
} else {
debug!("UDP datagram handler relay not configured for route {:?}", route_id);
}
continue;
}
// Get UDP config from route
let udp_config = UdpSessionConfig::from_route_udp(route.action.udp.as_ref());
// Check datagram size
if len as u32 > udp_config.max_datagram_size {
debug!("UDP datagram too large ({} > {}) from {}, dropping",
len, udp_config.max_datagram_size, client_addr);
continue;
}
// Session lookup or create
let session_key: SessionKey = (client_addr, port);
let session = match session_table.get(&session_key) {
Some(s) => s,
None => {
// New session — check per-IP limits
if !conn_tracker.try_accept(&client_addr.ip()) {
debug!("UDP session rejected for {} (rate limit)", client_addr);
continue;
}
if !session_table.can_create_session(
&client_addr.ip(),
udp_config.max_sessions_per_ip,
) {
debug!("UDP session rejected for {} (per-IP session limit)", client_addr);
continue;
}
// Resolve target
let target = match route_match.target.or_else(|| {
route.action.targets.as_ref().and_then(|t| t.first())
}) {
Some(t) => t,
None => {
warn!("No target for UDP route {:?}", route_id);
continue;
}
};
let backend_host = target.host.first();
let backend_port = target.port.resolve(port);
let backend_addr = format!("{}:{}", backend_host, backend_port);
// Create backend socket
let backend_socket = match UdpSocket::bind("0.0.0.0:0").await {
Ok(s) => s,
Err(e) => {
error!("Failed to bind backend UDP socket: {}", e);
continue;
}
};
if let Err(e) = backend_socket.connect(&backend_addr).await {
error!("Failed to connect backend UDP socket to {}: {}", backend_addr, e);
continue;
}
let backend_socket = Arc::new(backend_socket);
debug!("New UDP session: {} -> {} (via port {})",
client_addr, backend_addr, port);
// Spawn return-path relay task
let session_cancel = CancellationToken::new();
let return_task = tokio::spawn(Self::return_relay(
Arc::clone(&backend_socket),
Arc::clone(&socket),
client_addr,
Arc::clone(&session_table),
session_key,
Arc::clone(&metrics),
route_id.map(|s| s.to_string()),
session_cancel.child_token(),
));
let session = Arc::new(UdpSession {
backend_socket,
last_activity: std::sync::atomic::AtomicU64::new(session_table.elapsed_ms()),
created_at: std::time::Instant::now(),
route_id: route_id.map(|s| s.to_string()),
source_ip: client_addr.ip(),
client_addr,
return_task,
cancel: session_cancel,
});
if !session_table.insert(session_key, Arc::clone(&session), udp_config.max_sessions_per_ip) {
warn!("Failed to insert UDP session (race condition)");
continue;
}
// Track in metrics
conn_tracker.connection_opened(&client_addr.ip());
metrics.connection_opened(route_id, Some(&ip_str));
metrics.udp_session_opened();
session
}
};
// Forward datagram to backend
match session.backend_socket.send(datagram).await {
Ok(_) => {
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed);
metrics.record_bytes(len as u64, 0, route_id, Some(&ip_str));
metrics.record_datagram_in();
}
Err(e) => {
debug!("Failed to send UDP datagram to backend: {}", e);
}
}
}
}
/// Return-path relay: backend → client.
async fn return_relay(
backend_socket: Arc<UdpSocket>,
listener_socket: Arc<UdpSocket>,
client_addr: SocketAddr,
session_table: Arc<UdpSessionTable>,
session_key: SessionKey,
metrics: Arc<MetricsCollector>,
route_id: Option<String>,
cancel: CancellationToken,
) {
let mut buf = vec![0u8; 65535];
let ip_str = client_addr.ip().to_string();
loop {
let len = tokio::select! {
_ = cancel.cancelled() => break,
result = backend_socket.recv(&mut buf) => {
match result {
Ok(len) => len,
Err(e) => {
debug!("UDP backend recv error for {}: {}", client_addr, e);
break;
}
}
}
};
// Send reply back to client
match listener_socket.send_to(&buf[..len], client_addr).await {
Ok(_) => {
// Update session activity
if let Some(session) = session_table.get(&session_key) {
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed);
}
metrics.record_bytes(0, len as u64, route_id.as_deref(), Some(&ip_str));
metrics.record_datagram_out();
}
Err(e) => {
debug!("Failed to send UDP reply to {}: {}", client_addr, e);
break;
}
}
}
}
/// Relay a UDP datagram to the TypeScript handler via Unix socket.
/// Uses length-prefixed JSON framing: [4-byte BE length][JSON payload]
async fn relay_datagram_to_ts(
relay_path: &str,
route_key: &str,
client_addr: &SocketAddr,
dest_port: u16,
datagram: &[u8],
) -> anyhow::Result<()> {
use base64::Engine;
let payload_b64 = base64::engine::general_purpose::STANDARD.encode(datagram);
let msg = serde_json::json!({
"type": "datagram",
"routeKey": route_key,
"sourceIp": client_addr.ip().to_string(),
"sourcePort": client_addr.port(),
"destPort": dest_port,
"payloadBase64": payload_b64,
});
let json = serde_json::to_vec(&msg)?;
// Connect to relay (one-shot for now; persistent connection optimization deferred)
let mut stream = tokio::net::UnixStream::connect(relay_path).await?;
// Length-prefixed frame
let len_bytes = (json.len() as u32).to_be_bytes();
stream.write_all(&len_bytes).await?;
stream.write_all(&json).await?;
stream.flush().await?;
Ok(())
}
}

View File

@@ -0,0 +1,320 @@
//! UDP session (flow) tracking.
//!
//! A UDP "session" is a flow identified by (client_addr, listening_port).
//! Each session maintains a backend socket bound to an ephemeral port and
//! connected to the backend target, plus a background task that relays
//! return datagrams from the backend back to the client.
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use dashmap::DashMap;
use tokio::net::UdpSocket;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use rustproxy_metrics::MetricsCollector;
/// A single UDP session (flow).
pub struct UdpSession {
/// Socket bound to ephemeral port, connected to backend
pub backend_socket: Arc<UdpSocket>,
/// Milliseconds since the session table's epoch
pub last_activity: AtomicU64,
/// When the session was created
pub created_at: Instant,
/// Route ID for metrics
pub route_id: Option<String>,
/// Source IP for metrics/tracking
pub source_ip: IpAddr,
/// Client address (for return path)
pub client_addr: SocketAddr,
/// Handle for the return-path relay task
pub return_task: JoinHandle<()>,
/// Per-session cancellation
pub cancel: CancellationToken,
}
impl Drop for UdpSession {
fn drop(&mut self) {
self.cancel.cancel();
self.return_task.abort();
}
}
/// Configuration for UDP session behavior.
#[derive(Debug, Clone)]
pub struct UdpSessionConfig {
/// Idle timeout in milliseconds. Default: 60000.
pub session_timeout_ms: u64,
/// Max concurrent sessions per source IP. Default: 1000.
pub max_sessions_per_ip: u32,
/// Max accepted datagram size in bytes. Default: 65535.
pub max_datagram_size: u32,
}
impl Default for UdpSessionConfig {
fn default() -> Self {
Self {
session_timeout_ms: 60_000,
max_sessions_per_ip: 1_000,
max_datagram_size: 65_535,
}
}
}
impl UdpSessionConfig {
/// Build from route's UDP config, falling back to defaults.
pub fn from_route_udp(udp: Option<&rustproxy_config::RouteUdp>) -> Self {
match udp {
Some(u) => Self {
session_timeout_ms: u.session_timeout.unwrap_or(60_000),
max_sessions_per_ip: u.max_sessions_per_ip.unwrap_or(1_000),
max_datagram_size: u.max_datagram_size.unwrap_or(65_535),
},
None => Self::default(),
}
}
}
/// Session key: (client address, listening port).
pub type SessionKey = (SocketAddr, u16);
/// Tracks all active UDP sessions across all ports.
pub struct UdpSessionTable {
/// Active sessions keyed by (client_addr, listen_port)
sessions: DashMap<SessionKey, Arc<UdpSession>>,
/// Per-IP session counts for enforcing limits
ip_session_counts: DashMap<IpAddr, u32>,
/// Time reference for last_activity
epoch: Instant,
}
impl UdpSessionTable {
pub fn new() -> Self {
Self {
sessions: DashMap::new(),
ip_session_counts: DashMap::new(),
epoch: Instant::now(),
}
}
/// Get elapsed milliseconds since epoch (for last_activity tracking).
pub fn elapsed_ms(&self) -> u64 {
self.epoch.elapsed().as_millis() as u64
}
/// Look up an existing session.
pub fn get(&self, key: &SessionKey) -> Option<Arc<UdpSession>> {
self.sessions.get(key).map(|entry| Arc::clone(entry.value()))
}
/// Check if we can create a new session for this IP (under the per-IP limit).
pub fn can_create_session(&self, ip: &IpAddr, max_per_ip: u32) -> bool {
let count = self.ip_session_counts
.get(ip)
.map(|c| *c.value())
.unwrap_or(0);
count < max_per_ip
}
/// Insert a new session. Returns false if per-IP limit exceeded.
pub fn insert(
&self,
key: SessionKey,
session: Arc<UdpSession>,
max_per_ip: u32,
) -> bool {
let ip = session.source_ip;
// Atomically check and increment per-IP count
let mut count_entry = self.ip_session_counts.entry(ip).or_insert(0);
if *count_entry.value() >= max_per_ip {
return false;
}
*count_entry.value_mut() += 1;
drop(count_entry);
self.sessions.insert(key, session);
true
}
/// Remove a session and decrement per-IP count.
pub fn remove(&self, key: &SessionKey) -> Option<Arc<UdpSession>> {
if let Some((_, session)) = self.sessions.remove(key) {
let ip = session.source_ip;
if let Some(mut count) = self.ip_session_counts.get_mut(&ip) {
*count.value_mut() = count.value().saturating_sub(1);
if *count.value() == 0 {
drop(count);
self.ip_session_counts.remove(&ip);
}
}
Some(session)
} else {
None
}
}
/// Clean up idle sessions past the given timeout.
/// Returns the number of sessions removed.
pub fn cleanup_idle(
&self,
timeout_ms: u64,
metrics: &MetricsCollector,
) -> usize {
let now_ms = self.elapsed_ms();
let mut removed = 0;
// Collect keys to remove (avoid holding DashMap refs during removal)
let stale_keys: Vec<SessionKey> = self.sessions.iter()
.filter(|entry| {
let last = entry.value().last_activity.load(Ordering::Relaxed);
now_ms.saturating_sub(last) >= timeout_ms
})
.map(|entry| *entry.key())
.collect();
for key in stale_keys {
if let Some(session) = self.remove(&key) {
debug!(
"UDP session expired: {} -> port {} (idle {}ms)",
session.client_addr, key.1,
now_ms.saturating_sub(session.last_activity.load(Ordering::Relaxed))
);
metrics.connection_closed(
session.route_id.as_deref(),
Some(&session.source_ip.to_string()),
);
metrics.udp_session_closed();
removed += 1;
}
}
removed
}
/// Total number of active sessions.
pub fn session_count(&self) -> usize {
self.sessions.len()
}
/// Number of tracked IPs with active sessions.
pub fn tracked_ips(&self) -> usize {
self.ip_session_counts.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{Ipv4Addr, SocketAddrV4};
fn make_addr(port: u16) -> SocketAddr {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), port))
}
fn make_session(client_addr: SocketAddr, cancel: CancellationToken) -> Arc<UdpSession> {
// Create a dummy backend socket for testing
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let backend_socket = rt.block_on(async {
Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap())
});
let child_cancel = cancel.child_token();
let return_task = rt.spawn(async move {
child_cancel.cancelled().await;
});
Arc::new(UdpSession {
backend_socket,
last_activity: AtomicU64::new(0),
created_at: Instant::now(),
route_id: None,
source_ip: client_addr.ip(),
client_addr,
return_task,
cancel,
})
}
#[test]
fn test_session_table_insert_and_get() {
let table = UdpSessionTable::new();
let cancel = CancellationToken::new();
let addr = make_addr(12345);
let key: SessionKey = (addr, 53);
let session = make_session(addr, cancel);
assert!(table.insert(key, session, 1000));
assert!(table.get(&key).is_some());
assert_eq!(table.session_count(), 1);
}
#[test]
fn test_session_table_per_ip_limit() {
let table = UdpSessionTable::new();
let ip = Ipv4Addr::new(10, 0, 0, 1);
// Insert 2 sessions from same IP, limit is 2
for port in [12345u16, 12346] {
let addr = SocketAddr::V4(SocketAddrV4::new(ip, port));
let cancel = CancellationToken::new();
let session = make_session(addr, cancel);
assert!(table.insert((addr, 53), session, 2));
}
// Third should be rejected
let addr3 = SocketAddr::V4(SocketAddrV4::new(ip, 12347));
let cancel3 = CancellationToken::new();
let session3 = make_session(addr3, cancel3);
assert!(!table.insert((addr3, 53), session3, 2));
assert_eq!(table.session_count(), 2);
}
#[test]
fn test_session_table_remove() {
let table = UdpSessionTable::new();
let cancel = CancellationToken::new();
let addr = make_addr(12345);
let key: SessionKey = (addr, 53);
let session = make_session(addr, cancel);
table.insert(key, session, 1000);
assert_eq!(table.session_count(), 1);
assert_eq!(table.tracked_ips(), 1);
table.remove(&key);
assert_eq!(table.session_count(), 0);
assert_eq!(table.tracked_ips(), 0);
}
#[test]
fn test_session_config_defaults() {
let config = UdpSessionConfig::default();
assert_eq!(config.session_timeout_ms, 60_000);
assert_eq!(config.max_sessions_per_ip, 1_000);
assert_eq!(config.max_datagram_size, 65_535);
}
#[test]
fn test_session_config_from_route() {
let route_udp = rustproxy_config::RouteUdp {
session_timeout: Some(10_000),
max_sessions_per_ip: Some(500),
max_datagram_size: Some(1400),
quic: None,
};
let config = UdpSessionConfig::from_route_udp(Some(&route_udp));
assert_eq!(config.session_timeout_ms, 10_000);
assert_eq!(config.max_sessions_per_ip, 500);
assert_eq!(config.max_datagram_size, 1400);
}
}

View File

@@ -1,6 +1,6 @@
use std::collections::HashMap;
use rustproxy_config::{RouteConfig, RouteTarget, TlsMode};
use rustproxy_config::{RouteConfig, RouteTarget, TransportProtocol, TlsMode};
use crate::matchers;
/// Context for route matching (subset of connection info).
@@ -12,8 +12,10 @@ pub struct MatchContext<'a> {
pub tls_version: Option<&'a str>,
pub headers: Option<&'a HashMap<String, String>>,
pub is_tls: bool,
/// Detected protocol: "http" or "tcp". None when unknown (e.g. pre-TLS-termination).
/// Detected protocol: "http", "tcp", "udp", "quic". None when unknown.
pub protocol: Option<&'a str>,
/// Transport protocol of the listener: None = TCP (backward compat), Some(Udp), Some(All).
pub transport: Option<TransportProtocol>,
}
/// Result of a route match.
@@ -92,6 +94,22 @@ impl RouteManager {
fn matches_route(&self, route: &RouteConfig, ctx: &MatchContext<'_>) -> bool {
let rm = &route.route_match;
// Transport filtering: ensure route transport matches context transport
let route_transport = rm.transport.as_ref();
let ctx_transport = ctx.transport.as_ref();
match (route_transport, ctx_transport) {
// Route requires UDP only — reject non-UDP contexts
(Some(TransportProtocol::Udp), None) |
(Some(TransportProtocol::Udp), Some(TransportProtocol::Tcp)) => return false,
// Route requires TCP only — reject UDP contexts
(Some(TransportProtocol::Tcp), Some(TransportProtocol::Udp)) => return false,
// Route has no transport (default = TCP) — reject UDP contexts
(None, Some(TransportProtocol::Udp)) => return false,
// All other combinations match: All matches everything, same transport matches,
// None + None/Tcp matches (backward compat)
_ => {}
}
// Domain matching
if let Some(ref domains) = rm.domains {
if let Some(domain) = ctx.domain {
@@ -303,6 +321,7 @@ mod tests {
id: None,
route_match: RouteMatch {
ports: PortRange::Single(port),
transport: None,
domains: domain.map(|d| DomainSpec::Single(d.to_string())),
path: None,
client_ip: None,
@@ -322,6 +341,7 @@ mod tests {
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
}]),
tls: None,
@@ -332,6 +352,7 @@ mod tests {
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
udp: None,
},
headers: None,
security: None,
@@ -360,6 +381,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
let result = manager.find_route(&ctx);
@@ -383,6 +405,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
let result = manager.find_route(&ctx).unwrap();
@@ -407,6 +430,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_none());
@@ -493,6 +517,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_some());
@@ -513,6 +538,7 @@ mod tests {
headers: None,
is_tls: true,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_none());
@@ -533,6 +559,7 @@ mod tests {
headers: None,
is_tls: true,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_none());
@@ -553,6 +580,7 @@ mod tests {
headers: None,
is_tls: true,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_some());
@@ -577,6 +605,7 @@ mod tests {
headers: None,
is_tls: true,
protocol: None,
transport: None,
};
let result = manager.find_route(&ctx);
@@ -602,6 +631,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_some());
@@ -621,6 +651,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_some());
@@ -645,6 +676,7 @@ mod tests {
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: Some(10),
},
RouteTarget {
@@ -657,6 +689,7 @@ mod tests {
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
},
]);
@@ -672,6 +705,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
let result = manager.find_route(&ctx).unwrap();
assert_eq!(result.target.unwrap().host.first(), "api-backend");
@@ -686,6 +720,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
let result = manager.find_route(&ctx).unwrap();
assert_eq!(result.target.unwrap().host.first(), "default-backend");
@@ -711,6 +746,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: Some("http"),
transport: None,
};
assert!(manager.find_route(&ctx).is_some());
}
@@ -729,6 +765,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: Some("tcp"),
transport: None,
};
assert!(manager.find_route(&ctx).is_none());
}
@@ -748,6 +785,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: Some("http"),
transport: None,
};
assert!(manager.find_route(&ctx_http).is_some());
@@ -760,6 +798,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: Some("tcp"),
transport: None,
};
assert!(manager.find_route(&ctx_tcp).is_some());
}
@@ -780,7 +819,182 @@ mod tests {
headers: None,
is_tls: true,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_some());
}
// ===== Transport filtering tests =====
fn make_route_with_transport(port: u16, transport: Option<TransportProtocol>) -> RouteConfig {
let mut route = make_route(port, None, 0);
route.route_match.transport = transport;
route
}
#[test]
fn test_transport_udp_route_matches_udp_context() {
let routes = vec![make_route_with_transport(53, Some(TransportProtocol::Udp))];
let manager = RouteManager::new(routes);
let ctx = MatchContext {
port: 53,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: Some("udp"),
transport: Some(TransportProtocol::Udp),
};
assert!(manager.find_route(&ctx).is_some());
}
#[test]
fn test_transport_udp_route_rejects_tcp_context() {
let routes = vec![make_route_with_transport(53, Some(TransportProtocol::Udp))];
let manager = RouteManager::new(routes);
// TCP context (transport: None = TCP)
let ctx = MatchContext {
port: 53,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_none());
}
#[test]
fn test_transport_tcp_route_rejects_udp_context() {
let routes = vec![make_route_with_transport(80, Some(TransportProtocol::Tcp))];
let manager = RouteManager::new(routes);
let ctx = MatchContext {
port: 80,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: Some("udp"),
transport: Some(TransportProtocol::Udp),
};
assert!(manager.find_route(&ctx).is_none());
}
#[test]
fn test_transport_all_matches_both() {
let routes = vec![make_route_with_transport(443, Some(TransportProtocol::All))];
let manager = RouteManager::new(routes);
// TCP context
let tcp_ctx = MatchContext {
port: 443,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
assert!(manager.find_route(&tcp_ctx).is_some());
// UDP context
let udp_ctx = MatchContext {
port: 443,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: Some("udp"),
transport: Some(TransportProtocol::Udp),
};
assert!(manager.find_route(&udp_ctx).is_some());
}
#[test]
fn test_transport_none_default_matches_tcp_only() {
// Route with no transport field = TCP only (backward compat)
let routes = vec![make_route_with_transport(80, None)];
let manager = RouteManager::new(routes);
// TCP context should match
let tcp_ctx = MatchContext {
port: 80,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
assert!(manager.find_route(&tcp_ctx).is_some());
// UDP context should NOT match
let udp_ctx = MatchContext {
port: 80,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: Some("udp"),
transport: Some(TransportProtocol::Udp),
};
assert!(manager.find_route(&udp_ctx).is_none());
}
#[test]
fn test_transport_mixed_routes_same_port() {
// TCP and UDP routes on the same port — each matches only its transport
let mut tcp_route = make_route_with_transport(443, Some(TransportProtocol::Tcp));
tcp_route.name = Some("tcp-route".to_string());
let mut udp_route = make_route_with_transport(443, Some(TransportProtocol::Udp));
udp_route.name = Some("udp-route".to_string());
let manager = RouteManager::new(vec![tcp_route, udp_route]);
let tcp_ctx = MatchContext {
port: 443,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
let result = manager.find_route(&tcp_ctx).unwrap();
assert_eq!(result.route.name.as_deref(), Some("tcp-route"));
let udp_ctx = MatchContext {
port: 443,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: Some("udp"),
transport: Some(TransportProtocol::Udp),
};
let result = manager.find_route(&udp_ctx).unwrap();
assert_eq!(result.route.name.as_deref(), Some("udp-route"));
}
}

View File

@@ -32,6 +32,7 @@ arc-swap = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
rustls = { workspace = true }
rustls-pemfile = { workspace = true }
tokio-rustls = { workspace = true }
tokio-util = { workspace = true }
dashmap = { workspace = true }

View File

@@ -47,7 +47,7 @@ pub use rustproxy_security;
use rustproxy_config::{RouteConfig, RustProxyOptions, TlsMode, CertificateSpec, ForwardingEngine};
use rustproxy_routing::RouteManager;
use rustproxy_passthrough::{TcpListenerManager, TlsCertConfig, ConnectionConfig};
use rustproxy_passthrough::{TcpListenerManager, UdpListenerManager, TlsCertConfig, ConnectionConfig};
use rustproxy_metrics::{MetricsCollector, Metrics, Statistics};
use rustproxy_tls::{CertManager, CertStore, CertBundle, CertMetadata, CertSource};
use rustproxy_nftables::{NftManager, rule_builder};
@@ -68,6 +68,7 @@ pub struct RustProxy {
options: RustProxyOptions,
route_table: ArcSwap<RouteManager>,
listener_manager: Option<TcpListenerManager>,
udp_listener_manager: Option<UdpListenerManager>,
metrics: Arc<MetricsCollector>,
cert_manager: Option<Arc<tokio::sync::Mutex<CertManager>>>,
challenge_server: Option<challenge_server::ChallengeServer>,
@@ -114,6 +115,7 @@ impl RustProxy {
options,
route_table: ArcSwap::from(Arc::new(route_manager)),
listener_manager: None,
udp_listener_manager: None,
metrics: Arc::new(MetricsCollector::with_retention(retention)),
cert_manager,
challenge_server: None,
@@ -153,6 +155,7 @@ impl RustProxy {
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
}
]);
@@ -289,17 +292,62 @@ impl RustProxy {
}
}
// Build QUIC TLS config before set_tls_configs consumes the map
let quic_tls_config = Self::build_quic_tls_config(&tls_configs);
if !tls_configs.is_empty() {
debug!("Loaded TLS certificates for {} domains", tls_configs.len());
listener.set_tls_configs(tls_configs);
}
// Bind all ports
for port in &ports {
// Determine which ports need TCP vs UDP based on route transport config
let mut tcp_ports = std::collections::HashSet::new();
let mut udp_ports = std::collections::HashSet::new();
for route in &self.options.routes {
if !route.is_enabled() { continue; }
let transport = route.route_match.transport.as_ref();
let route_ports = route.route_match.ports.to_ports();
for port in route_ports {
match transport {
Some(rustproxy_config::TransportProtocol::Udp) => {
udp_ports.insert(port);
}
Some(rustproxy_config::TransportProtocol::All) => {
tcp_ports.insert(port);
udp_ports.insert(port);
}
Some(rustproxy_config::TransportProtocol::Tcp) | None => {
tcp_ports.insert(port);
}
}
}
}
// Bind TCP ports
for port in &tcp_ports {
listener.add_port(*port).await?;
}
self.listener_manager = Some(listener);
// Bind UDP ports (if any)
if !udp_ports.is_empty() {
let conn_tracker = self.listener_manager.as_ref().unwrap().conn_tracker().clone();
let mut udp_mgr = UdpListenerManager::new(
Arc::clone(&*self.route_table.load()),
Arc::clone(&self.metrics),
conn_tracker,
self.cancel_token.clone(),
);
for port in &udp_ports {
udp_mgr.add_port_with_tls(*port, quic_tls_config.clone()).await?;
}
info!("UDP listeners started on {} ports: {:?}",
udp_ports.len(), udp_mgr.listening_ports());
self.udp_listener_manager = Some(udp_mgr);
}
self.started = true;
self.started_at = Some(Instant::now());
@@ -567,6 +615,13 @@ impl RustProxy {
listener.graceful_stop().await;
}
self.listener_manager = None;
// Stop UDP listeners
if let Some(ref mut udp_mgr) = self.udp_listener_manager {
udp_mgr.stop().await;
}
self.udp_listener_manager = None;
self.started = false;
// Reset cancel token so proxy can be restarted
self.cancel_token = CancellationToken::new();
@@ -681,6 +736,67 @@ impl RustProxy {
}
}
// Reconcile UDP ports
{
let mut new_udp_ports = HashSet::new();
for route in &routes {
if !route.is_enabled() { continue; }
let transport = route.route_match.transport.as_ref();
match transport {
Some(rustproxy_config::TransportProtocol::Udp) |
Some(rustproxy_config::TransportProtocol::All) => {
for port in route.route_match.ports.to_ports() {
new_udp_ports.insert(port);
}
}
_ => {}
}
}
let old_udp_ports: HashSet<u16> = self.udp_listener_manager
.as_ref()
.map(|u| u.listening_ports().into_iter().collect())
.unwrap_or_default();
if !new_udp_ports.is_empty() {
// Ensure UDP manager exists
if self.udp_listener_manager.is_none() {
if let Some(ref listener) = self.listener_manager {
let conn_tracker = listener.conn_tracker().clone();
self.udp_listener_manager = Some(UdpListenerManager::new(
Arc::clone(&new_manager),
Arc::clone(&self.metrics),
conn_tracker,
self.cancel_token.clone(),
));
}
}
if let Some(ref mut udp_mgr) = self.udp_listener_manager {
udp_mgr.update_routes(Arc::clone(&new_manager));
// Add new UDP ports
for port in &new_udp_ports {
if !old_udp_ports.contains(port) {
udp_mgr.add_port(*port).await?;
}
}
// Remove old UDP ports
for port in &old_udp_ports {
if !new_udp_ports.contains(port) {
udp_mgr.remove_port(*port);
}
}
}
} else if self.udp_listener_manager.is_some() {
// All UDP routes removed — shut down UDP manager
if let Some(ref mut udp_mgr) = self.udp_listener_manager {
udp_mgr.stop().await;
}
self.udp_listener_manager = None;
}
}
// Update NFTables rules: remove old, apply new
self.update_nftables_rules(&routes).await;
@@ -840,6 +956,65 @@ impl RustProxy {
self.socket_handler_relay.read().unwrap().clone()
}
/// Build a rustls ServerConfig suitable for QUIC (TLS 1.3 only, h3 ALPN).
/// Uses the first available cert from tls_configs, or returns None if no certs available.
fn build_quic_tls_config(
tls_configs: &HashMap<String, TlsCertConfig>,
) -> Option<Arc<rustls::ServerConfig>> {
// Find the first available cert (prefer wildcard, then any)
let cert_config = tls_configs.get("*")
.or_else(|| tls_configs.values().next());
let cert_config = match cert_config {
Some(c) => c,
None => return None,
};
// Parse cert chain from PEM
let mut cert_reader = std::io::BufReader::new(cert_config.cert_pem.as_bytes());
let certs: Vec<rustls::pki_types::CertificateDer<'static>> =
rustls_pemfile::certs(&mut cert_reader)
.filter_map(|r| r.ok())
.collect();
if certs.is_empty() {
return None;
}
// Parse private key from PEM
let mut key_reader = std::io::BufReader::new(cert_config.key_pem.as_bytes());
let key = match rustls_pemfile::private_key(&mut key_reader) {
Ok(Some(key)) => key,
_ => return None,
};
let mut tls_config = match rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
{
Ok(c) => c,
Err(e) => {
warn!("Failed to build QUIC TLS config: {}", e);
return None;
}
};
// QUIC requires h3 ALPN
tls_config.alpn_protocols = vec![b"h3".to_vec()];
Some(Arc::new(tls_config))
}
/// Set the Unix domain socket path for relaying UDP datagrams to TypeScript datagramHandler callbacks.
pub async fn set_datagram_handler_relay_path(&mut self, path: Option<String>) {
info!("Datagram handler relay path set to: {:?}", path);
if let Some(ref udp_mgr) = self.udp_listener_manager {
if let Some(ref p) = path {
udp_mgr.set_datagram_handler_relay(p.clone()).await;
}
}
}
/// Load a certificate for a domain and hot-swap the TLS configuration.
pub async fn load_certificate(
&mut self,

View File

@@ -149,6 +149,7 @@ async fn handle_request(
"getListeningPorts" => handle_get_listening_ports(&id, proxy),
"getNftablesStatus" => handle_get_nftables_status(&id, proxy).await,
"setSocketHandlerRelay" => handle_set_socket_handler_relay(&id, &request.params, proxy).await,
"setDatagramHandlerRelay" => handle_set_datagram_handler_relay(&id, &request.params, proxy).await,
"addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await,
"removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await,
"loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await,
@@ -391,6 +392,26 @@ async fn handle_set_socket_handler_relay(
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
async fn handle_set_datagram_handler_relay(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let socket_path = params.get("socketPath")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
info!("setDatagramHandlerRelay: socket_path={:?}", socket_path);
p.set_datagram_handler_relay_path(socket_path).await;
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
async fn handle_add_listening_port(
id: &str,
params: &serde_json::Value,

View File

@@ -269,6 +269,7 @@ pub fn make_test_route(
id: None,
route_match: rustproxy_config::RouteMatch {
ports: rustproxy_config::PortRange::Single(port),
transport: None,
domains: domain.map(|d| rustproxy_config::DomainSpec::Single(d.to_string())),
path: None,
client_ip: None,
@@ -288,6 +289,7 @@ pub fn make_test_route(
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
}]),
tls: None,
@@ -298,6 +300,7 @@ pub fn make_test_route(
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
udp: None,
},
headers: None,
security: None,