Compare commits
40 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e806f7257f | |||
| af4908b63f | |||
| 8fa3a51b03 | |||
| 088ef6ab09 | |||
| fdb5ec59bc | |||
| 1ea290a085 | |||
| cb71f32b90 | |||
| 46155ab12c | |||
| 490a310b54 | |||
| 6c5180573a | |||
| 30e5ab308f | |||
| d2a54b3491 | |||
| dc922c97df | |||
| 8d1bae7604 | |||
| 200e86e311 | |||
| a53a2c4ca5 | |||
| 6ee7237357 | |||
| b5b4c608f0 | |||
| af132f40fc | |||
| 781634446a | |||
| e988d935b6 | |||
| 99a026627d | |||
| 572e31587a | |||
| 8587fb997c | |||
| 9ba101c59b | |||
| 1ad3e61c15 | |||
| 3bfa451341 | |||
| 7b3ab7378b | |||
| 527c616cd4 | |||
| b04eb0ab17 | |||
| a55ff20391 | |||
| 3c24bf659b | |||
| 5be93c8d38 | |||
| 788ccea81e | |||
| 47140e5403 | |||
| a6ffa24e36 | |||
| c0e432fd9b | |||
| a3d8a3a388 | |||
| 437d1a3329 | |||
| 746d93663d |
+139
@@ -1,5 +1,144 @@
|
||||
# Changelog
|
||||
|
||||
## 2026-04-26 - 27.9.0 - feat(smart-proxy)
|
||||
add hot-reloadable global ingress security policy across Rust and TypeScript proxy layers
|
||||
|
||||
- adds global securityPolicy config with blocked IP and CIDR support to SmartProxy and RustProxy options
|
||||
- introduces management IPC support to update the security policy at runtime via setSecurityPolicy
|
||||
- enforces the global block list early for TCP, UDP, and QUIC traffic before route selection and backend handling
|
||||
|
||||
## 2026-04-26 - 27.8.2 - fix(rustproxy-metrics)
|
||||
retain inactive per-IP metric buckets briefly to capture final throughput before pruning
|
||||
|
||||
- adds a bounded retention window for closed IP buckets so short-lived transfers are still included in per-IP throughput sampling
|
||||
- prunes expired inactive IP tracking by TTL and hard cap to prevent unbounded metric map growth
|
||||
- updates Rust and throughput tests to expect zero active connections during the temporary retention period
|
||||
|
||||
## 2026-04-26 - 27.8.1 - fix(rustproxy-metrics)
|
||||
preserve high-throughput IPs in metrics snapshots when active-connection rankings are saturated
|
||||
|
||||
- Select snapshot IPs using a blend of active-connection and throughput rankings instead of only active connections
|
||||
- Adds a regression test to ensure a high-bandwidth IP remains included when many other IPs have more active connections
|
||||
|
||||
## 2026-04-14 - 27.8.0 - feat(metrics)
|
||||
add per-domain HTTP request rate metrics
|
||||
|
||||
- Record canonicalized HTTP request rates per domain in the Rust metrics collector and expose per-second and last-minute values in snapshots.
|
||||
- Add TypeScript metrics interfaces and adapter support for requests.byDomain().
|
||||
- Cover HTTP domain rate tracking and ensure TLS passthrough SNI traffic does not affect HTTP request rate metrics.
|
||||
|
||||
## 2026-04-14 - 27.7.4 - fix(rustproxy metrics)
|
||||
use stable route metrics keys across HTTP and passthrough listeners
|
||||
|
||||
- adds a shared RouteConfig::metrics_key helper that prefers route name and falls back to route id
|
||||
- updates HTTP, TCP, UDP, and QUIC metrics labeling to use the shared route metrics key consistently
|
||||
- keeps route cancellation and rate limiter indexing bound to route config ids where required
|
||||
- adds tests covering metrics key selection behavior
|
||||
|
||||
## 2026-04-14 - 27.7.3 - fix(repo)
|
||||
no changes detected
|
||||
|
||||
|
||||
## 2026-04-14 - 27.7.2 - fix(docs)
|
||||
clarify metrics documentation for domain normalization and saturating gauges
|
||||
|
||||
- Document that per-IP domain keys are normalized to lowercase and have trailing dots stripped before counting.
|
||||
- Clarify that the saturating close pattern also applies to connection and UDP active gauges.
|
||||
|
||||
## 2026-04-14 - 27.7.1 - fix(rustproxy-http,rustproxy-metrics)
|
||||
fix domain-scoped request host detection and harden connection metrics cleanup
|
||||
|
||||
- use a shared request host extractor that falls back to URI authority so domain-scoped IP allow lists work for HTTP/2 and HTTP/3 requests without a Host header
|
||||
- add request filter and host extraction tests covering domain-scoped ACL behavior
|
||||
- prevent connection counters from underflowing during close handling and clean up per-IP metrics entries more safely
|
||||
- normalize tracked domain keys in metrics to reduce duplicate entries caused by case or trailing-dot variations
|
||||
|
||||
## 2026-04-13 - 27.7.0 - feat(smart-proxy)
|
||||
add typed Rust config serialization and regex header contract coverage
|
||||
|
||||
- serialize SmartProxy routes and top-level options into explicit Rust-safe types, including header regex literals, UDP field normalization, ACME, defaults, and proxy settings
|
||||
- support JS-style regex header literals with flags in Rust header matching and add cross-contract tests for route preprocessing and config deserialization
|
||||
- improve TypeScript safety for Rust bridge and metrics integration by replacing loose any-based payloads with dedicated Rust type definitions
|
||||
|
||||
## 2026-04-13 - 27.6.0 - feat(metrics)
|
||||
track per-IP domain request metrics across HTTP and TCP passthrough traffic
|
||||
|
||||
- records domain request counts per frontend IP from HTTP Host headers and TCP SNI
|
||||
- exposes per-IP domain maps and top IP-domain request pairs through the TypeScript metrics adapter
|
||||
- bounds per-IP domain tracking and prunes stale entries to limit memory growth
|
||||
- adds metrics system documentation covering architecture, collected data, and known gaps
|
||||
|
||||
## 2026-04-06 - 27.5.0 - feat(security)
|
||||
add domain-scoped IP allow list support across HTTP and passthrough filtering
|
||||
|
||||
- extend route security types to accept IP allow entries scoped to specific domains
|
||||
- apply domain-aware IP checks using Host headers for HTTP and SNI context for QUIC and passthrough connections
|
||||
- preserve compatibility for existing plain allow list entries and add validation and tests for scoped matching
|
||||
|
||||
## 2026-04-04 - 27.4.0 - feat(rustproxy)
|
||||
add HTTP/3 proxy service wiring for QUIC listeners
|
||||
|
||||
- registers H3ProxyService with the UDP listener manager so QUIC connections can serve HTTP/3
|
||||
- keeps proxy IP configuration intact while enabling HTTP/3 handling during listener setup
|
||||
|
||||
## 2026-04-04 - 27.3.1 - fix(metrics)
|
||||
correct frontend and backend protocol connection tracking across h1, h2, h3, and websocket traffic
|
||||
|
||||
- move frontend protocol accounting from per-request to connection lifetime tracking for HTTP/1, HTTP/2, and HTTP/3
|
||||
- add backend protocol guards to connection drivers so active protocol metrics reflect live upstream connections
|
||||
- prevent protocol counter underflow by using atomic saturating decrements in the metrics collector
|
||||
- read backend protocol distribution directly from cached aggregate counters in the Rust metrics adapter
|
||||
|
||||
## 2026-04-04 - 27.3.0 - feat(test)
|
||||
add end-to-end WebSocket proxy test coverage
|
||||
|
||||
- add comprehensive WebSocket e2e tests for upgrade handling, bidirectional messaging, header forwarding, close propagation, and large payloads
|
||||
- add ws and @types/ws as development dependencies to support the new test suite
|
||||
|
||||
## 2026-04-04 - 27.2.0 - feat(metrics)
|
||||
add frontend and backend protocol distribution metrics
|
||||
|
||||
- track active and total frontend protocol counts for h1, h2, h3, websocket, and other traffic
|
||||
- add backend protocol counters with RAII guards to ensure metrics are decremented on all exit paths
|
||||
- expose protocol distribution through the TypeScript metrics interfaces and Rust metrics adapter
|
||||
|
||||
## 2026-03-27 - 27.1.0 - feat(rustproxy-passthrough)
|
||||
add selective connection recycling for route, security, and certificate updates
|
||||
|
||||
- introduce a shared connection registry to track active TCP and QUIC connections by route, source IP, and domain
|
||||
- recycle only affected connections when route actions or security rules change instead of broadly invalidating traffic
|
||||
- gracefully recycle existing connections when TLS certificates change for a domain
|
||||
- apply route-level IP security checks to QUIC connections and share route cancellation state with UDP listeners
|
||||
|
||||
## 2026-03-26 - 27.0.0 - BREAKING CHANGE(smart-proxy)
|
||||
remove route helper APIs and standardize route configuration on plain route objects
|
||||
|
||||
- Removes TypeScript route helper exports and related Rust config helpers in favor of defining routes directly with match and action properties.
|
||||
- Updates documentation and tests to use plain IRouteConfig objects and SocketHandlers imports instead of helper factory functions.
|
||||
- Moves socket handlers to a top-level utils export and keeps direct socket-handler route configuration as the supported pattern.
|
||||
|
||||
## 2026-03-26 - 26.3.0 - feat(nftables)
|
||||
move NFTables forwarding management from the Rust engine to @push.rocks/smartnftables
|
||||
|
||||
- add @push.rocks/smartnftables as a runtime dependency and export it via the plugin layer
|
||||
- remove the internal rustproxy-nftables crate along with Rust-side NFTables rule application and status management
|
||||
- apply and clean up NFTables port-forwarding rules in the TypeScript SmartProxy lifecycle and route update flow
|
||||
- change getNfTablesStatus to return local smartnftables status instead of querying the Rust bridge
|
||||
- update README documentation to describe NFTables support as provided through @push.rocks/smartnftables
|
||||
|
||||
## 2026-03-26 - 26.2.4 - fix(rustproxy-http)
|
||||
improve HTTP/3 connection reuse and clean up stale proxy state
|
||||
|
||||
- Reuse pooled HTTP/3 SendRequest handles to skip repeated SETTINGS handshakes and reduce request overhead on QUIC pool hits
|
||||
- Add periodic cleanup for per-route rate limiters and orphaned backend metrics to prevent unbounded memory growth after traffic or backend errors stop
|
||||
- Enforce HTTP max connection lifetime alongside idle timeouts and apply configured lifetime values from the TCP listener
|
||||
- Reduce HTTP/3 body copying by using owned Bytes paths for request and response streaming, and replace the custom response body adapter with a stream-based implementation
|
||||
- Harden auxiliary proxy components by capping datagram handler buffer growth and removing duplicate RustProxy exit listeners
|
||||
|
||||
## 2026-03-25 - 26.2.3 - fix(repo)
|
||||
no changes to commit
|
||||
|
||||
|
||||
## 2026-03-25 - 26.2.2 - fix(proxy)
|
||||
improve connection cleanup and route validation handling
|
||||
|
||||
|
||||
@@ -7,13 +7,16 @@
|
||||
"npm:@git.zone/tstest@^3.6.0": "3.6.0_typescript@6.0.2",
|
||||
"npm:@push.rocks/smartcrypto@^2.0.4": "2.0.4",
|
||||
"npm:@push.rocks/smartlog@^3.2.1": "3.2.1",
|
||||
"npm:@push.rocks/smartnftables@^1.0.1": "1.0.1",
|
||||
"npm:@push.rocks/smartrust@^1.3.2": "1.3.2",
|
||||
"npm:@push.rocks/smartserve@^2.0.3": "2.0.3",
|
||||
"npm:@tsclass/tsclass@^9.5.0": "9.5.0",
|
||||
"npm:@types/node@^25.5.0": "25.5.0",
|
||||
"npm:@types/ws@^8.18.1": "8.18.1",
|
||||
"npm:minimatch@^10.2.4": "10.2.4",
|
||||
"npm:typescript@^6.0.2": "6.0.2",
|
||||
"npm:why-is-node-running@^3.2.2": "3.2.2"
|
||||
"npm:why-is-node-running@^3.2.2": "3.2.2",
|
||||
"npm:ws@^8.20.0": "8.20.0"
|
||||
},
|
||||
"npm": {
|
||||
"@api.global/typedrequest-interfaces@2.0.2": {
|
||||
@@ -2298,6 +2301,14 @@
|
||||
],
|
||||
"tarball": "https://verdaccio.lossless.digital/@push.rocks/smartnetwork/-/smartnetwork-4.4.0.tgz"
|
||||
},
|
||||
"@push.rocks/smartnftables@1.0.1": {
|
||||
"integrity": "sha512-o822GH4J8dlEBvNLbm+CwU4h6isMUEh03tf2ZnOSWXc5iewRDdKdOCDwI/e+WdnGYWyv7gvH0DHztCmne6rTCg==",
|
||||
"dependencies": [
|
||||
"@push.rocks/smartlog",
|
||||
"@push.rocks/smartpromise"
|
||||
],
|
||||
"tarball": "https://verdaccio.lossless.digital/@push.rocks/smartnftables/-/smartnftables-1.0.1.tgz"
|
||||
},
|
||||
"@push.rocks/smartnpm@2.0.6": {
|
||||
"integrity": "sha512-7anKDOjX6gXWs1IAc+YWz9ZZ8gDsTwaLh+CxRnGHjAawOmK788NrrgVCg2Fb3qojrPnoxecc46F8Ivp1BT7Izw==",
|
||||
"dependencies": [
|
||||
@@ -6729,13 +6740,16 @@
|
||||
"npm:@git.zone/tstest@^3.6.0",
|
||||
"npm:@push.rocks/smartcrypto@^2.0.4",
|
||||
"npm:@push.rocks/smartlog@^3.2.1",
|
||||
"npm:@push.rocks/smartnftables@^1.0.1",
|
||||
"npm:@push.rocks/smartrust@^1.3.2",
|
||||
"npm:@push.rocks/smartserve@^2.0.3",
|
||||
"npm:@tsclass/tsclass@^9.5.0",
|
||||
"npm:@types/node@^25.5.0",
|
||||
"npm:@types/ws@^8.18.1",
|
||||
"npm:minimatch@^10.2.4",
|
||||
"npm:typescript@^6.0.2",
|
||||
"npm:why-is-node-running@^3.2.2"
|
||||
"npm:why-is-node-running@^3.2.2",
|
||||
"npm:ws@^8.20.0"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
+5
-2
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@push.rocks/smartproxy",
|
||||
"version": "26.2.2",
|
||||
"version": "27.9.0",
|
||||
"private": false,
|
||||
"description": "A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.",
|
||||
"main": "dist_ts/index.js",
|
||||
@@ -22,12 +22,15 @@
|
||||
"@git.zone/tstest": "^3.6.0",
|
||||
"@push.rocks/smartserve": "^2.0.3",
|
||||
"@types/node": "^25.5.0",
|
||||
"@types/ws": "^8.18.1",
|
||||
"typescript": "^6.0.2",
|
||||
"why-is-node-running": "^3.2.2"
|
||||
"why-is-node-running": "^3.2.2",
|
||||
"ws": "^8.20.0"
|
||||
},
|
||||
"dependencies": {
|
||||
"@push.rocks/smartcrypto": "^2.0.4",
|
||||
"@push.rocks/smartlog": "^3.2.1",
|
||||
"@push.rocks/smartnftables": "^1.0.1",
|
||||
"@push.rocks/smartrust": "^1.3.2",
|
||||
"@tsclass/tsclass": "^9.5.0",
|
||||
"minimatch": "^10.2.4"
|
||||
|
||||
Generated
+44
-15
@@ -14,6 +14,9 @@ importers:
|
||||
'@push.rocks/smartlog':
|
||||
specifier: ^3.2.1
|
||||
version: 3.2.1
|
||||
'@push.rocks/smartnftables':
|
||||
specifier: ^1.0.1
|
||||
version: 1.0.1
|
||||
'@push.rocks/smartrust':
|
||||
specifier: ^1.3.2
|
||||
version: 1.3.2
|
||||
@@ -42,12 +45,18 @@ importers:
|
||||
'@types/node':
|
||||
specifier: ^25.5.0
|
||||
version: 25.5.0
|
||||
'@types/ws':
|
||||
specifier: ^8.18.1
|
||||
version: 8.18.1
|
||||
typescript:
|
||||
specifier: ^6.0.2
|
||||
version: 6.0.2
|
||||
why-is-node-running:
|
||||
specifier: ^3.2.2
|
||||
version: 3.2.2
|
||||
ws:
|
||||
specifier: ^8.20.0
|
||||
version: 8.20.0
|
||||
|
||||
packages:
|
||||
|
||||
@@ -468,89 +477,105 @@ packages:
|
||||
resolution: {integrity: sha512-excjX8DfsIcJ10x1Kzr4RcWe1edC9PquDRRPx3YVCvQv+U5p7Yin2s32ftzikXojb1PIFc/9Mt28/y+iRklkrw==}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@img/sharp-libvips-linux-arm@1.2.4':
|
||||
resolution: {integrity: sha512-bFI7xcKFELdiNCVov8e44Ia4u2byA+l3XtsAj+Q8tfCwO6BQ8iDojYdvoPMqsKDkuoOo+X6HZA0s0q11ANMQ8A==}
|
||||
cpu: [arm]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@img/sharp-libvips-linux-ppc64@1.2.4':
|
||||
resolution: {integrity: sha512-FMuvGijLDYG6lW+b/UvyilUWu5Ayu+3r2d1S8notiGCIyYU/76eig1UfMmkZ7vwgOrzKzlQbFSuQfgm7GYUPpA==}
|
||||
cpu: [ppc64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@img/sharp-libvips-linux-riscv64@1.2.4':
|
||||
resolution: {integrity: sha512-oVDbcR4zUC0ce82teubSm+x6ETixtKZBh/qbREIOcI3cULzDyb18Sr/Wcyx7NRQeQzOiHTNbZFF1UwPS2scyGA==}
|
||||
cpu: [riscv64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@img/sharp-libvips-linux-s390x@1.2.4':
|
||||
resolution: {integrity: sha512-qmp9VrzgPgMoGZyPvrQHqk02uyjA0/QrTO26Tqk6l4ZV0MPWIW6LTkqOIov+J1yEu7MbFQaDpwdwJKhbJvuRxQ==}
|
||||
cpu: [s390x]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@img/sharp-libvips-linux-x64@1.2.4':
|
||||
resolution: {integrity: sha512-tJxiiLsmHc9Ax1bz3oaOYBURTXGIRDODBqhveVHonrHJ9/+k89qbLl0bcJns+e4t4rvaNBxaEZsFtSfAdquPrw==}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@img/sharp-libvips-linuxmusl-arm64@1.2.4':
|
||||
resolution: {integrity: sha512-FVQHuwx1IIuNow9QAbYUzJ+En8KcVm9Lk5+uGUQJHaZmMECZmOlix9HnH7n1TRkXMS0pGxIJokIVB9SuqZGGXw==}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@img/sharp-libvips-linuxmusl-x64@1.2.4':
|
||||
resolution: {integrity: sha512-+LpyBk7L44ZIXwz/VYfglaX/okxezESc6UxDSoyo2Ks6Jxc4Y7sGjpgU9s4PMgqgjj1gZCylTieNamqA1MF7Dg==}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@img/sharp-linux-arm64@0.34.5':
|
||||
resolution: {integrity: sha512-bKQzaJRY/bkPOXyKx5EVup7qkaojECG6NLYswgktOZjaXecSAeCWiZwwiFf3/Y+O1HrauiE3FVsGxFg8c24rZg==}
|
||||
engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@img/sharp-linux-arm@0.34.5':
|
||||
resolution: {integrity: sha512-9dLqsvwtg1uuXBGZKsxem9595+ujv0sJ6Vi8wcTANSFpwV/GONat5eCkzQo/1O6zRIkh0m/8+5BjrRr7jDUSZw==}
|
||||
engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0}
|
||||
cpu: [arm]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@img/sharp-linux-ppc64@0.34.5':
|
||||
resolution: {integrity: sha512-7zznwNaqW6YtsfrGGDA6BRkISKAAE1Jo0QdpNYXNMHu2+0dTrPflTLNkpc8l7MUP5M16ZJcUvysVWWrMefZquA==}
|
||||
engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0}
|
||||
cpu: [ppc64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@img/sharp-linux-riscv64@0.34.5':
|
||||
resolution: {integrity: sha512-51gJuLPTKa7piYPaVs8GmByo7/U7/7TZOq+cnXJIHZKavIRHAP77e3N2HEl3dgiqdD/w0yUfiJnII77PuDDFdw==}
|
||||
engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0}
|
||||
cpu: [riscv64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@img/sharp-linux-s390x@0.34.5':
|
||||
resolution: {integrity: sha512-nQtCk0PdKfho3eC5MrbQoigJ2gd1CgddUMkabUj+rBevs8tZ2cULOx46E7oyX+04WGfABgIwmMC0VqieTiR4jg==}
|
||||
engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0}
|
||||
cpu: [s390x]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@img/sharp-linux-x64@0.34.5':
|
||||
resolution: {integrity: sha512-MEzd8HPKxVxVenwAa+JRPwEC7QFjoPWuS5NZnBt6B3pu7EG2Ge0id1oLHZpPJdn3OQK+BQDiw9zStiHBTJQQQQ==}
|
||||
engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@img/sharp-linuxmusl-arm64@0.34.5':
|
||||
resolution: {integrity: sha512-fprJR6GtRsMt6Kyfq44IsChVZeGN97gTD331weR1ex1c1rypDEABN6Tm2xa1wE6lYb5DdEnk03NZPqA7Id21yg==}
|
||||
engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@img/sharp-linuxmusl-x64@0.34.5':
|
||||
resolution: {integrity: sha512-Jg8wNT1MUzIvhBFxViqrEhWDGzqymo3sV7z7ZsaWbZNDLXRJZoRGrjulp60YYtV4wfY8VIKcWidjojlLcWrd8Q==}
|
||||
engines: {node: ^18.17.0 || ^20.3.0 || >=21.0.0}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@img/sharp-wasm32@0.34.5':
|
||||
resolution: {integrity: sha512-OdWTEiVkY2PHwqkbBI8frFxQQFekHaSSkUIJkwzclWZe64O1X4UlUjqqqLaPbUpMOQk6FBu/HtlGXNblIs0huw==}
|
||||
@@ -978,6 +1003,9 @@ packages:
|
||||
'@push.rocks/smartnetwork@4.4.0':
|
||||
resolution: {integrity: sha512-OvFtz41cvQ7lcXwaIOhghNUUlNoMxvwKDctbDvMyuZyEH08SpLjhyv2FuKbKL/mgwA/WxakTbohoC8SW7t+kiw==}
|
||||
|
||||
'@push.rocks/smartnftables@1.0.1':
|
||||
resolution: {integrity: sha512-o822GH4J8dlEBvNLbm+CwU4h6isMUEh03tf2ZnOSWXc5iewRDdKdOCDwI/e+WdnGYWyv7gvH0DHztCmne6rTCg==}
|
||||
|
||||
'@push.rocks/smartnpm@2.0.6':
|
||||
resolution: {integrity: sha512-7anKDOjX6gXWs1IAc+YWz9ZZ8gDsTwaLh+CxRnGHjAawOmK788NrrgVCg2Fb3qojrPnoxecc46F8Ivp1BT7Izw==}
|
||||
|
||||
@@ -1121,36 +1149,42 @@ packages:
|
||||
engines: {node: ^20.19.0 || >=22.12.0}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rolldown/binding-linux-arm64-musl@1.0.0-rc.11':
|
||||
resolution: {integrity: sha512-jfndI9tsfm4APzjNt6QdBkYwre5lRPUgHeDHoI7ydKUuJvz3lZeCfMsI56BZj+7BYqiKsJm7cfd/6KYV7ubrBg==}
|
||||
engines: {node: ^20.19.0 || >=22.12.0}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@rolldown/binding-linux-ppc64-gnu@1.0.0-rc.11':
|
||||
resolution: {integrity: sha512-ZlFgw46NOAGMgcdvdYwAGu2Q+SLFA9LzbJLW+iyMOJyhj5wk6P3KEE9Gct4xWwSzFoPI7JCdYmYMzVtlgQ+zfw==}
|
||||
engines: {node: ^20.19.0 || >=22.12.0}
|
||||
cpu: [ppc64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rolldown/binding-linux-s390x-gnu@1.0.0-rc.11':
|
||||
resolution: {integrity: sha512-hIOYmuT6ofM4K04XAZd3OzMySEO4K0/nc9+jmNcxNAxRi6c5UWpqfw3KMFV4MVFWL+jQsSh+bGw2VqmaPMTLyw==}
|
||||
engines: {node: ^20.19.0 || >=22.12.0}
|
||||
cpu: [s390x]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rolldown/binding-linux-x64-gnu@1.0.0-rc.11':
|
||||
resolution: {integrity: sha512-qXBQQO9OvkjjQPLdUVr7Nr2t3QTZI7s4KZtfw7HzBgjbmAPSFwSv4rmET9lLSgq3rH/ndA3ngv3Qb8l2njoPNA==}
|
||||
engines: {node: ^20.19.0 || >=22.12.0}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rolldown/binding-linux-x64-musl@1.0.0-rc.11':
|
||||
resolution: {integrity: sha512-/tpFfoSTzUkH9LPY+cYbqZBDyyX62w5fICq9qzsHLL8uTI6BHip3Q9Uzft0wylk/i8OOwKik8OxW+QAhDmzwmg==}
|
||||
engines: {node: ^20.19.0 || >=22.12.0}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@rolldown/binding-openharmony-arm64@1.0.0-rc.11':
|
||||
resolution: {integrity: sha512-mcp3Rio2w72IvdZG0oQ4bM2c2oumtwHfUfKncUM6zGgz0KgPz4YmDPQfnXEiY5t3+KD/i8HG2rOB/LxdmieK2g==}
|
||||
@@ -1192,21 +1226,25 @@ packages:
|
||||
resolution: {integrity: sha512-Z4reus7UxGM4+JuhiIht8KuGP1KgM7nNhOlXUHcQCMswP/Rymj5oJQN3TDWgijFUZs09ULl8t3T+AQAVTd/WvA==}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rspack/binding-linux-arm64-musl@1.7.10':
|
||||
resolution: {integrity: sha512-LYaoVmWizG4oQ3g+St3eM5qxsyfH07kLirP7NJcDMgvu3eQ29MeyTZ3ugkgW6LvlmJue7eTQyf6CZlanoF5SSg==}
|
||||
cpu: [arm64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@rspack/binding-linux-x64-gnu@1.7.10':
|
||||
resolution: {integrity: sha512-aIm2G4Kcm3qxDTNqKarK0oaLY2iXnCmpRQQhAcMlR0aS2LmxL89XzVeRr9GFA1MzGrAsZONWCLkxQvn3WUbm4Q==}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
libc: [glibc]
|
||||
|
||||
'@rspack/binding-linux-x64-musl@1.7.10':
|
||||
resolution: {integrity: sha512-SIHQbAgB9IPH0H3H+i5rN5jo9yA/yTMq8b7XfRkTMvZ7P7MXxJ0dE8EJu3BmCLM19sqnTc2eX+SVfE8ZMDzghA==}
|
||||
cpu: [x64]
|
||||
os: [linux]
|
||||
libc: [musl]
|
||||
|
||||
'@rspack/binding-wasm32-wasi@1.7.10':
|
||||
resolution: {integrity: sha512-J9HDXHD1tj+9FmX4+K3CTkO7dCE2bootlR37YuC2Owc0Lwl1/i2oGT71KHnMqI9faF/hipAaQM5OywkiiuNB7w==}
|
||||
@@ -3272,18 +3310,6 @@ packages:
|
||||
wrappy@1.0.2:
|
||||
resolution: {integrity: sha1-tSQ9jz7BqjXxNkYFvA0QNuMKtp8=}
|
||||
|
||||
ws@8.19.0:
|
||||
resolution: {integrity: sha512-blAT2mjOEIi0ZzruJfIhb3nps74PRWTCz1IjglWEEpQl5XS/UNama6u2/rjFkDDouqr4L67ry+1aGIALViWjDg==}
|
||||
engines: {node: '>=10.0.0'}
|
||||
peerDependencies:
|
||||
bufferutil: ^4.0.1
|
||||
utf-8-validate: '>=5.0.2'
|
||||
peerDependenciesMeta:
|
||||
bufferutil:
|
||||
optional: true
|
||||
utf-8-validate:
|
||||
optional: true
|
||||
|
||||
ws@8.20.0:
|
||||
resolution: {integrity: sha512-sAt8BhgNbzCtgGbt2OxmpuryO63ZoDk/sqaB/znQm94T4fCEsy/yV+7CdC1kJhOU9lboAEU7R3kquuycDoibVA==}
|
||||
engines: {node: '>=10.0.0'}
|
||||
@@ -5130,6 +5156,11 @@ snapshots:
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
'@push.rocks/smartnftables@1.0.1':
|
||||
dependencies:
|
||||
'@push.rocks/smartlog': 3.2.1
|
||||
'@push.rocks/smartpromise': 4.2.3
|
||||
|
||||
'@push.rocks/smartnpm@2.0.6':
|
||||
dependencies:
|
||||
'@push.rocks/consolecolor': 2.0.3
|
||||
@@ -5259,7 +5290,7 @@ snapshots:
|
||||
'@push.rocks/smartenv': 6.0.0
|
||||
'@push.rocks/smartlog': 3.2.1
|
||||
'@push.rocks/smartpath': 6.0.0
|
||||
ws: 8.19.0
|
||||
ws: 8.20.0
|
||||
transitivePeerDependencies:
|
||||
- bufferutil
|
||||
- utf-8-validate
|
||||
@@ -7996,8 +8027,6 @@ snapshots:
|
||||
|
||||
wrappy@1.0.2: {}
|
||||
|
||||
ws@8.19.0: {}
|
||||
|
||||
ws@8.20.0: {}
|
||||
|
||||
xml-parse-from-string@1.0.1: {}
|
||||
|
||||
+42
-20
@@ -462,35 +462,57 @@ For TLS termination modes (`terminate` and `terminate-and-reencrypt`), SmartProx
|
||||
|
||||
**HTTP to HTTPS Redirect**:
|
||||
```typescript
|
||||
import { createHttpToHttpsRedirect } from '@push.rocks/smartproxy';
|
||||
import { SocketHandlers } from '@push.rocks/smartproxy';
|
||||
|
||||
const redirectRoute = createHttpToHttpsRedirect(['example.com', 'www.example.com']);
|
||||
const redirectRoute = {
|
||||
name: 'http-to-https',
|
||||
match: { ports: 80, domains: ['example.com', 'www.example.com'] },
|
||||
action: {
|
||||
type: 'socket-handler' as const,
|
||||
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
**Complete HTTPS Server (with redirect)**:
|
||||
```typescript
|
||||
import { createCompleteHttpsServer } from '@push.rocks/smartproxy';
|
||||
|
||||
const routes = createCompleteHttpsServer(
|
||||
'example.com',
|
||||
{ host: 'localhost', port: 8080 },
|
||||
{ certificate: 'auto' }
|
||||
);
|
||||
const routes = [
|
||||
{
|
||||
name: 'https-server',
|
||||
match: { ports: 443, domains: 'example.com' },
|
||||
action: {
|
||||
type: 'forward' as const,
|
||||
targets: [{ host: 'localhost', port: 8080 }],
|
||||
tls: { mode: 'terminate' as const, certificate: 'auto' as const }
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'http-redirect',
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: {
|
||||
type: 'socket-handler' as const,
|
||||
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
|
||||
}
|
||||
}
|
||||
];
|
||||
```
|
||||
|
||||
**Load Balancer with Health Checks**:
|
||||
```typescript
|
||||
import { createLoadBalancerRoute } from '@push.rocks/smartproxy';
|
||||
|
||||
const lbRoute = createLoadBalancerRoute(
|
||||
'api.example.com',
|
||||
[
|
||||
{ host: 'backend1', port: 8080 },
|
||||
{ host: 'backend2', port: 8080 },
|
||||
{ host: 'backend3', port: 8080 }
|
||||
],
|
||||
{ tls: { mode: 'terminate', certificate: 'auto' } }
|
||||
);
|
||||
const lbRoute = {
|
||||
name: 'load-balancer',
|
||||
match: { ports: 443, domains: 'api.example.com' },
|
||||
action: {
|
||||
type: 'forward' as const,
|
||||
targets: [
|
||||
{ host: 'backend1', port: 8080 },
|
||||
{ host: 'backend2', port: 8080 },
|
||||
{ host: 'backend3', port: 8080 }
|
||||
],
|
||||
tls: { mode: 'terminate' as const, certificate: 'auto' as const },
|
||||
loadBalancing: { algorithm: 'round-robin' as const }
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
### Smart SNI Requirement (v22.3+)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# @push.rocks/smartproxy 🚀
|
||||
|
||||
**A high-performance, Rust-powered proxy toolkit for Node.js** — unified route-based configuration for SSL/TLS termination, HTTP/HTTPS reverse proxying, WebSocket support, UDP/QUIC/HTTP3, load balancing, custom protocol handlers, and kernel-level NFTables forwarding.
|
||||
**A high-performance, Rust-powered proxy toolkit for Node.js** — unified route-based configuration for SSL/TLS termination, HTTP/HTTPS reverse proxying, WebSocket support, UDP/QUIC/HTTP3, load balancing, custom protocol handlers, and kernel-level NFTables forwarding via [`@push.rocks/smartnftables`](https://code.foss.global/push.rocks/smartnftables).
|
||||
|
||||
## 📦 Installation
|
||||
|
||||
@@ -44,7 +44,7 @@ Whether you're building microservices, deploying edge infrastructure, proxying U
|
||||
Get up and running in 30 seconds:
|
||||
|
||||
```typescript
|
||||
import { SmartProxy, createCompleteHttpsServer } from '@push.rocks/smartproxy';
|
||||
import { SmartProxy, SocketHandlers } from '@push.rocks/smartproxy';
|
||||
|
||||
// Create a proxy with automatic HTTPS
|
||||
const proxy = new SmartProxy({
|
||||
@@ -53,13 +53,25 @@ const proxy = new SmartProxy({
|
||||
useProduction: true
|
||||
},
|
||||
routes: [
|
||||
// Complete HTTPS setup in one call! ✨
|
||||
...createCompleteHttpsServer('app.example.com', {
|
||||
host: 'localhost',
|
||||
port: 3000
|
||||
}, {
|
||||
certificate: 'auto' // Automatic Let's Encrypt cert 🎩
|
||||
})
|
||||
// HTTPS route with automatic Let's Encrypt cert
|
||||
{
|
||||
name: 'https-app',
|
||||
match: { ports: 443, domains: 'app.example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 3000 }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' }
|
||||
}
|
||||
},
|
||||
// HTTP → HTTPS redirect
|
||||
{
|
||||
name: 'http-redirect',
|
||||
match: { ports: 80, domains: 'app.example.com' },
|
||||
action: {
|
||||
type: 'socket-handler',
|
||||
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
@@ -111,31 +123,38 @@ SmartProxy supports three TLS handling modes:
|
||||
### 🌐 HTTP to HTTPS Redirect
|
||||
|
||||
```typescript
|
||||
import { SmartProxy, createHttpToHttpsRedirect } from '@push.rocks/smartproxy';
|
||||
import { SmartProxy, SocketHandlers } from '@push.rocks/smartproxy';
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [
|
||||
createHttpToHttpsRedirect(['example.com', '*.example.com'])
|
||||
]
|
||||
routes: [{
|
||||
name: 'http-to-https',
|
||||
match: { ports: 80, domains: ['example.com', '*.example.com'] },
|
||||
action: {
|
||||
type: 'socket-handler',
|
||||
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
|
||||
}
|
||||
}]
|
||||
});
|
||||
```
|
||||
|
||||
### ⚖️ Load Balancer with Health Checks
|
||||
|
||||
```typescript
|
||||
import { SmartProxy, createLoadBalancerRoute } from '@push.rocks/smartproxy';
|
||||
import { SmartProxy } from '@push.rocks/smartproxy';
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [
|
||||
createLoadBalancerRoute(
|
||||
'app.example.com',
|
||||
[
|
||||
routes: [{
|
||||
name: 'load-balancer',
|
||||
match: { ports: 443, domains: 'app.example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [
|
||||
{ host: 'server1.internal', port: 8080 },
|
||||
{ host: 'server2.internal', port: 8080 },
|
||||
{ host: 'server3.internal', port: 8080 }
|
||||
],
|
||||
{
|
||||
tls: { mode: 'terminate', certificate: 'auto' },
|
||||
tls: { mode: 'terminate', certificate: 'auto' },
|
||||
loadBalancing: {
|
||||
algorithm: 'round-robin',
|
||||
healthCheck: {
|
||||
path: '/health',
|
||||
@@ -145,57 +164,68 @@ const proxy = new SmartProxy({
|
||||
healthyThreshold: 2
|
||||
}
|
||||
}
|
||||
)
|
||||
]
|
||||
}
|
||||
}]
|
||||
});
|
||||
```
|
||||
|
||||
### 🔌 WebSocket Proxy
|
||||
|
||||
```typescript
|
||||
import { SmartProxy, createWebSocketRoute } from '@push.rocks/smartproxy';
|
||||
import { SmartProxy } from '@push.rocks/smartproxy';
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [
|
||||
createWebSocketRoute(
|
||||
'ws.example.com',
|
||||
{ host: 'websocket-server', port: 8080 },
|
||||
{
|
||||
path: '/socket',
|
||||
useTls: true,
|
||||
certificate: 'auto',
|
||||
routes: [{
|
||||
name: 'websocket',
|
||||
match: { ports: 443, domains: 'ws.example.com', path: '/socket' },
|
||||
priority: 100,
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'websocket-server', port: 8080 }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' },
|
||||
websocket: {
|
||||
enabled: true,
|
||||
pingInterval: 30000,
|
||||
pingTimeout: 10000
|
||||
}
|
||||
)
|
||||
]
|
||||
}
|
||||
}]
|
||||
});
|
||||
```
|
||||
|
||||
### 🚦 API Gateway with Rate Limiting
|
||||
|
||||
```typescript
|
||||
import { SmartProxy, createApiGatewayRoute, addRateLimiting } from '@push.rocks/smartproxy';
|
||||
import { SmartProxy } from '@push.rocks/smartproxy';
|
||||
|
||||
let apiRoute = createApiGatewayRoute(
|
||||
'api.example.com',
|
||||
'/api',
|
||||
{ host: 'api-backend', port: 8080 },
|
||||
{
|
||||
useTls: true,
|
||||
certificate: 'auto',
|
||||
addCorsHeaders: true
|
||||
}
|
||||
);
|
||||
|
||||
// Add rate limiting — 100 requests per minute per IP
|
||||
apiRoute = addRateLimiting(apiRoute, {
|
||||
maxRequests: 100,
|
||||
window: 60,
|
||||
keyBy: 'ip'
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'api-gateway',
|
||||
match: { ports: 443, domains: 'api.example.com', path: '/api/*' },
|
||||
priority: 100,
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'api-backend', port: 8080 }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' }
|
||||
},
|
||||
headers: {
|
||||
response: {
|
||||
'Access-Control-Allow-Origin': '*',
|
||||
'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
|
||||
'Access-Control-Allow-Headers': 'Content-Type, Authorization',
|
||||
'Access-Control-Max-Age': '86400'
|
||||
}
|
||||
},
|
||||
security: {
|
||||
rateLimit: {
|
||||
enabled: true,
|
||||
maxRequests: 100,
|
||||
window: 60,
|
||||
keyBy: 'ip'
|
||||
}
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
const proxy = new SmartProxy({ routes: [apiRoute] });
|
||||
```
|
||||
|
||||
### 🎮 Custom Protocol Handler (TCP)
|
||||
@@ -203,36 +233,40 @@ const proxy = new SmartProxy({ routes: [apiRoute] });
|
||||
SmartProxy lets you implement any protocol with full socket control. Routes with JavaScript socket handlers are automatically relayed from the Rust engine back to your TypeScript code:
|
||||
|
||||
```typescript
|
||||
import { SmartProxy, createSocketHandlerRoute, SocketHandlers } from '@push.rocks/smartproxy';
|
||||
import { SmartProxy, SocketHandlers } from '@push.rocks/smartproxy';
|
||||
|
||||
// Use pre-built handlers
|
||||
const echoRoute = createSocketHandlerRoute(
|
||||
'echo.example.com',
|
||||
7777,
|
||||
SocketHandlers.echo
|
||||
);
|
||||
const proxy = new SmartProxy({
|
||||
routes: [
|
||||
// Use pre-built handlers
|
||||
{
|
||||
name: 'echo-server',
|
||||
match: { ports: 7777, domains: 'echo.example.com' },
|
||||
action: { type: 'socket-handler', socketHandler: SocketHandlers.echo }
|
||||
},
|
||||
// Or create your own custom protocol
|
||||
{
|
||||
name: 'custom-protocol',
|
||||
match: { ports: 9999, domains: 'custom.example.com' },
|
||||
action: {
|
||||
type: 'socket-handler',
|
||||
socketHandler: async (socket) => {
|
||||
console.log(`New connection on custom protocol`);
|
||||
socket.write('Welcome to my custom protocol!\n');
|
||||
|
||||
// Or create your own custom protocol
|
||||
const customRoute = createSocketHandlerRoute(
|
||||
'custom.example.com',
|
||||
9999,
|
||||
async (socket) => {
|
||||
console.log(`New connection on custom protocol`);
|
||||
socket.write('Welcome to my custom protocol!\n');
|
||||
|
||||
socket.on('data', (data) => {
|
||||
const command = data.toString().trim();
|
||||
switch (command) {
|
||||
case 'PING': socket.write('PONG\n'); break;
|
||||
case 'TIME': socket.write(`${new Date().toISOString()}\n`); break;
|
||||
case 'QUIT': socket.end('Goodbye!\n'); break;
|
||||
default: socket.write(`Unknown: ${command}\n`);
|
||||
socket.on('data', (data) => {
|
||||
const command = data.toString().trim();
|
||||
switch (command) {
|
||||
case 'PING': socket.write('PONG\n'); break;
|
||||
case 'TIME': socket.write(`${new Date().toISOString()}\n`); break;
|
||||
case 'QUIT': socket.end('Goodbye!\n'); break;
|
||||
default: socket.write(`Unknown: ${command}\n`);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
);
|
||||
|
||||
const proxy = new SmartProxy({ routes: [echoRoute, customRoute] });
|
||||
}
|
||||
]
|
||||
});
|
||||
```
|
||||
|
||||
**Pre-built Socket Handlers:**
|
||||
@@ -384,23 +418,26 @@ const dualStackRoute: IRouteConfig = {
|
||||
|
||||
### ⚡ High-Performance NFTables Forwarding
|
||||
|
||||
For ultra-low latency on Linux, use kernel-level forwarding (requires root):
|
||||
For ultra-low latency on Linux, use kernel-level forwarding via [`@push.rocks/smartnftables`](https://code.foss.global/push.rocks/smartnftables) (requires root):
|
||||
|
||||
```typescript
|
||||
import { SmartProxy, createNfTablesTerminateRoute } from '@push.rocks/smartproxy';
|
||||
import { SmartProxy } from '@push.rocks/smartproxy';
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [
|
||||
createNfTablesTerminateRoute(
|
||||
'fast.example.com',
|
||||
{ host: 'backend', port: 8080 },
|
||||
{
|
||||
ports: 443,
|
||||
certificate: 'auto',
|
||||
routes: [{
|
||||
name: 'nftables-fast',
|
||||
match: { ports: 443, domains: 'fast.example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
forwardingEngine: 'nftables',
|
||||
targets: [{ host: 'backend', port: 8080 }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' },
|
||||
nftables: {
|
||||
protocol: 'tcp',
|
||||
preserveSourceIP: true // Backend sees real client IP
|
||||
}
|
||||
)
|
||||
]
|
||||
}
|
||||
}]
|
||||
});
|
||||
```
|
||||
|
||||
@@ -409,15 +446,18 @@ const proxy = new SmartProxy({
|
||||
Forward encrypted traffic to backends without terminating TLS — the proxy routes based on the SNI hostname alone:
|
||||
|
||||
```typescript
|
||||
import { SmartProxy, createHttpsPassthroughRoute } from '@push.rocks/smartproxy';
|
||||
import { SmartProxy } from '@push.rocks/smartproxy';
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [
|
||||
createHttpsPassthroughRoute('secure.example.com', {
|
||||
host: 'backend-that-handles-tls',
|
||||
port: 8443
|
||||
})
|
||||
]
|
||||
routes: [{
|
||||
name: 'sni-passthrough',
|
||||
match: { ports: 443, domains: 'secure.example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'backend-that-handles-tls', port: 8443 }],
|
||||
tls: { mode: 'passthrough' }
|
||||
}
|
||||
}]
|
||||
});
|
||||
```
|
||||
|
||||
@@ -524,15 +564,7 @@ Comprehensive per-route security options:
|
||||
}
|
||||
```
|
||||
|
||||
**Security modifier helpers** let you add security to any existing route:
|
||||
|
||||
```typescript
|
||||
import { addRateLimiting, addBasicAuth, addJwtAuth } from '@push.rocks/smartproxy';
|
||||
|
||||
let route = createHttpsTerminateRoute('api.example.com', { host: 'backend', port: 8080 });
|
||||
route = addRateLimiting(route, { maxRequests: 100, window: 60, keyBy: 'ip' });
|
||||
route = addBasicAuth(route, { users: [{ username: 'admin', password: 'secret' }] });
|
||||
```
|
||||
Security options are configured directly on each route's `security` property — no separate helpers needed.
|
||||
|
||||
### 📊 Runtime Management
|
||||
|
||||
@@ -694,22 +726,26 @@ SmartProxy uses a hybrid **Rust + TypeScript** architecture:
|
||||
│ │ Listener│ │ Reverse │ │ Matcher │ │ Cert Mgr │ │
|
||||
│ │ │ │ Proxy │ │ │ │ │ │
|
||||
│ └─────────┘ └─────────┘ └─────────┘ └──────────┘ │
|
||||
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌──────────┐ │
|
||||
│ │ UDP │ │ Security│ │ Metrics │ │ NFTables │ │
|
||||
│ │ QUIC │ │ Enforce │ │ Collect │ │ Mgr │ │
|
||||
│ │ HTTP/3 │ │ │ │ │ │ │ │
|
||||
│ └─────────┘ └─────────┘ └─────────┘ └──────────┘ │
|
||||
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
|
||||
│ │ UDP │ │ Security│ │ Metrics │ │
|
||||
│ │ QUIC │ │ Enforce │ │ Collect │ │
|
||||
│ │ HTTP/3 │ │ │ │ │ │
|
||||
│ └─────────┘ └─────────┘ └─────────┘ │
|
||||
└──────────────────┬──────────────────────────────────┘
|
||||
│ Unix Socket Relay
|
||||
┌──────────────────▼──────────────────────────────────┐
|
||||
│ TypeScript Socket & Datagram Handler Servers │
|
||||
│ (for JS socket handlers, datagram handlers, │
|
||||
│ and dynamic routes) │
|
||||
├─────────────────────────────────────────────────────┤
|
||||
│ @push.rocks/smartnftables (kernel-level NFTables) │
|
||||
│ (DNAT/SNAT, firewall, rate limiting via nft CLI) │
|
||||
└─────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
- **Rust Engine** handles all networking: TCP, UDP, TLS, QUIC, HTTP proxying, connection management, security, and metrics
|
||||
- **TypeScript** provides the npm API, configuration types, route helpers, validation, and handler callbacks
|
||||
- **TypeScript** provides the npm API, configuration types, validation, and handler callbacks
|
||||
- **NFTables** managed by [`@push.rocks/smartnftables`](https://code.foss.global/push.rocks/smartnftables) — kernel-level DNAT/SNAT forwarding, firewall rules, and rate limiting via the `nft` CLI
|
||||
- **IPC** — The TypeScript wrapper uses JSON commands/events over stdin/stdout to communicate with the Rust binary
|
||||
- **Socket/Datagram Relay** — Unix domain socket servers for routes requiring TypeScript-side handling (socket handlers, datagram handlers, dynamic host/port functions)
|
||||
|
||||
@@ -854,47 +890,13 @@ interface IRouteQuic {
|
||||
}
|
||||
```
|
||||
|
||||
## 🛠️ Helper Functions Reference
|
||||
|
||||
All helpers are fully typed and return `IRouteConfig` or `IRouteConfig[]`:
|
||||
## 🛠️ Exports Reference
|
||||
|
||||
```typescript
|
||||
import {
|
||||
// HTTP/HTTPS
|
||||
createHttpRoute, // Plain HTTP route
|
||||
createHttpsTerminateRoute, // HTTPS with TLS termination
|
||||
createHttpsPassthroughRoute, // SNI passthrough (no termination)
|
||||
createHttpToHttpsRedirect, // HTTP → HTTPS redirect
|
||||
createCompleteHttpsServer, // HTTPS + redirect combo (returns IRouteConfig[])
|
||||
|
||||
// Load Balancing
|
||||
createLoadBalancerRoute, // Multi-backend with health checks
|
||||
createSmartLoadBalancer, // Dynamic domain-based backend selection
|
||||
|
||||
// API & WebSocket
|
||||
createApiRoute, // API route with path matching
|
||||
createApiGatewayRoute, // API gateway with CORS
|
||||
createWebSocketRoute, // WebSocket-enabled route
|
||||
|
||||
// Custom Protocols
|
||||
createSocketHandlerRoute, // Custom TCP socket handler
|
||||
SocketHandlers, // Pre-built handlers (echo, proxy, block, etc.)
|
||||
|
||||
// NFTables (Linux, requires root)
|
||||
createNfTablesRoute, // Kernel-level packet forwarding
|
||||
createNfTablesTerminateRoute, // NFTables + TLS termination
|
||||
createCompleteNfTablesHttpsServer, // NFTables HTTPS + redirect combo
|
||||
|
||||
// Dynamic Routing
|
||||
createPortMappingRoute, // Port mapping with context
|
||||
createOffsetPortMappingRoute, // Simple port offset
|
||||
createDynamicRoute, // Dynamic host/port via functions
|
||||
createPortOffset, // Port offset factory
|
||||
|
||||
// Security Modifiers
|
||||
addRateLimiting, // Add rate limiting to any route
|
||||
addBasicAuth, // Add basic auth to any route
|
||||
addJwtAuth, // Add JWT auth to any route
|
||||
// Core
|
||||
SmartProxy, // Main proxy class
|
||||
SocketHandlers, // Pre-built socket handlers (echo, proxy, block, httpRedirect, httpServer, etc.)
|
||||
|
||||
// Route Utilities
|
||||
mergeRouteConfigs, // Deep-merge two route configs
|
||||
@@ -906,7 +908,7 @@ import {
|
||||
} from '@push.rocks/smartproxy';
|
||||
```
|
||||
|
||||
> **Tip:** For UDP datagram handler routes or QUIC/HTTP3 routes, construct `IRouteConfig` objects directly — there are no helper functions for these yet. See the [UDP Datagram Handler](#-udp-datagram-handler) and [QUIC / HTTP3 Forwarding](#-quic--http3-forwarding) examples above.
|
||||
All routes are configured as plain `IRouteConfig` objects with `match` and `action` properties — see the examples throughout this document.
|
||||
|
||||
## 📖 API Documentation
|
||||
|
||||
@@ -938,8 +940,8 @@ class SmartProxy extends EventEmitter {
|
||||
getCertificateStatus(routeName: string): Promise<any>;
|
||||
getEligibleDomainsForCertificates(): string[];
|
||||
|
||||
// NFTables
|
||||
getNfTablesStatus(): Promise<Record<string, any>>;
|
||||
// NFTables (managed by @push.rocks/smartnftables)
|
||||
getNfTablesStatus(): INftStatus | null;
|
||||
|
||||
// Events
|
||||
on(event: 'error', handler: (err: Error) => void): this;
|
||||
@@ -991,11 +993,11 @@ interface ISmartProxyOptions {
|
||||
sendProxyProtocol?: boolean; // Send PROXY protocol to targets
|
||||
|
||||
// Timeouts
|
||||
connectionTimeout?: number; // Backend connection timeout (default: 30s)
|
||||
initialDataTimeout?: number; // Initial data/SNI timeout (default: 120s)
|
||||
socketTimeout?: number; // Socket inactivity timeout (default: 1h)
|
||||
maxConnectionLifetime?: number; // Max connection lifetime (default: 24h)
|
||||
inactivityTimeout?: number; // Inactivity timeout (default: 4h)
|
||||
connectionTimeout?: number; // Backend connection timeout (default: 60s)
|
||||
initialDataTimeout?: number; // Initial data/SNI timeout (default: 60s)
|
||||
socketTimeout?: number; // Socket inactivity timeout (default: 60s)
|
||||
maxConnectionLifetime?: number; // Max connection lifetime (default: 1h)
|
||||
inactivityTimeout?: number; // Inactivity timeout (default: 75s)
|
||||
gracefulShutdownTimeout?: number; // Shutdown grace period (default: 30s)
|
||||
|
||||
// Connection limits
|
||||
@@ -1004,8 +1006,8 @@ interface ISmartProxyOptions {
|
||||
|
||||
// Keep-alive
|
||||
keepAliveTreatment?: 'standard' | 'extended' | 'immortal';
|
||||
keepAliveInactivityMultiplier?: number; // (default: 6)
|
||||
extendedKeepAliveLifetime?: number; // (default: 7 days)
|
||||
keepAliveInactivityMultiplier?: number; // (default: 4)
|
||||
extendedKeepAliveLifetime?: number; // (default: 1h)
|
||||
|
||||
// Metrics
|
||||
metrics?: {
|
||||
@@ -1137,7 +1139,7 @@ SmartProxy searches for the Rust binary in this order:
|
||||
|
||||
## License and Legal Information
|
||||
|
||||
This repository contains open-source code licensed under the MIT License. A copy of the license can be found in the [LICENSE](./LICENSE) file.
|
||||
This repository contains open-source code licensed under the MIT License. A copy of the license can be found in the [license](./license) file.
|
||||
|
||||
**Please note:** The MIT License does not grant permission to use the trade names, trademarks, service marks, or product names of the project, except as required for reasonable and customary use in describing the origin of the work and reproducing the content of the NOTICE file.
|
||||
|
||||
|
||||
@@ -0,0 +1,484 @@
|
||||
# SmartProxy Metrics System
|
||||
|
||||
## Architecture
|
||||
|
||||
Two-tier design separating the data plane from the observation plane:
|
||||
|
||||
**Hot path (per-chunk, lock-free):** All recording in the proxy data plane touches only `AtomicU64` counters. No `Mutex` is ever acquired on the forwarding path. `CountingBody` batches flushes every 64KB to reduce DashMap shard contention.
|
||||
|
||||
**Cold path (1Hz sampling):** A background tokio task drains pending atomics into `ThroughputTracker` circular buffers (Mutex-guarded), producing per-second throughput history. Same task prunes orphaned entries and cleans up rate limiter state.
|
||||
|
||||
**Read path (on-demand):** `snapshot()` reads all atomics and locks ThroughputTrackers to build a serializable `Metrics` struct. TypeScript polls this at 1s intervals via IPC.
|
||||
|
||||
```
|
||||
Data Plane (lock-free) Background (1Hz) Read Path
|
||||
───────────────────── ────────────────── ─────────
|
||||
record_bytes() ──> AtomicU64 ──┐
|
||||
record_http_request() ──> AtomicU64 ──┤
|
||||
connection_opened/closed() ──> AtomicU64 ──┤ sample_all() snapshot()
|
||||
backend_*() ──> DashMap<AtomicU64> ──┤────> drain atomics ──────> Metrics struct
|
||||
protocol_*() ──> AtomicU64 ──┤ feed ThroughputTrackers ──> JSON
|
||||
datagram_*() ──> AtomicU64 ──┘ prune orphans ──> IPC stdout
|
||||
──> TS cache
|
||||
──> IMetrics API
|
||||
```
|
||||
|
||||
### Key Types
|
||||
|
||||
| Type | Crate | Purpose |
|
||||
|---|---|---|
|
||||
| `MetricsCollector` | `rustproxy-metrics` | Central store. All DashMaps, atomics, and throughput trackers |
|
||||
| `ThroughputTracker` | `rustproxy-metrics` | Circular buffer of 1Hz samples. Default 3600 capacity (1 hour) |
|
||||
| `ForwardMetricsCtx` | `rustproxy-passthrough` | Carries `Arc<MetricsCollector>` + route_id + source_ip through TCP forwarding |
|
||||
| `CountingBody` | `rustproxy-http` | Wraps HTTP bodies, batches byte recording per 64KB, flushes on drop |
|
||||
| `ProtocolGuard` | `rustproxy-http` | RAII guard for frontend/backend protocol active/total counters |
|
||||
| `ConnectionGuard` | `rustproxy-passthrough` | RAII guard calling `connection_closed()` on drop |
|
||||
| `RustMetricsAdapter` | TypeScript | Polls Rust via IPC, implements `IMetrics` interface over cached JSON |
|
||||
|
||||
---
|
||||
|
||||
## What's Collected
|
||||
|
||||
### Global Counters
|
||||
|
||||
| Metric | Type | Updated by |
|
||||
|---|---|---|
|
||||
| Active connections | AtomicU64 | `connection_opened/closed` |
|
||||
| Total connections (lifetime) | AtomicU64 | `connection_opened` |
|
||||
| Total bytes in | AtomicU64 | `record_bytes` |
|
||||
| Total bytes out | AtomicU64 | `record_bytes` |
|
||||
| Total HTTP requests | AtomicU64 | `record_http_request` |
|
||||
| Active UDP sessions | AtomicU64 | `udp_session_opened/closed` |
|
||||
| Total UDP sessions | AtomicU64 | `udp_session_opened` |
|
||||
| Total datagrams in | AtomicU64 | `record_datagram_in` |
|
||||
| Total datagrams out | AtomicU64 | `record_datagram_out` |
|
||||
|
||||
### Per-Route Metrics (keyed by route ID string)
|
||||
|
||||
| Metric | Storage |
|
||||
|---|---|
|
||||
| Active connections | `DashMap<String, AtomicU64>` |
|
||||
| Total connections | `DashMap<String, AtomicU64>` |
|
||||
| Bytes in / out | `DashMap<String, AtomicU64>` |
|
||||
| Pending throughput (in, out) | `DashMap<String, (AtomicU64, AtomicU64)>` |
|
||||
| Throughput history | `DashMap<String, Mutex<ThroughputTracker>>` |
|
||||
|
||||
Entries are pruned via `retain_routes()` when routes are removed.
|
||||
|
||||
### Per-IP Metrics (keyed by IP string)
|
||||
|
||||
| Metric | Storage |
|
||||
|---|---|
|
||||
| Active connections | `DashMap<String, AtomicU64>` |
|
||||
| Total connections | `DashMap<String, AtomicU64>` |
|
||||
| Bytes in / out | `DashMap<String, AtomicU64>` |
|
||||
| Pending throughput (in, out) | `DashMap<String, (AtomicU64, AtomicU64)>` |
|
||||
| Throughput history | `DashMap<String, Mutex<ThroughputTracker>>` |
|
||||
| Domain requests | `DashMap<String, DashMap<String, AtomicU64>>` (IP → domain → count) |
|
||||
|
||||
All seven maps for an IP are evicted when its active connection count drops to 0. Safety-net pruning in `sample_all()` catches entries orphaned by races. Snapshots cap at 100 IPs, sorted by active connections descending.
|
||||
|
||||
**Domain request tracking:** Records which domains each frontend IP has requested. Populated from HTTP Host headers (for HTTP/1.1, HTTP/2, HTTP/3 requests) and from SNI (for TLS passthrough connections). Domain keys are normalized to lowercase with any trailing dot stripped so the same hostname does not fragment across multiple counters. Capped at 256 domains per IP (`MAX_DOMAINS_PER_IP`) to prevent subdomain-spray abuse. Inner DashMap uses 2 shards to minimise base memory per IP (~200 bytes). Common case (IP + domain both known) is two DashMap reads + one atomic increment with zero allocation.
|
||||
|
||||
### Per-Backend Metrics (keyed by "host:port")
|
||||
|
||||
| Metric | Storage |
|
||||
|---|---|
|
||||
| Active connections | `DashMap<String, AtomicU64>` |
|
||||
| Total connections | `DashMap<String, AtomicU64>` |
|
||||
| Detected protocol (h1/h2/h3) | `DashMap<String, String>` |
|
||||
| Connect errors | `DashMap<String, AtomicU64>` |
|
||||
| Handshake errors | `DashMap<String, AtomicU64>` |
|
||||
| Request errors | `DashMap<String, AtomicU64>` |
|
||||
| Total connect time (microseconds) | `DashMap<String, AtomicU64>` |
|
||||
| Connect count | `DashMap<String, AtomicU64>` |
|
||||
| Pool hits | `DashMap<String, AtomicU64>` |
|
||||
| Pool misses | `DashMap<String, AtomicU64>` |
|
||||
| H2 failures (fallback to H1) | `DashMap<String, AtomicU64>` |
|
||||
|
||||
All per-backend maps are evicted when active count reaches 0. Pruned via `retain_backends()`.
|
||||
|
||||
### Frontend Protocol Distribution
|
||||
|
||||
Tracked via `ProtocolGuard` RAII guards and `FrontendProtocolTracker`. Five protocol categories, each with active + total counters (AtomicU64):
|
||||
|
||||
| Protocol | Where detected |
|
||||
|---|---|
|
||||
| h1 | `FrontendProtocolTracker` on first HTTP/1.x request |
|
||||
| h2 | `FrontendProtocolTracker` on first HTTP/2 request |
|
||||
| h3 | `ProtocolGuard::frontend("h3")` in H3ProxyService |
|
||||
| ws | `ProtocolGuard::frontend("ws")` on WebSocket upgrade |
|
||||
| other | Fallback (TCP passthrough without HTTP) |
|
||||
|
||||
Uses `fetch_update` for saturating decrements to prevent underflow races. The same saturating-close pattern is also used for connection and UDP active gauges.
|
||||
|
||||
### Backend Protocol Distribution
|
||||
|
||||
Same five categories (h1/h2/h3/ws/other), tracked via `ProtocolGuard::backend()` at connection establishment. Backend h2 failures (fallback to h1) are separately counted.
|
||||
|
||||
### Throughput History
|
||||
|
||||
`ThroughputTracker` is a circular buffer storing `ThroughputSample { timestamp_ms, bytes_in, bytes_out }` at 1Hz.
|
||||
|
||||
- Global tracker: 1 instance, default 3600 capacity
|
||||
- Per-route trackers: 1 per active route
|
||||
- Per-IP trackers: 1 per connected IP (evicted with the IP)
|
||||
- HTTP request tracker: reuses ThroughputTracker with bytes_in = request count, bytes_out = 0
|
||||
|
||||
Query methods:
|
||||
- `instant()` — last 1 second average
|
||||
- `recent()` — last 10 seconds average
|
||||
- `throughput(N)` — last N seconds average
|
||||
- `history(N)` — last N raw samples in chronological order
|
||||
|
||||
Snapshots return 60 samples of global throughput history.
|
||||
|
||||
### Protocol Detection Cache
|
||||
|
||||
Not part of MetricsCollector. Maintained by `HttpProxyService`'s protocol detection system. Injected into the metrics snapshot at read time by `get_metrics()`.
|
||||
|
||||
Each entry records: host, port, domain, detected protocol (h1/h2/h3), H3 port, age, last accessed, last probed, suppression flags, cooldown timers, consecutive failure counts.
|
||||
|
||||
---
|
||||
|
||||
## Instrumentation Points
|
||||
|
||||
### TCP Passthrough (`rustproxy-passthrough`)
|
||||
|
||||
**Connection lifecycle** — `tcp_listener.rs`:
|
||||
- Accept: `conn_tracker.connection_opened(&ip)` (rate limiter) + `ConnectionTrackerGuard` RAII
|
||||
- Route match: `metrics.connection_opened(route_id, source_ip)` + `ConnectionGuard` RAII
|
||||
- Close: Both guards call their respective `_closed()` methods on drop
|
||||
|
||||
**Byte recording** — `forwarder.rs` (`forward_bidirectional_with_timeouts`):
|
||||
- Initial peeked data recorded immediately
|
||||
- Per-chunk in both directions: `record_bytes(n, 0, ...)` / `record_bytes(0, n, ...)`
|
||||
- Same pattern in `forward_bidirectional_split_with_timeouts` (tcp_listener.rs) for TLS-terminated paths
|
||||
|
||||
### HTTP Proxy (`rustproxy-http`)
|
||||
|
||||
**Request counting** — `proxy_service.rs`:
|
||||
- `record_http_request()` called once per request after route matching succeeds
|
||||
|
||||
**Body byte counting** — `counting_body.rs` wrapping:
|
||||
- Request bodies: `CountingBody::new(body, ..., Direction::In)` — counts client-to-upstream bytes
|
||||
- Response bodies: `CountingBody::new(body, ..., Direction::Out)` — counts upstream-to-client bytes
|
||||
- Batched flush every 64KB (`BYTE_FLUSH_THRESHOLD = 65_536`), remainder flushed on drop
|
||||
- Also updates `connection_activity` atomic (idle watchdog) and `active_requests` counter (streaming detection)
|
||||
|
||||
**Backend metrics** — `proxy_service.rs`:
|
||||
- `backend_connection_opened(key, connect_time)` — after TCP/TLS connect succeeds
|
||||
- `backend_connection_closed(key)` — on teardown
|
||||
- `backend_connect_error(key)` — TCP/TLS connect failure or timeout
|
||||
- `backend_handshake_error(key)` — H1/H2 protocol handshake failure
|
||||
- `backend_request_error(key)` — send_request failure
|
||||
- `backend_h2_failure(key)` — H2 attempted, fell back to H1
|
||||
- `backend_pool_hit(key)` / `backend_pool_miss(key)` — connection pool reuse
|
||||
- `set_backend_protocol(key, proto)` — records detected protocol
|
||||
|
||||
**WebSocket** — `proxy_service.rs`:
|
||||
- Does NOT use CountingBody; records bytes directly per-chunk in both directions of the bidirectional copy loop
|
||||
|
||||
### QUIC (`rustproxy-passthrough`)
|
||||
|
||||
**Connection level** — `quic_handler.rs`:
|
||||
- `connection_opened` / `connection_closed` via `QuicConnGuard` RAII
|
||||
- `conn_tracker.connection_opened/closed` for rate limiting
|
||||
|
||||
**Stream level**:
|
||||
- For QUIC-to-TCP stream forwarding: `record_bytes(bytes_in, bytes_out, ...)` called once per stream at completion (post-hoc, not per-chunk)
|
||||
- For HTTP/3: delegates to `HttpProxyService.handle_request()`, so all HTTP proxy metrics apply
|
||||
|
||||
**H3 specifics** — `h3_service.rs`:
|
||||
- `ProtocolGuard::frontend("h3")` tracks the H3 connection
|
||||
- H3 request bodies: `record_bytes(data.len(), 0, ...)` called directly (not CountingBody) since H3 uses `stream.send_data()`
|
||||
- H3 response bodies: wrapped in CountingBody like HTTP/1 and HTTP/2
|
||||
|
||||
### UDP (`rustproxy-passthrough`)
|
||||
|
||||
**Session lifecycle** — `udp_listener.rs` / `udp_session.rs`:
|
||||
- `udp_session_opened()` + `connection_opened(route_id, source_ip)` on new session
|
||||
- `udp_session_closed()` + `connection_closed(route_id, source_ip)` on idle reap or port drain
|
||||
|
||||
**Datagram counting** — `udp_listener.rs`:
|
||||
- Inbound: `record_bytes(len, 0, ...)` + `record_datagram_in()`
|
||||
- Outbound (backend reply): `record_bytes(0, len, ...)` + `record_datagram_out()`
|
||||
|
||||
---
|
||||
|
||||
## Sampling Loop
|
||||
|
||||
`lib.rs` spawns a tokio task at configurable interval (default 1000ms):
|
||||
|
||||
```rust
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel => break,
|
||||
_ = interval.tick() => {
|
||||
metrics.sample_all();
|
||||
conn_tracker.cleanup_stale_timestamps();
|
||||
http_proxy.cleanup_all_rate_limiters();
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`sample_all()` performs in one pass:
|
||||
1. Drains `global_pending_tp_in/out` into global ThroughputTracker, samples
|
||||
2. Drains per-route pending counters into per-route trackers, samples each
|
||||
3. Samples idle route trackers (no new data) to advance their window
|
||||
4. Drains per-IP pending counters into per-IP trackers, samples each
|
||||
5. Drains `pending_http_requests` into HTTP request throughput tracker
|
||||
6. Prunes orphaned per-IP entries (bytes/throughput maps with no matching ip_connections key)
|
||||
7. Prunes orphaned per-backend entries (error/stats maps with no matching active/total key)
|
||||
|
||||
---
|
||||
|
||||
## Data Flow: Rust to TypeScript
|
||||
|
||||
```
|
||||
MetricsCollector.snapshot()
|
||||
├── reads all AtomicU64 counters
|
||||
├── iterates DashMaps (routes, IPs, backends)
|
||||
├── locks ThroughputTrackers for instant/recent rates + history
|
||||
└── produces Metrics struct
|
||||
|
||||
RustProxy::get_metrics()
|
||||
├── calls snapshot()
|
||||
├── enriches with detectedProtocols from HTTP proxy protocol cache
|
||||
└── returns Metrics
|
||||
|
||||
management.rs "getMetrics" IPC command
|
||||
├── calls get_metrics()
|
||||
├── serde_json::to_value (camelCase)
|
||||
└── writes JSON to stdout
|
||||
|
||||
RustProxyBridge (TypeScript)
|
||||
├── reads JSON from Rust process stdout
|
||||
└── returns parsed object
|
||||
|
||||
RustMetricsAdapter
|
||||
├── setInterval polls bridge.getMetrics() every 1s
|
||||
├── stores raw JSON in this.cache
|
||||
└── IMetrics methods read synchronously from cache
|
||||
|
||||
SmartProxy.getMetrics()
|
||||
└── returns the RustMetricsAdapter instance
|
||||
```
|
||||
|
||||
### IPC JSON Shape (Metrics)
|
||||
|
||||
```json
|
||||
{
|
||||
"activeConnections": 42,
|
||||
"totalConnections": 1000,
|
||||
"bytesIn": 123456789,
|
||||
"bytesOut": 987654321,
|
||||
"throughputInBytesPerSec": 50000,
|
||||
"throughputOutBytesPerSec": 80000,
|
||||
"throughputRecentInBytesPerSec": 45000,
|
||||
"throughputRecentOutBytesPerSec": 75000,
|
||||
"routes": {
|
||||
"<route-id>": {
|
||||
"activeConnections": 5,
|
||||
"totalConnections": 100,
|
||||
"bytesIn": 0, "bytesOut": 0,
|
||||
"throughputInBytesPerSec": 0, "throughputOutBytesPerSec": 0,
|
||||
"throughputRecentInBytesPerSec": 0, "throughputRecentOutBytesPerSec": 0
|
||||
}
|
||||
},
|
||||
"ips": {
|
||||
"<ip>": {
|
||||
"activeConnections": 2, "totalConnections": 10,
|
||||
"bytesIn": 0, "bytesOut": 0,
|
||||
"throughputInBytesPerSec": 0, "throughputOutBytesPerSec": 0,
|
||||
"domainRequests": {
|
||||
"example.com": 4821,
|
||||
"api.example.com": 312
|
||||
}
|
||||
}
|
||||
},
|
||||
"backends": {
|
||||
"<host:port>": {
|
||||
"activeConnections": 3, "totalConnections": 50,
|
||||
"protocol": "h2",
|
||||
"connectErrors": 0, "handshakeErrors": 0, "requestErrors": 0,
|
||||
"totalConnectTimeUs": 150000, "connectCount": 50,
|
||||
"poolHits": 40, "poolMisses": 10, "h2Failures": 1
|
||||
}
|
||||
},
|
||||
"throughputHistory": [
|
||||
{ "timestampMs": 1713000000000, "bytesIn": 50000, "bytesOut": 80000 }
|
||||
],
|
||||
"totalHttpRequests": 5000,
|
||||
"httpRequestsPerSec": 100,
|
||||
"httpRequestsPerSecRecent": 95,
|
||||
"activeUdpSessions": 0, "totalUdpSessions": 5,
|
||||
"totalDatagramsIn": 1000, "totalDatagramsOut": 1000,
|
||||
"frontendProtocols": {
|
||||
"h1Active": 10, "h1Total": 500,
|
||||
"h2Active": 5, "h2Total": 200,
|
||||
"h3Active": 1, "h3Total": 50,
|
||||
"wsActive": 2, "wsTotal": 30,
|
||||
"otherActive": 0, "otherTotal": 0
|
||||
},
|
||||
"backendProtocols": { "...same shape..." },
|
||||
"detectedProtocols": [
|
||||
{
|
||||
"host": "backend", "port": 443, "domain": "example.com",
|
||||
"protocol": "h2", "h3Port": 443,
|
||||
"ageSecs": 120, "lastAccessedSecs": 5, "lastProbedSecs": 120,
|
||||
"h2Suppressed": false, "h3Suppressed": false,
|
||||
"h2CooldownRemainingSecs": null, "h3CooldownRemainingSecs": null,
|
||||
"h2ConsecutiveFailures": null, "h3ConsecutiveFailures": null
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### IPC JSON Shape (Statistics)
|
||||
|
||||
Lightweight administrative summary, fetched on-demand (not polled):
|
||||
|
||||
```json
|
||||
{
|
||||
"activeConnections": 42,
|
||||
"totalConnections": 1000,
|
||||
"routesCount": 5,
|
||||
"listeningPorts": [80, 443, 8443],
|
||||
"uptimeSeconds": 86400
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## TypeScript Consumer API
|
||||
|
||||
`SmartProxy.getMetrics()` returns an `IMetrics` object. All members are synchronous methods reading from the polled cache.
|
||||
|
||||
### connections
|
||||
|
||||
| Method | Return | Source |
|
||||
|---|---|---|
|
||||
| `active()` | `number` | `cache.activeConnections` |
|
||||
| `total()` | `number` | `cache.totalConnections` |
|
||||
| `byRoute()` | `Map<string, number>` | `cache.routes[name].activeConnections` |
|
||||
| `byIP()` | `Map<string, number>` | `cache.ips[ip].activeConnections` |
|
||||
| `topIPs(limit?)` | `Array<{ip, count}>` | `cache.ips` sorted by active desc, default 10 |
|
||||
| `domainRequestsByIP()` | `Map<string, Map<string, number>>` | `cache.ips[ip].domainRequests` |
|
||||
| `topDomainRequests(limit?)` | `Array<{ip, domain, count}>` | Flattened from all IPs, sorted by count desc, default 20 |
|
||||
| `frontendProtocols()` | `IProtocolDistribution` | `cache.frontendProtocols.*` |
|
||||
| `backendProtocols()` | `IProtocolDistribution` | `cache.backendProtocols.*` |
|
||||
|
||||
### throughput
|
||||
|
||||
| Method | Return | Source |
|
||||
|---|---|---|
|
||||
| `instant()` | `{in, out}` | `cache.throughputInBytesPerSec/Out` |
|
||||
| `recent()` | `{in, out}` | `cache.throughputRecentInBytesPerSec/Out` |
|
||||
| `average()` | `{in, out}` | Falls back to `instant()` (not wired to windowed average) |
|
||||
| `custom(seconds)` | `{in, out}` | Falls back to `instant()` (not wired) |
|
||||
| `history(seconds)` | `IThroughputHistoryPoint[]` | `cache.throughputHistory` sliced to last N entries |
|
||||
| `byRoute(windowSeconds?)` | `Map<string, {in, out}>` | `cache.routes[name].throughputIn/OutBytesPerSec` |
|
||||
| `byIP(windowSeconds?)` | `Map<string, {in, out}>` | `cache.ips[ip].throughputIn/OutBytesPerSec` |
|
||||
|
||||
### requests
|
||||
|
||||
| Method | Return | Source |
|
||||
|---|---|---|
|
||||
| `perSecond()` | `number` | `cache.httpRequestsPerSec` |
|
||||
| `perMinute()` | `number` | `cache.httpRequestsPerSecRecent * 60` |
|
||||
| `total()` | `number` | `cache.totalHttpRequests` (fallback: totalConnections) |
|
||||
|
||||
### totals
|
||||
|
||||
| Method | Return | Source |
|
||||
|---|---|---|
|
||||
| `bytesIn()` | `number` | `cache.bytesIn` |
|
||||
| `bytesOut()` | `number` | `cache.bytesOut` |
|
||||
| `connections()` | `number` | `cache.totalConnections` |
|
||||
|
||||
### backends
|
||||
|
||||
| Method | Return | Source |
|
||||
|---|---|---|
|
||||
| `byBackend()` | `Map<string, IBackendMetrics>` | `cache.backends[key].*` with computed `avgConnectTimeMs` and `poolHitRate` |
|
||||
| `protocols()` | `Map<string, string>` | `cache.backends[key].protocol` |
|
||||
| `topByErrors(limit?)` | `IBackendMetrics[]` | Sorted by total errors desc |
|
||||
| `detectedProtocols()` | `IProtocolCacheEntry[]` | `cache.detectedProtocols` passthrough |
|
||||
|
||||
`IBackendMetrics`: `{ protocol, activeConnections, totalConnections, connectErrors, handshakeErrors, requestErrors, avgConnectTimeMs, poolHitRate, h2Failures }`
|
||||
|
||||
### udp
|
||||
|
||||
| Method | Return | Source |
|
||||
|---|---|---|
|
||||
| `activeSessions()` | `number` | `cache.activeUdpSessions` |
|
||||
| `totalSessions()` | `number` | `cache.totalUdpSessions` |
|
||||
| `datagramsIn()` | `number` | `cache.totalDatagramsIn` |
|
||||
| `datagramsOut()` | `number` | `cache.totalDatagramsOut` |
|
||||
|
||||
### percentiles (stub)
|
||||
|
||||
`connectionDuration()` and `bytesTransferred()` always return zeros. Not implemented.
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
```typescript
|
||||
interface IMetricsConfig {
|
||||
enabled: boolean; // default true
|
||||
sampleIntervalMs: number; // default 1000 (1Hz sampling + TS polling)
|
||||
retentionSeconds: number; // default 3600 (ThroughputTracker capacity)
|
||||
enableDetailedTracking: boolean;
|
||||
enablePercentiles: boolean;
|
||||
cacheResultsMs: number;
|
||||
prometheusEnabled: boolean; // not wired
|
||||
prometheusPath: string; // not wired
|
||||
prometheusPrefix: string; // not wired
|
||||
}
|
||||
```
|
||||
|
||||
Rust-side config (`MetricsConfig` in `rustproxy-config`):
|
||||
|
||||
```rust
|
||||
pub struct MetricsConfig {
|
||||
pub enabled: Option<bool>,
|
||||
pub sample_interval_ms: Option<u64>, // default 1000
|
||||
pub retention_seconds: Option<u64>, // default 3600
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Design Decisions
|
||||
|
||||
**Lock-free hot path.** `record_bytes()` is the most frequently called method (per-chunk in TCP, per-64KB in HTTP). It only touches `AtomicU64` with `Relaxed` ordering and short-circuits zero-byte directions to skip DashMap lookups entirely.
|
||||
|
||||
**CountingBody batching.** HTTP body frames are typically 16KB. Flushing to MetricsCollector every 64KB reduces DashMap shard contention by ~4x compared to per-frame recording.
|
||||
|
||||
**RAII guards everywhere.** `ConnectionGuard`, `ConnectionTrackerGuard`, `QuicConnGuard`, `ProtocolGuard`, `FrontendProtocolTracker` all use Drop to guarantee counter cleanup on all exit paths including panics.
|
||||
|
||||
**Saturating decrements.** Protocol counters use `fetch_update` loops instead of `fetch_sub` to prevent underflow to `u64::MAX` from concurrent close races.
|
||||
|
||||
**Bounded memory.** Per-IP entries evicted on last connection close. Per-backend entries evicted on last connection close. Snapshot caps IPs and backends at 100 each. `sample_all()` prunes orphaned entries every second.
|
||||
|
||||
**Two-phase throughput.** Pending bytes accumulate in lock-free atomics. The 1Hz cold path drains them into Mutex-guarded ThroughputTrackers. This means the hot path never contends on a Mutex, while the cold path does minimal work (one drain + one sample per tracker).
|
||||
|
||||
---
|
||||
|
||||
## Known Gaps
|
||||
|
||||
| Gap | Status |
|
||||
|---|---|
|
||||
| `throughput.average()` / `throughput.custom(seconds)` | Fall back to `instant()`. Not wired to Rust windowed queries. |
|
||||
| `percentiles.connectionDuration()` / `percentiles.bytesTransferred()` | Stub returning zeros. No histogram in Rust. |
|
||||
| Prometheus export | Config fields exist but not wired to any exporter. |
|
||||
| `LogDeduplicator` | Implemented in `rustproxy-metrics` but not connected to any call site. |
|
||||
| Rate limit hit counters | Rate-limited requests return 429 but no counter is recorded in MetricsCollector. |
|
||||
| QUIC stream byte counting | Post-hoc (per-stream totals after close), not per-chunk like TCP. |
|
||||
| Throughput history in snapshot | Capped at 60 samples. TS `history(seconds)` cannot return more than 60 points regardless of `retentionSeconds`. |
|
||||
| Per-route total connections / bytes | Available in Rust JSON but `IMetrics.connections.byRoute()` only exposes active connections. |
|
||||
| Per-IP total connections / bytes | Available in Rust JSON but `IMetrics.connections.byIP()` only exposes active connections. |
|
||||
| IPC response typing | `RustProxyBridge` declares `result: any` for both metrics commands. No type-safe response. |
|
||||
Generated
+2
-15
@@ -1238,7 +1238,6 @@ dependencies = [
|
||||
"rustproxy-config",
|
||||
"rustproxy-http",
|
||||
"rustproxy-metrics",
|
||||
"rustproxy-nftables",
|
||||
"rustproxy-passthrough",
|
||||
"rustproxy-routing",
|
||||
"rustproxy-security",
|
||||
@@ -1270,6 +1269,7 @@ dependencies = [
|
||||
"arc-swap",
|
||||
"bytes",
|
||||
"dashmap",
|
||||
"futures",
|
||||
"h3",
|
||||
"h3-quinn",
|
||||
"http-body",
|
||||
@@ -1303,20 +1303,6 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustproxy-nftables"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"libc",
|
||||
"rustproxy-config",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustproxy-passthrough"
|
||||
version = "0.1.0"
|
||||
@@ -1333,6 +1319,7 @@ dependencies = [
|
||||
"rustproxy-http",
|
||||
"rustproxy-metrics",
|
||||
"rustproxy-routing",
|
||||
"rustproxy-security",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"socket2 0.5.10",
|
||||
|
||||
@@ -7,7 +7,6 @@ members = [
|
||||
"crates/rustproxy-tls",
|
||||
"crates/rustproxy-passthrough",
|
||||
"crates/rustproxy-http",
|
||||
"crates/rustproxy-nftables",
|
||||
"crates/rustproxy-metrics",
|
||||
"crates/rustproxy-security",
|
||||
]
|
||||
@@ -107,6 +106,5 @@ rustproxy-routing = { path = "crates/rustproxy-routing" }
|
||||
rustproxy-tls = { path = "crates/rustproxy-tls" }
|
||||
rustproxy-passthrough = { path = "crates/rustproxy-passthrough" }
|
||||
rustproxy-http = { path = "crates/rustproxy-http" }
|
||||
rustproxy-nftables = { path = "crates/rustproxy-nftables" }
|
||||
rustproxy-metrics = { path = "crates/rustproxy-metrics" }
|
||||
rustproxy-security = { path = "crates/rustproxy-security" }
|
||||
|
||||
@@ -1,345 +0,0 @@
|
||||
use crate::route_types::*;
|
||||
use crate::tls_types::*;
|
||||
|
||||
/// Create a simple HTTP forwarding route.
|
||||
/// Equivalent to SmartProxy's `createHttpRoute()`.
|
||||
pub fn create_http_route(
|
||||
domains: impl Into<DomainSpec>,
|
||||
target_host: impl Into<String>,
|
||||
target_port: u16,
|
||||
) -> RouteConfig {
|
||||
RouteConfig {
|
||||
id: None,
|
||||
route_match: RouteMatch {
|
||||
ports: PortRange::Single(80),
|
||||
domains: Some(domains.into()),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
transport: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
protocol: None,
|
||||
},
|
||||
action: RouteAction {
|
||||
action_type: RouteActionType::Forward,
|
||||
targets: Some(vec![RouteTarget {
|
||||
target_match: None,
|
||||
host: HostSpec::Single(target_host.into()),
|
||||
port: PortSpec::Fixed(target_port),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
backend_transport: None,
|
||||
priority: None,
|
||||
}]),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
advanced: None,
|
||||
options: None,
|
||||
forwarding_engine: None,
|
||||
nftables: None,
|
||||
send_proxy_protocol: None,
|
||||
udp: None,
|
||||
},
|
||||
headers: None,
|
||||
security: None,
|
||||
name: None,
|
||||
description: None,
|
||||
priority: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an HTTPS termination route.
|
||||
/// Equivalent to SmartProxy's `createHttpsTerminateRoute()`.
|
||||
pub fn create_https_terminate_route(
|
||||
domains: impl Into<DomainSpec>,
|
||||
target_host: impl Into<String>,
|
||||
target_port: u16,
|
||||
) -> RouteConfig {
|
||||
let mut route = create_http_route(domains, target_host, target_port);
|
||||
route.route_match.ports = PortRange::Single(443);
|
||||
route.action.tls = Some(RouteTls {
|
||||
mode: TlsMode::Terminate,
|
||||
certificate: Some(CertificateSpec::Auto("auto".to_string())),
|
||||
acme: None,
|
||||
versions: None,
|
||||
ciphers: None,
|
||||
honor_cipher_order: None,
|
||||
session_timeout: None,
|
||||
});
|
||||
route
|
||||
}
|
||||
|
||||
/// Create a TLS passthrough route.
|
||||
/// Equivalent to SmartProxy's `createHttpsPassthroughRoute()`.
|
||||
pub fn create_https_passthrough_route(
|
||||
domains: impl Into<DomainSpec>,
|
||||
target_host: impl Into<String>,
|
||||
target_port: u16,
|
||||
) -> RouteConfig {
|
||||
let mut route = create_http_route(domains, target_host, target_port);
|
||||
route.route_match.ports = PortRange::Single(443);
|
||||
route.action.tls = Some(RouteTls {
|
||||
mode: TlsMode::Passthrough,
|
||||
certificate: None,
|
||||
acme: None,
|
||||
versions: None,
|
||||
ciphers: None,
|
||||
honor_cipher_order: None,
|
||||
session_timeout: None,
|
||||
});
|
||||
route
|
||||
}
|
||||
|
||||
/// Create an HTTP-to-HTTPS redirect route.
|
||||
/// Equivalent to SmartProxy's `createHttpToHttpsRedirect()`.
|
||||
pub fn create_http_to_https_redirect(
|
||||
domains: impl Into<DomainSpec>,
|
||||
) -> RouteConfig {
|
||||
let domains = domains.into();
|
||||
RouteConfig {
|
||||
id: None,
|
||||
route_match: RouteMatch {
|
||||
ports: PortRange::Single(80),
|
||||
domains: Some(domains),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
transport: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
protocol: None,
|
||||
},
|
||||
action: RouteAction {
|
||||
action_type: RouteActionType::Forward,
|
||||
targets: None,
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
advanced: Some(RouteAdvanced {
|
||||
timeout: None,
|
||||
headers: None,
|
||||
keep_alive: None,
|
||||
static_files: None,
|
||||
test_response: Some(RouteTestResponse {
|
||||
status: 301,
|
||||
headers: {
|
||||
let mut h = std::collections::HashMap::new();
|
||||
h.insert("Location".to_string(), "https://{domain}{path}".to_string());
|
||||
h
|
||||
},
|
||||
body: String::new(),
|
||||
}),
|
||||
url_rewrite: None,
|
||||
}),
|
||||
options: None,
|
||||
forwarding_engine: None,
|
||||
nftables: None,
|
||||
send_proxy_protocol: None,
|
||||
udp: None,
|
||||
},
|
||||
headers: None,
|
||||
security: None,
|
||||
name: Some("HTTP to HTTPS Redirect".to_string()),
|
||||
description: None,
|
||||
priority: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a complete HTTPS server with HTTP redirect.
|
||||
/// Equivalent to SmartProxy's `createCompleteHttpsServer()`.
|
||||
pub fn create_complete_https_server(
|
||||
domain: impl Into<String>,
|
||||
target_host: impl Into<String>,
|
||||
target_port: u16,
|
||||
) -> Vec<RouteConfig> {
|
||||
let domain = domain.into();
|
||||
let target_host = target_host.into();
|
||||
|
||||
vec![
|
||||
create_http_to_https_redirect(DomainSpec::Single(domain.clone())),
|
||||
create_https_terminate_route(
|
||||
DomainSpec::Single(domain),
|
||||
target_host,
|
||||
target_port,
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
/// Create a load balancer route.
|
||||
/// Equivalent to SmartProxy's `createLoadBalancerRoute()`.
|
||||
pub fn create_load_balancer_route(
|
||||
domains: impl Into<DomainSpec>,
|
||||
targets: Vec<(String, u16)>,
|
||||
tls: Option<RouteTls>,
|
||||
) -> RouteConfig {
|
||||
let route_targets: Vec<RouteTarget> = targets
|
||||
.into_iter()
|
||||
.map(|(host, port)| RouteTarget {
|
||||
target_match: None,
|
||||
host: HostSpec::Single(host),
|
||||
port: PortSpec::Fixed(port),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
backend_transport: None,
|
||||
priority: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let port = if tls.is_some() { 443 } else { 80 };
|
||||
|
||||
RouteConfig {
|
||||
id: None,
|
||||
route_match: RouteMatch {
|
||||
ports: PortRange::Single(port),
|
||||
domains: Some(domains.into()),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
transport: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
protocol: None,
|
||||
},
|
||||
action: RouteAction {
|
||||
action_type: RouteActionType::Forward,
|
||||
targets: Some(route_targets),
|
||||
tls,
|
||||
websocket: None,
|
||||
load_balancing: Some(RouteLoadBalancing {
|
||||
algorithm: LoadBalancingAlgorithm::RoundRobin,
|
||||
health_check: None,
|
||||
}),
|
||||
advanced: None,
|
||||
options: None,
|
||||
forwarding_engine: None,
|
||||
nftables: None,
|
||||
send_proxy_protocol: None,
|
||||
udp: None,
|
||||
},
|
||||
headers: None,
|
||||
security: None,
|
||||
name: Some("Load Balancer".to_string()),
|
||||
description: None,
|
||||
priority: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
// Convenience conversions for DomainSpec
|
||||
impl From<&str> for DomainSpec {
|
||||
fn from(s: &str) -> Self {
|
||||
DomainSpec::Single(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for DomainSpec {
|
||||
fn from(s: String) -> Self {
|
||||
DomainSpec::Single(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<String>> for DomainSpec {
|
||||
fn from(v: Vec<String>) -> Self {
|
||||
DomainSpec::List(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<&str>> for DomainSpec {
|
||||
fn from(v: Vec<&str>) -> Self {
|
||||
DomainSpec::List(v.into_iter().map(|s| s.to_string()).collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tls_types::TlsMode;
|
||||
|
||||
#[test]
|
||||
fn test_create_http_route() {
|
||||
let route = create_http_route("example.com", "localhost", 8080);
|
||||
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
|
||||
let domains = route.route_match.domains.as_ref().unwrap().to_vec();
|
||||
assert_eq!(domains, vec!["example.com"]);
|
||||
let target = &route.action.targets.as_ref().unwrap()[0];
|
||||
assert_eq!(target.host.first(), "localhost");
|
||||
assert_eq!(target.port.resolve(80), 8080);
|
||||
assert!(route.action.tls.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_https_terminate_route() {
|
||||
let route = create_https_terminate_route("api.example.com", "backend", 3000);
|
||||
assert_eq!(route.route_match.ports.to_ports(), vec![443]);
|
||||
let tls = route.action.tls.as_ref().unwrap();
|
||||
assert_eq!(tls.mode, TlsMode::Terminate);
|
||||
assert!(tls.certificate.as_ref().unwrap().is_auto());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_https_passthrough_route() {
|
||||
let route = create_https_passthrough_route("secure.example.com", "backend", 443);
|
||||
assert_eq!(route.route_match.ports.to_ports(), vec![443]);
|
||||
let tls = route.action.tls.as_ref().unwrap();
|
||||
assert_eq!(tls.mode, TlsMode::Passthrough);
|
||||
assert!(tls.certificate.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_http_to_https_redirect() {
|
||||
let route = create_http_to_https_redirect("example.com");
|
||||
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
|
||||
assert!(route.action.targets.is_none());
|
||||
let test_response = route.action.advanced.as_ref().unwrap().test_response.as_ref().unwrap();
|
||||
assert_eq!(test_response.status, 301);
|
||||
assert!(test_response.headers.contains_key("Location"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_complete_https_server() {
|
||||
let routes = create_complete_https_server("example.com", "backend", 8080);
|
||||
assert_eq!(routes.len(), 2);
|
||||
// First route is HTTP redirect
|
||||
assert_eq!(routes[0].route_match.ports.to_ports(), vec![80]);
|
||||
// Second route is HTTPS terminate
|
||||
assert_eq!(routes[1].route_match.ports.to_ports(), vec![443]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_load_balancer_route() {
|
||||
let targets = vec![
|
||||
("backend1".to_string(), 8080),
|
||||
("backend2".to_string(), 8080),
|
||||
("backend3".to_string(), 8080),
|
||||
];
|
||||
let route = create_load_balancer_route("*.example.com", targets, None);
|
||||
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
|
||||
assert_eq!(route.action.targets.as_ref().unwrap().len(), 3);
|
||||
let lb = route.action.load_balancing.as_ref().unwrap();
|
||||
assert_eq!(lb.algorithm, LoadBalancingAlgorithm::RoundRobin);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_spec_from_str() {
|
||||
let spec: DomainSpec = "example.com".into();
|
||||
assert_eq!(spec.to_vec(), vec!["example.com"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_spec_from_vec() {
|
||||
let spec: DomainSpec = vec!["a.com", "b.com"].into();
|
||||
assert_eq!(spec.to_vec(), vec!["a.com", "b.com"]);
|
||||
}
|
||||
}
|
||||
@@ -3,17 +3,15 @@
|
||||
//! Configuration types for RustProxy, fully compatible with SmartProxy's JSON schema.
|
||||
//! All types use `#[serde(rename_all = "camelCase")]` to match TypeScript field naming.
|
||||
|
||||
pub mod route_types;
|
||||
pub mod proxy_options;
|
||||
pub mod tls_types;
|
||||
pub mod route_types;
|
||||
pub mod security_types;
|
||||
pub mod tls_types;
|
||||
pub mod validation;
|
||||
pub mod helpers;
|
||||
|
||||
// Re-export all primary types
|
||||
pub use route_types::*;
|
||||
pub use proxy_options::*;
|
||||
pub use tls_types::*;
|
||||
pub use route_types::*;
|
||||
pub use security_types::*;
|
||||
pub use tls_types::*;
|
||||
pub use validation::*;
|
||||
pub use helpers::*;
|
||||
|
||||
@@ -97,6 +97,16 @@ pub struct MetricsConfig {
|
||||
pub retention_seconds: Option<u64>,
|
||||
}
|
||||
|
||||
/// Global ingress security policy.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SecurityPolicy {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub blocked_ips: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub blocked_cidrs: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// RustProxy configuration options.
|
||||
/// Matches TypeScript: `ISmartProxyOptions`
|
||||
///
|
||||
@@ -129,7 +139,6 @@ pub struct RustProxyOptions {
|
||||
pub defaults: Option<DefaultConfig>,
|
||||
|
||||
// ─── Timeout Settings ────────────────────────────────────────────
|
||||
|
||||
/// Timeout for establishing connection to backend (ms), default: 30000
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub connection_timeout: Option<u64>,
|
||||
@@ -159,7 +168,6 @@ pub struct RustProxyOptions {
|
||||
pub graceful_shutdown_timeout: Option<u64>,
|
||||
|
||||
// ─── Socket Optimization ─────────────────────────────────────────
|
||||
|
||||
/// Disable Nagle's algorithm (default: true)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub no_delay: Option<bool>,
|
||||
@@ -177,7 +185,6 @@ pub struct RustProxyOptions {
|
||||
pub max_pending_data_size: Option<u64>,
|
||||
|
||||
// ─── Enhanced Features ───────────────────────────────────────────
|
||||
|
||||
/// Disable inactivity checking entirely
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub disable_inactivity_check: Option<bool>,
|
||||
@@ -199,7 +206,6 @@ pub struct RustProxyOptions {
|
||||
pub enable_randomized_timeouts: Option<bool>,
|
||||
|
||||
// ─── Rate Limiting ───────────────────────────────────────────────
|
||||
|
||||
/// Maximum simultaneous connections from a single IP
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_connections_per_ip: Option<u64>,
|
||||
@@ -213,7 +219,6 @@ pub struct RustProxyOptions {
|
||||
pub max_connections: Option<u64>,
|
||||
|
||||
// ─── Keep-Alive Settings ─────────────────────────────────────────
|
||||
|
||||
/// How to treat keep-alive connections
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub keep_alive_treatment: Option<KeepAliveTreatment>,
|
||||
@@ -227,7 +232,6 @@ pub struct RustProxyOptions {
|
||||
pub extended_keep_alive_lifetime: Option<u64>,
|
||||
|
||||
// ─── HttpProxy Integration ───────────────────────────────────────
|
||||
|
||||
/// Array of ports to forward to HttpProxy
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_http_proxy: Option<Vec<u16>>,
|
||||
@@ -237,13 +241,15 @@ pub struct RustProxyOptions {
|
||||
pub http_proxy_port: Option<u16>,
|
||||
|
||||
// ─── Metrics ─────────────────────────────────────────────────────
|
||||
|
||||
/// Metrics configuration
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metrics: Option<MetricsConfig>,
|
||||
|
||||
// ─── ACME ────────────────────────────────────────────────────────
|
||||
/// Global ingress security policy, enforced before route selection.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub security_policy: Option<SecurityPolicy>,
|
||||
|
||||
// ─── ACME ────────────────────────────────────────────────────────
|
||||
/// Global ACME configuration
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub acme: Option<AcmeOptions>,
|
||||
@@ -283,6 +289,7 @@ impl Default for RustProxyOptions {
|
||||
use_http_proxy: None,
|
||||
http_proxy_port: None,
|
||||
metrics: None,
|
||||
security_policy: None,
|
||||
acme: None,
|
||||
}
|
||||
}
|
||||
@@ -318,7 +325,8 @@ impl RustProxyOptions {
|
||||
|
||||
/// Get all unique ports that routes listen on.
|
||||
pub fn all_listening_ports(&self) -> Vec<u16> {
|
||||
let mut ports: Vec<u16> = self.routes
|
||||
let mut ports: Vec<u16> = self
|
||||
.routes
|
||||
.iter()
|
||||
.flat_map(|r| r.listening_ports())
|
||||
.collect();
|
||||
@@ -331,12 +339,73 @@ impl RustProxyOptions {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::helpers::*;
|
||||
use crate::route_types::*;
|
||||
use crate::tls_types::*;
|
||||
|
||||
fn make_route(domain: &str, host: &str, port: u16, listen_port: u16) -> RouteConfig {
|
||||
RouteConfig {
|
||||
id: None,
|
||||
route_match: RouteMatch {
|
||||
ports: PortRange::Single(listen_port),
|
||||
domains: Some(DomainSpec::Single(domain.to_string())),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
transport: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
protocol: None,
|
||||
},
|
||||
action: RouteAction {
|
||||
action_type: RouteActionType::Forward,
|
||||
targets: Some(vec![RouteTarget {
|
||||
target_match: None,
|
||||
host: HostSpec::Single(host.to_string()),
|
||||
port: PortSpec::Fixed(port),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
backend_transport: None,
|
||||
priority: None,
|
||||
}]),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
advanced: None,
|
||||
options: None,
|
||||
send_proxy_protocol: None,
|
||||
udp: None,
|
||||
},
|
||||
headers: None,
|
||||
security: None,
|
||||
name: None,
|
||||
description: None,
|
||||
priority: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_passthrough_route(domain: &str, host: &str, port: u16) -> RouteConfig {
|
||||
let mut route = make_route(domain, host, port, 443);
|
||||
route.action.tls = Some(RouteTls {
|
||||
mode: TlsMode::Passthrough,
|
||||
certificate: None,
|
||||
acme: None,
|
||||
versions: None,
|
||||
ciphers: None,
|
||||
honor_cipher_order: None,
|
||||
session_timeout: None,
|
||||
});
|
||||
route
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serde_roundtrip_minimal() {
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![create_http_route("example.com", "localhost", 8080)],
|
||||
routes: vec![make_route("example.com", "localhost", 8080, 80)],
|
||||
..Default::default()
|
||||
};
|
||||
let json = serde_json::to_string(&options).unwrap();
|
||||
@@ -348,8 +417,8 @@ mod tests {
|
||||
fn test_serde_roundtrip_full() {
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
create_http_route("a.com", "backend1", 8080),
|
||||
create_https_passthrough_route("b.com", "backend2", 443),
|
||||
make_route("a.com", "backend1", 8080, 80),
|
||||
make_passthrough_route("b.com", "backend2", 443),
|
||||
],
|
||||
connection_timeout: Some(5000),
|
||||
socket_timeout: Some(60000),
|
||||
@@ -374,6 +443,209 @@ mod tests {
|
||||
assert_eq!(parsed.connection_timeout, Some(5000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_ts_contract_route_shapes() {
|
||||
let value = serde_json::json!({
|
||||
"routes": [{
|
||||
"name": "contract-route",
|
||||
"match": {
|
||||
"ports": [443, { "from": 8443, "to": 8444 }],
|
||||
"domains": ["api.example.com", "*.example.com"],
|
||||
"transport": "udp",
|
||||
"protocol": "http3",
|
||||
"headers": {
|
||||
"content-type": "/^application\\/json$/i"
|
||||
}
|
||||
},
|
||||
"action": {
|
||||
"type": "forward",
|
||||
"targets": [{
|
||||
"match": {
|
||||
"ports": [443],
|
||||
"path": "/api/*",
|
||||
"method": ["GET"],
|
||||
"headers": {
|
||||
"x-env": "/^(prod|stage)$/"
|
||||
}
|
||||
},
|
||||
"host": ["backend-a", "backend-b"],
|
||||
"port": "preserve",
|
||||
"sendProxyProtocol": true,
|
||||
"backendTransport": "tcp"
|
||||
}],
|
||||
"tls": {
|
||||
"mode": "terminate",
|
||||
"certificate": "auto"
|
||||
},
|
||||
"sendProxyProtocol": true,
|
||||
"udp": {
|
||||
"maxSessionsPerIp": 321,
|
||||
"quic": {
|
||||
"enableHttp3": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": {
|
||||
"ipAllowList": [{
|
||||
"ip": "10.0.0.0/8",
|
||||
"domains": ["api.example.com"]
|
||||
}]
|
||||
}
|
||||
}],
|
||||
"preserveSourceIp": true,
|
||||
"proxyIps": ["10.0.0.1"],
|
||||
"acceptProxyProtocol": true,
|
||||
"sendProxyProtocol": true,
|
||||
"noDelay": true,
|
||||
"keepAlive": true,
|
||||
"keepAliveInitialDelay": 1500,
|
||||
"maxPendingDataSize": 4096,
|
||||
"disableInactivityCheck": true,
|
||||
"enableKeepAliveProbes": true,
|
||||
"enableDetailedLogging": true,
|
||||
"enableTlsDebugLogging": true,
|
||||
"enableRandomizedTimeouts": true,
|
||||
"connectionTimeout": 5000,
|
||||
"initialDataTimeout": 7000,
|
||||
"socketTimeout": 9000,
|
||||
"inactivityCheckInterval": 1100,
|
||||
"maxConnectionLifetime": 13000,
|
||||
"inactivityTimeout": 15000,
|
||||
"gracefulShutdownTimeout": 17000,
|
||||
"maxConnectionsPerIp": 20,
|
||||
"connectionRateLimitPerMinute": 30,
|
||||
"keepAliveTreatment": "extended",
|
||||
"keepAliveInactivityMultiplier": 2.0,
|
||||
"extendedKeepAliveLifetime": 19000,
|
||||
"metrics": {
|
||||
"enabled": true,
|
||||
"sampleIntervalMs": 250,
|
||||
"retentionSeconds": 60
|
||||
},
|
||||
"acme": {
|
||||
"enabled": true,
|
||||
"email": "ops@example.com",
|
||||
"environment": "staging",
|
||||
"useProduction": false,
|
||||
"skipConfiguredCerts": true,
|
||||
"renewThresholdDays": 14,
|
||||
"renewCheckIntervalHours": 12,
|
||||
"autoRenew": true,
|
||||
"port": 80
|
||||
}
|
||||
});
|
||||
|
||||
let options: RustProxyOptions = serde_json::from_value(value).unwrap();
|
||||
|
||||
assert_eq!(options.routes.len(), 1);
|
||||
assert_eq!(options.preserve_source_ip, Some(true));
|
||||
assert_eq!(options.proxy_ips, Some(vec!["10.0.0.1".to_string()]));
|
||||
assert_eq!(options.accept_proxy_protocol, Some(true));
|
||||
assert_eq!(options.send_proxy_protocol, Some(true));
|
||||
assert_eq!(options.no_delay, Some(true));
|
||||
assert_eq!(options.keep_alive, Some(true));
|
||||
assert_eq!(options.keep_alive_initial_delay, Some(1500));
|
||||
assert_eq!(options.max_pending_data_size, Some(4096));
|
||||
assert_eq!(options.disable_inactivity_check, Some(true));
|
||||
assert_eq!(options.enable_keep_alive_probes, Some(true));
|
||||
assert_eq!(options.enable_detailed_logging, Some(true));
|
||||
assert_eq!(options.enable_tls_debug_logging, Some(true));
|
||||
assert_eq!(options.enable_randomized_timeouts, Some(true));
|
||||
assert_eq!(options.connection_timeout, Some(5000));
|
||||
assert_eq!(options.initial_data_timeout, Some(7000));
|
||||
assert_eq!(options.socket_timeout, Some(9000));
|
||||
assert_eq!(options.inactivity_check_interval, Some(1100));
|
||||
assert_eq!(options.max_connection_lifetime, Some(13000));
|
||||
assert_eq!(options.inactivity_timeout, Some(15000));
|
||||
assert_eq!(options.graceful_shutdown_timeout, Some(17000));
|
||||
assert_eq!(options.max_connections_per_ip, Some(20));
|
||||
assert_eq!(options.connection_rate_limit_per_minute, Some(30));
|
||||
assert_eq!(
|
||||
options.keep_alive_treatment,
|
||||
Some(KeepAliveTreatment::Extended)
|
||||
);
|
||||
assert_eq!(options.keep_alive_inactivity_multiplier, Some(2.0));
|
||||
assert_eq!(options.extended_keep_alive_lifetime, Some(19000));
|
||||
|
||||
let route = &options.routes[0];
|
||||
assert_eq!(route.route_match.transport, Some(TransportProtocol::Udp));
|
||||
assert_eq!(route.route_match.protocol.as_deref(), Some("http3"));
|
||||
assert_eq!(
|
||||
route
|
||||
.route_match
|
||||
.headers
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.get("content-type")
|
||||
.unwrap(),
|
||||
"/^application\\/json$/i"
|
||||
);
|
||||
|
||||
let target = &route.action.targets.as_ref().unwrap()[0];
|
||||
assert!(matches!(target.host, HostSpec::List(_)));
|
||||
assert!(matches!(target.port, PortSpec::Special(ref p) if p == "preserve"));
|
||||
assert_eq!(target.backend_transport, Some(TransportProtocol::Tcp));
|
||||
assert_eq!(target.send_proxy_protocol, Some(true));
|
||||
assert_eq!(
|
||||
target
|
||||
.target_match
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.headers
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.get("x-env")
|
||||
.unwrap(),
|
||||
"/^(prod|stage)$/"
|
||||
);
|
||||
assert_eq!(route.action.send_proxy_protocol, Some(true));
|
||||
assert_eq!(
|
||||
route.action.udp.as_ref().unwrap().max_sessions_per_ip,
|
||||
Some(321)
|
||||
);
|
||||
assert_eq!(
|
||||
route
|
||||
.action
|
||||
.udp
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.quic
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.enable_http3,
|
||||
Some(true)
|
||||
);
|
||||
|
||||
let allow_list = route
|
||||
.security
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.ip_allow_list
|
||||
.as_ref()
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
&allow_list[0],
|
||||
crate::security_types::IpAllowEntry::DomainScoped { ip, domains }
|
||||
if ip == "10.0.0.0/8" && domains == &vec!["api.example.com".to_string()]
|
||||
));
|
||||
|
||||
let metrics = options.metrics.as_ref().unwrap();
|
||||
assert_eq!(metrics.enabled, Some(true));
|
||||
assert_eq!(metrics.sample_interval_ms, Some(250));
|
||||
assert_eq!(metrics.retention_seconds, Some(60));
|
||||
|
||||
let acme = options.acme.as_ref().unwrap();
|
||||
assert_eq!(acme.enabled, Some(true));
|
||||
assert_eq!(acme.email.as_deref(), Some("ops@example.com"));
|
||||
assert_eq!(acme.environment, Some(AcmeEnvironment::Staging));
|
||||
assert_eq!(acme.use_production, Some(false));
|
||||
assert_eq!(acme.skip_configured_certs, Some(true));
|
||||
assert_eq!(acme.renew_threshold_days, Some(14));
|
||||
assert_eq!(acme.renew_check_interval_hours, Some(12));
|
||||
assert_eq!(acme.auto_renew, Some(true));
|
||||
assert_eq!(acme.port, Some(80));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_timeouts() {
|
||||
let options = RustProxyOptions::default();
|
||||
@@ -402,9 +674,9 @@ mod tests {
|
||||
fn test_all_listening_ports() {
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
create_http_route("a.com", "backend", 8080), // port 80
|
||||
create_https_passthrough_route("b.com", "backend", 443), // port 443
|
||||
create_http_route("c.com", "backend", 9090), // port 80 (duplicate)
|
||||
make_route("a.com", "backend", 8080, 80), // port 80
|
||||
make_passthrough_route("b.com", "backend", 443), // port 443
|
||||
make_route("c.com", "backend", 9090, 80), // port 80 (duplicate)
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
@@ -428,9 +700,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_example_json() {
|
||||
let content = std::fs::read_to_string(
|
||||
concat!(env!("CARGO_MANIFEST_DIR"), "/../../config/example.json")
|
||||
).unwrap();
|
||||
let content = std::fs::read_to_string(concat!(
|
||||
env!("CARGO_MANIFEST_DIR"),
|
||||
"/../../config/example.json"
|
||||
))
|
||||
.unwrap();
|
||||
let options: RustProxyOptions = serde_json::from_str(&content).unwrap();
|
||||
assert_eq!(options.routes.len(), 4);
|
||||
let ports = options.all_listening_ports();
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::tls_types::RouteTls;
|
||||
use crate::security_types::RouteSecurity;
|
||||
use crate::tls_types::RouteTls;
|
||||
|
||||
// ─── Port Range ──────────────────────────────────────────────────────
|
||||
|
||||
@@ -32,12 +32,13 @@ impl PortRange {
|
||||
pub fn to_ports(&self) -> Vec<u16> {
|
||||
match self {
|
||||
PortRange::Single(p) => vec![*p],
|
||||
PortRange::List(items) => {
|
||||
items.iter().flat_map(|item| match item {
|
||||
PortRange::List(items) => items
|
||||
.iter()
|
||||
.flat_map(|item| match item {
|
||||
PortRangeItem::Port(p) => vec![*p],
|
||||
PortRangeItem::Range(r) => (r.from..=r.to).collect(),
|
||||
}).collect()
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -60,16 +61,6 @@ pub enum RouteActionType {
|
||||
SocketHandler,
|
||||
}
|
||||
|
||||
// ─── Forwarding Engine ───────────────────────────────────────────────
|
||||
|
||||
/// Forwarding engine specification.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ForwardingEngine {
|
||||
Node,
|
||||
Nftables,
|
||||
}
|
||||
|
||||
// ─── Route Match ─────────────────────────────────────────────────────
|
||||
|
||||
/// Domain specification: single string or array.
|
||||
@@ -89,8 +80,34 @@ impl DomainSpec {
|
||||
}
|
||||
}
|
||||
|
||||
// Convenience conversions for DomainSpec
|
||||
impl From<&str> for DomainSpec {
|
||||
fn from(s: &str) -> Self {
|
||||
DomainSpec::Single(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for DomainSpec {
|
||||
fn from(s: String) -> Self {
|
||||
DomainSpec::Single(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<String>> for DomainSpec {
|
||||
fn from(v: Vec<String>) -> Self {
|
||||
DomainSpec::List(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<&str>> for DomainSpec {
|
||||
fn from(v: Vec<&str>) -> Self {
|
||||
DomainSpec::List(v.into_iter().map(|s| s.to_string()).collect())
|
||||
}
|
||||
}
|
||||
|
||||
/// Header match value: either exact string or regex pattern.
|
||||
/// In JSON, all values come as strings. Regex patterns are prefixed with `/` and suffixed with `/`.
|
||||
/// In JSON, all values come as strings. Regex patterns use JS-style literal syntax,
|
||||
/// e.g. `/^application\/json$/` or `/^application\/json$/i`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum HeaderMatchValue {
|
||||
@@ -341,38 +358,6 @@ pub struct RouteAdvanced {
|
||||
pub url_rewrite: Option<RouteUrlRewrite>,
|
||||
}
|
||||
|
||||
// ─── NFTables Options ────────────────────────────────────────────────
|
||||
|
||||
/// NFTables protocol type.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum NfTablesProtocol {
|
||||
Tcp,
|
||||
Udp,
|
||||
All,
|
||||
}
|
||||
|
||||
/// NFTables-specific configuration options.
|
||||
/// Matches TypeScript: `INfTablesOptions`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct NfTablesOptions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub preserve_source_ip: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub protocol: Option<NfTablesProtocol>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_rate: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub priority: Option<i32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub table_name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_ip_sets: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_advanced_nat: Option<bool>,
|
||||
}
|
||||
|
||||
// ─── Backend Protocol ────────────────────────────────────────────────
|
||||
|
||||
/// Backend protocol.
|
||||
@@ -541,14 +526,6 @@ pub struct RouteAction {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub options: Option<ActionOptions>,
|
||||
|
||||
/// Forwarding engine specification
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub forwarding_engine: Option<ForwardingEngine>,
|
||||
|
||||
/// NFTables-specific options
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub nftables: Option<NfTablesOptions>,
|
||||
|
||||
/// PROXY protocol support (default for all targets)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub send_proxy_protocol: Option<bool>,
|
||||
@@ -679,6 +656,11 @@ impl RouteConfig {
|
||||
self.route_match.ports.to_ports()
|
||||
}
|
||||
|
||||
/// Stable key used for frontend route-scoped metrics.
|
||||
pub fn metrics_key(&self) -> Option<&str> {
|
||||
self.name.as_deref().or(self.id.as_deref())
|
||||
}
|
||||
|
||||
/// Get the TLS mode for this route (from action-level or first target).
|
||||
pub fn tls_mode(&self) -> Option<&crate::tls_types::TlsMode> {
|
||||
// Check action-level TLS first
|
||||
@@ -696,3 +678,63 @@ impl RouteConfig {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_route(name: Option<&str>, id: Option<&str>) -> RouteConfig {
|
||||
RouteConfig {
|
||||
id: id.map(str::to_string),
|
||||
route_match: RouteMatch {
|
||||
ports: PortRange::Single(443),
|
||||
transport: None,
|
||||
domains: None,
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
protocol: None,
|
||||
},
|
||||
action: RouteAction {
|
||||
action_type: RouteActionType::Forward,
|
||||
targets: None,
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
advanced: None,
|
||||
options: None,
|
||||
send_proxy_protocol: None,
|
||||
udp: None,
|
||||
},
|
||||
headers: None,
|
||||
security: None,
|
||||
name: name.map(str::to_string),
|
||||
description: None,
|
||||
priority: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metrics_key_prefers_name() {
|
||||
let route = test_route(Some("named-route"), Some("route-id"));
|
||||
|
||||
assert_eq!(route.metrics_key(), Some("named-route"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metrics_key_falls_back_to_id() {
|
||||
let route = test_route(None, Some("route-id"));
|
||||
|
||||
assert_eq!(route.metrics_key(), Some("route-id"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn metrics_key_is_absent_without_name_or_id() {
|
||||
let route = test_route(None, None);
|
||||
|
||||
assert_eq!(route.metrics_key(), None);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,14 +103,27 @@ pub struct JwtAuthConfig {
|
||||
pub exclude_paths: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// An entry in the IP allow list: either a plain IP/CIDR string
|
||||
/// or a domain-scoped entry that restricts the IP to specific domains.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum IpAllowEntry {
|
||||
/// Plain IP/CIDR — allowed for all domains on this route
|
||||
Plain(String),
|
||||
/// Domain-scoped — allowed only when the requested domain matches
|
||||
DomainScoped { ip: String, domains: Vec<String> },
|
||||
}
|
||||
|
||||
/// Security options for routes.
|
||||
/// Matches TypeScript: `IRouteSecurity`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteSecurity {
|
||||
/// IP addresses that are allowed to connect
|
||||
/// IP addresses that are allowed to connect.
|
||||
/// Entries can be plain strings (full route access) or objects with
|
||||
/// `{ ip, domains }` to scope access to specific domains.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ip_allow_list: Option<Vec<String>>,
|
||||
pub ip_allow_list: Option<Vec<IpAllowEntry>>,
|
||||
/// IP addresses that are blocked from connecting
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ip_block_list: Option<Vec<String>>,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::route_types::{RouteConfig, RouteActionType};
|
||||
use crate::route_types::{RouteActionType, RouteConfig};
|
||||
|
||||
/// Validation errors for route configurations.
|
||||
#[derive(Debug, Error)]
|
||||
@@ -30,9 +30,10 @@ pub enum ValidationError {
|
||||
/// Validate a single route configuration.
|
||||
pub fn validate_route(route: &RouteConfig) -> Result<(), Vec<ValidationError>> {
|
||||
let mut errors = Vec::new();
|
||||
let name = route.name.clone().unwrap_or_else(|| {
|
||||
route.id.clone().unwrap_or_else(|| "unnamed".to_string())
|
||||
});
|
||||
let name = route
|
||||
.name
|
||||
.clone()
|
||||
.unwrap_or_else(|| route.id.clone().unwrap_or_else(|| "unnamed".to_string()));
|
||||
|
||||
// Check ports
|
||||
let ports = route.listening_ports();
|
||||
@@ -104,7 +105,49 @@ mod tests {
|
||||
use crate::route_types::*;
|
||||
|
||||
fn make_valid_route() -> RouteConfig {
|
||||
crate::helpers::create_http_route("example.com", "localhost", 8080)
|
||||
RouteConfig {
|
||||
id: None,
|
||||
route_match: RouteMatch {
|
||||
ports: PortRange::Single(80),
|
||||
domains: Some(DomainSpec::Single("example.com".to_string())),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
transport: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
protocol: None,
|
||||
},
|
||||
action: RouteAction {
|
||||
action_type: RouteActionType::Forward,
|
||||
targets: Some(vec![RouteTarget {
|
||||
target_match: None,
|
||||
host: HostSpec::Single("localhost".to_string()),
|
||||
port: PortSpec::Fixed(8080),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
backend_transport: None,
|
||||
priority: None,
|
||||
}]),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
advanced: None,
|
||||
options: None,
|
||||
send_proxy_protocol: None,
|
||||
udp: None,
|
||||
},
|
||||
headers: None,
|
||||
security: None,
|
||||
name: None,
|
||||
description: None,
|
||||
priority: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -118,7 +161,9 @@ mod tests {
|
||||
let mut route = make_valid_route();
|
||||
route.action.targets = None;
|
||||
let errors = validate_route(&route).unwrap_err();
|
||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::MissingTargets { .. })));
|
||||
assert!(errors
|
||||
.iter()
|
||||
.any(|e| matches!(e, ValidationError::MissingTargets { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -126,7 +171,9 @@ mod tests {
|
||||
let mut route = make_valid_route();
|
||||
route.action.targets = Some(vec![]);
|
||||
let errors = validate_route(&route).unwrap_err();
|
||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::EmptyTargets { .. })));
|
||||
assert!(errors
|
||||
.iter()
|
||||
.any(|e| matches!(e, ValidationError::EmptyTargets { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -134,7 +181,9 @@ mod tests {
|
||||
let mut route = make_valid_route();
|
||||
route.route_match.ports = PortRange::Single(0);
|
||||
let errors = validate_route(&route).unwrap_err();
|
||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. })));
|
||||
assert!(errors
|
||||
.iter()
|
||||
.any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -144,7 +193,9 @@ mod tests {
|
||||
let mut r2 = make_valid_route();
|
||||
r2.id = Some("route-1".to_string());
|
||||
let errors = validate_routes(&[r1, r2]).unwrap_err();
|
||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::DuplicateId { .. })));
|
||||
assert!(errors
|
||||
.iter()
|
||||
.any(|e| matches!(e, ValidationError::DuplicateId { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -30,3 +30,4 @@ socket2 = { workspace = true }
|
||||
quinn = { workspace = true }
|
||||
h3 = { workspace = true }
|
||||
h3-quinn = { workspace = true }
|
||||
futures = { version = "0.3", default-features = false, features = ["std"] }
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
//! Reuses idle keep-alive connections to avoid per-request TCP+TLS handshakes.
|
||||
//! HTTP/2 and HTTP/3 connections are multiplexed (clone the sender / share the connection).
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use bytes::Bytes;
|
||||
@@ -56,7 +56,11 @@ struct PooledH2 {
|
||||
}
|
||||
|
||||
/// A pooled QUIC/HTTP/3 connection (multiplexed like H2).
|
||||
/// Stores the h3 `SendRequest` handle so pool hits skip the h3 SETTINGS handshake.
|
||||
pub struct PooledH3 {
|
||||
/// Multiplexed h3 request handle — clone to open a new stream.
|
||||
pub send_request: h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
|
||||
/// Raw QUIC connection — kept for liveness probing (close_reason) only.
|
||||
pub connection: quinn::Connection,
|
||||
pub created_at: Instant,
|
||||
pub generation: u64,
|
||||
@@ -101,13 +105,19 @@ impl ConnectionPool {
|
||||
|
||||
/// Try to check out an idle HTTP/1.1 sender for the given key.
|
||||
/// Returns `None` if no usable idle connection exists.
|
||||
pub fn checkout_h1(&self, key: &PoolKey) -> Option<http1::SendRequest<BoxBody<Bytes, hyper::Error>>> {
|
||||
pub fn checkout_h1(
|
||||
&self,
|
||||
key: &PoolKey,
|
||||
) -> Option<http1::SendRequest<BoxBody<Bytes, hyper::Error>>> {
|
||||
let mut entry = self.h1_pool.get_mut(key)?;
|
||||
let idles = entry.value_mut();
|
||||
|
||||
while let Some(idle) = idles.pop() {
|
||||
// Check if the connection is still alive and ready
|
||||
if idle.idle_since.elapsed() < IDLE_TIMEOUT && idle.sender.is_ready() && !idle.sender.is_closed() {
|
||||
if idle.idle_since.elapsed() < IDLE_TIMEOUT
|
||||
&& idle.sender.is_ready()
|
||||
&& !idle.sender.is_closed()
|
||||
{
|
||||
// H1 pool hit — no logging on hot path
|
||||
return Some(idle.sender);
|
||||
}
|
||||
@@ -124,7 +134,11 @@ impl ConnectionPool {
|
||||
|
||||
/// Return an HTTP/1.1 sender to the pool after the response body has been prepared.
|
||||
/// The caller should NOT call this if the sender is closed or not ready.
|
||||
pub fn checkin_h1(&self, key: PoolKey, sender: http1::SendRequest<BoxBody<Bytes, hyper::Error>>) {
|
||||
pub fn checkin_h1(
|
||||
&self,
|
||||
key: PoolKey,
|
||||
sender: http1::SendRequest<BoxBody<Bytes, hyper::Error>>,
|
||||
) {
|
||||
if sender.is_closed() || !sender.is_ready() {
|
||||
return; // Don't pool broken connections
|
||||
}
|
||||
@@ -141,7 +155,10 @@ impl ConnectionPool {
|
||||
|
||||
/// Try to get a cloned HTTP/2 sender for the given key.
|
||||
/// HTTP/2 senders are Clone-able (multiplexed), so we clone rather than remove.
|
||||
pub fn checkout_h2(&self, key: &PoolKey) -> Option<(http2::SendRequest<BoxBody<Bytes, hyper::Error>>, Duration)> {
|
||||
pub fn checkout_h2(
|
||||
&self,
|
||||
key: &PoolKey,
|
||||
) -> Option<(http2::SendRequest<BoxBody<Bytes, hyper::Error>>, Duration)> {
|
||||
let entry = self.h2_pool.get(key)?;
|
||||
let pooled = entry.value();
|
||||
let age = pooled.created_at.elapsed();
|
||||
@@ -180,16 +197,23 @@ impl ConnectionPool {
|
||||
/// Register an HTTP/2 sender in the pool. Returns the generation ID for this entry.
|
||||
/// The caller should pass this generation to the connection driver so it can use
|
||||
/// `remove_h2_if_generation` instead of `remove_h2` to avoid phantom eviction.
|
||||
pub fn register_h2(&self, key: PoolKey, sender: http2::SendRequest<BoxBody<Bytes, hyper::Error>>) -> u64 {
|
||||
pub fn register_h2(
|
||||
&self,
|
||||
key: PoolKey,
|
||||
sender: http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
|
||||
) -> u64 {
|
||||
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
|
||||
if sender.is_closed() {
|
||||
return gen;
|
||||
}
|
||||
self.h2_pool.insert(key, PooledH2 {
|
||||
sender,
|
||||
created_at: Instant::now(),
|
||||
generation: gen,
|
||||
});
|
||||
self.h2_pool.insert(
|
||||
key,
|
||||
PooledH2 {
|
||||
sender,
|
||||
created_at: Instant::now(),
|
||||
generation: gen,
|
||||
},
|
||||
);
|
||||
gen
|
||||
}
|
||||
|
||||
@@ -197,7 +221,14 @@ impl ConnectionPool {
|
||||
|
||||
/// Try to get a pooled QUIC connection for the given key.
|
||||
/// QUIC connections are multiplexed — the connection is shared, not removed.
|
||||
pub fn checkout_h3(&self, key: &PoolKey) -> Option<(quinn::Connection, Duration)> {
|
||||
pub fn checkout_h3(
|
||||
&self,
|
||||
key: &PoolKey,
|
||||
) -> Option<(
|
||||
h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
|
||||
quinn::Connection,
|
||||
Duration,
|
||||
)> {
|
||||
let entry = self.h3_pool.get(key)?;
|
||||
let pooled = entry.value();
|
||||
let age = pooled.created_at.elapsed();
|
||||
@@ -215,17 +246,27 @@ impl ConnectionPool {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some((pooled.connection.clone(), age))
|
||||
Some((pooled.send_request.clone(), pooled.connection.clone(), age))
|
||||
}
|
||||
|
||||
/// Register a QUIC connection in the pool. Returns the generation ID.
|
||||
pub fn register_h3(&self, key: PoolKey, connection: quinn::Connection) -> u64 {
|
||||
/// Register a QUIC connection and its h3 SendRequest handle in the pool.
|
||||
/// Returns the generation ID.
|
||||
pub fn register_h3(
|
||||
&self,
|
||||
key: PoolKey,
|
||||
connection: quinn::Connection,
|
||||
send_request: h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
|
||||
) -> u64 {
|
||||
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
|
||||
self.h3_pool.insert(key, PooledH3 {
|
||||
connection,
|
||||
created_at: Instant::now(),
|
||||
generation: gen,
|
||||
});
|
||||
self.h3_pool.insert(
|
||||
key,
|
||||
PooledH3 {
|
||||
send_request,
|
||||
connection,
|
||||
created_at: Instant::now(),
|
||||
generation: gen,
|
||||
},
|
||||
);
|
||||
gen
|
||||
}
|
||||
|
||||
@@ -266,7 +307,9 @@ impl ConnectionPool {
|
||||
// Evict dead or aged-out H2 connections
|
||||
let mut dead_h2 = Vec::new();
|
||||
for entry in h2_pool.iter() {
|
||||
if entry.value().sender.is_closed() || entry.value().created_at.elapsed() >= MAX_H2_AGE {
|
||||
if entry.value().sender.is_closed()
|
||||
|| entry.value().created_at.elapsed() >= MAX_H2_AGE
|
||||
{
|
||||
dead_h2.push(entry.key().clone());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
//! A body wrapper that counts bytes flowing through and reports them to MetricsCollector.
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use bytes::Bytes;
|
||||
@@ -76,7 +76,11 @@ impl<B> CountingBody<B> {
|
||||
/// Set the connection-level activity tracker. When set, each data frame
|
||||
/// updates this timestamp to prevent the idle watchdog from killing the
|
||||
/// connection during active body streaming.
|
||||
pub fn with_connection_activity(mut self, activity: Arc<AtomicU64>, start: std::time::Instant) -> Self {
|
||||
pub fn with_connection_activity(
|
||||
mut self,
|
||||
activity: Arc<AtomicU64>,
|
||||
start: std::time::Instant,
|
||||
) -> Self {
|
||||
self.connection_activity = Some(activity);
|
||||
self.activity_start = Some(start);
|
||||
self
|
||||
@@ -134,7 +138,9 @@ where
|
||||
}
|
||||
// Keep the connection-level idle watchdog alive on every frame
|
||||
// (this is just one atomic store — cheap enough per-frame)
|
||||
if let (Some(activity), Some(start)) = (&this.connection_activity, &this.activity_start) {
|
||||
if let (Some(activity), Some(start)) =
|
||||
(&this.connection_activity, &this.activity_start)
|
||||
{
|
||||
activity.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,14 +11,14 @@ use std::task::{Context, Poll};
|
||||
|
||||
use bytes::{Buf, Bytes};
|
||||
use http_body::Frame;
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use rustproxy_config::RouteConfig;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::proxy_service::{ConnActivity, HttpProxyService};
|
||||
use crate::proxy_service::{ConnActivity, HttpProxyService, ProtocolGuard};
|
||||
|
||||
/// HTTP/3 proxy service.
|
||||
///
|
||||
@@ -48,6 +48,10 @@ impl H3ProxyService {
|
||||
let remote_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
|
||||
debug!("HTTP/3 connection from {} on port {}", remote_addr, port);
|
||||
|
||||
// Track frontend H3 connection for the QUIC connection's lifetime.
|
||||
let _frontend_h3_guard =
|
||||
ProtocolGuard::frontend(Arc::clone(self.http_proxy.metrics()), "h3");
|
||||
|
||||
let mut h3_conn: h3::server::Connection<h3_quinn::Connection, Bytes> =
|
||||
h3::server::builder()
|
||||
.send_grease(false)
|
||||
@@ -89,8 +93,15 @@ impl H3ProxyService {
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_h3_request(
|
||||
request, stream, port, remote_addr, &http_proxy, request_cancel,
|
||||
).await {
|
||||
request,
|
||||
stream,
|
||||
port,
|
||||
remote_addr,
|
||||
&http_proxy,
|
||||
request_cancel,
|
||||
)
|
||||
.await
|
||||
{
|
||||
debug!("HTTP/3 request error from {}: {}", remote_addr, e);
|
||||
}
|
||||
});
|
||||
@@ -116,7 +127,7 @@ async fn handle_h3_request(
|
||||
cancel: CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
// Stream request body from H3 client via an mpsc channel.
|
||||
let (body_tx, body_rx) = tokio::sync::mpsc::channel::<Bytes>(4);
|
||||
let (body_tx, body_rx) = tokio::sync::mpsc::channel::<Bytes>(32);
|
||||
|
||||
// Spawn the H3 body reader task with cancellation
|
||||
let body_cancel = cancel.clone();
|
||||
@@ -132,8 +143,7 @@ async fn handle_h3_request(
|
||||
}
|
||||
};
|
||||
let mut chunk = chunk;
|
||||
let data = Bytes::copy_from_slice(chunk.chunk());
|
||||
chunk.advance(chunk.remaining());
|
||||
let data = chunk.copy_to_bytes(chunk.remaining());
|
||||
if body_tx.send(data).await.is_err() {
|
||||
break;
|
||||
}
|
||||
@@ -151,11 +161,14 @@ async fn handle_h3_request(
|
||||
// Delegate to HttpProxyService — same backend path as TCP/HTTP:
|
||||
// route matching, ALPN protocol detection, connection pool, H1/H2/H3 auto.
|
||||
let conn_activity = ConnActivity::new_standalone();
|
||||
let response = http_proxy.handle_request(req, peer_addr, port, cancel, conn_activity).await
|
||||
let response = http_proxy
|
||||
.handle_request(req, peer_addr, port, cancel, conn_activity)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?;
|
||||
|
||||
// Await the body reader to get the H3 stream back
|
||||
let mut stream = body_reader.await
|
||||
let mut stream = body_reader
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Body reader task failed: {}", e))?;
|
||||
|
||||
// Send response headers over H3 (skip hop-by-hop headers)
|
||||
@@ -168,10 +181,13 @@ async fn handle_h3_request(
|
||||
}
|
||||
h3_response = h3_response.header(name, value);
|
||||
}
|
||||
let h3_response = h3_response.body(())
|
||||
let h3_response = h3_response
|
||||
.body(())
|
||||
.map_err(|e| anyhow::anyhow!("Failed to build H3 response: {}", e))?;
|
||||
|
||||
stream.send_response(h3_response).await
|
||||
stream
|
||||
.send_response(h3_response)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to send H3 response: {}", e))?;
|
||||
|
||||
// Stream response body back over H3
|
||||
@@ -179,8 +195,10 @@ async fn handle_h3_request(
|
||||
while let Some(frame) = resp_body.frame().await {
|
||||
match frame {
|
||||
Ok(frame) => {
|
||||
if let Some(data) = frame.data_ref() {
|
||||
stream.send_data(Bytes::copy_from_slice(data)).await
|
||||
if let Ok(data) = frame.into_data() {
|
||||
stream
|
||||
.send_data(data)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to send H3 data: {}", e))?;
|
||||
}
|
||||
}
|
||||
@@ -192,7 +210,9 @@ async fn handle_h3_request(
|
||||
}
|
||||
|
||||
// Finish the H3 stream (send QUIC FIN)
|
||||
stream.finish().await
|
||||
stream
|
||||
.finish()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to finish H3 stream: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -5,14 +5,15 @@
|
||||
|
||||
pub mod connection_pool;
|
||||
pub mod counting_body;
|
||||
pub mod h3_service;
|
||||
pub mod protocol_cache;
|
||||
pub mod proxy_service;
|
||||
pub mod request_filter;
|
||||
mod request_host;
|
||||
pub mod response_filter;
|
||||
pub mod shutdown_on_drop;
|
||||
pub mod template;
|
||||
pub mod upstream_selector;
|
||||
pub mod h3_service;
|
||||
|
||||
pub use connection_pool::*;
|
||||
pub use counting_body::*;
|
||||
|
||||
@@ -144,10 +144,14 @@ impl FailureState {
|
||||
}
|
||||
|
||||
fn all_expired(&self) -> bool {
|
||||
let h2_expired = self.h2.as_ref()
|
||||
let h2_expired = self
|
||||
.h2
|
||||
.as_ref()
|
||||
.map(|r| r.failed_at.elapsed() >= r.cooldown)
|
||||
.unwrap_or(true);
|
||||
let h3_expired = self.h3.as_ref()
|
||||
let h3_expired = self
|
||||
.h3
|
||||
.as_ref()
|
||||
.map(|r| r.failed_at.elapsed() >= r.cooldown)
|
||||
.unwrap_or(true);
|
||||
h2_expired && h3_expired
|
||||
@@ -355,9 +359,13 @@ impl ProtocolCache {
|
||||
|
||||
let record = entry.get_mut(protocol);
|
||||
let (consecutive, new_cooldown) = match record {
|
||||
Some(existing) if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) => {
|
||||
Some(existing)
|
||||
if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) =>
|
||||
{
|
||||
// Still within the "recent" window — escalate
|
||||
let c = existing.consecutive_failures.saturating_add(1)
|
||||
let c = existing
|
||||
.consecutive_failures
|
||||
.saturating_add(1)
|
||||
.min(PROTOCOL_FAILURE_ESCALATION_CAP);
|
||||
(c, escalate_cooldown(c))
|
||||
}
|
||||
@@ -394,8 +402,13 @@ impl ProtocolCache {
|
||||
if protocol == DetectedProtocol::H1 {
|
||||
return false;
|
||||
}
|
||||
self.failures.get(key)
|
||||
.and_then(|entry| entry.get(protocol).map(|r| r.failed_at.elapsed() < r.cooldown))
|
||||
self.failures
|
||||
.get(key)
|
||||
.and_then(|entry| {
|
||||
entry
|
||||
.get(protocol)
|
||||
.map(|r| r.failed_at.elapsed() < r.cooldown)
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
@@ -464,19 +477,18 @@ impl ProtocolCache {
|
||||
|
||||
/// Snapshot all non-expired cache entries for metrics/UI display.
|
||||
pub fn snapshot(&self) -> Vec<ProtocolCacheEntry> {
|
||||
self.cache.iter()
|
||||
self.cache
|
||||
.iter()
|
||||
.filter(|entry| entry.value().last_accessed_at.elapsed() < PROTOCOL_CACHE_TTL)
|
||||
.map(|entry| {
|
||||
let key = entry.key();
|
||||
let val = entry.value();
|
||||
let failure_info = self.failures.get(key);
|
||||
|
||||
let (h2_sup, h2_cd, h2_cons) = Self::suppression_info(
|
||||
failure_info.as_deref().and_then(|f| f.h2.as_ref()),
|
||||
);
|
||||
let (h3_sup, h3_cd, h3_cons) = Self::suppression_info(
|
||||
failure_info.as_deref().and_then(|f| f.h3.as_ref()),
|
||||
);
|
||||
let (h2_sup, h2_cd, h2_cons) =
|
||||
Self::suppression_info(failure_info.as_deref().and_then(|f| f.h2.as_ref()));
|
||||
let (h3_sup, h3_cd, h3_cons) =
|
||||
Self::suppression_info(failure_info.as_deref().and_then(|f| f.h3.as_ref()));
|
||||
|
||||
ProtocolCacheEntry {
|
||||
host: key.host.clone(),
|
||||
@@ -507,7 +519,13 @@ impl ProtocolCache {
|
||||
/// Insert a protocol detection result with an optional H3 port.
|
||||
/// Logs protocol transitions when overwriting an existing entry.
|
||||
/// No suppression check — callers must check before calling.
|
||||
fn insert_internal(&self, key: ProtocolCacheKey, protocol: DetectedProtocol, h3_port: Option<u16>, reason: &str) {
|
||||
fn insert_internal(
|
||||
&self,
|
||||
key: ProtocolCacheKey,
|
||||
protocol: DetectedProtocol,
|
||||
h3_port: Option<u16>,
|
||||
reason: &str,
|
||||
) {
|
||||
// Check for existing entry to log protocol transitions
|
||||
if let Some(existing) = self.cache.get(&key) {
|
||||
if existing.protocol != protocol {
|
||||
@@ -522,7 +540,9 @@ impl ProtocolCache {
|
||||
|
||||
// Evict oldest entry if at capacity
|
||||
if self.cache.len() >= PROTOCOL_CACHE_MAX_ENTRIES && !self.cache.contains_key(&key) {
|
||||
let oldest = self.cache.iter()
|
||||
let oldest = self
|
||||
.cache
|
||||
.iter()
|
||||
.min_by_key(|entry| entry.value().last_accessed_at)
|
||||
.map(|entry| entry.key().clone());
|
||||
if let Some(oldest_key) = oldest {
|
||||
@@ -531,13 +551,16 @@ impl ProtocolCache {
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
self.cache.insert(key, CachedEntry {
|
||||
protocol,
|
||||
detected_at: now,
|
||||
last_accessed_at: now,
|
||||
last_probed_at: now,
|
||||
h3_port,
|
||||
});
|
||||
self.cache.insert(
|
||||
key,
|
||||
CachedEntry {
|
||||
protocol,
|
||||
detected_at: now,
|
||||
last_accessed_at: now,
|
||||
last_probed_at: now,
|
||||
h3_port,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Reduce a failure record's remaining cooldown to `target`, if it currently
|
||||
@@ -582,26 +605,34 @@ impl ProtocolCache {
|
||||
interval.tick().await;
|
||||
|
||||
// Clean expired cache entries (sliding TTL based on last_accessed_at)
|
||||
let expired: Vec<ProtocolCacheKey> = cache.iter()
|
||||
let expired: Vec<ProtocolCacheKey> = cache
|
||||
.iter()
|
||||
.filter(|entry| entry.value().last_accessed_at.elapsed() >= PROTOCOL_CACHE_TTL)
|
||||
.map(|entry| entry.key().clone())
|
||||
.collect();
|
||||
|
||||
if !expired.is_empty() {
|
||||
debug!("Protocol cache cleanup: removing {} expired entries", expired.len());
|
||||
debug!(
|
||||
"Protocol cache cleanup: removing {} expired entries",
|
||||
expired.len()
|
||||
);
|
||||
for key in expired {
|
||||
cache.remove(&key);
|
||||
}
|
||||
}
|
||||
|
||||
// Clean fully-expired failure entries
|
||||
let expired_failures: Vec<ProtocolCacheKey> = failures.iter()
|
||||
let expired_failures: Vec<ProtocolCacheKey> = failures
|
||||
.iter()
|
||||
.filter(|entry| entry.value().all_expired())
|
||||
.map(|entry| entry.key().clone())
|
||||
.collect();
|
||||
|
||||
if !expired_failures.is_empty() {
|
||||
debug!("Protocol cache cleanup: removing {} expired failure entries", expired_failures.len());
|
||||
debug!(
|
||||
"Protocol cache cleanup: removing {} expired failure entries",
|
||||
expired_failures.len()
|
||||
);
|
||||
for key in expired_failures {
|
||||
failures.remove(&key);
|
||||
}
|
||||
@@ -609,7 +640,8 @@ impl ProtocolCache {
|
||||
|
||||
// Safety net: cap failures map at 2× max entries
|
||||
if failures.len() > PROTOCOL_CACHE_MAX_ENTRIES * 2 {
|
||||
let oldest: Vec<ProtocolCacheKey> = failures.iter()
|
||||
let oldest: Vec<ProtocolCacheKey> = failures
|
||||
.iter()
|
||||
.filter(|e| e.value().all_expired())
|
||||
.map(|e| e.key().clone())
|
||||
.take(failures.len() - PROTOCOL_CACHE_MAX_ENTRIES)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,13 +4,15 @@ use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use http_body_util::Full;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::Full;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
|
||||
use rustproxy_config::RouteSecurity;
|
||||
use rustproxy_security::{IpFilter, BasicAuthValidator, JwtValidator, RateLimiter};
|
||||
use rustproxy_security::{BasicAuthValidator, IpFilter, JwtValidator, RateLimiter};
|
||||
|
||||
use crate::request_host::extract_request_host;
|
||||
|
||||
pub struct RequestFilter;
|
||||
|
||||
@@ -35,13 +37,14 @@ impl RequestFilter {
|
||||
let client_ip = peer_addr.ip();
|
||||
let request_path = req.uri().path();
|
||||
|
||||
// IP filter
|
||||
// IP filter (domain-aware: use the same host extraction as route matching)
|
||||
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
|
||||
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
|
||||
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
|
||||
let filter = IpFilter::new(allow, block);
|
||||
let normalized = IpFilter::normalize_ip(&client_ip);
|
||||
if !filter.is_allowed(&normalized) {
|
||||
let host = extract_request_host(req);
|
||||
if !filter.is_allowed_for_domain(&normalized, host) {
|
||||
return Some(error_response(StatusCode::FORBIDDEN, "Access denied"));
|
||||
}
|
||||
}
|
||||
@@ -55,16 +58,15 @@ impl RequestFilter {
|
||||
!limiter.check(&key)
|
||||
} else {
|
||||
// Create a per-check limiter (less ideal but works for non-shared case)
|
||||
let limiter = RateLimiter::new(
|
||||
rate_limit_config.max_requests,
|
||||
rate_limit_config.window,
|
||||
);
|
||||
let limiter =
|
||||
RateLimiter::new(rate_limit_config.max_requests, rate_limit_config.window);
|
||||
let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
|
||||
!limiter.check(&key)
|
||||
};
|
||||
|
||||
if should_block {
|
||||
let message = rate_limit_config.error_message
|
||||
let message = rate_limit_config
|
||||
.error_message
|
||||
.as_deref()
|
||||
.unwrap_or("Rate limit exceeded");
|
||||
return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message));
|
||||
@@ -80,36 +82,48 @@ impl RequestFilter {
|
||||
if let Some(ref basic_auth) = security.basic_auth {
|
||||
if basic_auth.enabled {
|
||||
// Check basic auth exclude paths
|
||||
let skip_basic = basic_auth.exclude_paths.as_ref()
|
||||
let skip_basic = basic_auth
|
||||
.exclude_paths
|
||||
.as_ref()
|
||||
.map(|paths| Self::path_matches_any(request_path, paths))
|
||||
.unwrap_or(false);
|
||||
|
||||
if !skip_basic {
|
||||
let users: Vec<(String, String)> = basic_auth.users.iter()
|
||||
let users: Vec<(String, String)> = basic_auth
|
||||
.users
|
||||
.iter()
|
||||
.map(|c| (c.username.clone(), c.password.clone()))
|
||||
.collect();
|
||||
let validator = BasicAuthValidator::new(users, basic_auth.realm.clone());
|
||||
|
||||
let auth_header = req.headers()
|
||||
let auth_header = req
|
||||
.headers()
|
||||
.get("authorization")
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
match auth_header {
|
||||
Some(header) => {
|
||||
if validator.validate(header).is_none() {
|
||||
return Some(Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header("WWW-Authenticate", validator.www_authenticate())
|
||||
.body(boxed_body("Invalid credentials"))
|
||||
.unwrap());
|
||||
return Some(
|
||||
Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header(
|
||||
"WWW-Authenticate",
|
||||
validator.www_authenticate(),
|
||||
)
|
||||
.body(boxed_body("Invalid credentials"))
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Some(Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header("WWW-Authenticate", validator.www_authenticate())
|
||||
.body(boxed_body("Authentication required"))
|
||||
.unwrap());
|
||||
return Some(
|
||||
Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header("WWW-Authenticate", validator.www_authenticate())
|
||||
.body(boxed_body("Authentication required"))
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -120,7 +134,9 @@ impl RequestFilter {
|
||||
if let Some(ref jwt_auth) = security.jwt_auth {
|
||||
if jwt_auth.enabled {
|
||||
// Check JWT auth exclude paths
|
||||
let skip_jwt = jwt_auth.exclude_paths.as_ref()
|
||||
let skip_jwt = jwt_auth
|
||||
.exclude_paths
|
||||
.as_ref()
|
||||
.map(|paths| Self::path_matches_any(request_path, paths))
|
||||
.unwrap_or(false);
|
||||
|
||||
@@ -132,18 +148,25 @@ impl RequestFilter {
|
||||
jwt_auth.audience.as_deref(),
|
||||
);
|
||||
|
||||
let auth_header = req.headers()
|
||||
let auth_header = req
|
||||
.headers()
|
||||
.get("authorization")
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
match auth_header.and_then(JwtValidator::extract_token) {
|
||||
Some(token) => {
|
||||
if validator.validate(token).is_err() {
|
||||
return Some(error_response(StatusCode::UNAUTHORIZED, "Invalid token"));
|
||||
return Some(error_response(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"Invalid token",
|
||||
));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Some(error_response(StatusCode::UNAUTHORIZED, "Bearer token required"));
|
||||
return Some(error_response(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"Bearer token required",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -203,14 +226,19 @@ impl RequestFilter {
|
||||
}
|
||||
|
||||
/// Check IP-based security (for use in passthrough / TCP-level connections).
|
||||
/// `domain` is the SNI from the TLS handshake (if available) for domain-scoped filtering.
|
||||
/// Returns true if allowed, false if blocked.
|
||||
pub fn check_ip_security(security: &RouteSecurity, client_ip: &std::net::IpAddr) -> bool {
|
||||
pub fn check_ip_security(
|
||||
security: &RouteSecurity,
|
||||
client_ip: &std::net::IpAddr,
|
||||
domain: Option<&str>,
|
||||
) -> bool {
|
||||
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
|
||||
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
|
||||
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
|
||||
let filter = IpFilter::new(allow, block);
|
||||
let normalized = IpFilter::normalize_ip(client_ip);
|
||||
filter.is_allowed(&normalized)
|
||||
filter.is_allowed_for_domain(&normalized, domain)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
@@ -233,19 +261,28 @@ impl RequestFilter {
|
||||
return None;
|
||||
}
|
||||
|
||||
let origin = req.headers()
|
||||
let origin = req
|
||||
.headers()
|
||||
.get("origin")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("*");
|
||||
|
||||
Some(Response::builder()
|
||||
.status(StatusCode::NO_CONTENT)
|
||||
.header("Access-Control-Allow-Origin", origin)
|
||||
.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
|
||||
.header("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
|
||||
.header("Access-Control-Max-Age", "86400")
|
||||
.body(boxed_body(""))
|
||||
.unwrap())
|
||||
Some(
|
||||
Response::builder()
|
||||
.status(StatusCode::NO_CONTENT)
|
||||
.header("Access-Control-Allow-Origin", origin)
|
||||
.header(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET, POST, PUT, DELETE, PATCH, OPTIONS",
|
||||
)
|
||||
.header(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Content-Type, Authorization, X-Requested-With",
|
||||
)
|
||||
.header("Access-Control-Max-Age", "86400")
|
||||
.body(boxed_body(""))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -260,3 +297,71 @@ fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes,
|
||||
fn boxed_body(data: &str) -> BoxBody<Bytes, hyper::Error> {
|
||||
BoxBody::new(Full::new(Bytes::from(data.to_string())).map_err(|never| match never {}))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use bytes::Bytes;
|
||||
use http_body_util::Empty;
|
||||
use hyper::{Request, StatusCode, Version};
|
||||
use rustproxy_config::{IpAllowEntry, RouteSecurity};
|
||||
|
||||
use super::RequestFilter;
|
||||
|
||||
fn domain_scoped_security() -> RouteSecurity {
|
||||
RouteSecurity {
|
||||
ip_allow_list: Some(vec![IpAllowEntry::DomainScoped {
|
||||
ip: "10.8.0.2".to_string(),
|
||||
domains: vec!["*.abc.xyz".to_string()],
|
||||
}]),
|
||||
ip_block_list: None,
|
||||
max_connections: None,
|
||||
authentication: None,
|
||||
rate_limit: None,
|
||||
basic_auth: None,
|
||||
jwt_auth: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn peer_addr() -> std::net::SocketAddr {
|
||||
std::net::SocketAddr::from(([10, 8, 0, 2], 4242))
|
||||
}
|
||||
|
||||
fn request(uri: &str, version: Version, host: Option<&str>) -> Request<Empty<Bytes>> {
|
||||
let mut builder = Request::builder().uri(uri).version(version);
|
||||
if let Some(host) = host {
|
||||
builder = builder.header("host", host);
|
||||
}
|
||||
|
||||
builder.body(Empty::<Bytes>::new()).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn domain_scoped_acl_allows_uri_authority_without_host_header() {
|
||||
let security = domain_scoped_security();
|
||||
let req = request("https://outline.abc.xyz/", Version::HTTP_2, None);
|
||||
|
||||
assert!(RequestFilter::apply(&security, &req, &peer_addr()).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn domain_scoped_acl_allows_host_header_with_port() {
|
||||
let security = domain_scoped_security();
|
||||
let req = request(
|
||||
"https://unrelated.invalid/",
|
||||
Version::HTTP_11,
|
||||
Some("outline.abc.xyz:443"),
|
||||
);
|
||||
|
||||
assert!(RequestFilter::apply(&security, &req, &peer_addr()).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn domain_scoped_acl_denies_non_matching_uri_authority() {
|
||||
let security = domain_scoped_security();
|
||||
let req = request("https://outline.other.xyz/", Version::HTTP_2, None);
|
||||
|
||||
let response = RequestFilter::apply(&security, &req, &peer_addr())
|
||||
.expect("non-matching domain should be denied");
|
||||
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
use hyper::Request;
|
||||
|
||||
/// Extract the effective request host for routing and scoped ACL checks.
|
||||
///
|
||||
/// Prefer the explicit `Host` header when present, otherwise fall back to the
|
||||
/// URI authority used by HTTP/2 and HTTP/3 requests.
|
||||
pub(crate) fn extract_request_host<B>(req: &Request<B>) -> Option<&str> {
|
||||
req.headers()
|
||||
.get("host")
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(|host| host.split(':').next().unwrap_or(host))
|
||||
.or_else(|| req.uri().host())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use bytes::Bytes;
|
||||
use http_body_util::Empty;
|
||||
use hyper::Request;
|
||||
|
||||
use super::extract_request_host;
|
||||
|
||||
#[test]
|
||||
fn extracts_host_header_before_uri_authority() {
|
||||
let req = Request::builder()
|
||||
.uri("https://uri.abc.xyz/test")
|
||||
.header("host", "header.abc.xyz:443")
|
||||
.body(Empty::<Bytes>::new())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(extract_request_host(&req), Some("header.abc.xyz"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn falls_back_to_uri_authority_when_host_header_missing() {
|
||||
let req = Request::builder()
|
||||
.uri("https://outline.abc.xyz/test")
|
||||
.body(Empty::<Bytes>::new())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(extract_request_host(&req), Some("outline.abc.xyz"));
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,7 @@
|
||||
use hyper::header::{HeaderMap, HeaderName, HeaderValue};
|
||||
use rustproxy_config::RouteConfig;
|
||||
|
||||
use crate::template::{RequestContext, expand_template};
|
||||
use crate::template::{expand_template, RequestContext};
|
||||
|
||||
pub struct ResponseFilter;
|
||||
|
||||
@@ -11,12 +11,17 @@ impl ResponseFilter {
|
||||
/// Apply response headers from route config and CORS settings.
|
||||
/// If a `RequestContext` is provided, template variables in header values will be expanded.
|
||||
/// Also injects Alt-Svc header for routes with HTTP/3 enabled.
|
||||
pub fn apply_headers(route: &RouteConfig, headers: &mut HeaderMap, req_ctx: Option<&RequestContext>) {
|
||||
pub fn apply_headers(
|
||||
route: &RouteConfig,
|
||||
headers: &mut HeaderMap,
|
||||
req_ctx: Option<&RequestContext>,
|
||||
) {
|
||||
// Inject Alt-Svc for HTTP/3 advertisement if QUIC/HTTP3 is enabled on this route
|
||||
if let Some(ref udp) = route.action.udp {
|
||||
if let Some(ref quic) = udp.quic {
|
||||
if quic.enable_http3.unwrap_or(false) {
|
||||
let port = quic.alt_svc_port
|
||||
let port = quic
|
||||
.alt_svc_port
|
||||
.or_else(|| req_ctx.map(|c| c.port))
|
||||
.unwrap_or(443);
|
||||
let max_age = quic.alt_svc_max_age.unwrap_or(86400);
|
||||
@@ -63,10 +68,7 @@ impl ResponseFilter {
|
||||
headers.insert("access-control-allow-origin", val);
|
||||
}
|
||||
} else {
|
||||
headers.insert(
|
||||
"access-control-allow-origin",
|
||||
HeaderValue::from_static("*"),
|
||||
);
|
||||
headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
|
||||
}
|
||||
|
||||
// Allow-Methods
|
||||
|
||||
@@ -62,17 +62,11 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncWrite for Shutdown
|
||||
self.inner.as_ref().unwrap().is_write_vectored()
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(self.get_mut().inner.as_mut().unwrap()).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let this = self.get_mut();
|
||||
let result = Pin::new(this.inner.as_mut().unwrap()).poll_shutdown(cx);
|
||||
if result.is_ready() {
|
||||
@@ -93,7 +87,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Drop for ShutdownOnDrop
|
||||
let _ = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(2),
|
||||
tokio::io::AsyncWriteExt::shutdown(&mut stream),
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
// stream is dropped here — all resources freed
|
||||
});
|
||||
}
|
||||
|
||||
@@ -39,7 +39,8 @@ pub fn expand_headers(
|
||||
headers: &HashMap<String, String>,
|
||||
ctx: &RequestContext,
|
||||
) -> HashMap<String, String> {
|
||||
headers.iter()
|
||||
headers
|
||||
.iter()
|
||||
.map(|(k, v)| (k.clone(), expand_template(v, ctx)))
|
||||
.collect()
|
||||
}
|
||||
@@ -150,7 +151,10 @@ mod tests {
|
||||
let ctx = test_context();
|
||||
let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}";
|
||||
let result = expand_template(template, &ctx);
|
||||
assert_eq!(result, "192.168.1.100|example.com|443|/api/v1/users|api-route|42");
|
||||
assert_eq!(
|
||||
result,
|
||||
"192.168.1.100|example.com|443|/api/v1/users|api-route|42"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -7,7 +7,7 @@ use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use rustproxy_config::{RouteTarget, LoadBalancingAlgorithm};
|
||||
use rustproxy_config::{LoadBalancingAlgorithm, RouteTarget};
|
||||
|
||||
/// Upstream selection result.
|
||||
pub struct UpstreamSelection {
|
||||
@@ -51,21 +51,19 @@ impl UpstreamSelector {
|
||||
}
|
||||
|
||||
// Determine load balancing algorithm
|
||||
let algorithm = target.load_balancing.as_ref()
|
||||
let algorithm = target
|
||||
.load_balancing
|
||||
.as_ref()
|
||||
.map(|lb| &lb.algorithm)
|
||||
.unwrap_or(&LoadBalancingAlgorithm::RoundRobin);
|
||||
|
||||
let idx = match algorithm {
|
||||
LoadBalancingAlgorithm::RoundRobin => {
|
||||
self.round_robin_select(&hosts, port)
|
||||
}
|
||||
LoadBalancingAlgorithm::RoundRobin => self.round_robin_select(&hosts, port),
|
||||
LoadBalancingAlgorithm::IpHash => {
|
||||
let hash = Self::ip_hash(client_addr);
|
||||
hash % hosts.len()
|
||||
}
|
||||
LoadBalancingAlgorithm::LeastConnections => {
|
||||
self.least_connections_select(&hosts, port)
|
||||
}
|
||||
LoadBalancingAlgorithm::LeastConnections => self.least_connections_select(&hosts, port),
|
||||
};
|
||||
|
||||
UpstreamSelection {
|
||||
@@ -78,9 +76,7 @@ impl UpstreamSelector {
|
||||
fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize {
|
||||
let key = format!("{}:{}", hosts[0], port);
|
||||
let mut counters = self.round_robin.lock().unwrap();
|
||||
let counter = counters
|
||||
.entry(key)
|
||||
.or_insert_with(|| AtomicUsize::new(0));
|
||||
let counter = counters.entry(key).or_insert_with(|| AtomicUsize::new(0));
|
||||
let idx = counter.fetch_add(1, Ordering::Relaxed);
|
||||
idx % hosts.len()
|
||||
}
|
||||
@@ -91,7 +87,8 @@ impl UpstreamSelector {
|
||||
|
||||
for (i, host) in hosts.iter().enumerate() {
|
||||
let key = format!("{}:{}", host, port);
|
||||
let conns = self.active_connections
|
||||
let conns = self
|
||||
.active_connections
|
||||
.get(&key)
|
||||
.map(|entry| entry.value().load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
@@ -228,13 +225,21 @@ mod tests {
|
||||
selector.connection_started("backend:8080");
|
||||
selector.connection_started("backend:8080");
|
||||
assert_eq!(
|
||||
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed),
|
||||
selector
|
||||
.active_connections
|
||||
.get("backend:8080")
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
2
|
||||
);
|
||||
|
||||
selector.connection_ended("backend:8080");
|
||||
assert_eq!(
|
||||
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed),
|
||||
selector
|
||||
.active_connections
|
||||
.get("backend:8080")
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
1
|
||||
);
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,10 +2,10 @@
|
||||
//!
|
||||
//! Metrics and throughput tracking for RustProxy.
|
||||
|
||||
pub mod throughput;
|
||||
pub mod collector;
|
||||
pub mod log_dedup;
|
||||
pub mod throughput;
|
||||
|
||||
pub use throughput::*;
|
||||
pub use collector::*;
|
||||
pub use log_dedup::*;
|
||||
pub use throughput::*;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tracing::info;
|
||||
|
||||
@@ -47,13 +47,16 @@ impl LogDeduplicator {
|
||||
let map_key = format!("{}:{}", category, key);
|
||||
let now = Instant::now();
|
||||
|
||||
let entry = self.events.entry(map_key).or_insert_with(|| AggregatedEvent {
|
||||
category: category.to_string(),
|
||||
first_message: message.to_string(),
|
||||
count: AtomicU64::new(0),
|
||||
first_seen: now,
|
||||
last_seen: now,
|
||||
});
|
||||
let entry = self
|
||||
.events
|
||||
.entry(map_key)
|
||||
.or_insert_with(|| AggregatedEvent {
|
||||
category: category.to_string(),
|
||||
first_message: message.to_string(),
|
||||
count: AtomicU64::new(0),
|
||||
first_seen: now,
|
||||
last_seen: now,
|
||||
});
|
||||
|
||||
let count = entry.count.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
|
||||
|
||||
@@ -29,6 +29,113 @@ pub struct ThroughputTracker {
|
||||
created_at: Instant,
|
||||
}
|
||||
|
||||
fn unix_timestamp_seconds() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
/// Circular buffer for per-second event counts.
|
||||
///
|
||||
/// Unlike `ThroughputTracker`, events are recorded directly into the current
|
||||
/// second so request counts remain stable even when the collector is sampled
|
||||
/// more frequently than once per second.
|
||||
pub(crate) struct RequestRateTracker {
|
||||
samples: Vec<u64>,
|
||||
write_index: usize,
|
||||
count: usize,
|
||||
capacity: usize,
|
||||
current_second: Option<u64>,
|
||||
current_count: u64,
|
||||
}
|
||||
|
||||
impl RequestRateTracker {
|
||||
pub(crate) fn new(retention_seconds: usize) -> Self {
|
||||
Self {
|
||||
samples: Vec::with_capacity(retention_seconds.max(1)),
|
||||
write_index: 0,
|
||||
count: 0,
|
||||
capacity: retention_seconds.max(1),
|
||||
current_second: None,
|
||||
current_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn push_sample(&mut self, count: u64) {
|
||||
if self.samples.len() < self.capacity {
|
||||
self.samples.push(count);
|
||||
} else {
|
||||
self.samples[self.write_index] = count;
|
||||
}
|
||||
self.write_index = (self.write_index + 1) % self.capacity;
|
||||
self.count = (self.count + 1).min(self.capacity);
|
||||
}
|
||||
|
||||
pub(crate) fn record_event(&mut self) {
|
||||
self.record_events_at(unix_timestamp_seconds(), 1);
|
||||
}
|
||||
|
||||
pub(crate) fn record_events_at(&mut self, now_sec: u64, count: u64) {
|
||||
self.advance_to(now_sec);
|
||||
self.current_count = self.current_count.saturating_add(count);
|
||||
}
|
||||
|
||||
pub(crate) fn advance_to_now(&mut self) {
|
||||
self.advance_to(unix_timestamp_seconds());
|
||||
}
|
||||
|
||||
pub(crate) fn advance_to(&mut self, now_sec: u64) {
|
||||
match self.current_second {
|
||||
Some(current_second) if now_sec > current_second => {
|
||||
self.push_sample(self.current_count);
|
||||
for _ in 1..(now_sec - current_second) {
|
||||
self.push_sample(0);
|
||||
}
|
||||
self.current_second = Some(now_sec);
|
||||
self.current_count = 0;
|
||||
}
|
||||
Some(_) => {}
|
||||
None => {
|
||||
self.current_second = Some(now_sec);
|
||||
self.current_count = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn sum_recent(&self, window_seconds: usize) -> u64 {
|
||||
let window = window_seconds.min(self.count);
|
||||
if window == 0 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let mut total = 0u64;
|
||||
for i in 0..window {
|
||||
let idx = if self.write_index >= i + 1 {
|
||||
self.write_index - i - 1
|
||||
} else {
|
||||
self.capacity - (i + 1 - self.write_index)
|
||||
};
|
||||
if idx < self.samples.len() {
|
||||
total += self.samples[idx];
|
||||
}
|
||||
}
|
||||
total
|
||||
}
|
||||
|
||||
pub(crate) fn last_second(&self) -> u64 {
|
||||
self.sum_recent(1)
|
||||
}
|
||||
|
||||
pub(crate) fn last_minute(&self) -> u64 {
|
||||
self.sum_recent(60)
|
||||
}
|
||||
|
||||
pub(crate) fn is_idle(&self) -> bool {
|
||||
self.current_count == 0 && self.sum_recent(self.capacity) == 0
|
||||
}
|
||||
}
|
||||
|
||||
impl ThroughputTracker {
|
||||
/// Create a new tracker with the given capacity (seconds of retention).
|
||||
pub fn new(retention_seconds: usize) -> Self {
|
||||
@@ -46,7 +153,8 @@ impl ThroughputTracker {
|
||||
/// Record bytes (called from data flow callbacks).
|
||||
pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64) {
|
||||
self.pending_bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
self.pending_bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
self.pending_bytes_out
|
||||
.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Take a sample (called at 1Hz).
|
||||
@@ -229,4 +337,41 @@ mod tests {
|
||||
let history = tracker.history(10);
|
||||
assert!(history.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_rate_tracker_counts_last_second_and_last_minute() {
|
||||
let mut tracker = RequestRateTracker::new(60);
|
||||
|
||||
tracker.record_events_at(100, 2);
|
||||
tracker.record_events_at(100, 3);
|
||||
tracker.advance_to(101);
|
||||
|
||||
assert_eq!(tracker.last_second(), 5);
|
||||
assert_eq!(tracker.last_minute(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_rate_tracker_adds_zero_samples_for_gaps() {
|
||||
let mut tracker = RequestRateTracker::new(60);
|
||||
|
||||
tracker.record_events_at(100, 4);
|
||||
tracker.record_events_at(102, 1);
|
||||
tracker.advance_to(103);
|
||||
|
||||
assert_eq!(tracker.last_second(), 1);
|
||||
assert_eq!(tracker.last_minute(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_rate_tracker_decays_to_zero_over_window() {
|
||||
let mut tracker = RequestRateTracker::new(60);
|
||||
|
||||
tracker.record_events_at(100, 7);
|
||||
tracker.advance_to(101);
|
||||
tracker.advance_to(161);
|
||||
|
||||
assert_eq!(tracker.last_second(), 0);
|
||||
assert_eq!(tracker.last_minute(), 0);
|
||||
assert!(tracker.is_idle());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
[package]
|
||||
name = "rustproxy-nftables"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "NFTables kernel-level forwarding for RustProxy"
|
||||
|
||||
[dependencies]
|
||||
rustproxy-config = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
libc = { workspace = true }
|
||||
@@ -1,10 +0,0 @@
|
||||
//! # rustproxy-nftables
|
||||
//!
|
||||
//! NFTables kernel-level forwarding for RustProxy.
|
||||
//! Generates and manages nft CLI rules for DNAT/SNAT.
|
||||
|
||||
pub mod nft_manager;
|
||||
pub mod rule_builder;
|
||||
|
||||
pub use nft_manager::*;
|
||||
pub use rule_builder::*;
|
||||
@@ -1,238 +0,0 @@
|
||||
use thiserror::Error;
|
||||
use std::collections::HashMap;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum NftError {
|
||||
#[error("nft command failed: {0}")]
|
||||
CommandFailed(String),
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("Not running as root")]
|
||||
NotRoot,
|
||||
}
|
||||
|
||||
/// Manager for nftables rules.
|
||||
///
|
||||
/// Executes `nft` CLI commands to manage kernel-level packet forwarding.
|
||||
/// Requires root privileges; operations are skipped gracefully if not root.
|
||||
pub struct NftManager {
|
||||
table_name: String,
|
||||
/// Active rules indexed by route ID
|
||||
active_rules: HashMap<String, Vec<String>>,
|
||||
/// Whether the table has been initialized
|
||||
table_initialized: bool,
|
||||
}
|
||||
|
||||
impl NftManager {
|
||||
pub fn new(table_name: Option<String>) -> Self {
|
||||
Self {
|
||||
table_name: table_name.unwrap_or_else(|| "rustproxy".to_string()),
|
||||
active_rules: HashMap::new(),
|
||||
table_initialized: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if we are running as root.
|
||||
fn is_root() -> bool {
|
||||
unsafe { libc::geteuid() == 0 }
|
||||
}
|
||||
|
||||
/// Execute a single nft command via the CLI.
|
||||
async fn exec_nft(command: &str) -> Result<String, NftError> {
|
||||
// The command starts with "nft ", strip it to get the args
|
||||
let args = if command.starts_with("nft ") {
|
||||
&command[4..]
|
||||
} else {
|
||||
command
|
||||
};
|
||||
|
||||
let output = tokio::process::Command::new("nft")
|
||||
.args(args.split_whitespace())
|
||||
.output()
|
||||
.await
|
||||
.map_err(NftError::Io)?;
|
||||
|
||||
if output.status.success() {
|
||||
Ok(String::from_utf8_lossy(&output.stdout).to_string())
|
||||
} else {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
Err(NftError::CommandFailed(format!(
|
||||
"Command '{}' failed: {}",
|
||||
command, stderr
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensure the nftables table and chains are set up.
|
||||
async fn ensure_table(&mut self) -> Result<(), NftError> {
|
||||
if self.table_initialized {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let setup_commands = crate::rule_builder::build_table_setup(&self.table_name);
|
||||
for cmd in &setup_commands {
|
||||
Self::exec_nft(cmd).await?;
|
||||
}
|
||||
|
||||
self.table_initialized = true;
|
||||
info!("NFTables table '{}' initialized", self.table_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Apply rules for a route.
|
||||
///
|
||||
/// Executes the nft commands via the CLI. If not running as root,
|
||||
/// the rules are stored locally but not applied to the kernel.
|
||||
pub async fn apply_rules(&mut self, route_id: &str, rules: Vec<String>) -> Result<(), NftError> {
|
||||
if !Self::is_root() {
|
||||
warn!("Not running as root, nftables rules will not be applied to kernel");
|
||||
self.active_rules.insert(route_id.to_string(), rules);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.ensure_table().await?;
|
||||
|
||||
for cmd in &rules {
|
||||
Self::exec_nft(cmd).await?;
|
||||
debug!("Applied nft rule: {}", cmd);
|
||||
}
|
||||
|
||||
info!("Applied {} nftables rules for route '{}'", rules.len(), route_id);
|
||||
self.active_rules.insert(route_id.to_string(), rules);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove rules for a route.
|
||||
///
|
||||
/// Currently removes the route from tracking. To fully remove specific
|
||||
/// rules would require handle-based tracking; for now, cleanup() removes
|
||||
/// the entire table.
|
||||
pub async fn remove_rules(&mut self, route_id: &str) -> Result<(), NftError> {
|
||||
if let Some(rules) = self.active_rules.remove(route_id) {
|
||||
info!("Removed {} tracked nft rules for route '{}'", rules.len(), route_id);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Clean up all managed rules by deleting the entire nftables table.
|
||||
pub async fn cleanup(&mut self) -> Result<(), NftError> {
|
||||
if !Self::is_root() {
|
||||
warn!("Not running as root, skipping nftables cleanup");
|
||||
self.active_rules.clear();
|
||||
self.table_initialized = false;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if self.table_initialized {
|
||||
let cleanup_commands = crate::rule_builder::build_table_cleanup(&self.table_name);
|
||||
for cmd in &cleanup_commands {
|
||||
match Self::exec_nft(cmd).await {
|
||||
Ok(_) => debug!("Cleanup: {}", cmd),
|
||||
Err(e) => warn!("Cleanup command failed (may be ok): {}", e),
|
||||
}
|
||||
}
|
||||
info!("NFTables table '{}' cleaned up", self.table_name);
|
||||
}
|
||||
|
||||
self.active_rules.clear();
|
||||
self.table_initialized = false;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the table name.
|
||||
pub fn table_name(&self) -> &str {
|
||||
&self.table_name
|
||||
}
|
||||
|
||||
/// Whether the table has been initialized in the kernel.
|
||||
pub fn is_initialized(&self) -> bool {
|
||||
self.table_initialized
|
||||
}
|
||||
|
||||
/// Get the number of active route rule sets.
|
||||
pub fn active_route_count(&self) -> usize {
|
||||
self.active_rules.len()
|
||||
}
|
||||
|
||||
/// Get the status of all active rules.
|
||||
pub fn status(&self) -> HashMap<String, serde_json::Value> {
|
||||
let mut status = HashMap::new();
|
||||
for (route_id, rules) in &self.active_rules {
|
||||
status.insert(
|
||||
route_id.clone(),
|
||||
serde_json::json!({
|
||||
"ruleCount": rules.len(),
|
||||
"rules": rules,
|
||||
}),
|
||||
);
|
||||
}
|
||||
status
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_new_default_table_name() {
|
||||
let mgr = NftManager::new(None);
|
||||
assert_eq!(mgr.table_name(), "rustproxy");
|
||||
assert!(!mgr.is_initialized());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_new_custom_table_name() {
|
||||
let mgr = NftManager::new(Some("custom".to_string()));
|
||||
assert_eq!(mgr.table_name(), "custom");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_apply_rules_non_root() {
|
||||
let mut mgr = NftManager::new(None);
|
||||
// When not root, rules are stored but not applied to kernel
|
||||
let rules = vec!["nft add rule ip rustproxy prerouting tcp dport 443 dnat to 10.0.0.1:8443".to_string()];
|
||||
mgr.apply_rules("route-1", rules).await.unwrap();
|
||||
assert_eq!(mgr.active_route_count(), 1);
|
||||
|
||||
let status = mgr.status();
|
||||
assert!(status.contains_key("route-1"));
|
||||
assert_eq!(status["route-1"]["ruleCount"], 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_remove_rules() {
|
||||
let mut mgr = NftManager::new(None);
|
||||
let rules = vec!["nft add rule test".to_string()];
|
||||
mgr.apply_rules("route-1", rules).await.unwrap();
|
||||
assert_eq!(mgr.active_route_count(), 1);
|
||||
|
||||
mgr.remove_rules("route-1").await.unwrap();
|
||||
assert_eq!(mgr.active_route_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cleanup_non_root() {
|
||||
let mut mgr = NftManager::new(None);
|
||||
let rules = vec!["nft add rule test".to_string()];
|
||||
mgr.apply_rules("route-1", rules).await.unwrap();
|
||||
mgr.apply_rules("route-2", vec!["nft add rule test2".to_string()]).await.unwrap();
|
||||
|
||||
mgr.cleanup().await.unwrap();
|
||||
assert_eq!(mgr.active_route_count(), 0);
|
||||
assert!(!mgr.is_initialized());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_status_multiple_routes() {
|
||||
let mut mgr = NftManager::new(None);
|
||||
mgr.apply_rules("web", vec!["rule1".to_string(), "rule2".to_string()]).await.unwrap();
|
||||
mgr.apply_rules("api", vec!["rule3".to_string()]).await.unwrap();
|
||||
|
||||
let status = mgr.status();
|
||||
assert_eq!(status.len(), 2);
|
||||
assert_eq!(status["web"]["ruleCount"], 2);
|
||||
assert_eq!(status["api"]["ruleCount"], 1);
|
||||
}
|
||||
}
|
||||
@@ -1,146 +0,0 @@
|
||||
use rustproxy_config::{NfTablesOptions, NfTablesProtocol};
|
||||
|
||||
/// Build nftables DNAT rule for port forwarding.
|
||||
pub fn build_dnat_rule(
|
||||
table_name: &str,
|
||||
chain_name: &str,
|
||||
source_port: u16,
|
||||
target_host: &str,
|
||||
target_port: u16,
|
||||
options: &NfTablesOptions,
|
||||
) -> Vec<String> {
|
||||
let protocols: Vec<&str> = match options.protocol.as_ref().unwrap_or(&NfTablesProtocol::Tcp) {
|
||||
NfTablesProtocol::Tcp => vec!["tcp"],
|
||||
NfTablesProtocol::Udp => vec!["udp"],
|
||||
NfTablesProtocol::All => vec!["tcp", "udp"],
|
||||
};
|
||||
|
||||
let mut rules = Vec::new();
|
||||
|
||||
for protocol in &protocols {
|
||||
// DNAT rule
|
||||
rules.push(format!(
|
||||
"nft add rule ip {} {} {} dport {} dnat to {}:{}",
|
||||
table_name, chain_name, protocol, source_port, target_host, target_port,
|
||||
));
|
||||
|
||||
// SNAT rule if preserving source IP is not enabled
|
||||
if !options.preserve_source_ip.unwrap_or(false) {
|
||||
rules.push(format!(
|
||||
"nft add rule ip {} postrouting {} dport {} masquerade",
|
||||
table_name, protocol, target_port,
|
||||
));
|
||||
}
|
||||
|
||||
// Rate limiting
|
||||
if let Some(max_rate) = &options.max_rate {
|
||||
rules.push(format!(
|
||||
"nft add rule ip {} {} {} dport {} limit rate {} accept",
|
||||
table_name, chain_name, protocol, source_port, max_rate,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
/// Build the initial table and chain setup commands.
|
||||
pub fn build_table_setup(table_name: &str) -> Vec<String> {
|
||||
vec![
|
||||
format!("nft add table ip {}", table_name),
|
||||
format!("nft add chain ip {} prerouting {{ type nat hook prerouting priority 0 \\; }}", table_name),
|
||||
format!("nft add chain ip {} postrouting {{ type nat hook postrouting priority 100 \\; }}", table_name),
|
||||
]
|
||||
}
|
||||
|
||||
/// Build cleanup commands to remove the table.
|
||||
pub fn build_table_cleanup(table_name: &str) -> Vec<String> {
|
||||
vec![format!("nft delete table ip {}", table_name)]
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_options() -> NfTablesOptions {
|
||||
NfTablesOptions {
|
||||
preserve_source_ip: None,
|
||||
protocol: None,
|
||||
max_rate: None,
|
||||
priority: None,
|
||||
table_name: None,
|
||||
use_ip_sets: None,
|
||||
use_advanced_nat: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_basic_dnat_rule() {
|
||||
let options = make_options();
|
||||
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
|
||||
assert!(rules.len() >= 1);
|
||||
assert!(rules[0].contains("dnat to 10.0.0.1:8443"));
|
||||
assert!(rules[0].contains("dport 443"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_preserve_source_ip() {
|
||||
let mut options = make_options();
|
||||
options.preserve_source_ip = Some(true);
|
||||
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
|
||||
// When preserving source IP, no masquerade rule
|
||||
assert!(rules.iter().all(|r| !r.contains("masquerade")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_without_preserve_source_ip() {
|
||||
let options = make_options();
|
||||
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
|
||||
assert!(rules.iter().any(|r| r.contains("masquerade")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_limited_rule() {
|
||||
let mut options = make_options();
|
||||
options.max_rate = Some("100/second".to_string());
|
||||
let rules = build_dnat_rule("rustproxy", "prerouting", 80, "10.0.0.1", 8080, &options);
|
||||
assert!(rules.iter().any(|r| r.contains("limit rate 100/second")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_table_setup_commands() {
|
||||
let commands = build_table_setup("rustproxy");
|
||||
assert_eq!(commands.len(), 3);
|
||||
assert!(commands[0].contains("add table ip rustproxy"));
|
||||
assert!(commands[1].contains("prerouting"));
|
||||
assert!(commands[2].contains("postrouting"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_table_cleanup() {
|
||||
let commands = build_table_cleanup("rustproxy");
|
||||
assert_eq!(commands.len(), 1);
|
||||
assert!(commands[0].contains("delete table ip rustproxy"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_protocol_all_generates_tcp_and_udp_rules() {
|
||||
let mut options = make_options();
|
||||
options.protocol = Some(NfTablesProtocol::All);
|
||||
let rules = build_dnat_rule("rustproxy", "prerouting", 53, "10.0.0.53", 53, &options);
|
||||
// Should have TCP DNAT + masquerade + UDP DNAT + masquerade = 4 rules
|
||||
assert_eq!(rules.len(), 4);
|
||||
assert!(rules.iter().any(|r| r.contains("tcp dport 53 dnat")));
|
||||
assert!(rules.iter().any(|r| r.contains("udp dport 53 dnat")));
|
||||
assert!(rules.iter().filter(|r| r.contains("masquerade")).count() == 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_protocol_udp() {
|
||||
let mut options = make_options();
|
||||
options.protocol = Some(NfTablesProtocol::Udp);
|
||||
let rules = build_dnat_rule("rustproxy", "prerouting", 53, "10.0.0.53", 53, &options);
|
||||
assert!(rules.iter().all(|r| !r.contains("tcp")));
|
||||
assert!(rules.iter().any(|r| r.contains("udp dport 53 dnat")));
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ description = "Raw TCP/SNI passthrough engine for RustProxy"
|
||||
rustproxy-config = { workspace = true }
|
||||
rustproxy-routing = { workspace = true }
|
||||
rustproxy-metrics = { workspace = true }
|
||||
rustproxy-security = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
|
||||
@@ -0,0 +1,335 @@
|
||||
//! Shared connection registry for selective connection recycling.
|
||||
//!
|
||||
//! Tracks active connections across both TCP and QUIC with metadata
|
||||
//! (source IP, SNI domain, route ID, cancel token) so that connections
|
||||
//! can be selectively recycled when certificates, security rules, or
|
||||
//! route targets change.
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::info;
|
||||
|
||||
use rustproxy_config::RouteSecurity;
|
||||
use rustproxy_http::request_filter::RequestFilter;
|
||||
use rustproxy_routing::matchers::domain_matches;
|
||||
|
||||
/// Metadata about an active connection.
|
||||
pub struct ConnectionEntry {
|
||||
/// Per-connection cancel token (child of per-route token).
|
||||
pub cancel: CancellationToken,
|
||||
/// Client source IP.
|
||||
pub source_ip: IpAddr,
|
||||
/// SNI domain from TLS handshake (None for non-TLS connections).
|
||||
pub domain: Option<String>,
|
||||
/// Route ID this connection was matched to (None if route has no ID).
|
||||
pub route_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Transport-agnostic registry of active connections.
|
||||
///
|
||||
/// Used by both `TcpListenerManager` and `UdpListenerManager` to track
|
||||
/// connections and enable selective recycling on config changes.
|
||||
pub struct ConnectionRegistry {
|
||||
connections: DashMap<u64, ConnectionEntry>,
|
||||
next_id: AtomicU64,
|
||||
}
|
||||
|
||||
impl ConnectionRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
connections: DashMap::new(),
|
||||
next_id: AtomicU64::new(1),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a connection and return its ID + RAII guard.
|
||||
///
|
||||
/// The guard automatically removes the connection from the registry on drop.
|
||||
pub fn register(self: &Arc<Self>, entry: ConnectionEntry) -> (u64, ConnectionRegistryGuard) {
|
||||
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
|
||||
self.connections.insert(id, entry);
|
||||
let guard = ConnectionRegistryGuard {
|
||||
registry: Arc::clone(self),
|
||||
conn_id: id,
|
||||
};
|
||||
(id, guard)
|
||||
}
|
||||
|
||||
/// Number of tracked connections (for metrics/debugging).
|
||||
pub fn len(&self) -> usize {
|
||||
self.connections.len()
|
||||
}
|
||||
|
||||
/// Recycle connections whose SNI domain matches a renewed certificate domain.
|
||||
///
|
||||
/// Uses bidirectional domain matching so that:
|
||||
/// - Cert `*.example.com` recycles connections for `sub.example.com`
|
||||
/// - Cert `sub.example.com` recycles connections on routes with `*.example.com`
|
||||
pub fn recycle_for_cert_change(&self, cert_domain: &str) {
|
||||
let mut recycled = 0u64;
|
||||
self.connections.retain(|_, entry| {
|
||||
let matches = entry
|
||||
.domain
|
||||
.as_deref()
|
||||
.map(|d| domain_matches(cert_domain, d) || domain_matches(d, cert_domain))
|
||||
.unwrap_or(false);
|
||||
if matches {
|
||||
entry.cancel.cancel();
|
||||
recycled += 1;
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
if recycled > 0 {
|
||||
info!(
|
||||
"Recycled {} connection(s) for cert change on domain '{}'",
|
||||
recycled, cert_domain
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Recycle connections on a route whose security config changed.
|
||||
///
|
||||
/// Re-evaluates each connection's source IP against the new security rules.
|
||||
/// Only connections from now-blocked IPs are terminated; allowed IPs are undisturbed.
|
||||
pub fn recycle_for_security_change(&self, route_id: &str, new_security: &RouteSecurity) {
|
||||
let mut recycled = 0u64;
|
||||
self.connections.retain(|_, entry| {
|
||||
if entry.route_id.as_deref() == Some(route_id) {
|
||||
if !RequestFilter::check_ip_security(
|
||||
new_security,
|
||||
&entry.source_ip,
|
||||
entry.domain.as_deref(),
|
||||
) {
|
||||
info!(
|
||||
"Terminating connection from {} — IP now blocked on route '{}'",
|
||||
entry.source_ip, route_id
|
||||
);
|
||||
entry.cancel.cancel();
|
||||
recycled += 1;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
});
|
||||
if recycled > 0 {
|
||||
info!(
|
||||
"Recycled {} connection(s) for security change on route '{}'",
|
||||
recycled, route_id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Recycle all connections on a route (e.g., when targets changed).
|
||||
pub fn recycle_for_route_change(&self, route_id: &str) {
|
||||
let mut recycled = 0u64;
|
||||
self.connections.retain(|_, entry| {
|
||||
if entry.route_id.as_deref() == Some(route_id) {
|
||||
entry.cancel.cancel();
|
||||
recycled += 1;
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
if recycled > 0 {
|
||||
info!(
|
||||
"Recycled {} connection(s) for config change on route '{}'",
|
||||
recycled, route_id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove connections on routes that no longer exist.
|
||||
///
|
||||
/// This complements per-route CancellationToken cancellation —
|
||||
/// the token cascade handles graceful shutdown, this cleans up the registry.
|
||||
pub fn cleanup_removed_routes(&self, active_route_ids: &HashSet<String>) {
|
||||
self.connections.retain(|_, entry| {
|
||||
match &entry.route_id {
|
||||
Some(id) => active_route_ids.contains(id),
|
||||
None => true, // keep connections without a route ID
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// RAII guard that removes a connection from the registry on drop.
|
||||
pub struct ConnectionRegistryGuard {
|
||||
registry: Arc<ConnectionRegistry>,
|
||||
conn_id: u64,
|
||||
}
|
||||
|
||||
impl Drop for ConnectionRegistryGuard {
|
||||
fn drop(&mut self) {
|
||||
self.registry.connections.remove(&self.conn_id);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_registry() -> Arc<ConnectionRegistry> {
|
||||
Arc::new(ConnectionRegistry::new())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_register_and_guard_cleanup() {
|
||||
let reg = make_registry();
|
||||
let token = CancellationToken::new();
|
||||
let entry = ConnectionEntry {
|
||||
cancel: token.clone(),
|
||||
source_ip: "10.0.0.1".parse().unwrap(),
|
||||
domain: Some("example.com".to_string()),
|
||||
route_id: Some("route-1".to_string()),
|
||||
};
|
||||
let (id, guard) = reg.register(entry);
|
||||
assert_eq!(reg.len(), 1);
|
||||
assert!(reg.connections.contains_key(&id));
|
||||
|
||||
drop(guard);
|
||||
assert_eq!(reg.len(), 0);
|
||||
assert!(!token.is_cancelled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recycle_for_cert_change_exact() {
|
||||
let reg = make_registry();
|
||||
let t1 = CancellationToken::new();
|
||||
let t2 = CancellationToken::new();
|
||||
let (_, _g1) = reg.register(ConnectionEntry {
|
||||
cancel: t1.clone(),
|
||||
source_ip: "10.0.0.1".parse().unwrap(),
|
||||
domain: Some("api.example.com".to_string()),
|
||||
route_id: Some("r1".to_string()),
|
||||
});
|
||||
let (_, _g2) = reg.register(ConnectionEntry {
|
||||
cancel: t2.clone(),
|
||||
source_ip: "10.0.0.2".parse().unwrap(),
|
||||
domain: Some("other.com".to_string()),
|
||||
route_id: Some("r2".to_string()),
|
||||
});
|
||||
|
||||
reg.recycle_for_cert_change("api.example.com");
|
||||
assert!(t1.is_cancelled());
|
||||
assert!(!t2.is_cancelled());
|
||||
// Registry retains unmatched entry (guard still alive keeps it too,
|
||||
// but the retain removed the matched one before guard could)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recycle_for_cert_change_wildcard() {
|
||||
let reg = make_registry();
|
||||
let t1 = CancellationToken::new();
|
||||
let t2 = CancellationToken::new();
|
||||
let (_, _g1) = reg.register(ConnectionEntry {
|
||||
cancel: t1.clone(),
|
||||
source_ip: "10.0.0.1".parse().unwrap(),
|
||||
domain: Some("sub.example.com".to_string()),
|
||||
route_id: Some("r1".to_string()),
|
||||
});
|
||||
let (_, _g2) = reg.register(ConnectionEntry {
|
||||
cancel: t2.clone(),
|
||||
source_ip: "10.0.0.2".parse().unwrap(),
|
||||
domain: Some("other.com".to_string()),
|
||||
route_id: Some("r2".to_string()),
|
||||
});
|
||||
|
||||
// Wildcard cert should match subdomain
|
||||
reg.recycle_for_cert_change("*.example.com");
|
||||
assert!(t1.is_cancelled());
|
||||
assert!(!t2.is_cancelled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recycle_for_security_change() {
|
||||
let reg = make_registry();
|
||||
let t1 = CancellationToken::new();
|
||||
let t2 = CancellationToken::new();
|
||||
let (_, _g1) = reg.register(ConnectionEntry {
|
||||
cancel: t1.clone(),
|
||||
source_ip: "10.0.0.1".parse().unwrap(),
|
||||
domain: None,
|
||||
route_id: Some("r1".to_string()),
|
||||
});
|
||||
let (_, _g2) = reg.register(ConnectionEntry {
|
||||
cancel: t2.clone(),
|
||||
source_ip: "10.0.0.2".parse().unwrap(),
|
||||
domain: None,
|
||||
route_id: Some("r1".to_string()),
|
||||
});
|
||||
|
||||
// Block 10.0.0.1, allow 10.0.0.2
|
||||
let security = RouteSecurity {
|
||||
ip_allow_list: None,
|
||||
ip_block_list: Some(vec!["10.0.0.1".to_string()]),
|
||||
max_connections: None,
|
||||
authentication: None,
|
||||
rate_limit: None,
|
||||
basic_auth: None,
|
||||
jwt_auth: None,
|
||||
};
|
||||
|
||||
reg.recycle_for_security_change("r1", &security);
|
||||
assert!(t1.is_cancelled());
|
||||
assert!(!t2.is_cancelled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recycle_for_route_change() {
|
||||
let reg = make_registry();
|
||||
let t1 = CancellationToken::new();
|
||||
let t2 = CancellationToken::new();
|
||||
let (_, _g1) = reg.register(ConnectionEntry {
|
||||
cancel: t1.clone(),
|
||||
source_ip: "10.0.0.1".parse().unwrap(),
|
||||
domain: None,
|
||||
route_id: Some("r1".to_string()),
|
||||
});
|
||||
let (_, _g2) = reg.register(ConnectionEntry {
|
||||
cancel: t2.clone(),
|
||||
source_ip: "10.0.0.2".parse().unwrap(),
|
||||
domain: None,
|
||||
route_id: Some("r2".to_string()),
|
||||
});
|
||||
|
||||
reg.recycle_for_route_change("r1");
|
||||
assert!(t1.is_cancelled());
|
||||
assert!(!t2.is_cancelled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cleanup_removed_routes() {
|
||||
let reg = make_registry();
|
||||
let t1 = CancellationToken::new();
|
||||
let t2 = CancellationToken::new();
|
||||
let (_, _g1) = reg.register(ConnectionEntry {
|
||||
cancel: t1.clone(),
|
||||
source_ip: "10.0.0.1".parse().unwrap(),
|
||||
domain: None,
|
||||
route_id: Some("active".to_string()),
|
||||
});
|
||||
let (_, _g2) = reg.register(ConnectionEntry {
|
||||
cancel: t2.clone(),
|
||||
source_ip: "10.0.0.2".parse().unwrap(),
|
||||
domain: None,
|
||||
route_id: Some("removed".to_string()),
|
||||
});
|
||||
|
||||
let mut active = HashSet::new();
|
||||
active.insert("active".to_string());
|
||||
reg.cleanup_removed_routes(&active);
|
||||
|
||||
// "removed" route entry was cleaned from registry
|
||||
// (but guard is still alive so len may differ — the retain already removed it)
|
||||
assert!(!t1.is_cancelled()); // not cancelled by cleanup, only by token cascade
|
||||
assert!(!t2.is_cancelled()); // cleanup doesn't cancel, just removes from registry
|
||||
}
|
||||
}
|
||||
@@ -31,7 +31,8 @@ impl ConnectionTracker {
|
||||
pub fn try_accept(&self, ip: &IpAddr) -> bool {
|
||||
// Check per-IP connection limit
|
||||
if let Some(max) = self.max_per_ip {
|
||||
let count = self.active
|
||||
let count = self
|
||||
.active
|
||||
.get(ip)
|
||||
.map(|c| c.value().load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
@@ -48,7 +49,10 @@ impl ConnectionTracker {
|
||||
let timestamps = entry.value_mut();
|
||||
|
||||
// Remove timestamps older than 1 minute
|
||||
while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) {
|
||||
while timestamps
|
||||
.front()
|
||||
.is_some_and(|t| now.duration_since(*t) >= one_minute)
|
||||
{
|
||||
timestamps.pop_front();
|
||||
}
|
||||
|
||||
@@ -111,7 +115,6 @@ impl ConnectionTracker {
|
||||
pub fn tracked_ips(&self) -> usize {
|
||||
self.active.len()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use tracing::debug;
|
||||
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
@@ -87,7 +87,12 @@ pub async fn forward_bidirectional_with_timeouts(
|
||||
if let Some(data) = initial_data {
|
||||
backend.write_all(data).await?;
|
||||
if let Some(ref ctx) = metrics {
|
||||
ctx.collector.record_bytes(data.len() as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
|
||||
ctx.collector.record_bytes(
|
||||
data.len() as u64,
|
||||
0,
|
||||
ctx.route_id.as_deref(),
|
||||
ctx.source_ip.as_deref(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -123,14 +128,17 @@ pub async fn forward_bidirectional_with_timeouts(
|
||||
total += n as u64;
|
||||
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
if let Some(ref ctx) = metrics_c2b {
|
||||
ctx.collector.record_bytes(n as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
|
||||
ctx.collector.record_bytes(
|
||||
n as u64,
|
||||
0,
|
||||
ctx.route_id.as_deref(),
|
||||
ctx.source_ip.as_deref(),
|
||||
);
|
||||
}
|
||||
}
|
||||
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
|
||||
let _ = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(2),
|
||||
backend_write.shutdown(),
|
||||
).await;
|
||||
let _ =
|
||||
tokio::time::timeout(std::time::Duration::from_secs(2), backend_write.shutdown()).await;
|
||||
total
|
||||
});
|
||||
|
||||
@@ -154,14 +162,17 @@ pub async fn forward_bidirectional_with_timeouts(
|
||||
total += n as u64;
|
||||
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
if let Some(ref ctx) = metrics_b2c {
|
||||
ctx.collector.record_bytes(0, n as u64, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
|
||||
ctx.collector.record_bytes(
|
||||
0,
|
||||
n as u64,
|
||||
ctx.route_id.as_deref(),
|
||||
ctx.source_ip.as_deref(),
|
||||
);
|
||||
}
|
||||
}
|
||||
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
|
||||
let _ = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(2),
|
||||
client_write.shutdown(),
|
||||
).await;
|
||||
let _ =
|
||||
tokio::time::timeout(std::time::Duration::from_secs(2), client_write.shutdown()).await;
|
||||
total
|
||||
});
|
||||
|
||||
|
||||
@@ -4,26 +4,26 @@
|
||||
//! Handles TCP listening, TLS ClientHello SNI extraction, bidirectional forwarding,
|
||||
//! and UDP datagram session tracking with forwarding.
|
||||
|
||||
pub mod tcp_listener;
|
||||
pub mod sni_parser;
|
||||
pub mod connection_registry;
|
||||
pub mod connection_tracker;
|
||||
pub mod forwarder;
|
||||
pub mod proxy_protocol;
|
||||
pub mod tls_handler;
|
||||
pub mod connection_tracker;
|
||||
pub mod socket_relay;
|
||||
pub mod socket_opts;
|
||||
pub mod udp_session;
|
||||
pub mod udp_listener;
|
||||
pub mod quic_handler;
|
||||
pub mod sni_parser;
|
||||
pub mod socket_opts;
|
||||
pub mod tcp_listener;
|
||||
pub mod tls_handler;
|
||||
pub mod udp_listener;
|
||||
pub mod udp_session;
|
||||
|
||||
pub use tcp_listener::*;
|
||||
pub use sni_parser::*;
|
||||
pub use connection_registry::*;
|
||||
pub use connection_tracker::*;
|
||||
pub use forwarder::*;
|
||||
pub use proxy_protocol::*;
|
||||
pub use tls_handler::*;
|
||||
pub use connection_tracker::*;
|
||||
pub use socket_relay::*;
|
||||
pub use socket_opts::*;
|
||||
pub use udp_session::*;
|
||||
pub use udp_listener::*;
|
||||
pub use quic_handler::*;
|
||||
pub use sni_parser::*;
|
||||
pub use socket_opts::*;
|
||||
pub use tcp_listener::*;
|
||||
pub use tls_handler::*;
|
||||
pub use udp_listener::*;
|
||||
pub use udp_session::*;
|
||||
|
||||
@@ -54,8 +54,8 @@ pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
.position(|w| w == b"\r\n")
|
||||
.ok_or(ProxyProtocolError::InvalidHeader)?;
|
||||
|
||||
let line = std::str::from_utf8(&data[..line_end])
|
||||
.map_err(|_| ProxyProtocolError::InvalidHeader)?;
|
||||
let line =
|
||||
std::str::from_utf8(&data[..line_end]).map_err(|_| ProxyProtocolError::InvalidHeader)?;
|
||||
|
||||
if !line.starts_with("PROXY ") {
|
||||
return Err(ProxyProtocolError::InvalidHeader);
|
||||
@@ -148,7 +148,10 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
let command = data[12] & 0x0F;
|
||||
// 0x0 = LOCAL, 0x1 = PROXY
|
||||
if command > 1 {
|
||||
return Err(ProxyProtocolError::Parse(format!("Unknown command: {}", command)));
|
||||
return Err(ProxyProtocolError::Parse(format!(
|
||||
"Unknown command: {}",
|
||||
command
|
||||
)));
|
||||
}
|
||||
|
||||
// Address family (high nibble) + transport (low nibble) of byte 13
|
||||
@@ -182,7 +185,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
// AF_INET (0x1) + STREAM (0x1) = TCP4
|
||||
(0x1, 0x1) => {
|
||||
if addr_len < 12 {
|
||||
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string()));
|
||||
return Err(ProxyProtocolError::Parse(
|
||||
"IPv4 address block too short".to_string(),
|
||||
));
|
||||
}
|
||||
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
|
||||
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
|
||||
@@ -200,7 +205,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
// AF_INET (0x1) + DGRAM (0x2) = UDP4
|
||||
(0x1, 0x2) => {
|
||||
if addr_len < 12 {
|
||||
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string()));
|
||||
return Err(ProxyProtocolError::Parse(
|
||||
"IPv4 address block too short".to_string(),
|
||||
));
|
||||
}
|
||||
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
|
||||
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
|
||||
@@ -218,7 +225,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
// AF_INET6 (0x2) + STREAM (0x1) = TCP6
|
||||
(0x2, 0x1) => {
|
||||
if addr_len < 36 {
|
||||
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string()));
|
||||
return Err(ProxyProtocolError::Parse(
|
||||
"IPv6 address block too short".to_string(),
|
||||
));
|
||||
}
|
||||
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
|
||||
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
|
||||
@@ -236,7 +245,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
// AF_INET6 (0x2) + DGRAM (0x2) = UDP6
|
||||
(0x2, 0x2) => {
|
||||
if addr_len < 36 {
|
||||
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string()));
|
||||
return Err(ProxyProtocolError::Parse(
|
||||
"IPv6 address block too short".to_string(),
|
||||
));
|
||||
}
|
||||
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
|
||||
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
|
||||
@@ -268,11 +279,7 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
|
||||
}
|
||||
|
||||
/// Generate a PROXY protocol v2 binary header.
|
||||
pub fn generate_v2(
|
||||
source: &SocketAddr,
|
||||
dest: &SocketAddr,
|
||||
transport: ProxyV2Transport,
|
||||
) -> Vec<u8> {
|
||||
pub fn generate_v2(source: &SocketAddr, dest: &SocketAddr, transport: ProxyV2Transport) -> Vec<u8> {
|
||||
let transport_nibble: u8 = match transport {
|
||||
ProxyV2Transport::Stream => 0x1,
|
||||
ProxyV2Transport::Datagram => 0x2,
|
||||
@@ -462,7 +469,10 @@ mod tests {
|
||||
header.push(0x11);
|
||||
header.extend_from_slice(&12u16.to_be_bytes());
|
||||
header.extend_from_slice(&[0u8; 12]);
|
||||
assert!(matches!(parse_v2(&header), Err(ProxyProtocolError::UnsupportedVersion)));
|
||||
assert!(matches!(
|
||||
parse_v2(&header),
|
||||
Err(ProxyProtocolError::UnsupportedVersion)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -26,9 +26,11 @@ use tracing::{debug, info, warn};
|
||||
use rustproxy_config::{RouteConfig, TransportProtocol};
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
use rustproxy_routing::{MatchContext, RouteManager};
|
||||
use rustproxy_security::IpBlockList;
|
||||
|
||||
use rustproxy_http::h3_service::H3ProxyService;
|
||||
|
||||
use crate::connection_registry::{ConnectionEntry, ConnectionRegistry};
|
||||
use crate::connection_tracker::ConnectionTracker;
|
||||
|
||||
/// Create a QUIC server endpoint on the given port with the provided TLS config.
|
||||
@@ -47,8 +49,7 @@ pub fn create_quic_endpoint(
|
||||
quinn::EndpointConfig::default(),
|
||||
Some(server_config),
|
||||
socket,
|
||||
quinn::default_runtime()
|
||||
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
||||
quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
||||
)?;
|
||||
|
||||
info!("QUIC endpoint listening on port {}", port);
|
||||
@@ -96,6 +97,7 @@ pub fn create_quic_endpoint_with_proxy_relay(
|
||||
port: u16,
|
||||
tls_config: Arc<RustlsServerConfig>,
|
||||
proxy_ips: Arc<Vec<IpAddr>>,
|
||||
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||
cancel: CancellationToken,
|
||||
) -> anyhow::Result<QuicProxyRelay> {
|
||||
// Bind external socket on the real port
|
||||
@@ -118,8 +120,7 @@ pub fn create_quic_endpoint_with_proxy_relay(
|
||||
quinn::EndpointConfig::default(),
|
||||
Some(server_config),
|
||||
internal_socket,
|
||||
quinn::default_runtime()
|
||||
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
||||
quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
||||
)?;
|
||||
|
||||
let real_client_map = Arc::new(DashMap::new());
|
||||
@@ -128,12 +129,20 @@ pub fn create_quic_endpoint_with_proxy_relay(
|
||||
external_socket,
|
||||
quinn_internal_addr,
|
||||
proxy_ips,
|
||||
security_policy,
|
||||
Arc::clone(&real_client_map),
|
||||
cancel,
|
||||
));
|
||||
|
||||
info!("QUIC endpoint with PROXY relay on port {} (quinn internal: {})", port, quinn_internal_addr);
|
||||
Ok(QuicProxyRelay { endpoint, relay_task, real_client_map })
|
||||
info!(
|
||||
"QUIC endpoint with PROXY relay on port {} (quinn internal: {})",
|
||||
port, quinn_internal_addr
|
||||
);
|
||||
Ok(QuicProxyRelay {
|
||||
endpoint,
|
||||
relay_task,
|
||||
real_client_map,
|
||||
})
|
||||
}
|
||||
|
||||
/// Main relay loop: reads datagrams from the external socket, filters PROXY v2
|
||||
@@ -143,6 +152,7 @@ async fn quic_proxy_relay_loop(
|
||||
external_socket: Arc<UdpSocket>,
|
||||
quinn_internal_addr: SocketAddr,
|
||||
proxy_ips: Arc<Vec<IpAddr>>,
|
||||
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||
real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>,
|
||||
cancel: CancellationToken,
|
||||
) {
|
||||
@@ -183,26 +193,43 @@ async fn quic_proxy_relay_loop(
|
||||
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
|
||||
match crate::proxy_protocol::parse_v2(datagram) {
|
||||
Ok((header, _consumed)) => {
|
||||
debug!("QUIC PROXY v2 from {}: real client {}", src_addr, header.source_addr);
|
||||
debug!(
|
||||
"QUIC PROXY v2 from {}: real client {}",
|
||||
src_addr, header.source_addr
|
||||
);
|
||||
proxy_addr_map.insert(src_addr, header.source_addr);
|
||||
continue; // consume the PROXY v2 datagram
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("QUIC proxy relay: failed to parse PROXY v2 from {}: {}", src_addr, e);
|
||||
debug!(
|
||||
"QUIC proxy relay: failed to parse PROXY v2 from {}: {}",
|
||||
src_addr, e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine real client address
|
||||
let real_client = proxy_addr_map.get(&src_addr)
|
||||
let real_client = proxy_addr_map
|
||||
.get(&src_addr)
|
||||
.map(|r| *r)
|
||||
.unwrap_or(src_addr);
|
||||
|
||||
let block_list = security_policy.load();
|
||||
if !block_list.is_empty() && block_list.is_blocked(&real_client.ip()) {
|
||||
debug!(
|
||||
"QUIC datagram from {} blocked by global security policy",
|
||||
real_client
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get or create relay session for this external source
|
||||
let session = match relay_sessions.get(&src_addr) {
|
||||
Some(s) => {
|
||||
s.last_activity.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
s.last_activity
|
||||
.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
Arc::clone(s.value())
|
||||
}
|
||||
None => {
|
||||
@@ -215,7 +242,10 @@ async fn quic_proxy_relay_loop(
|
||||
}
|
||||
};
|
||||
if let Err(e) = relay_socket.connect(quinn_internal_addr).await {
|
||||
warn!("QUIC relay: failed to connect relay socket to {}: {}", quinn_internal_addr, e);
|
||||
warn!(
|
||||
"QUIC relay: failed to connect relay socket to {}: {}",
|
||||
quinn_internal_addr, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
let relay_local_addr = match relay_socket.local_addr() {
|
||||
@@ -247,8 +277,10 @@ async fn quic_proxy_relay_loop(
|
||||
});
|
||||
|
||||
relay_sessions.insert(src_addr, Arc::clone(&session));
|
||||
debug!("QUIC relay: new session for {} (relay {}), real client {}",
|
||||
src_addr, relay_local_addr, real_client);
|
||||
debug!(
|
||||
"QUIC relay: new session for {} (relay {}), real client {}",
|
||||
src_addr, relay_local_addr, real_client
|
||||
);
|
||||
|
||||
session
|
||||
}
|
||||
@@ -263,9 +295,11 @@ async fn quic_proxy_relay_loop(
|
||||
if last_cleanup.elapsed() >= cleanup_interval {
|
||||
last_cleanup = Instant::now();
|
||||
let now_ms = epoch.elapsed().as_millis() as u64;
|
||||
let stale_keys: Vec<SocketAddr> = relay_sessions.iter()
|
||||
let stale_keys: Vec<SocketAddr> = relay_sessions
|
||||
.iter()
|
||||
.filter(|entry| {
|
||||
let age = now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
|
||||
let age =
|
||||
now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
|
||||
age > session_timeout_ms
|
||||
})
|
||||
.map(|entry| *entry.key())
|
||||
@@ -286,13 +320,17 @@ async fn quic_proxy_relay_loop(
|
||||
|
||||
// Also clean orphaned proxy_addr_map entries (PROXY header received
|
||||
// but no relay session was ever created, e.g. client never sent data)
|
||||
let orphaned: Vec<SocketAddr> = proxy_addr_map.iter()
|
||||
let orphaned: Vec<SocketAddr> = proxy_addr_map
|
||||
.iter()
|
||||
.filter(|entry| relay_sessions.get(entry.key()).is_none())
|
||||
.map(|entry| *entry.key())
|
||||
.collect();
|
||||
for key in orphaned {
|
||||
proxy_addr_map.remove(&key);
|
||||
debug!("QUIC relay: cleaned up orphaned proxy_addr_map entry for {}", key);
|
||||
debug!(
|
||||
"QUIC relay: cleaned up orphaned proxy_addr_map entry for {}",
|
||||
key
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -327,8 +365,14 @@ async fn relay_return_path(
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = external_socket.send_to(&buf[..len], external_src_addr).await {
|
||||
debug!("QUIC relay return send error to {}: {}", external_src_addr, e);
|
||||
if let Err(e) = external_socket
|
||||
.send_to(&buf[..len], external_src_addr)
|
||||
.await
|
||||
{
|
||||
debug!(
|
||||
"QUIC relay return send error to {}: {}",
|
||||
external_src_addr, e
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -350,6 +394,9 @@ pub async fn quic_accept_loop(
|
||||
cancel: CancellationToken,
|
||||
h3_service: Option<Arc<H3ProxyService>>,
|
||||
real_client_map: Option<Arc<DashMap<SocketAddr, SocketAddr>>>,
|
||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||
connection_registry: Arc<ConnectionRegistry>,
|
||||
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||
) {
|
||||
loop {
|
||||
let incoming = tokio::select! {
|
||||
@@ -371,11 +418,21 @@ pub async fn quic_accept_loop(
|
||||
let remote_addr = incoming.remote_address();
|
||||
|
||||
// Resolve real client IP from PROXY protocol map if available
|
||||
let real_addr = real_client_map.as_ref()
|
||||
let real_addr = real_client_map
|
||||
.as_ref()
|
||||
.and_then(|map| map.get(&remote_addr).map(|r| *r))
|
||||
.unwrap_or(remote_addr);
|
||||
let ip = real_addr.ip();
|
||||
|
||||
let block_list = security_policy.load();
|
||||
if !block_list.is_empty() && block_list.is_blocked(&ip) {
|
||||
debug!(
|
||||
"QUIC connection from {} blocked by global security policy",
|
||||
real_addr
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Per-IP rate limiting
|
||||
if !conn_tracker.try_accept(&ip) {
|
||||
debug!("QUIC connection rejected from {} (rate limit)", real_addr);
|
||||
@@ -406,17 +463,56 @@ pub async fn quic_accept_loop(
|
||||
}
|
||||
};
|
||||
|
||||
// Check route-level IP security for QUIC (domain from SNI context)
|
||||
if let Some(ref security) = route.security {
|
||||
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
|
||||
security, &ip, ctx.domain,
|
||||
) {
|
||||
debug!(
|
||||
"QUIC connection from {} blocked by route security",
|
||||
real_addr
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
conn_tracker.connection_opened(&ip);
|
||||
let route_id = route.name.clone().or(route.id.clone());
|
||||
let route_id = route.metrics_key().map(str::to_string);
|
||||
metrics.connection_opened(route_id.as_deref(), Some(&ip_str));
|
||||
|
||||
// Resolve per-route cancel token (child of global cancel)
|
||||
let route_cancel = match route_id.as_deref() {
|
||||
Some(id) => route_cancels
|
||||
.entry(id.to_string())
|
||||
.or_insert_with(|| cancel.child_token())
|
||||
.clone(),
|
||||
None => cancel.child_token(),
|
||||
};
|
||||
// Per-connection child token for selective recycling
|
||||
let conn_cancel = route_cancel.child_token();
|
||||
|
||||
// Register in connection registry
|
||||
let registry = Arc::clone(&connection_registry);
|
||||
let reg_entry = ConnectionEntry {
|
||||
cancel: conn_cancel.clone(),
|
||||
source_ip: ip,
|
||||
domain: None, // QUIC Initial is encrypted, domain comes later via H3 :authority
|
||||
route_id: route_id.clone(),
|
||||
};
|
||||
|
||||
let metrics = Arc::clone(&metrics);
|
||||
let conn_tracker = Arc::clone(&conn_tracker);
|
||||
let cancel = cancel.child_token();
|
||||
let h3_svc = h3_service.clone();
|
||||
let real_client_addr = if real_addr != remote_addr { Some(real_addr) } else { None };
|
||||
let real_client_addr = if real_addr != remote_addr {
|
||||
Some(real_addr)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Register in connection registry (RAII guard removes on drop)
|
||||
let (_conn_id, _registry_guard) = registry.register(reg_entry);
|
||||
|
||||
// RAII guard: ensures metrics/tracker cleanup even on panic
|
||||
struct QuicConnGuard {
|
||||
tracker: Arc<ConnectionTracker>,
|
||||
@@ -428,7 +524,8 @@ pub async fn quic_accept_loop(
|
||||
impl Drop for QuicConnGuard {
|
||||
fn drop(&mut self) {
|
||||
self.tracker.connection_closed(&self.ip);
|
||||
self.metrics.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
|
||||
self.metrics
|
||||
.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
|
||||
}
|
||||
}
|
||||
let _guard = QuicConnGuard {
|
||||
@@ -439,7 +536,17 @@ pub async fn quic_accept_loop(
|
||||
route_id,
|
||||
};
|
||||
|
||||
match handle_quic_connection(incoming, route, port, Arc::clone(&metrics), &cancel, h3_svc, real_client_addr).await {
|
||||
match handle_quic_connection(
|
||||
incoming,
|
||||
route,
|
||||
port,
|
||||
Arc::clone(&metrics),
|
||||
&conn_cancel,
|
||||
h3_svc,
|
||||
real_client_addr,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(()) => debug!("QUIC connection from {} completed", real_addr),
|
||||
Err(e) => debug!("QUIC connection from {} error: {}", real_addr, e),
|
||||
}
|
||||
@@ -467,17 +574,28 @@ async fn handle_quic_connection(
|
||||
debug!("QUIC connection established from {}", effective_addr);
|
||||
|
||||
// Check if this route has HTTP/3 enabled
|
||||
let enable_http3 = route.action.udp.as_ref()
|
||||
let enable_http3 = route
|
||||
.action
|
||||
.udp
|
||||
.as_ref()
|
||||
.and_then(|u| u.quic.as_ref())
|
||||
.and_then(|q| q.enable_http3)
|
||||
.unwrap_or(false);
|
||||
|
||||
if enable_http3 {
|
||||
if let Some(ref h3_svc) = h3_service {
|
||||
debug!("HTTP/3 enabled for route {:?}, dispatching to H3ProxyService", route.name);
|
||||
h3_svc.handle_connection(connection, &route, port, real_client_addr, cancel).await
|
||||
debug!(
|
||||
"HTTP/3 enabled for route {:?}, dispatching to H3ProxyService",
|
||||
route.name
|
||||
);
|
||||
h3_svc
|
||||
.handle_connection(connection, &route, port, real_client_addr, cancel)
|
||||
.await
|
||||
} else {
|
||||
warn!("HTTP/3 enabled for route {:?} but H3ProxyService not initialized", route.name);
|
||||
warn!(
|
||||
"HTTP/3 enabled for route {:?} but H3ProxyService not initialized",
|
||||
route.name
|
||||
);
|
||||
// Keep connection alive until cancelled
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {}
|
||||
@@ -489,7 +607,8 @@ async fn handle_quic_connection(
|
||||
}
|
||||
} else {
|
||||
// Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend
|
||||
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr).await
|
||||
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -507,11 +626,14 @@ async fn handle_quic_stream_forwarding(
|
||||
real_client_addr: Option<SocketAddr>,
|
||||
) -> anyhow::Result<()> {
|
||||
let effective_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
|
||||
let route_id = route.name.as_deref().or(route.id.as_deref());
|
||||
let route_id = route.metrics_key();
|
||||
let metrics_arc = metrics;
|
||||
|
||||
// Resolve backend target
|
||||
let target = route.action.targets.as_ref()
|
||||
let target = route
|
||||
.action
|
||||
.targets
|
||||
.as_ref()
|
||||
.and_then(|t| t.first())
|
||||
.ok_or_else(|| anyhow::anyhow!("No target for QUIC route"))?;
|
||||
let backend_host = target.host.first();
|
||||
@@ -542,19 +664,20 @@ async fn handle_quic_stream_forwarding(
|
||||
|
||||
// Spawn a task for each QUIC stream → TCP bidirectional forwarding
|
||||
tokio::spawn(async move {
|
||||
match forward_quic_stream_to_tcp(
|
||||
send_stream,
|
||||
recv_stream,
|
||||
&backend_addr,
|
||||
stream_cancel,
|
||||
).await {
|
||||
match forward_quic_stream_to_tcp(send_stream, recv_stream, &backend_addr, stream_cancel)
|
||||
.await
|
||||
{
|
||||
Ok((bytes_in, bytes_out)) => {
|
||||
stream_metrics.record_bytes(
|
||||
bytes_in, bytes_out,
|
||||
bytes_in,
|
||||
bytes_out,
|
||||
stream_route_id.as_deref(),
|
||||
Some(&ip_str),
|
||||
);
|
||||
debug!("QUIC stream forwarded: {}B in, {}B out", bytes_in, bytes_out);
|
||||
debug!(
|
||||
"QUIC stream forwarded: {}B in, {}B out",
|
||||
bytes_in, bytes_out
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("QUIC stream forwarding error: {}", e);
|
||||
@@ -606,10 +729,7 @@ async fn forward_quic_stream_to_tcp(
|
||||
total += n as u64;
|
||||
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
}
|
||||
let _ = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(2),
|
||||
tcp_write.shutdown(),
|
||||
).await;
|
||||
let _ = tokio::time::timeout(std::time::Duration::from_secs(2), tcp_write.shutdown()).await;
|
||||
total
|
||||
});
|
||||
|
||||
@@ -687,8 +807,8 @@ mod tests {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
|
||||
// Generate a single self-signed cert and use its key pair
|
||||
let self_signed = rcgen::generate_simple_self_signed(vec!["localhost".to_string()])
|
||||
.unwrap();
|
||||
let self_signed =
|
||||
rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
|
||||
let cert_der = self_signed.cert.der().clone();
|
||||
let key_der = self_signed.key_pair.serialize_der();
|
||||
|
||||
@@ -703,6 +823,10 @@ mod tests {
|
||||
|
||||
// Port 0 = OS assigns a free port
|
||||
let result = create_quic_endpoint(0, Arc::new(tls_config));
|
||||
assert!(result.is_ok(), "QUIC endpoint creation failed: {:?}", result.err());
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"QUIC endpoint creation failed: {:?}",
|
||||
result.err()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,9 +47,8 @@ pub fn extract_sni(data: &[u8]) -> SniResult {
|
||||
}
|
||||
|
||||
// Handshake length (3 bytes) - informational, we parse incrementally
|
||||
let _handshake_len = ((data[6] as usize) << 16)
|
||||
| ((data[7] as usize) << 8)
|
||||
| (data[8] as usize);
|
||||
let _handshake_len =
|
||||
((data[6] as usize) << 16) | ((data[7] as usize) << 8) | (data[8] as usize);
|
||||
|
||||
let hello = &data[9..];
|
||||
|
||||
@@ -170,7 +169,10 @@ pub fn extract_http_path(data: &[u8]) -> Option<String> {
|
||||
pub fn extract_http_host(data: &[u8]) -> Option<String> {
|
||||
let text = std::str::from_utf8(data).ok()?;
|
||||
for line in text.split("\r\n") {
|
||||
if let Some(value) = line.strip_prefix("Host: ").or_else(|| line.strip_prefix("host: ")) {
|
||||
if let Some(value) = line
|
||||
.strip_prefix("Host: ")
|
||||
.or_else(|| line.strip_prefix("host: "))
|
||||
{
|
||||
// Strip port if present
|
||||
let host = value.split(':').next().unwrap_or(value).trim();
|
||||
if !host.is_empty() {
|
||||
@@ -196,7 +198,7 @@ pub fn is_http(data: &[u8]) -> bool {
|
||||
b"PATC",
|
||||
b"OPTI",
|
||||
b"CONN",
|
||||
b"PRI ", // HTTP/2 connection preface
|
||||
b"PRI ", // HTTP/2 connection preface
|
||||
];
|
||||
starts.iter().any(|s| data.starts_with(s))
|
||||
}
|
||||
@@ -213,7 +215,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_too_short() {
|
||||
assert!(matches!(extract_sni(&[0x16, 0x03]), SniResult::NeedMoreData));
|
||||
assert!(matches!(
|
||||
extract_sni(&[0x16, 0x03]),
|
||||
SniResult::NeedMoreData
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -263,7 +268,8 @@ mod tests {
|
||||
// Extension: type=0x0000 (SNI), length, data
|
||||
let sni_extension = {
|
||||
let mut e = Vec::new();
|
||||
e.push(0x00); e.push(0x00); // SNI type
|
||||
e.push(0x00);
|
||||
e.push(0x00); // SNI type
|
||||
e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8);
|
||||
e.push((sni_ext_data.len() & 0xFF) as u8);
|
||||
e.extend_from_slice(&sni_ext_data);
|
||||
@@ -283,16 +289,20 @@ mod tests {
|
||||
let hello_body = {
|
||||
let mut h = Vec::new();
|
||||
// Client version TLS 1.2
|
||||
h.push(0x03); h.push(0x03);
|
||||
h.push(0x03);
|
||||
h.push(0x03);
|
||||
// Random (32 bytes)
|
||||
h.extend_from_slice(&[0u8; 32]);
|
||||
// Session ID length = 0
|
||||
h.push(0x00);
|
||||
// Cipher suites: length=2, one suite
|
||||
h.push(0x00); h.push(0x02);
|
||||
h.push(0x00); h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
|
||||
// Compression methods: length=1, null
|
||||
h.push(0x01); h.push(0x00);
|
||||
h.push(0x00);
|
||||
h.push(0x02);
|
||||
h.push(0x00);
|
||||
h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
|
||||
// Compression methods: length=1, null
|
||||
h.push(0x01);
|
||||
h.push(0x00);
|
||||
// Extensions
|
||||
h.extend_from_slice(&extensions);
|
||||
h
|
||||
@@ -302,7 +312,7 @@ mod tests {
|
||||
let handshake = {
|
||||
let mut hs = Vec::new();
|
||||
hs.push(0x01); // ClientHello
|
||||
// 3-byte length
|
||||
// 3-byte length
|
||||
hs.push(((hello_body.len() >> 16) & 0xFF) as u8);
|
||||
hs.push(((hello_body.len() >> 8) & 0xFF) as u8);
|
||||
hs.push((hello_body.len() & 0xFF) as u8);
|
||||
@@ -313,7 +323,8 @@ mod tests {
|
||||
// TLS record: type=0x16, version TLS 1.0, length
|
||||
let mut record = Vec::new();
|
||||
record.push(0x16); // Handshake
|
||||
record.push(0x03); record.push(0x01); // TLS 1.0
|
||||
record.push(0x03);
|
||||
record.push(0x01); // TLS 1.0
|
||||
record.push(((handshake.len() >> 8) & 0xFF) as u8);
|
||||
record.push((handshake.len() & 0xFF) as u8);
|
||||
record.extend_from_slice(&handshake);
|
||||
|
||||
@@ -1,126 +1,4 @@
|
||||
//! Socket handler relay for connecting client connections to a TypeScript handler
|
||||
//! via a Unix domain socket.
|
||||
//! Socket handler relay module.
|
||||
//!
|
||||
//! Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay.
|
||||
|
||||
use tokio::net::UnixStream;
|
||||
use tokio::io::{AsyncWriteExt, AsyncReadExt};
|
||||
use tokio::net::TcpStream;
|
||||
use serde::Serialize;
|
||||
use tracing::debug;
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct RelayMetadata {
|
||||
connection_id: u64,
|
||||
remote_ip: String,
|
||||
remote_port: u16,
|
||||
local_port: u16,
|
||||
sni: Option<String>,
|
||||
route_name: String,
|
||||
initial_data_base64: Option<String>,
|
||||
}
|
||||
|
||||
/// Relay a client connection to a TypeScript handler via Unix domain socket.
|
||||
///
|
||||
/// Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay.
|
||||
pub async fn relay_to_handler(
|
||||
client: TcpStream,
|
||||
relay_socket_path: &str,
|
||||
connection_id: u64,
|
||||
remote_ip: String,
|
||||
remote_port: u16,
|
||||
local_port: u16,
|
||||
sni: Option<String>,
|
||||
route_name: String,
|
||||
initial_data: Option<&[u8]>,
|
||||
) -> std::io::Result<()> {
|
||||
debug!(
|
||||
"Relaying connection {} to handler socket {}",
|
||||
connection_id, relay_socket_path
|
||||
);
|
||||
|
||||
// Connect to TypeScript handler Unix socket
|
||||
let mut handler = UnixStream::connect(relay_socket_path).await?;
|
||||
|
||||
// Build and send metadata header
|
||||
let initial_data_base64 = initial_data.map(base64_encode);
|
||||
|
||||
let metadata = RelayMetadata {
|
||||
connection_id,
|
||||
remote_ip,
|
||||
remote_port,
|
||||
local_port,
|
||||
sni,
|
||||
route_name,
|
||||
initial_data_base64,
|
||||
};
|
||||
|
||||
let metadata_json = serde_json::to_string(&metadata)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
|
||||
|
||||
handler.write_all(metadata_json.as_bytes()).await?;
|
||||
handler.write_all(b"\n").await?;
|
||||
|
||||
// Bidirectional relay between client and handler
|
||||
let (mut client_read, mut client_write) = client.into_split();
|
||||
let (mut handler_read, mut handler_write) = handler.into_split();
|
||||
|
||||
let c2h = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
loop {
|
||||
let n = match client_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if handler_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let _ = handler_write.shutdown().await;
|
||||
});
|
||||
|
||||
let h2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
loop {
|
||||
let n = match handler_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if client_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let _ = client_write.shutdown().await;
|
||||
});
|
||||
|
||||
let _ = tokio::join!(c2h, h2c);
|
||||
|
||||
debug!("Relay connection {} completed", connection_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Simple base64 encoding without external dependency.
|
||||
fn base64_encode(data: &[u8]) -> String {
|
||||
const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
||||
let mut result = String::new();
|
||||
for chunk in data.chunks(3) {
|
||||
let b0 = chunk[0] as u32;
|
||||
let b1 = if chunk.len() > 1 { chunk[1] as u32 } else { 0 };
|
||||
let b2 = if chunk.len() > 2 { chunk[2] as u32 } else { 0 };
|
||||
let n = (b0 << 16) | (b1 << 8) | b2;
|
||||
result.push(CHARS[((n >> 18) & 0x3F) as usize] as char);
|
||||
result.push(CHARS[((n >> 12) & 0x3F) as usize] as char);
|
||||
if chunk.len() > 1 {
|
||||
result.push(CHARS[((n >> 6) & 0x3F) as usize] as char);
|
||||
} else {
|
||||
result.push('=');
|
||||
}
|
||||
if chunk.len() > 2 {
|
||||
result.push(CHARS[(n & 0x3F) as usize] as char);
|
||||
} else {
|
||||
result.push('=');
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
//! Note: The actual relay logic lives in `tcp_listener::relay_to_socket_handler()`
|
||||
//! which has proper timeouts, cancellation, and metrics integration.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,7 @@ use rustls::server::ResolvesServerCert;
|
||||
use rustls::sign::CertifiedKey;
|
||||
use rustls::ServerConfig;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream};
|
||||
use tokio_rustls::{server::TlsStream as ServerTlsStream, TlsAcceptor, TlsConnector};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::tcp_listener::TlsCertConfig;
|
||||
@@ -29,7 +29,9 @@ pub struct CertResolver {
|
||||
impl CertResolver {
|
||||
/// Build a resolver from PEM-encoded cert/key configs.
|
||||
/// Parses all PEM data upfront so connections only do a cheap HashMap lookup.
|
||||
pub fn new(configs: &HashMap<String, TlsCertConfig>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub fn new(
|
||||
configs: &HashMap<String, TlsCertConfig>,
|
||||
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||||
ensure_crypto_provider();
|
||||
let provider = rustls::crypto::ring::default_provider();
|
||||
let mut certs = HashMap::new();
|
||||
@@ -38,8 +40,10 @@ impl CertResolver {
|
||||
for (domain, cfg) in configs {
|
||||
let cert_chain = load_certs(&cfg.cert_pem)?;
|
||||
let key = load_private_key(&cfg.key_pem)?;
|
||||
let ck = Arc::new(CertifiedKey::from_der(cert_chain, key, &provider)
|
||||
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?);
|
||||
let ck = Arc::new(
|
||||
CertifiedKey::from_der(cert_chain, key, &provider)
|
||||
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?,
|
||||
);
|
||||
if domain == "*" {
|
||||
fallback = Some(Arc::clone(&ck));
|
||||
}
|
||||
@@ -78,7 +82,9 @@ impl ResolvesServerCert for CertResolver {
|
||||
|
||||
/// Build a shared TLS acceptor with SNI resolution, session cache, and session tickets.
|
||||
/// The returned acceptor can be reused across all connections (cheap Arc clone).
|
||||
pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub fn build_shared_tls_acceptor(
|
||||
resolver: CertResolver,
|
||||
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
ensure_crypto_provider();
|
||||
let mut config = ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
@@ -90,22 +96,30 @@ pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor,
|
||||
// Shared session cache — enables session ID resumption across connections
|
||||
config.session_storage = rustls::server::ServerSessionMemoryCache::new(4096);
|
||||
// Session ticket resumption (12-hour lifetime, Chacha20Poly1305 encrypted)
|
||||
config.ticketer = rustls::crypto::ring::Ticketer::new()
|
||||
.map_err(|e| format!("Ticketer: {}", e))?;
|
||||
config.ticketer =
|
||||
rustls::crypto::ring::Ticketer::new().map_err(|e| format!("Ticketer: {}", e))?;
|
||||
|
||||
info!("Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1");
|
||||
info!(
|
||||
"Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1"
|
||||
);
|
||||
Ok(TlsAcceptor::from(Arc::new(config)))
|
||||
}
|
||||
|
||||
/// Build a TLS acceptor from PEM-encoded cert and key data.
|
||||
/// Advertises both h2 and http/1.1 via ALPN (for client-facing connections).
|
||||
pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub fn build_tls_acceptor(
|
||||
cert_pem: &str,
|
||||
key_pem: &str,
|
||||
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
build_tls_acceptor_with_config(cert_pem, key_pem, None)
|
||||
}
|
||||
|
||||
/// Build a TLS acceptor for backend servers that only speak HTTP/1.1.
|
||||
/// Does NOT advertise h2 in ALPN, preventing false h2 auto-detection.
|
||||
pub fn build_tls_acceptor_h1_only(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub fn build_tls_acceptor_h1_only(
|
||||
cert_pem: &str,
|
||||
key_pem: &str,
|
||||
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
ensure_crypto_provider();
|
||||
let certs = load_certs(cert_pem)?;
|
||||
let key = load_private_key(key_pem)?;
|
||||
@@ -130,9 +144,7 @@ pub fn build_tls_acceptor_with_config(
|
||||
// Apply TLS version restrictions
|
||||
let versions = resolve_tls_versions(route_tls.versions.as_deref());
|
||||
let builder = ServerConfig::builder_with_protocol_versions(&versions);
|
||||
builder
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)?
|
||||
builder.with_no_client_auth().with_single_cert(certs, key)?
|
||||
} else {
|
||||
ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
@@ -156,7 +168,9 @@ pub fn build_tls_acceptor_with_config(
|
||||
}
|
||||
|
||||
/// Resolve TLS version strings to rustls SupportedProtocolVersion.
|
||||
fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> {
|
||||
fn resolve_tls_versions(
|
||||
versions: Option<&[String]>,
|
||||
) -> Vec<&'static rustls::SupportedProtocolVersion> {
|
||||
let versions = match versions {
|
||||
Some(v) if !v.is_empty() => v,
|
||||
_ => return vec![&rustls::version::TLS12, &rustls::version::TLS13],
|
||||
@@ -207,15 +221,17 @@ pub async fn accept_tls(
|
||||
static SHARED_CLIENT_CONFIG: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
|
||||
|
||||
pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
|
||||
SHARED_CLIENT_CONFIG.get_or_init(|| {
|
||||
ensure_crypto_provider();
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
info!("Built shared backend TLS client config with session resumption");
|
||||
Arc::new(config)
|
||||
}).clone()
|
||||
SHARED_CLIENT_CONFIG
|
||||
.get_or_init(|| {
|
||||
ensure_crypto_provider();
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
info!("Built shared backend TLS client config with session resumption");
|
||||
Arc::new(config)
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
|
||||
/// Get or create a shared backend TLS `ClientConfig` with ALPN `h2` + `http/1.1`.
|
||||
@@ -225,16 +241,20 @@ pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
|
||||
static SHARED_CLIENT_CONFIG_ALPN: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
|
||||
|
||||
pub fn shared_backend_tls_config_alpn() -> Arc<rustls::ClientConfig> {
|
||||
SHARED_CLIENT_CONFIG_ALPN.get_or_init(|| {
|
||||
ensure_crypto_provider();
|
||||
let mut config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
info!("Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection");
|
||||
Arc::new(config)
|
||||
}).clone()
|
||||
SHARED_CLIENT_CONFIG_ALPN
|
||||
.get_or_init(|| {
|
||||
ensure_crypto_provider();
|
||||
let mut config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
info!(
|
||||
"Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection"
|
||||
);
|
||||
Arc::new(config)
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
|
||||
/// Connect to a backend with TLS (for terminate-and-reencrypt mode).
|
||||
@@ -249,7 +269,8 @@ pub async fn connect_tls(
|
||||
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
|
||||
stream.set_nodelay(true)?;
|
||||
// Apply keepalive with 60s default (tls_handler doesn't have ConnectionConfig access)
|
||||
if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60)) {
|
||||
if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60))
|
||||
{
|
||||
debug!("Failed to set keepalive on backend TLS socket: {}", e);
|
||||
}
|
||||
|
||||
@@ -260,10 +281,12 @@ pub async fn connect_tls(
|
||||
}
|
||||
|
||||
/// Load certificates from PEM string.
|
||||
fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
fn load_certs(
|
||||
pem: &str,
|
||||
) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut reader = BufReader::new(pem.as_bytes());
|
||||
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
let certs: Vec<CertificateDer<'static>> =
|
||||
rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?;
|
||||
if certs.is_empty() {
|
||||
return Err("No certificates found in PEM data".into());
|
||||
}
|
||||
@@ -271,11 +294,13 @@ fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::er
|
||||
}
|
||||
|
||||
/// Load private key from PEM string.
|
||||
fn load_private_key(pem: &str) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
fn load_private_key(
|
||||
pem: &str,
|
||||
) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut reader = BufReader::new(pem.as_bytes());
|
||||
// Try PKCS8 first, then RSA, then EC
|
||||
let key = rustls_pemfile::private_key(&mut reader)?
|
||||
.ok_or("No private key found in PEM data")?;
|
||||
let key =
|
||||
rustls_pemfile::private_key(&mut reader)?.ok_or("No private key found in PEM data")?;
|
||||
Ok(key)
|
||||
}
|
||||
|
||||
|
||||
@@ -17,17 +17,20 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use rustproxy_config::{RouteActionType, TransportProtocol};
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
use rustproxy_routing::{MatchContext, RouteManager};
|
||||
use rustproxy_security::IpBlockList;
|
||||
|
||||
use rustproxy_http::h3_service::H3ProxyService;
|
||||
|
||||
use crate::connection_registry::ConnectionRegistry;
|
||||
|
||||
use crate::connection_tracker::ConnectionTracker;
|
||||
use crate::udp_session::{SessionKey, UdpSession, UdpSessionConfig, UdpSessionTable};
|
||||
|
||||
@@ -56,6 +59,12 @@ pub struct UdpListenerManager {
|
||||
/// Trusted proxy IPs that may send PROXY protocol v2 headers.
|
||||
/// When non-empty, PROXY v2 detection is enabled on both raw UDP and QUIC paths.
|
||||
proxy_ips: Arc<Vec<IpAddr>>,
|
||||
/// Per-route cancellation tokens (shared with TcpListenerManager).
|
||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||
/// Shared connection registry for selective recycling.
|
||||
connection_registry: Arc<ConnectionRegistry>,
|
||||
/// Global ingress block policy, hot-reloadable without restarting listeners.
|
||||
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||
}
|
||||
|
||||
impl Drop for UdpListenerManager {
|
||||
@@ -76,6 +85,8 @@ impl UdpListenerManager {
|
||||
metrics: Arc<MetricsCollector>,
|
||||
conn_tracker: Arc<ConnectionTracker>,
|
||||
cancel_token: CancellationToken,
|
||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||
connection_registry: Arc<ConnectionRegistry>,
|
||||
) -> Self {
|
||||
Self {
|
||||
listeners: HashMap::new(),
|
||||
@@ -89,17 +100,28 @@ impl UdpListenerManager {
|
||||
relay_reader_cancel: None,
|
||||
h3_service: None,
|
||||
proxy_ips: Arc::new(Vec::new()),
|
||||
route_cancels,
|
||||
connection_registry,
|
||||
security_policy: Arc::new(ArcSwap::from(Arc::new(IpBlockList::empty()))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the trusted proxy IPs for PROXY protocol v2 detection.
|
||||
pub fn set_proxy_ips(&mut self, ips: Vec<IpAddr>) {
|
||||
if !ips.is_empty() {
|
||||
info!("UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs", ips.len());
|
||||
info!(
|
||||
"UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs",
|
||||
ips.len()
|
||||
);
|
||||
}
|
||||
self.proxy_ips = Arc::new(ips);
|
||||
}
|
||||
|
||||
/// Set the shared global ingress security policy.
|
||||
pub fn set_security_policy(&mut self, policy: Arc<ArcSwap<IpBlockList>>) {
|
||||
self.security_policy = policy;
|
||||
}
|
||||
|
||||
/// Set the H3 proxy service for HTTP/3 request handling.
|
||||
pub fn set_h3_service(&mut self, svc: Arc<H3ProxyService>) {
|
||||
self.h3_service = Some(svc);
|
||||
@@ -132,7 +154,9 @@ impl UdpListenerManager {
|
||||
// Check if any route on this port uses QUIC
|
||||
let rm = self.route_manager.load();
|
||||
let has_quic = rm.routes_for_port(port).iter().any(|r| {
|
||||
r.action.udp.as_ref()
|
||||
r.action
|
||||
.udp
|
||||
.as_ref()
|
||||
.and_then(|u| u.quic.as_ref())
|
||||
.is_some()
|
||||
});
|
||||
@@ -152,8 +176,12 @@ impl UdpListenerManager {
|
||||
self.cancel_token.child_token(),
|
||||
self.h3_service.clone(),
|
||||
None,
|
||||
Arc::clone(&self.route_cancels),
|
||||
Arc::clone(&self.connection_registry),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
self.listeners
|
||||
.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
info!("QUIC endpoint started on port {}", port);
|
||||
} else {
|
||||
// Proxy relay path: we own external socket, quinn on localhost
|
||||
@@ -161,6 +189,7 @@ impl UdpListenerManager {
|
||||
port,
|
||||
tls,
|
||||
Arc::clone(&self.proxy_ips),
|
||||
Arc::clone(&self.security_policy),
|
||||
self.cancel_token.child_token(),
|
||||
)?;
|
||||
let endpoint_for_updates = relay.endpoint.clone();
|
||||
@@ -173,13 +202,20 @@ impl UdpListenerManager {
|
||||
self.cancel_token.child_token(),
|
||||
self.h3_service.clone(),
|
||||
Some(relay.real_client_map),
|
||||
Arc::clone(&self.route_cancels),
|
||||
Arc::clone(&self.connection_registry),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
self.listeners
|
||||
.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
info!("QUIC endpoint with PROXY relay started on port {}", port);
|
||||
}
|
||||
return Ok(());
|
||||
} else {
|
||||
warn!("QUIC routes on port {} but no TLS config provided, falling back to raw UDP", port);
|
||||
warn!(
|
||||
"QUIC routes on port {} but no TLS config provided, falling back to raw UDP",
|
||||
port
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -200,6 +236,7 @@ impl UdpListenerManager {
|
||||
Arc::clone(&self.relay_writer),
|
||||
self.cancel_token.child_token(),
|
||||
Arc::clone(&self.proxy_ips),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
|
||||
self.listeners.insert(port, (handle, None));
|
||||
@@ -240,8 +277,10 @@ impl UdpListenerManager {
|
||||
}
|
||||
debug!("UDP listener stopped on port {}", port);
|
||||
}
|
||||
info!("All UDP listeners stopped, {} sessions remaining",
|
||||
self.session_table.session_count());
|
||||
info!(
|
||||
"All UDP listeners stopped, {} sessions remaining",
|
||||
self.session_table.session_count()
|
||||
);
|
||||
}
|
||||
|
||||
/// Update TLS config on all active QUIC endpoints (cert refresh).
|
||||
@@ -274,11 +313,15 @@ impl UdpListenerManager {
|
||||
pub async fn upgrade_raw_to_quic(&mut self, tls_config: Arc<rustls::ServerConfig>) {
|
||||
// Find ports that are raw UDP fallback (endpoint=None) but have QUIC routes
|
||||
let rm = self.route_manager.load();
|
||||
let upgrade_ports: Vec<u16> = self.listeners.iter()
|
||||
let upgrade_ports: Vec<u16> = self
|
||||
.listeners
|
||||
.iter()
|
||||
.filter(|(_, (_, endpoint))| endpoint.is_none())
|
||||
.filter(|(port, _)| {
|
||||
rm.routes_for_port(**port).iter().any(|r| {
|
||||
r.action.udp.as_ref()
|
||||
r.action
|
||||
.udp
|
||||
.as_ref()
|
||||
.and_then(|u| u.quic.as_ref())
|
||||
.is_some()
|
||||
})
|
||||
@@ -287,17 +330,23 @@ impl UdpListenerManager {
|
||||
.collect();
|
||||
|
||||
for port in upgrade_ports {
|
||||
info!("Upgrading raw UDP listener on port {} to QUIC endpoint", port);
|
||||
info!(
|
||||
"Upgrading raw UDP listener on port {} to QUIC endpoint",
|
||||
port
|
||||
);
|
||||
|
||||
// Stop the raw UDP listener task and drain sessions to release the socket
|
||||
if let Some((handle, _)) = self.listeners.remove(&port) {
|
||||
handle.abort();
|
||||
}
|
||||
let drained = self.session_table.drain_port(
|
||||
port, &self.metrics, &self.conn_tracker,
|
||||
);
|
||||
let drained = self
|
||||
.session_table
|
||||
.drain_port(port, &self.metrics, &self.conn_tracker);
|
||||
if drained > 0 {
|
||||
debug!("Drained {} UDP sessions on port {} for QUIC upgrade", drained, port);
|
||||
debug!(
|
||||
"Drained {} UDP sessions on port {} for QUIC upgrade",
|
||||
drained, port
|
||||
);
|
||||
}
|
||||
|
||||
// Brief yield to let aborted tasks drop their socket references
|
||||
@@ -312,11 +361,17 @@ impl UdpListenerManager {
|
||||
|
||||
match create_result {
|
||||
Ok(()) => {
|
||||
info!("QUIC endpoint started on port {} (upgraded from raw UDP)", port);
|
||||
info!(
|
||||
"QUIC endpoint started on port {} (upgraded from raw UDP)",
|
||||
port
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
// Port may still be held — retry once after a brief delay
|
||||
warn!("QUIC endpoint creation failed on port {}, retrying: {}", port, e);
|
||||
warn!(
|
||||
"QUIC endpoint creation failed on port {}, retrying: {}",
|
||||
port, e
|
||||
);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
|
||||
let retry_result = if self.proxy_ips.is_empty() {
|
||||
@@ -327,11 +382,17 @@ impl UdpListenerManager {
|
||||
|
||||
match retry_result {
|
||||
Ok(()) => {
|
||||
info!("QUIC endpoint started on port {} (upgraded from raw UDP, retry)", port);
|
||||
info!(
|
||||
"QUIC endpoint started on port {} (upgraded from raw UDP, retry)",
|
||||
port
|
||||
);
|
||||
}
|
||||
Err(e2) => {
|
||||
error!("Failed to upgrade port {} to QUIC after retry: {}. \
|
||||
Rebinding as raw UDP.", port, e2);
|
||||
error!(
|
||||
"Failed to upgrade port {} to QUIC after retry: {}. \
|
||||
Rebinding as raw UDP.",
|
||||
port, e2
|
||||
);
|
||||
// Fallback: rebind as raw UDP so the port isn't dead
|
||||
if let Ok(()) = self.rebind_raw_udp(port).await {
|
||||
warn!("Port {} rebound as raw UDP (QUIC upgrade failed)", port);
|
||||
@@ -344,7 +405,11 @@ impl UdpListenerManager {
|
||||
}
|
||||
|
||||
/// Create a direct QUIC endpoint (quinn owns the socket).
|
||||
fn create_quic_direct(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> {
|
||||
fn create_quic_direct(
|
||||
&mut self,
|
||||
port: u16,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) -> anyhow::Result<()> {
|
||||
let endpoint = crate::quic_handler::create_quic_endpoint(port, tls_config)?;
|
||||
let endpoint_for_updates = endpoint.clone();
|
||||
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
|
||||
@@ -356,17 +421,26 @@ impl UdpListenerManager {
|
||||
self.cancel_token.child_token(),
|
||||
self.h3_service.clone(),
|
||||
None,
|
||||
Arc::clone(&self.route_cancels),
|
||||
Arc::clone(&self.connection_registry),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
self.listeners
|
||||
.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a QUIC endpoint with PROXY protocol relay.
|
||||
fn create_quic_with_relay(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> {
|
||||
fn create_quic_with_relay(
|
||||
&mut self,
|
||||
port: u16,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) -> anyhow::Result<()> {
|
||||
let relay = crate::quic_handler::create_quic_endpoint_with_proxy_relay(
|
||||
port,
|
||||
tls_config,
|
||||
Arc::clone(&self.proxy_ips),
|
||||
Arc::clone(&self.security_policy),
|
||||
self.cancel_token.child_token(),
|
||||
)?;
|
||||
let endpoint_for_updates = relay.endpoint.clone();
|
||||
@@ -379,8 +453,12 @@ impl UdpListenerManager {
|
||||
self.cancel_token.child_token(),
|
||||
self.h3_service.clone(),
|
||||
Some(relay.real_client_map),
|
||||
Arc::clone(&self.route_cancels),
|
||||
Arc::clone(&self.connection_registry),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
self.listeners
|
||||
.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -401,6 +479,7 @@ impl UdpListenerManager {
|
||||
Arc::clone(&self.relay_writer),
|
||||
self.cancel_token.child_token(),
|
||||
Arc::clone(&self.proxy_ips),
|
||||
Arc::clone(&self.security_policy),
|
||||
));
|
||||
|
||||
self.listeners.insert(port, (handle, None));
|
||||
@@ -440,7 +519,10 @@ impl UdpListenerManager {
|
||||
info!("Datagram handler relay connected to {}", path);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to connect datagram handler relay to {}: {}", path, e);
|
||||
error!(
|
||||
"Failed to connect datagram handler relay to {}: {}",
|
||||
path, e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -496,6 +578,7 @@ impl UdpListenerManager {
|
||||
relay_writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>,
|
||||
cancel: CancellationToken,
|
||||
proxy_ips: Arc<Vec<IpAddr>>,
|
||||
security_policy: Arc<ArcSwap<IpBlockList>>,
|
||||
) {
|
||||
// Use a reasonably large buffer; actual max is per-route but we need a single buffer
|
||||
let mut buf = vec![0u8; 65535];
|
||||
@@ -510,9 +593,11 @@ impl UdpListenerManager {
|
||||
|
||||
loop {
|
||||
// Periodic cleanup: remove proxy_addr_map entries with no active session
|
||||
if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval {
|
||||
if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval
|
||||
{
|
||||
last_proxy_cleanup = tokio::time::Instant::now();
|
||||
let stale: Vec<SocketAddr> = proxy_addr_map.iter()
|
||||
let stale: Vec<SocketAddr> = proxy_addr_map
|
||||
.iter()
|
||||
.filter(|entry| {
|
||||
let key: SessionKey = (*entry.key(), port);
|
||||
session_table.get(&key).is_none()
|
||||
@@ -520,7 +605,11 @@ impl UdpListenerManager {
|
||||
.map(|entry| *entry.key())
|
||||
.collect();
|
||||
if !stale.is_empty() {
|
||||
debug!("UDP proxy_addr_map cleanup: removing {} stale entries on port {}", stale.len(), port);
|
||||
debug!(
|
||||
"UDP proxy_addr_map cleanup: removing {} stale entries on port {}",
|
||||
stale.len(),
|
||||
port
|
||||
);
|
||||
for addr in stale {
|
||||
proxy_addr_map.remove(&addr);
|
||||
}
|
||||
@@ -546,34 +635,50 @@ impl UdpListenerManager {
|
||||
let datagram = &buf[..len];
|
||||
|
||||
// PROXY protocol v2 detection for datagrams from trusted proxy IPs
|
||||
let effective_client_ip = if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) {
|
||||
let session_key: SessionKey = (client_addr, port);
|
||||
if session_table.get(&session_key).is_none() && !proxy_addr_map.contains_key(&client_addr) {
|
||||
// No session and no prior PROXY header — check for PROXY v2
|
||||
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
|
||||
match crate::proxy_protocol::parse_v2(datagram) {
|
||||
Ok((header, _consumed)) => {
|
||||
debug!("UDP PROXY v2 from {}: real client {}", client_addr, header.source_addr);
|
||||
proxy_addr_map.insert(client_addr, header.source_addr);
|
||||
continue; // discard the PROXY v2 datagram
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("UDP PROXY v2 parse error from {}: {}", client_addr, e);
|
||||
client_addr.ip()
|
||||
let effective_client_ip =
|
||||
if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) {
|
||||
let session_key: SessionKey = (client_addr, port);
|
||||
if session_table.get(&session_key).is_none()
|
||||
&& !proxy_addr_map.contains_key(&client_addr)
|
||||
{
|
||||
// No session and no prior PROXY header — check for PROXY v2
|
||||
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
|
||||
match crate::proxy_protocol::parse_v2(datagram) {
|
||||
Ok((header, _consumed)) => {
|
||||
debug!(
|
||||
"UDP PROXY v2 from {}: real client {}",
|
||||
client_addr, header.source_addr
|
||||
);
|
||||
proxy_addr_map.insert(client_addr, header.source_addr);
|
||||
continue; // discard the PROXY v2 datagram
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("UDP PROXY v2 parse error from {}: {}", client_addr, e);
|
||||
client_addr.ip()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
client_addr.ip()
|
||||
}
|
||||
} else {
|
||||
client_addr.ip()
|
||||
// Use real client IP if we've previously seen a PROXY v2 header
|
||||
proxy_addr_map
|
||||
.get(&client_addr)
|
||||
.map(|r| r.ip())
|
||||
.unwrap_or_else(|| client_addr.ip())
|
||||
}
|
||||
} else {
|
||||
// Use real client IP if we've previously seen a PROXY v2 header
|
||||
proxy_addr_map.get(&client_addr)
|
||||
.map(|r| r.ip())
|
||||
.unwrap_or_else(|| client_addr.ip())
|
||||
}
|
||||
} else {
|
||||
client_addr.ip()
|
||||
};
|
||||
client_addr.ip()
|
||||
};
|
||||
|
||||
let block_list = security_policy.load();
|
||||
if !block_list.is_empty() && block_list.is_blocked(&effective_client_ip) {
|
||||
debug!(
|
||||
"UDP datagram from {} blocked by global security policy",
|
||||
effective_client_ip
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Route matching — use effective (real) client IP
|
||||
let rm = route_manager.load();
|
||||
@@ -593,13 +698,16 @@ impl UdpListenerManager {
|
||||
let route_match = match rm.find_route(&ctx) {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
debug!("No UDP route matched for port {} from {}", port, client_addr);
|
||||
debug!(
|
||||
"No UDP route matched for port {} from {}",
|
||||
port, client_addr
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let route = route_match.route;
|
||||
let route_id = route.name.as_deref().or(route.id.as_deref());
|
||||
let route_id = route.metrics_key();
|
||||
|
||||
// Socket handler routes → relay datagram to TS via persistent Unix socket
|
||||
if route.action.action_type == RouteActionType::SocketHandler {
|
||||
@@ -609,7 +717,9 @@ impl UdpListenerManager {
|
||||
&client_addr,
|
||||
port,
|
||||
datagram,
|
||||
).await {
|
||||
)
|
||||
.await
|
||||
{
|
||||
debug!("Failed to relay UDP datagram to TS: {}", e);
|
||||
}
|
||||
continue;
|
||||
@@ -620,8 +730,10 @@ impl UdpListenerManager {
|
||||
|
||||
// Check datagram size
|
||||
if len as u32 > udp_config.max_datagram_size {
|
||||
debug!("UDP datagram too large ({} > {}) from {}, dropping",
|
||||
len, udp_config.max_datagram_size, client_addr);
|
||||
debug!(
|
||||
"UDP datagram too large ({} > {}) from {}, dropping",
|
||||
len, udp_config.max_datagram_size, client_addr
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -633,21 +745,27 @@ impl UdpListenerManager {
|
||||
None => {
|
||||
// New session — check per-IP limits using the real client IP
|
||||
if !conn_tracker.try_accept(&effective_client_ip) {
|
||||
debug!("UDP session rejected for {} (rate limit)", effective_client_ip);
|
||||
debug!(
|
||||
"UDP session rejected for {} (rate limit)",
|
||||
effective_client_ip
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if !session_table.can_create_session(
|
||||
&effective_client_ip,
|
||||
udp_config.max_sessions_per_ip,
|
||||
) {
|
||||
debug!("UDP session rejected for {} (per-IP session limit)", effective_client_ip);
|
||||
if !session_table
|
||||
.can_create_session(&effective_client_ip, udp_config.max_sessions_per_ip)
|
||||
{
|
||||
debug!(
|
||||
"UDP session rejected for {} (per-IP session limit)",
|
||||
effective_client_ip
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Resolve target
|
||||
let target = match route_match.target.or_else(|| {
|
||||
route.action.targets.as_ref().and_then(|t| t.first())
|
||||
}) {
|
||||
let target = match route_match
|
||||
.target
|
||||
.or_else(|| route.action.targets.as_ref().and_then(|t| t.first()))
|
||||
{
|
||||
Some(t) => t,
|
||||
None => {
|
||||
warn!("No target for UDP route {:?}", route_id);
|
||||
@@ -668,13 +786,18 @@ impl UdpListenerManager {
|
||||
}
|
||||
};
|
||||
if let Err(e) = backend_socket.connect(&backend_addr).await {
|
||||
error!("Failed to connect backend UDP socket to {}: {}", backend_addr, e);
|
||||
error!(
|
||||
"Failed to connect backend UDP socket to {}: {}",
|
||||
backend_addr, e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
let backend_socket = Arc::new(backend_socket);
|
||||
|
||||
debug!("New UDP session: {} -> {} (via port {}, real client {})",
|
||||
client_addr, backend_addr, port, effective_client_ip);
|
||||
debug!(
|
||||
"New UDP session: {} -> {} (via port {}, real client {})",
|
||||
client_addr, backend_addr, port, effective_client_ip
|
||||
);
|
||||
|
||||
// Spawn return-path relay task
|
||||
let session_cancel = CancellationToken::new();
|
||||
@@ -691,7 +814,9 @@ impl UdpListenerManager {
|
||||
|
||||
let session = Arc::new(UdpSession {
|
||||
backend_socket,
|
||||
last_activity: std::sync::atomic::AtomicU64::new(session_table.elapsed_ms()),
|
||||
last_activity: std::sync::atomic::AtomicU64::new(
|
||||
session_table.elapsed_ms(),
|
||||
),
|
||||
created_at: std::time::Instant::now(),
|
||||
route_id: route_id.map(|s| s.to_string()),
|
||||
source_ip: effective_client_ip,
|
||||
@@ -700,7 +825,11 @@ impl UdpListenerManager {
|
||||
cancel: session_cancel,
|
||||
});
|
||||
|
||||
if !session_table.insert(session_key, Arc::clone(&session), udp_config.max_sessions_per_ip) {
|
||||
if !session_table.insert(
|
||||
session_key,
|
||||
Arc::clone(&session),
|
||||
udp_config.max_sessions_per_ip,
|
||||
) {
|
||||
warn!("Failed to insert UDP session (race condition)");
|
||||
continue;
|
||||
}
|
||||
@@ -717,7 +846,9 @@ impl UdpListenerManager {
|
||||
// Forward datagram to backend
|
||||
match session.backend_socket.send(datagram).await {
|
||||
Ok(_) => {
|
||||
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
||||
session
|
||||
.last_activity
|
||||
.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
||||
metrics.record_bytes(len as u64, 0, route_id, Some(&ip_str));
|
||||
metrics.record_datagram_in();
|
||||
}
|
||||
@@ -761,7 +892,9 @@ impl UdpListenerManager {
|
||||
Ok(_) => {
|
||||
// Update session activity
|
||||
if let Some(session) = session_table.get(&session_key) {
|
||||
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
||||
session
|
||||
.last_activity
|
||||
.store(session_table.elapsed_ms(), Ordering::Relaxed);
|
||||
}
|
||||
metrics.record_bytes(0, len as u64, route_id.as_deref(), Some(&ip_str));
|
||||
metrics.record_datagram_out();
|
||||
@@ -796,7 +929,8 @@ impl UdpListenerManager {
|
||||
let json = serde_json::to_vec(&msg)?;
|
||||
|
||||
let mut guard = writer.lock().await;
|
||||
let stream = guard.as_mut()
|
||||
let stream = guard
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow::anyhow!("Datagram relay not connected"))?;
|
||||
|
||||
// Length-prefixed frame
|
||||
@@ -861,9 +995,15 @@ impl UdpListenerManager {
|
||||
}
|
||||
|
||||
let source_ip = reply.get("sourceIp").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let source_port = reply.get("sourcePort").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
|
||||
let source_port = reply
|
||||
.get("sourcePort")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(0) as u16;
|
||||
let dest_port = reply.get("destPort").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
|
||||
let payload_b64 = reply.get("payloadBase64").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let payload_b64 = reply
|
||||
.get("payloadBase64")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let payload = match base64::engine::general_purpose::STANDARD.decode(payload_b64) {
|
||||
Ok(p) => p,
|
||||
|
||||
@@ -111,12 +111,15 @@ impl UdpSessionTable {
|
||||
|
||||
/// Look up an existing session.
|
||||
pub fn get(&self, key: &SessionKey) -> Option<Arc<UdpSession>> {
|
||||
self.sessions.get(key).map(|entry| Arc::clone(entry.value()))
|
||||
self.sessions
|
||||
.get(key)
|
||||
.map(|entry| Arc::clone(entry.value()))
|
||||
}
|
||||
|
||||
/// Check if we can create a new session for this IP (under the per-IP limit).
|
||||
pub fn can_create_session(&self, ip: &IpAddr, max_per_ip: u32) -> bool {
|
||||
let count = self.ip_session_counts
|
||||
let count = self
|
||||
.ip_session_counts
|
||||
.get(ip)
|
||||
.map(|c| *c.value())
|
||||
.unwrap_or(0);
|
||||
@@ -124,12 +127,7 @@ impl UdpSessionTable {
|
||||
}
|
||||
|
||||
/// Insert a new session. Returns false if per-IP limit exceeded.
|
||||
pub fn insert(
|
||||
&self,
|
||||
key: SessionKey,
|
||||
session: Arc<UdpSession>,
|
||||
max_per_ip: u32,
|
||||
) -> bool {
|
||||
pub fn insert(&self, key: SessionKey, session: Arc<UdpSession>, max_per_ip: u32) -> bool {
|
||||
let ip = session.source_ip;
|
||||
|
||||
// Atomically check and increment per-IP count
|
||||
@@ -173,7 +171,9 @@ impl UdpSessionTable {
|
||||
let mut removed = 0;
|
||||
|
||||
// Collect keys to remove (avoid holding DashMap refs during removal)
|
||||
let stale_keys: Vec<SessionKey> = self.sessions.iter()
|
||||
let stale_keys: Vec<SessionKey> = self
|
||||
.sessions
|
||||
.iter()
|
||||
.filter(|entry| {
|
||||
let last = entry.value().last_activity.load(Ordering::Relaxed);
|
||||
now_ms.saturating_sub(last) >= timeout_ms
|
||||
@@ -185,7 +185,8 @@ impl UdpSessionTable {
|
||||
if let Some(session) = self.remove(&key) {
|
||||
debug!(
|
||||
"UDP session expired: {} -> port {} (idle {}ms)",
|
||||
session.client_addr, key.1,
|
||||
session.client_addr,
|
||||
key.1,
|
||||
now_ms.saturating_sub(session.last_activity.load(Ordering::Relaxed))
|
||||
);
|
||||
conn_tracker.connection_closed(&session.source_ip);
|
||||
@@ -210,7 +211,9 @@ impl UdpSessionTable {
|
||||
metrics: &MetricsCollector,
|
||||
conn_tracker: &ConnectionTracker,
|
||||
) -> usize {
|
||||
let keys: Vec<SessionKey> = self.sessions.iter()
|
||||
let keys: Vec<SessionKey> = self
|
||||
.sessions
|
||||
.iter()
|
||||
.filter(|entry| entry.key().1 == port)
|
||||
.map(|entry| *entry.key())
|
||||
.collect();
|
||||
@@ -257,9 +260,8 @@ mod tests {
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap();
|
||||
let backend_socket = rt.block_on(async {
|
||||
Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap())
|
||||
});
|
||||
let backend_socket =
|
||||
rt.block_on(async { Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()) });
|
||||
|
||||
let child_cancel = cancel.child_token();
|
||||
let return_task = rt.spawn(async move {
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
//! Route matching engine for RustProxy.
|
||||
//! Provides domain/path/IP/header matchers and a port-indexed RouteManager.
|
||||
|
||||
pub mod route_manager;
|
||||
pub mod matchers;
|
||||
pub mod route_manager;
|
||||
|
||||
pub use route_manager::*;
|
||||
|
||||
@@ -20,7 +20,7 @@ pub fn domain_matches(pattern: &str, domain: &str) -> bool {
|
||||
// Wildcard patterns
|
||||
if pattern.starts_with("*.") || pattern.starts_with("*.") {
|
||||
let suffix = &pattern[2..]; // e.g., "example.com"
|
||||
// Match exact parent or any single-level subdomain
|
||||
// Match exact parent or any single-level subdomain
|
||||
if domain.eq_ignore_ascii_case(suffix) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,42 @@
|
||||
use std::collections::HashMap;
|
||||
use regex::Regex;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn compile_regex_pattern(pattern: &str) -> Option<Regex> {
|
||||
if !pattern.starts_with('/') {
|
||||
return None;
|
||||
}
|
||||
|
||||
let last_slash = pattern.rfind('/')?;
|
||||
if last_slash == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let regex_body = &pattern[1..last_slash];
|
||||
let flags = &pattern[last_slash + 1..];
|
||||
|
||||
let mut inline_flags = String::new();
|
||||
for flag in flags.chars() {
|
||||
match flag {
|
||||
'i' | 'm' | 's' | 'u' => {
|
||||
if !inline_flags.contains(flag) {
|
||||
inline_flags.push(flag);
|
||||
}
|
||||
}
|
||||
'g' => {
|
||||
// Global has no effect for single header matching.
|
||||
}
|
||||
_ => return None,
|
||||
}
|
||||
}
|
||||
|
||||
let compiled = if inline_flags.is_empty() {
|
||||
regex_body.to_string()
|
||||
} else {
|
||||
format!("(?{}){}", inline_flags, regex_body)
|
||||
};
|
||||
|
||||
Regex::new(&compiled).ok()
|
||||
}
|
||||
|
||||
/// Match HTTP headers against a set of patterns.
|
||||
///
|
||||
@@ -24,16 +61,15 @@ pub fn headers_match(
|
||||
None => return false, // Required header not present
|
||||
};
|
||||
|
||||
// Check if pattern is a regex (surrounded by /)
|
||||
if pattern.starts_with('/') && pattern.ends_with('/') && pattern.len() > 2 {
|
||||
let regex_str = &pattern[1..pattern.len() - 1];
|
||||
match Regex::new(regex_str) {
|
||||
Ok(re) => {
|
||||
// Check if pattern is a regex literal (/pattern/ or /pattern/flags)
|
||||
if pattern.starts_with('/') && pattern.len() > 2 {
|
||||
match compile_regex_pattern(pattern) {
|
||||
Some(re) => {
|
||||
if !re.is_match(header_value) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
None => {
|
||||
// Invalid regex, fall back to exact match
|
||||
if header_value != pattern {
|
||||
return false;
|
||||
@@ -85,6 +121,24 @@ mod tests {
|
||||
assert!(headers_match(&patterns, &headers));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_regex_header_match_with_flags() {
|
||||
let patterns: HashMap<String, String> = {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"Content-Type".to_string(),
|
||||
"/^application\\/json$/i".to_string(),
|
||||
);
|
||||
m
|
||||
};
|
||||
let headers: HashMap<String, String> = {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("content-type".to_string(), "Application/JSON".to_string());
|
||||
m
|
||||
};
|
||||
assert!(headers_match(&patterns, &headers));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_missing_header() {
|
||||
let patterns: HashMap<String, String> = {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use ipnet::IpNet;
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
use ipnet::IpNet;
|
||||
|
||||
/// Match an IP address against a pattern.
|
||||
///
|
||||
@@ -85,7 +85,10 @@ fn wildcard_to_cidr(pattern: &str) -> Option<String> {
|
||||
}
|
||||
}
|
||||
|
||||
Some(format!("{}.{}.{}.{}/{}", octets[0], octets[1], octets[2], octets[3], prefix_len))
|
||||
Some(format!(
|
||||
"{}.{}.{}.{}/{}",
|
||||
octets[0], octets[1], octets[2], octets[3], prefix_len
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
pub mod domain;
|
||||
pub mod path;
|
||||
pub mod ip;
|
||||
pub mod header;
|
||||
pub mod ip;
|
||||
pub mod path;
|
||||
|
||||
pub use domain::*;
|
||||
pub use path::*;
|
||||
pub use ip::*;
|
||||
pub use header::*;
|
||||
pub use ip::*;
|
||||
pub use path::*;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use rustproxy_config::{RouteConfig, RouteTarget, TransportProtocol, TlsMode};
|
||||
use crate::matchers;
|
||||
use rustproxy_config::{RouteConfig, RouteTarget, TlsMode, TransportProtocol};
|
||||
|
||||
/// Context for route matching (subset of connection info).
|
||||
pub struct MatchContext<'a> {
|
||||
@@ -42,19 +42,14 @@ impl RouteManager {
|
||||
};
|
||||
|
||||
// Filter enabled routes and sort by priority
|
||||
let mut enabled_routes: Vec<RouteConfig> = routes
|
||||
.into_iter()
|
||||
.filter(|r| r.is_enabled())
|
||||
.collect();
|
||||
let mut enabled_routes: Vec<RouteConfig> =
|
||||
routes.into_iter().filter(|r| r.is_enabled()).collect();
|
||||
enabled_routes.sort_by(|a, b| b.effective_priority().cmp(&a.effective_priority()));
|
||||
|
||||
// Build port index
|
||||
for (idx, route) in enabled_routes.iter().enumerate() {
|
||||
for port in route.listening_ports() {
|
||||
manager.port_index
|
||||
.entry(port)
|
||||
.or_default()
|
||||
.push(idx);
|
||||
manager.port_index.entry(port).or_default().push(idx);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +61,9 @@ impl RouteManager {
|
||||
/// Used to skip expensive header HashMap construction when no route needs it.
|
||||
pub fn any_route_has_headers(&self, port: u16) -> bool {
|
||||
if let Some(indices) = self.port_index.get(&port) {
|
||||
indices.iter().any(|&idx| self.routes[idx].route_match.headers.is_some())
|
||||
indices
|
||||
.iter()
|
||||
.any(|&idx| self.routes[idx].route_match.headers.is_some())
|
||||
} else {
|
||||
false
|
||||
}
|
||||
@@ -99,8 +96,8 @@ impl RouteManager {
|
||||
let ctx_transport = ctx.transport.as_ref();
|
||||
match (route_transport, ctx_transport) {
|
||||
// Route requires UDP only — reject non-UDP contexts
|
||||
(Some(TransportProtocol::Udp), None) |
|
||||
(Some(TransportProtocol::Udp), Some(TransportProtocol::Tcp)) => return false,
|
||||
(Some(TransportProtocol::Udp), None)
|
||||
| (Some(TransportProtocol::Udp), Some(TransportProtocol::Tcp)) => return false,
|
||||
// Route requires TCP only — reject UDP contexts
|
||||
(Some(TransportProtocol::Tcp), Some(TransportProtocol::Udp)) => return false,
|
||||
// Route has no transport (default = TCP) — reject UDP contexts
|
||||
@@ -196,7 +193,11 @@ impl RouteManager {
|
||||
}
|
||||
|
||||
/// Find the best matching target within a route.
|
||||
fn find_target<'a>(&self, route: &'a RouteConfig, ctx: &MatchContext<'_>) -> Option<&'a RouteTarget> {
|
||||
fn find_target<'a>(
|
||||
&self,
|
||||
route: &'a RouteConfig,
|
||||
ctx: &MatchContext<'_>,
|
||||
) -> Option<&'a RouteTarget> {
|
||||
let targets = route.action.targets.as_ref()?;
|
||||
|
||||
if targets.len() == 1 && targets[0].target_match.is_none() {
|
||||
@@ -223,17 +224,11 @@ impl RouteManager {
|
||||
}
|
||||
|
||||
// Fall back to first target without match criteria
|
||||
best.or_else(|| {
|
||||
targets.iter().find(|t| t.target_match.is_none())
|
||||
})
|
||||
best.or_else(|| targets.iter().find(|t| t.target_match.is_none()))
|
||||
}
|
||||
|
||||
/// Check if a target match criteria matches the context.
|
||||
fn matches_target(
|
||||
&self,
|
||||
tm: &rustproxy_config::TargetMatch,
|
||||
ctx: &MatchContext<'_>,
|
||||
) -> bool {
|
||||
fn matches_target(&self, tm: &rustproxy_config::TargetMatch, ctx: &MatchContext<'_>) -> bool {
|
||||
// Port matching
|
||||
if let Some(ref ports) = tm.ports {
|
||||
if !ports.contains(&ctx.port) {
|
||||
@@ -281,6 +276,11 @@ impl RouteManager {
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get all enabled routes.
|
||||
pub fn routes(&self) -> &[RouteConfig] {
|
||||
&self.routes
|
||||
}
|
||||
|
||||
/// Get the total number of enabled routes.
|
||||
pub fn route_count(&self) -> usize {
|
||||
self.routes.len()
|
||||
@@ -293,9 +293,7 @@ impl RouteManager {
|
||||
// If multiple passthrough routes on same port, SNI is needed
|
||||
let passthrough_routes: Vec<_> = routes
|
||||
.iter()
|
||||
.filter(|r| {
|
||||
r.tls_mode() == Some(&TlsMode::Passthrough)
|
||||
})
|
||||
.filter(|r| r.tls_mode() == Some(&TlsMode::Passthrough))
|
||||
.collect();
|
||||
|
||||
if passthrough_routes.len() > 1 {
|
||||
@@ -355,8 +353,6 @@ mod tests {
|
||||
load_balancing: None,
|
||||
advanced: None,
|
||||
options: None,
|
||||
forwarding_engine: None,
|
||||
nftables: None,
|
||||
send_proxy_protocol: None,
|
||||
udp: None,
|
||||
},
|
||||
@@ -416,7 +412,11 @@ mod tests {
|
||||
|
||||
let result = manager.find_route(&ctx).unwrap();
|
||||
// Should match the higher-priority specific route
|
||||
assert!(result.route.route_match.domains.as_ref()
|
||||
assert!(result
|
||||
.route
|
||||
.route_match
|
||||
.domains
|
||||
.as_ref()
|
||||
.map(|d| d.to_vec())
|
||||
.unwrap()
|
||||
.contains(&"api.example.com"));
|
||||
@@ -616,8 +616,14 @@ mod tests {
|
||||
|
||||
let result = manager.find_route(&ctx);
|
||||
assert!(result.is_some());
|
||||
let matched_domains = result.unwrap().route.route_match.domains.as_ref()
|
||||
.map(|d| d.to_vec()).unwrap();
|
||||
let matched_domains = result
|
||||
.unwrap()
|
||||
.route
|
||||
.route_match
|
||||
.domains
|
||||
.as_ref()
|
||||
.map(|d| d.to_vec())
|
||||
.unwrap();
|
||||
assert!(matched_domains.contains(&"*"));
|
||||
}
|
||||
|
||||
@@ -732,7 +738,11 @@ mod tests {
|
||||
assert_eq!(result.target.unwrap().host.first(), "default-backend");
|
||||
}
|
||||
|
||||
fn make_route_with_protocol(port: u16, domain: Option<&str>, protocol: Option<&str>) -> RouteConfig {
|
||||
fn make_route_with_protocol(
|
||||
port: u16,
|
||||
domain: Option<&str>,
|
||||
protocol: Option<&str>,
|
||||
) -> RouteConfig {
|
||||
let mut route = make_route(port, domain, 0);
|
||||
route.route_match.protocol = protocol.map(|s| s.to_string());
|
||||
route
|
||||
@@ -1026,8 +1036,10 @@ mod tests {
|
||||
transport: Some(TransportProtocol::Udp),
|
||||
};
|
||||
|
||||
assert!(manager.find_route(&ctx).is_some(),
|
||||
"QUIC (UDP) with is_tls=true and domain=None should match domain-restricted routes");
|
||||
assert!(
|
||||
manager.find_route(&ctx).is_some(),
|
||||
"QUIC (UDP) with is_tls=true and domain=None should match domain-restricted routes"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -1048,7 +1060,9 @@ mod tests {
|
||||
transport: None, // TCP (default)
|
||||
};
|
||||
|
||||
assert!(manager.find_route(&ctx).is_none(),
|
||||
"TCP TLS without SNI should NOT match domain-restricted routes");
|
||||
assert!(
|
||||
manager.find_route(&ctx).is_none(),
|
||||
"TCP TLS without SNI should NOT match domain-restricted routes"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use base64::Engine;
|
||||
use base64::engine::general_purpose::STANDARD as BASE64;
|
||||
use base64::Engine;
|
||||
|
||||
/// Basic auth validator.
|
||||
pub struct BasicAuthValidator {
|
||||
|
||||
@@ -2,14 +2,26 @@ use ipnet::IpNet;
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
|
||||
use rustproxy_config::IpAllowEntry;
|
||||
|
||||
/// IP filter supporting CIDR ranges, wildcards, and exact matches.
|
||||
/// Supports domain-scoped allow entries that restrict an IP to specific domains.
|
||||
pub struct IpFilter {
|
||||
/// Plain allow entries — IP allowed for any domain on the route
|
||||
allow_list: Vec<IpPattern>,
|
||||
/// Domain-scoped allow entries — IP allowed only for matching domains
|
||||
domain_scoped: Vec<DomainScopedEntry>,
|
||||
block_list: Vec<IpPattern>,
|
||||
}
|
||||
|
||||
/// A domain-scoped allow entry: IP + list of allowed domain patterns.
|
||||
struct DomainScopedEntry {
|
||||
pattern: IpPattern,
|
||||
domains: Vec<String>,
|
||||
}
|
||||
|
||||
/// Represents an IP pattern for matching.
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
enum IpPattern {
|
||||
/// Exact IP match
|
||||
Exact(IpAddr),
|
||||
@@ -19,6 +31,37 @@ enum IpPattern {
|
||||
Wildcard,
|
||||
}
|
||||
|
||||
/// Compiled block list for early ingress filtering.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IpBlockList {
|
||||
block_list: Vec<IpPattern>,
|
||||
}
|
||||
|
||||
impl IpBlockList {
|
||||
pub fn new(block_list: &[String]) -> Self {
|
||||
Self {
|
||||
block_list: block_list.iter().map(|s| IpPattern::parse(s)).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
block_list: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.block_list.is_empty()
|
||||
}
|
||||
|
||||
pub fn is_blocked(&self, ip: &IpAddr) -> bool {
|
||||
let normalized = IpFilter::normalize_ip(ip);
|
||||
self.block_list
|
||||
.iter()
|
||||
.any(|pattern| pattern.matches(&normalized))
|
||||
}
|
||||
}
|
||||
|
||||
impl IpPattern {
|
||||
fn parse(s: &str) -> Self {
|
||||
let s = s.trim();
|
||||
@@ -31,10 +74,6 @@ impl IpPattern {
|
||||
if let Ok(addr) = IpAddr::from_str(s) {
|
||||
return IpPattern::Exact(addr);
|
||||
}
|
||||
// Try as CIDR by appending default prefix
|
||||
if let Ok(addr) = IpAddr::from_str(s) {
|
||||
return IpPattern::Exact(addr);
|
||||
}
|
||||
// Fallback: treat as exact, will never match an invalid string
|
||||
IpPattern::Exact(IpAddr::from_str("0.0.0.0").unwrap())
|
||||
}
|
||||
@@ -48,19 +87,55 @@ impl IpPattern {
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple domain pattern matching (exact, `*`, or `*.suffix`).
|
||||
fn domain_matches_pattern(pattern: &str, domain: &str) -> bool {
|
||||
let p = pattern.trim();
|
||||
let d = domain.trim();
|
||||
if p == "*" {
|
||||
return true;
|
||||
}
|
||||
if p.eq_ignore_ascii_case(d) {
|
||||
return true;
|
||||
}
|
||||
if p.starts_with("*.") {
|
||||
let suffix = &p[1..]; // e.g., ".abc.xyz"
|
||||
d.len() > suffix.len() && d[d.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl IpFilter {
|
||||
/// Create a new IP filter from allow and block lists.
|
||||
pub fn new(allow_list: &[String], block_list: &[String]) -> Self {
|
||||
/// Create a new IP filter from allow entries and a block list.
|
||||
pub fn new(allow_entries: &[IpAllowEntry], block_list: &[String]) -> Self {
|
||||
let mut allow_list = Vec::new();
|
||||
let mut domain_scoped = Vec::new();
|
||||
|
||||
for entry in allow_entries {
|
||||
match entry {
|
||||
IpAllowEntry::Plain(ip) => {
|
||||
allow_list.push(IpPattern::parse(ip));
|
||||
}
|
||||
IpAllowEntry::DomainScoped { ip, domains } => {
|
||||
domain_scoped.push(DomainScopedEntry {
|
||||
pattern: IpPattern::parse(ip),
|
||||
domains: domains.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
allow_list: allow_list.iter().map(|s| IpPattern::parse(s)).collect(),
|
||||
allow_list,
|
||||
domain_scoped,
|
||||
block_list: block_list.iter().map(|s| IpPattern::parse(s)).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if an IP is allowed.
|
||||
/// If allow_list is non-empty, IP must match at least one entry.
|
||||
/// If block_list is non-empty, IP must NOT match any entry.
|
||||
pub fn is_allowed(&self, ip: &IpAddr) -> bool {
|
||||
/// Check if an IP is allowed, considering domain-scoped entries.
|
||||
/// If `domain` is Some, domain-scoped entries are evaluated against it.
|
||||
/// If `domain` is None, only plain allow entries are considered.
|
||||
pub fn is_allowed_for_domain(&self, ip: &IpAddr, domain: Option<&str>) -> bool {
|
||||
// Check block list first
|
||||
if !self.block_list.is_empty() {
|
||||
for pattern in &self.block_list {
|
||||
@@ -70,14 +145,40 @@ impl IpFilter {
|
||||
}
|
||||
}
|
||||
|
||||
// If allow list is non-empty, must match at least one
|
||||
if !self.allow_list.is_empty() {
|
||||
return self.allow_list.iter().any(|p| p.matches(ip));
|
||||
// If there are any allow entries (plain or domain-scoped), IP must match
|
||||
let has_any_allow = !self.allow_list.is_empty() || !self.domain_scoped.is_empty();
|
||||
if has_any_allow {
|
||||
// Check plain allow list — grants access to entire route
|
||||
if self.allow_list.iter().any(|p| p.matches(ip)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check domain-scoped entries — grants access only if domain matches
|
||||
if let Some(req_domain) = domain {
|
||||
for entry in &self.domain_scoped {
|
||||
if entry.pattern.matches(ip) {
|
||||
if entry
|
||||
.domains
|
||||
.iter()
|
||||
.any(|d| domain_matches_pattern(d, req_domain))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Check if an IP is allowed (backwards-compat wrapper, no domain context).
|
||||
pub fn is_allowed(&self, ip: &IpAddr) -> bool {
|
||||
self.is_allowed_for_domain(ip, None)
|
||||
}
|
||||
|
||||
/// Normalize IPv4-mapped IPv6 addresses (::ffff:x.x.x.x -> x.x.x.x)
|
||||
pub fn normalize_ip(ip: &IpAddr) -> IpAddr {
|
||||
match ip {
|
||||
@@ -97,19 +198,28 @@ impl IpFilter {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn plain(s: &str) -> IpAllowEntry {
|
||||
IpAllowEntry::Plain(s.to_string())
|
||||
}
|
||||
|
||||
fn scoped(ip: &str, domains: &[&str]) -> IpAllowEntry {
|
||||
IpAllowEntry::DomainScoped {
|
||||
ip: ip.to_string(),
|
||||
domains: domains.iter().map(|s| s.to_string()).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_lists_allow_all() {
|
||||
let filter = IpFilter::new(&[], &[]);
|
||||
let ip: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
assert!(filter.is_allowed(&ip));
|
||||
assert!(filter.is_allowed_for_domain(&ip, Some("example.com")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allow_list_exact() {
|
||||
let filter = IpFilter::new(
|
||||
&["10.0.0.1".to_string()],
|
||||
&[],
|
||||
);
|
||||
fn test_plain_allow_list_exact() {
|
||||
let filter = IpFilter::new(&[plain("10.0.0.1")], &[]);
|
||||
let allowed: IpAddr = "10.0.0.1".parse().unwrap();
|
||||
let denied: IpAddr = "10.0.0.2".parse().unwrap();
|
||||
assert!(filter.is_allowed(&allowed));
|
||||
@@ -117,11 +227,8 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allow_list_cidr() {
|
||||
let filter = IpFilter::new(
|
||||
&["10.0.0.0/8".to_string()],
|
||||
&[],
|
||||
);
|
||||
fn test_plain_allow_list_cidr() {
|
||||
let filter = IpFilter::new(&[plain("10.0.0.0/8")], &[]);
|
||||
let allowed: IpAddr = "10.255.255.255".parse().unwrap();
|
||||
let denied: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
assert!(filter.is_allowed(&allowed));
|
||||
@@ -130,10 +237,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_block_list() {
|
||||
let filter = IpFilter::new(
|
||||
&[],
|
||||
&["192.168.1.100".to_string()],
|
||||
);
|
||||
let filter = IpFilter::new(&[], &["192.168.1.100".to_string()]);
|
||||
let blocked: IpAddr = "192.168.1.100".parse().unwrap();
|
||||
let allowed: IpAddr = "192.168.1.101".parse().unwrap();
|
||||
assert!(!filter.is_allowed(&blocked));
|
||||
@@ -142,10 +246,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_block_trumps_allow() {
|
||||
let filter = IpFilter::new(
|
||||
&["10.0.0.0/8".to_string()],
|
||||
&["10.0.0.5".to_string()],
|
||||
);
|
||||
let filter = IpFilter::new(&[plain("10.0.0.0/8")], &["10.0.0.5".to_string()]);
|
||||
let blocked: IpAddr = "10.0.0.5".parse().unwrap();
|
||||
let allowed: IpAddr = "10.0.0.6".parse().unwrap();
|
||||
assert!(!filter.is_allowed(&blocked));
|
||||
@@ -154,20 +255,14 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_allow() {
|
||||
let filter = IpFilter::new(
|
||||
&["*".to_string()],
|
||||
&[],
|
||||
);
|
||||
let filter = IpFilter::new(&[plain("*")], &[]);
|
||||
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
assert!(filter.is_allowed(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_block() {
|
||||
let filter = IpFilter::new(
|
||||
&[],
|
||||
&["*".to_string()],
|
||||
);
|
||||
let filter = IpFilter::new(&[], &["*".to_string()]);
|
||||
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
assert!(!filter.is_allowed(&ip));
|
||||
}
|
||||
@@ -186,4 +281,85 @@ mod tests {
|
||||
let normalized = IpFilter::normalize_ip(&ip);
|
||||
assert_eq!(normalized, ip);
|
||||
}
|
||||
|
||||
// Domain-scoped tests
|
||||
|
||||
#[test]
|
||||
fn test_domain_scoped_allows_matching_domain() {
|
||||
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
|
||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_scoped_denies_non_matching_domain() {
|
||||
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
|
||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||
assert!(!filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_scoped_denies_without_domain() {
|
||||
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
|
||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||
// Without domain context, domain-scoped entries cannot match
|
||||
assert!(!filter.is_allowed_for_domain(&ip, None));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_scoped_wildcard_domain() {
|
||||
let filter = IpFilter::new(&[scoped("10.8.0.2", &["*.abc.xyz"])], &[]);
|
||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
|
||||
assert!(filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
|
||||
assert!(!filter.is_allowed_for_domain(&ip, Some("other.com")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plain_and_domain_scoped_coexist() {
|
||||
let filter = IpFilter::new(
|
||||
&[
|
||||
plain("1.2.3.4"), // full route access
|
||||
scoped("10.8.0.2", &["outline.abc.xyz"]), // scoped access
|
||||
],
|
||||
&[],
|
||||
);
|
||||
|
||||
let admin: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
let vpn: IpAddr = "10.8.0.2".parse().unwrap();
|
||||
let other: IpAddr = "9.9.9.9".parse().unwrap();
|
||||
|
||||
// Admin IP has full access
|
||||
assert!(filter.is_allowed_for_domain(&admin, Some("anything.abc.xyz")));
|
||||
assert!(filter.is_allowed_for_domain(&admin, Some("outline.abc.xyz")));
|
||||
|
||||
// VPN IP only has scoped access
|
||||
assert!(filter.is_allowed_for_domain(&vpn, Some("outline.abc.xyz")));
|
||||
assert!(!filter.is_allowed_for_domain(&vpn, Some("app.abc.xyz")));
|
||||
|
||||
// Unknown IP denied
|
||||
assert!(!filter.is_allowed_for_domain(&other, Some("outline.abc.xyz")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_block_trumps_domain_scoped() {
|
||||
let filter = IpFilter::new(
|
||||
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
|
||||
&["10.8.0.2".to_string()],
|
||||
);
|
||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||
assert!(!filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_matches_pattern_fn() {
|
||||
assert!(domain_matches_pattern("example.com", "example.com"));
|
||||
assert!(domain_matches_pattern("*.abc.xyz", "outline.abc.xyz"));
|
||||
assert!(domain_matches_pattern("*.abc.xyz", "app.abc.xyz"));
|
||||
assert!(!domain_matches_pattern("*.abc.xyz", "abc.xyz")); // suffix only, not exact parent
|
||||
assert!(domain_matches_pattern("*", "anything.com"));
|
||||
assert!(!domain_matches_pattern("outline.abc.xyz", "app.abc.xyz"));
|
||||
// Case insensitive
|
||||
assert!(domain_matches_pattern("*.ABC.XYZ", "outline.abc.xyz"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
|
||||
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// JWT claims (minimal structure).
|
||||
@@ -160,10 +160,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_extract_token_bearer() {
|
||||
assert_eq!(
|
||||
JwtValidator::extract_token("Bearer abc123"),
|
||||
Some("abc123")
|
||||
);
|
||||
assert_eq!(JwtValidator::extract_token("Bearer abc123"), Some("abc123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
//!
|
||||
//! IP filtering, rate limiting, and authentication for RustProxy.
|
||||
|
||||
pub mod ip_filter;
|
||||
pub mod rate_limiter;
|
||||
pub mod basic_auth;
|
||||
pub mod ip_filter;
|
||||
pub mod jwt_auth;
|
||||
pub mod rate_limiter;
|
||||
|
||||
pub use ip_filter::*;
|
||||
pub use rate_limiter::*;
|
||||
pub use basic_auth::*;
|
||||
pub use ip_filter::*;
|
||||
pub use jwt_auth::*;
|
||||
pub use rate_limiter::*;
|
||||
|
||||
@@ -79,7 +79,7 @@ mod tests {
|
||||
assert!(limiter.check("client-a"));
|
||||
assert!(limiter.check("client-a"));
|
||||
assert!(!limiter.check("client-a")); // blocked
|
||||
// Different key should still be allowed
|
||||
// Different key should still be allowed
|
||||
assert!(limiter.check("client-b"));
|
||||
assert!(limiter.check("client-b"));
|
||||
}
|
||||
|
||||
@@ -4,8 +4,7 @@
|
||||
//! Account credentials are ephemeral — the consumer owns all persistence.
|
||||
|
||||
use instant_acme::{
|
||||
Account, NewAccount, NewOrder, Identifier, ChallengeType, OrderStatus,
|
||||
AccountCredentials,
|
||||
Account, AccountCredentials, ChallengeType, Identifier, NewAccount, NewOrder, OrderStatus,
|
||||
};
|
||||
use rcgen::{CertificateParams, KeyPair};
|
||||
use thiserror::Error;
|
||||
@@ -89,7 +88,11 @@ impl AcmeClient {
|
||||
F: FnOnce(PendingChallenge) -> Fut,
|
||||
Fut: std::future::Future<Output = Result<(), AcmeError>>,
|
||||
{
|
||||
info!("Starting ACME provisioning for {} via {}", domain, self.directory_url());
|
||||
info!(
|
||||
"Starting ACME provisioning for {} via {}",
|
||||
domain,
|
||||
self.directory_url()
|
||||
);
|
||||
|
||||
// 1. Get or create ACME account
|
||||
let account = self.get_or_create_account().await?;
|
||||
@@ -170,14 +173,14 @@ impl AcmeClient {
|
||||
debug!("Order ready, finalizing...");
|
||||
|
||||
// 6. Generate CSR and finalize
|
||||
let key_pair = KeyPair::generate().map_err(|e| {
|
||||
AcmeError::FinalizationFailed(format!("Key generation failed: {}", e))
|
||||
})?;
|
||||
let key_pair = KeyPair::generate()
|
||||
.map_err(|e| AcmeError::FinalizationFailed(format!("Key generation failed: {}", e)))?;
|
||||
|
||||
let mut params = CertificateParams::new(vec![domain.to_string()]).map_err(|e| {
|
||||
AcmeError::FinalizationFailed(format!("CSR params failed: {}", e))
|
||||
})?;
|
||||
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
|
||||
let mut params = CertificateParams::new(vec![domain.to_string()])
|
||||
.map_err(|e| AcmeError::FinalizationFailed(format!("CSR params failed: {}", e)))?;
|
||||
params
|
||||
.distinguished_name
|
||||
.push(rcgen::DnType::CommonName, domain);
|
||||
|
||||
let csr = params.serialize_request(&key_pair).map_err(|e| {
|
||||
AcmeError::FinalizationFailed(format!("CSR serialization failed: {}", e))
|
||||
@@ -219,9 +222,7 @@ impl AcmeClient {
|
||||
.certificate()
|
||||
.await
|
||||
.map_err(|e| AcmeError::FinalizationFailed(e.to_string()))?
|
||||
.ok_or_else(|| {
|
||||
AcmeError::FinalizationFailed("No certificate returned".to_string())
|
||||
})?;
|
||||
.ok_or_else(|| AcmeError::FinalizationFailed("No certificate returned".to_string()))?;
|
||||
|
||||
let private_key_pem = key_pair.serialize_pem();
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@ use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use thiserror::Error;
|
||||
use tracing::info;
|
||||
|
||||
use crate::cert_store::{CertStore, CertBundle, CertMetadata, CertSource};
|
||||
use crate::acme::AcmeClient;
|
||||
use crate::cert_store::{CertBundle, CertMetadata, CertSource, CertStore};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CertManagerError {
|
||||
@@ -45,17 +45,13 @@ impl CertManager {
|
||||
/// Create an ACME client using this manager's configuration.
|
||||
/// Returns None if no ACME email is configured.
|
||||
pub fn acme_client(&self) -> Option<AcmeClient> {
|
||||
self.acme_email.as_ref().map(|email| {
|
||||
AcmeClient::new(email.clone(), self.use_production)
|
||||
})
|
||||
self.acme_email
|
||||
.as_ref()
|
||||
.map(|email| AcmeClient::new(email.clone(), self.use_production))
|
||||
}
|
||||
|
||||
/// Load a static certificate into the store (infallible — pure cache insert).
|
||||
pub fn load_static(
|
||||
&mut self,
|
||||
domain: String,
|
||||
bundle: CertBundle,
|
||||
) {
|
||||
pub fn load_static(&mut self, domain: String, bundle: CertBundle) {
|
||||
self.store.store(domain, bundle);
|
||||
}
|
||||
|
||||
@@ -108,23 +104,25 @@ impl CertManager {
|
||||
F: FnOnce(String, String) -> Fut,
|
||||
Fut: std::future::Future<Output = ()>,
|
||||
{
|
||||
let acme_client = self.acme_client()
|
||||
.ok_or(CertManagerError::NoEmail)?;
|
||||
let acme_client = self.acme_client().ok_or(CertManagerError::NoEmail)?;
|
||||
|
||||
info!("Renewing certificate for {}", domain);
|
||||
|
||||
let domain_owned = domain.to_string();
|
||||
let result = acme_client.provision(&domain_owned, |pending| {
|
||||
let token = pending.token.clone();
|
||||
let key_auth = pending.key_authorization.clone();
|
||||
async move {
|
||||
challenge_setup(token, key_auth).await;
|
||||
Ok(())
|
||||
}
|
||||
}).await.map_err(|e| CertManagerError::AcmeFailure {
|
||||
domain: domain.to_string(),
|
||||
message: e.to_string(),
|
||||
})?;
|
||||
let result = acme_client
|
||||
.provision(&domain_owned, |pending| {
|
||||
let token = pending.token.clone();
|
||||
let key_auth = pending.key_authorization.clone();
|
||||
async move {
|
||||
challenge_setup(token, key_auth).await;
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
.await
|
||||
.map_err(|e| CertManagerError::AcmeFailure {
|
||||
domain: domain.to_string(),
|
||||
message: e.to_string(),
|
||||
})?;
|
||||
|
||||
let (cert_pem, key_pem) = result;
|
||||
let now = SystemTime::now()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Certificate metadata stored alongside certs.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -90,8 +90,10 @@ mod tests {
|
||||
|
||||
fn make_test_bundle(domain: &str) -> CertBundle {
|
||||
CertBundle {
|
||||
key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n".to_string(),
|
||||
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n".to_string(),
|
||||
key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n"
|
||||
.to_string(),
|
||||
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n"
|
||||
.to_string(),
|
||||
ca_pem: None,
|
||||
metadata: CertMetadata {
|
||||
domain: domain.to_string(),
|
||||
@@ -122,7 +124,8 @@ mod tests {
|
||||
let mut store = CertStore::new();
|
||||
|
||||
let mut bundle = make_test_bundle("secure.com");
|
||||
bundle.ca_pem = Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string());
|
||||
bundle.ca_pem =
|
||||
Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string());
|
||||
store.store("secure.com".to_string(), bundle);
|
||||
|
||||
let loaded = store.get("secure.com").unwrap();
|
||||
@@ -147,7 +150,10 @@ mod tests {
|
||||
fn test_remove_cert() {
|
||||
let mut store = CertStore::new();
|
||||
|
||||
store.store("remove-me.com".to_string(), make_test_bundle("remove-me.com"));
|
||||
store.store(
|
||||
"remove-me.com".to_string(),
|
||||
make_test_bundle("remove-me.com"),
|
||||
);
|
||||
assert!(store.has("remove-me.com"));
|
||||
|
||||
let removed = store.remove("remove-me.com");
|
||||
@@ -165,7 +171,10 @@ mod tests {
|
||||
fn test_wildcard_domain() {
|
||||
let mut store = CertStore::new();
|
||||
|
||||
store.store("*.example.com".to_string(), make_test_bundle("*.example.com"));
|
||||
store.store(
|
||||
"*.example.com".to_string(),
|
||||
make_test_bundle("*.example.com"),
|
||||
);
|
||||
assert!(store.has("*.example.com"));
|
||||
|
||||
let loaded = store.get("*.example.com").unwrap();
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
//! TLS certificate management for RustProxy.
|
||||
//! Handles ACME (Let's Encrypt), static certificates, and dynamic SNI resolution.
|
||||
|
||||
pub mod cert_store;
|
||||
pub mod cert_manager;
|
||||
pub mod acme;
|
||||
pub mod cert_manager;
|
||||
pub mod cert_store;
|
||||
pub mod sni_resolver;
|
||||
|
||||
pub use cert_store::*;
|
||||
pub use cert_manager::*;
|
||||
pub use cert_store::*;
|
||||
pub use sni_resolver::*;
|
||||
|
||||
@@ -20,7 +20,6 @@ rustproxy-routing = { workspace = true }
|
||||
rustproxy-tls = { workspace = true }
|
||||
rustproxy-passthrough = { workspace = true }
|
||||
rustproxy-http = { workspace = true }
|
||||
rustproxy-nftables = { workspace = true }
|
||||
rustproxy-metrics = { workspace = true }
|
||||
rustproxy-security = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
|
||||
@@ -13,7 +13,7 @@ use hyper::{Request, Response, StatusCode};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, info, error};
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
/// ACME HTTP-01 challenge server.
|
||||
pub struct ChallengeServer {
|
||||
@@ -47,7 +47,10 @@ impl ChallengeServer {
|
||||
}
|
||||
|
||||
/// Start the challenge server on the given port.
|
||||
pub async fn start(&mut self, port: u16) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
pub async fn start(
|
||||
&mut self,
|
||||
port: u16,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let addr = format!("0.0.0.0:{}", port);
|
||||
let listener = TcpListener::bind(&addr).await?;
|
||||
info!("ACME challenge server listening on port {}", port);
|
||||
@@ -101,10 +104,7 @@ impl ChallengeServer {
|
||||
pub async fn stop(&mut self) {
|
||||
self.cancel.cancel();
|
||||
if let Some(handle) = self.handle.take() {
|
||||
let _ = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(5),
|
||||
handle,
|
||||
).await;
|
||||
let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
|
||||
}
|
||||
self.challenges.clear();
|
||||
self.cancel = CancellationToken::new();
|
||||
@@ -154,10 +154,14 @@ mod tests {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
|
||||
// Fetch the challenge
|
||||
let client = tokio::net::TcpStream::connect("127.0.0.1:19900").await.unwrap();
|
||||
let client = tokio::net::TcpStream::connect("127.0.0.1:19900")
|
||||
.await
|
||||
.unwrap();
|
||||
let io = TokioIo::new(client);
|
||||
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
|
||||
tokio::spawn(async move { let _ = conn.await; });
|
||||
tokio::spawn(async move {
|
||||
let _ = conn.await;
|
||||
});
|
||||
|
||||
let req = Request::get("/.well-known/acme-challenge/test-token")
|
||||
.body(Full::new(Bytes::new()))
|
||||
|
||||
+399
-250
File diff suppressed because it is too large
Load Diff
@@ -1,12 +1,12 @@
|
||||
#[global_allocator]
|
||||
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use anyhow::Result;
|
||||
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy::management;
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
|
||||
/// RustProxy - High-performance multi-protocol proxy
|
||||
@@ -43,8 +43,7 @@ async fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_writer(std::io::stderr)
|
||||
.with_env_filter(
|
||||
EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new(&cli.log_level))
|
||||
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&cli.log_level)),
|
||||
)
|
||||
.init();
|
||||
|
||||
@@ -60,11 +59,7 @@ async fn main() -> Result<()> {
|
||||
let options = RustProxyOptions::from_file(&cli.config)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?;
|
||||
|
||||
tracing::info!(
|
||||
"Loaded {} routes from {}",
|
||||
options.routes.len(),
|
||||
cli.config
|
||||
);
|
||||
tracing::info!("Loaded {} routes from {}", options.routes.len(), cli.config);
|
||||
|
||||
// Validate-only mode
|
||||
if cli.validate {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tracing::{info, error};
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::RustProxy;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
@@ -141,15 +141,19 @@ async fn handle_request(
|
||||
"start" => handle_start(&id, &request.params, proxy).await,
|
||||
"stop" => handle_stop(&id, proxy).await,
|
||||
"updateRoutes" => handle_update_routes(&id, &request.params, proxy).await,
|
||||
"setSecurityPolicy" => handle_set_security_policy(&id, &request.params, proxy),
|
||||
"getMetrics" => handle_get_metrics(&id, proxy),
|
||||
"getStatistics" => handle_get_statistics(&id, proxy),
|
||||
"provisionCertificate" => handle_provision_certificate(&id, &request.params, proxy).await,
|
||||
"renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await,
|
||||
"getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await,
|
||||
"getListeningPorts" => handle_get_listening_ports(&id, proxy),
|
||||
"getNftablesStatus" => handle_get_nftables_status(&id, proxy).await,
|
||||
"setSocketHandlerRelay" => handle_set_socket_handler_relay(&id, &request.params, proxy).await,
|
||||
"setDatagramHandlerRelay" => handle_set_datagram_handler_relay(&id, &request.params, proxy).await,
|
||||
"setSocketHandlerRelay" => {
|
||||
handle_set_socket_handler_relay(&id, &request.params, proxy).await
|
||||
}
|
||||
"setDatagramHandlerRelay" => {
|
||||
handle_set_datagram_handler_relay(&id, &request.params, proxy).await
|
||||
}
|
||||
"addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await,
|
||||
"removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await,
|
||||
"loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await,
|
||||
@@ -168,7 +172,12 @@ async fn handle_start(
|
||||
|
||||
let config = match params.get("config") {
|
||||
Some(config) => config,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'config' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(
|
||||
id.to_string(),
|
||||
"Missing 'config' parameter".to_string(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let options: RustProxyOptions = match serde_json::from_value(config.clone()) {
|
||||
@@ -177,38 +186,31 @@ async fn handle_start(
|
||||
};
|
||||
|
||||
match RustProxy::new(options) {
|
||||
Ok(mut p) => {
|
||||
match p.start().await {
|
||||
Ok(()) => {
|
||||
send_event("started", serde_json::json!({}));
|
||||
*proxy = Some(p);
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||
}
|
||||
Err(e) => {
|
||||
send_event("error", serde_json::json!({"message": format!("{}", e)}));
|
||||
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
|
||||
}
|
||||
Ok(mut p) => match p.start().await {
|
||||
Ok(()) => {
|
||||
send_event("started", serde_json::json!({}));
|
||||
*proxy = Some(p);
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
send_event("error", serde_json::json!({"message": format!("{}", e)}));
|
||||
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
|
||||
}
|
||||
},
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to create proxy: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_stop(
|
||||
id: &str,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
async fn handle_stop(id: &str, proxy: &mut Option<RustProxy>) -> ManagementResponse {
|
||||
match proxy.as_mut() {
|
||||
Some(p) => {
|
||||
match p.stop().await {
|
||||
Ok(()) => {
|
||||
*proxy = None;
|
||||
send_event("stopped", serde_json::json!({}));
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
|
||||
Some(p) => match p.stop().await {
|
||||
Ok(()) => {
|
||||
*proxy = None;
|
||||
send_event("stopped", serde_json::json!({}));
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||
}
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
|
||||
},
|
||||
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
}
|
||||
}
|
||||
@@ -225,7 +227,12 @@ async fn handle_update_routes(
|
||||
|
||||
let routes = match params.get("routes") {
|
||||
Some(routes) => routes,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routes' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(
|
||||
id.to_string(),
|
||||
"Missing 'routes' parameter".to_string(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let routes: Vec<rustproxy_config::RouteConfig> = match serde_json::from_value(routes.clone()) {
|
||||
@@ -235,36 +242,72 @@ async fn handle_update_routes(
|
||||
|
||||
match p.update_routes(routes).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e)),
|
||||
Err(e) => {
|
||||
ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_get_metrics(
|
||||
fn handle_set_security_policy(
|
||||
id: &str,
|
||||
proxy: &Option<RustProxy>,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let policy = match params.get("policy") {
|
||||
Some(policy) => policy,
|
||||
None => {
|
||||
return ManagementResponse::err(
|
||||
id.to_string(),
|
||||
"Missing 'policy' parameter".to_string(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let policy: rustproxy_config::SecurityPolicy = match serde_json::from_value(policy.clone()) {
|
||||
Ok(policy) => policy,
|
||||
Err(e) => {
|
||||
return ManagementResponse::err(
|
||||
id.to_string(),
|
||||
format!("Invalid security policy: {}", e),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
p.set_security_policy(policy);
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||
}
|
||||
|
||||
fn handle_get_metrics(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
|
||||
match proxy.as_ref() {
|
||||
Some(p) => {
|
||||
let metrics = p.get_metrics();
|
||||
match serde_json::to_value(&metrics) {
|
||||
Ok(v) => ManagementResponse::ok(id.to_string(), v),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize metrics: {}", e)),
|
||||
Err(e) => ManagementResponse::err(
|
||||
id.to_string(),
|
||||
format!("Failed to serialize metrics: {}", e),
|
||||
),
|
||||
}
|
||||
}
|
||||
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_get_statistics(
|
||||
id: &str,
|
||||
proxy: &Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
fn handle_get_statistics(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
|
||||
match proxy.as_ref() {
|
||||
Some(p) => {
|
||||
let stats = p.get_statistics();
|
||||
match serde_json::to_value(&stats) {
|
||||
Ok(v) => ManagementResponse::ok(id.to_string(), v),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize statistics: {}", e)),
|
||||
Err(e) => ManagementResponse::err(
|
||||
id.to_string(),
|
||||
format!("Failed to serialize statistics: {}", e),
|
||||
),
|
||||
}
|
||||
}
|
||||
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
@@ -283,12 +326,20 @@ async fn handle_provision_certificate(
|
||||
|
||||
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
||||
Some(name) => name.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(
|
||||
id.to_string(),
|
||||
"Missing 'routeName' parameter".to_string(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
match p.provision_certificate(&route_name).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to provision certificate: {}", e)),
|
||||
Err(e) => ManagementResponse::err(
|
||||
id.to_string(),
|
||||
format!("Failed to provision certificate: {}", e),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -304,12 +355,20 @@ async fn handle_renew_certificate(
|
||||
|
||||
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
||||
Some(name) => name.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(
|
||||
id.to_string(),
|
||||
"Missing 'routeName' parameter".to_string(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
match p.renew_certificate(&route_name).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to renew certificate: {}", e)),
|
||||
Err(e) => ManagementResponse::err(
|
||||
id.to_string(),
|
||||
format!("Failed to renew certificate: {}", e),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -325,24 +384,29 @@ async fn handle_get_certificate_status(
|
||||
|
||||
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
||||
Some(name) => name,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(
|
||||
id.to_string(),
|
||||
"Missing 'routeName' parameter".to_string(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
match p.get_certificate_status(route_name).await {
|
||||
Some(status) => ManagementResponse::ok(id.to_string(), serde_json::json!({
|
||||
"domain": status.domain,
|
||||
"source": status.source,
|
||||
"expiresAt": status.expires_at,
|
||||
"isValid": status.is_valid,
|
||||
})),
|
||||
Some(status) => ManagementResponse::ok(
|
||||
id.to_string(),
|
||||
serde_json::json!({
|
||||
"domain": status.domain,
|
||||
"source": status.source,
|
||||
"expiresAt": status.expires_at,
|
||||
"isValid": status.is_valid,
|
||||
}),
|
||||
),
|
||||
None => ManagementResponse::ok(id.to_string(), serde_json::Value::Null),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_get_listening_ports(
|
||||
id: &str,
|
||||
proxy: &Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
fn handle_get_listening_ports(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
|
||||
match proxy.as_ref() {
|
||||
Some(p) => {
|
||||
let ports = p.get_listening_ports();
|
||||
@@ -352,26 +416,6 @@ fn handle_get_listening_ports(
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_get_nftables_status(
|
||||
id: &str,
|
||||
proxy: &Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
match proxy.as_ref() {
|
||||
Some(p) => {
|
||||
match p.get_nftables_status().await {
|
||||
Ok(status) => {
|
||||
match serde_json::to_value(&status) {
|
||||
Ok(v) => ManagementResponse::ok(id.to_string(), v),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize: {}", e)),
|
||||
}
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to get status: {}", e)),
|
||||
}
|
||||
}
|
||||
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_set_socket_handler_relay(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
@@ -382,7 +426,8 @@ async fn handle_set_socket_handler_relay(
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let socket_path = params.get("socketPath")
|
||||
let socket_path = params
|
||||
.get("socketPath")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
@@ -402,7 +447,8 @@ async fn handle_set_datagram_handler_relay(
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let socket_path = params.get("socketPath")
|
||||
let socket_path = params
|
||||
.get("socketPath")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
@@ -424,12 +470,17 @@ async fn handle_add_listening_port(
|
||||
|
||||
let port = match params.get("port").and_then(|v| v.as_u64()) {
|
||||
Some(port) => port as u16,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string())
|
||||
}
|
||||
};
|
||||
|
||||
match p.add_listening_port(port).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to add port {}: {}", port, e)),
|
||||
Err(e) => ManagementResponse::err(
|
||||
id.to_string(),
|
||||
format!("Failed to add port {}: {}", port, e),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -445,12 +496,17 @@ async fn handle_remove_listening_port(
|
||||
|
||||
let port = match params.get("port").and_then(|v| v.as_u64()) {
|
||||
Some(port) => port as u16,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string())
|
||||
}
|
||||
};
|
||||
|
||||
match p.remove_listening_port(port).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to remove port {}: {}", port, e)),
|
||||
Err(e) => ManagementResponse::err(
|
||||
id.to_string(),
|
||||
format!("Failed to remove port {}: {}", port, e),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -466,26 +522,41 @@ async fn handle_load_certificate(
|
||||
|
||||
let domain = match params.get("domain").and_then(|v| v.as_str()) {
|
||||
Some(d) => d.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'domain' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(
|
||||
id.to_string(),
|
||||
"Missing 'domain' parameter".to_string(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let cert = match params.get("cert").and_then(|v| v.as_str()) {
|
||||
Some(c) => c.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string())
|
||||
}
|
||||
};
|
||||
|
||||
let key = match params.get("key").and_then(|v| v.as_str()) {
|
||||
Some(k) => k.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string()),
|
||||
None => {
|
||||
return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string())
|
||||
}
|
||||
};
|
||||
|
||||
let ca = params.get("ca").and_then(|v| v.as_str()).map(|s| s.to_string());
|
||||
let ca = params
|
||||
.get("ca")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
info!("loadCertificate: domain={}", domain);
|
||||
|
||||
// Load cert into cert manager and hot-swap TLS config
|
||||
match p.load_certificate(&domain, cert, key, ca).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to load certificate for {}: {}", domain, e)),
|
||||
Err(e) => ManagementResponse::err(
|
||||
id.to_string(),
|
||||
format!("Failed to load certificate for {}: {}", domain, e),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,7 +136,8 @@ pub async fn start_http_echo_backend(port: u16, backend_name: &str) -> JoinHandl
|
||||
let path = parts.get(1).copied().unwrap_or("/");
|
||||
|
||||
// Extract Host header
|
||||
let host = req_str.lines()
|
||||
let host = req_str
|
||||
.lines()
|
||||
.find(|l| l.to_lowercase().starts_with("host:"))
|
||||
.map(|l| l[5..].trim())
|
||||
.unwrap_or("unknown");
|
||||
@@ -297,8 +298,6 @@ pub fn make_test_route(
|
||||
load_balancing: None,
|
||||
advanced: None,
|
||||
options: None,
|
||||
forwarding_engine: None,
|
||||
nftables: None,
|
||||
send_proxy_protocol: None,
|
||||
udp: None,
|
||||
},
|
||||
@@ -338,7 +337,8 @@ pub async fn start_ws_echo_backend(port: u16) -> JoinHandle<()> {
|
||||
let req_str = String::from_utf8_lossy(&buf[..n]);
|
||||
|
||||
// Extract Sec-WebSocket-Key for proper handshake
|
||||
let ws_key = req_str.lines()
|
||||
let ws_key = req_str
|
||||
.lines()
|
||||
.find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
|
||||
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
|
||||
.unwrap_or_default();
|
||||
@@ -380,7 +380,9 @@ pub fn generate_self_signed_cert(domain: &str) -> (String, String) {
|
||||
use rcgen::{CertificateParams, KeyPair};
|
||||
|
||||
let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap();
|
||||
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
|
||||
params
|
||||
.distinguished_name
|
||||
.push(rcgen::DnType::CommonName, domain);
|
||||
|
||||
let key_pair = KeyPair::generate().unwrap();
|
||||
let cert = params.self_signed(&key_pair).unwrap();
|
||||
@@ -460,11 +462,7 @@ pub fn make_tls_terminate_route(
|
||||
|
||||
/// Start a TLS WebSocket echo backend: accepts TLS, performs WS handshake, then echoes data.
|
||||
/// Combines TLS acceptance (like `start_tls_http_backend`) with WebSocket echo (like `start_ws_echo_backend`).
|
||||
pub async fn start_tls_ws_echo_backend(
|
||||
port: u16,
|
||||
cert_pem: &str,
|
||||
key_pem: &str,
|
||||
) -> JoinHandle<()> {
|
||||
pub async fn start_tls_ws_echo_backend(port: u16, cert_pem: &str, key_pem: &str) -> JoinHandle<()> {
|
||||
use std::sync::Arc;
|
||||
|
||||
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
mod common;
|
||||
|
||||
use bytes::Buf;
|
||||
use common::*;
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy_config::{RustProxyOptions, TransportProtocol, RouteUdp, RouteQuic};
|
||||
use bytes::Buf;
|
||||
use rustproxy_config::{RouteQuic, RouteUdp, RustProxyOptions, TransportProtocol};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Build a route that listens on UDP with HTTP/3 enabled and TLS terminate.
|
||||
@@ -14,7 +14,14 @@ fn make_h3_route(
|
||||
cert_pem: &str,
|
||||
key_pem: &str,
|
||||
) -> rustproxy_config::RouteConfig {
|
||||
let mut route = make_tls_terminate_route(port, "localhost", target_host, target_port, cert_pem, key_pem);
|
||||
let mut route = make_tls_terminate_route(
|
||||
port,
|
||||
"localhost",
|
||||
target_host,
|
||||
target_port,
|
||||
cert_pem,
|
||||
key_pem,
|
||||
);
|
||||
route.route_match.transport = Some(TransportProtocol::All);
|
||||
// Keep domain="localhost" from make_tls_terminate_route — needed for TLS cert extraction
|
||||
route.action.udp = Some(RouteUdp {
|
||||
@@ -89,11 +96,9 @@ async fn test_h3_response_stream_finishes() {
|
||||
.await
|
||||
.expect("QUIC handshake failed");
|
||||
|
||||
let (mut driver, mut send_request) = h3::client::new(
|
||||
h3_quinn::Connection::new(connection),
|
||||
)
|
||||
.await
|
||||
.expect("H3 connection setup failed");
|
||||
let (mut driver, mut send_request) = h3::client::new(h3_quinn::Connection::new(connection))
|
||||
.await
|
||||
.expect("H3 connection setup failed");
|
||||
|
||||
// Drive the H3 connection in background
|
||||
tokio::spawn(async move {
|
||||
@@ -108,33 +113,46 @@ async fn test_h3_response_stream_finishes() {
|
||||
.body(())
|
||||
.unwrap();
|
||||
|
||||
let mut stream = send_request.send_request(req).await
|
||||
let mut stream = send_request
|
||||
.send_request(req)
|
||||
.await
|
||||
.expect("Failed to send H3 request");
|
||||
stream.finish().await
|
||||
stream
|
||||
.finish()
|
||||
.await
|
||||
.expect("Failed to finish sending H3 request body");
|
||||
|
||||
// 6. Read response headers
|
||||
let resp = stream.recv_response().await
|
||||
let resp = stream
|
||||
.recv_response()
|
||||
.await
|
||||
.expect("Failed to receive H3 response");
|
||||
assert_eq!(resp.status(), http::StatusCode::OK,
|
||||
"Expected 200 OK, got {}", resp.status());
|
||||
assert_eq!(
|
||||
resp.status(),
|
||||
http::StatusCode::OK,
|
||||
"Expected 200 OK, got {}",
|
||||
resp.status()
|
||||
);
|
||||
|
||||
// 7. Read body and verify stream ends (FIN received)
|
||||
// This is the critical assertion: recv_data() must return None (stream ended)
|
||||
// within the timeout, NOT hang forever waiting for a FIN that never arrives.
|
||||
let result = with_timeout(async {
|
||||
let mut total = 0usize;
|
||||
while let Some(chunk) = stream.recv_data().await.expect("H3 data receive error") {
|
||||
total += chunk.remaining();
|
||||
}
|
||||
// recv_data() returned None => stream ended (FIN received)
|
||||
total
|
||||
}, 10)
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut total = 0usize;
|
||||
while let Some(chunk) = stream.recv_data().await.expect("H3 data receive error") {
|
||||
total += chunk.remaining();
|
||||
}
|
||||
// recv_data() returned None => stream ended (FIN received)
|
||||
total
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await;
|
||||
|
||||
let bytes_received = result.expect(
|
||||
"TIMEOUT: H3 stream never ended (FIN not received by client). \
|
||||
The proxy sent all response data but failed to send the QUIC stream FIN."
|
||||
The proxy sent all response data but failed to send the QUIC stream FIN.",
|
||||
);
|
||||
assert_eq!(
|
||||
bytes_received,
|
||||
|
||||
@@ -43,17 +43,32 @@ async fn test_http_forward_basic() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await;
|
||||
let body = extract_body(&response);
|
||||
body.to_string()
|
||||
}, 10)
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await;
|
||||
let body = extract_body(&response);
|
||||
body.to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.contains(r#""method":"GET"#), "Expected GET method, got: {}", result);
|
||||
assert!(result.contains(r#""path":"/hello"#), "Expected /hello path, got: {}", result);
|
||||
assert!(result.contains(r#""backend":"main"#), "Expected main backend, got: {}", result);
|
||||
assert!(
|
||||
result.contains(r#""method":"GET"#),
|
||||
"Expected GET method, got: {}",
|
||||
result
|
||||
);
|
||||
assert!(
|
||||
result.contains(r#""path":"/hello"#),
|
||||
"Expected /hello path, got: {}",
|
||||
result
|
||||
);
|
||||
assert!(
|
||||
result.contains(r#""backend":"main"#),
|
||||
"Expected main backend, got: {}",
|
||||
result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -69,8 +84,18 @@ async fn test_http_forward_host_routing() {
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_test_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", backend1_port),
|
||||
make_test_route(proxy_port, Some("beta.example.com"), "127.0.0.1", backend2_port),
|
||||
make_test_route(
|
||||
proxy_port,
|
||||
Some("alpha.example.com"),
|
||||
"127.0.0.1",
|
||||
backend1_port,
|
||||
),
|
||||
make_test_route(
|
||||
proxy_port,
|
||||
Some("beta.example.com"),
|
||||
"127.0.0.1",
|
||||
backend2_port,
|
||||
),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
@@ -80,24 +105,38 @@ async fn test_http_forward_host_routing() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Test alpha domain
|
||||
let alpha_result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
let alpha_result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await;
|
||||
extract_body(&response).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(alpha_result.contains(r#""backend":"alpha"#), "Expected alpha backend, got: {}", alpha_result);
|
||||
assert!(
|
||||
alpha_result.contains(r#""backend":"alpha"#),
|
||||
"Expected alpha backend, got: {}",
|
||||
alpha_result
|
||||
);
|
||||
|
||||
// Test beta domain
|
||||
let beta_result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
let beta_result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await;
|
||||
extract_body(&response).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(beta_result.contains(r#""backend":"beta"#), "Expected beta backend, got: {}", beta_result);
|
||||
assert!(
|
||||
beta_result.contains(r#""backend":"beta"#),
|
||||
"Expected beta backend, got: {}",
|
||||
beta_result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -127,24 +166,38 @@ async fn test_http_forward_path_routing() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Test API path
|
||||
let api_result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
let api_result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await;
|
||||
extract_body(&response).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(api_result.contains(r#""backend":"api"#), "Expected api backend, got: {}", api_result);
|
||||
assert!(
|
||||
api_result.contains(r#""backend":"api"#),
|
||||
"Expected api backend, got: {}",
|
||||
api_result
|
||||
);
|
||||
|
||||
// Test web path (no /api prefix)
|
||||
let web_result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
let web_result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await;
|
||||
extract_body(&response).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(web_result.contains(r#""backend":"web"#), "Expected web backend, got: {}", web_result);
|
||||
assert!(
|
||||
web_result.contains(r#""backend":"web"#),
|
||||
"Expected web backend, got: {}",
|
||||
web_result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -184,9 +237,18 @@ async fn test_http_forward_cors_preflight() {
|
||||
.unwrap();
|
||||
|
||||
// Should get 204 No Content with CORS headers
|
||||
assert!(result.contains("204"), "Expected 204 status, got: {}", result);
|
||||
assert!(result.to_lowercase().contains("access-control-allow-origin"),
|
||||
"Expected CORS header, got: {}", result);
|
||||
assert!(
|
||||
result.contains("204"),
|
||||
"Expected 204 status, got: {}",
|
||||
result
|
||||
);
|
||||
assert!(
|
||||
result
|
||||
.to_lowercase()
|
||||
.contains("access-control-allow-origin"),
|
||||
"Expected CORS header, got: {}",
|
||||
result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -208,15 +270,22 @@ async fn test_http_forward_backend_error() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await;
|
||||
response
|
||||
}, 10)
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await;
|
||||
response
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Proxy should relay the 500 from backend
|
||||
assert!(result.contains("500"), "Expected 500 status, got: {}", result);
|
||||
assert!(
|
||||
result.contains("500"),
|
||||
"Expected 500 status, got: {}",
|
||||
result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -227,7 +296,12 @@ async fn test_http_forward_no_route_matched() {
|
||||
|
||||
// Create a route only for a specific domain
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, Some("known.example.com"), "127.0.0.1", 9999)],
|
||||
routes: vec![make_test_route(
|
||||
proxy_port,
|
||||
Some("known.example.com"),
|
||||
"127.0.0.1",
|
||||
9999,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
@@ -235,15 +309,22 @@ async fn test_http_forward_no_route_matched() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await;
|
||||
response
|
||||
}, 10)
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await;
|
||||
response
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should get 502 Bad Gateway (no route matched)
|
||||
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
|
||||
assert!(
|
||||
result.contains("502"),
|
||||
"Expected 502 status, got: {}",
|
||||
result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -262,15 +343,22 @@ async fn test_http_forward_backend_unavailable() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "example.com", "GET", "/").await;
|
||||
response
|
||||
}, 10)
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "example.com", "GET", "/").await;
|
||||
response
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should get 502 Bad Gateway (backend unavailable)
|
||||
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
|
||||
assert!(
|
||||
result.contains("502"),
|
||||
"Expected 502 status, got: {}",
|
||||
result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -286,7 +374,12 @@ async fn test_https_terminate_http_forward() {
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
||||
proxy_port,
|
||||
domain,
|
||||
"127.0.0.1",
|
||||
backend_port,
|
||||
&cert_pem,
|
||||
&key_pem,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
@@ -295,38 +388,53 @@ async fn test_https_terminate_http_forward() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
// Send HTTP request through TLS
|
||||
let request = format!(
|
||||
"GET /api/data HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
|
||||
domain
|
||||
);
|
||||
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
||||
// Send HTTP request through TLS
|
||||
let request = format!(
|
||||
"GET /api/data HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
|
||||
domain
|
||||
);
|
||||
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
||||
|
||||
let mut response = Vec::new();
|
||||
tls_stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
}, 10)
|
||||
let mut response = Vec::new();
|
||||
tls_stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let body = extract_body(&result);
|
||||
assert!(body.contains(r#""method":"GET"#), "Expected GET, got: {}", body);
|
||||
assert!(body.contains(r#""path":"/api/data"#), "Expected /api/data, got: {}", body);
|
||||
assert!(body.contains(r#""backend":"tls-backend"#), "Expected tls-backend, got: {}", body);
|
||||
assert!(
|
||||
body.contains(r#""method":"GET"#),
|
||||
"Expected GET, got: {}",
|
||||
body
|
||||
);
|
||||
assert!(
|
||||
body.contains(r#""path":"/api/data"#),
|
||||
"Expected /api/data, got: {}",
|
||||
body
|
||||
);
|
||||
assert!(
|
||||
body.contains(r#""backend":"tls-backend"#),
|
||||
"Expected tls-backend, got: {}",
|
||||
body
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -347,59 +455,68 @@ async fn test_websocket_through_proxy() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Send WebSocket upgrade request
|
||||
let request = format!(
|
||||
"GET /ws HTTP/1.1\r\n\
|
||||
// Send WebSocket upgrade request
|
||||
let request = format!(
|
||||
"GET /ws HTTP/1.1\r\n\
|
||||
Host: example.com\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
|
||||
Sec-WebSocket-Version: 13\r\n\
|
||||
\r\n"
|
||||
);
|
||||
stream.write_all(request.as_bytes()).await.unwrap();
|
||||
);
|
||||
stream.write_all(request.as_bytes()).await.unwrap();
|
||||
|
||||
// Read the 101 response
|
||||
let mut response_buf = Vec::with_capacity(4096);
|
||||
let mut temp = [0u8; 1];
|
||||
loop {
|
||||
let n = stream.read(&mut temp).await.unwrap();
|
||||
if n == 0 { break; }
|
||||
response_buf.push(temp[0]);
|
||||
if response_buf.len() >= 4 {
|
||||
let len = response_buf.len();
|
||||
if response_buf[len-4..] == *b"\r\n\r\n" {
|
||||
// Read the 101 response
|
||||
let mut response_buf = Vec::with_capacity(4096);
|
||||
let mut temp = [0u8; 1];
|
||||
loop {
|
||||
let n = stream.read(&mut temp).await.unwrap();
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
response_buf.push(temp[0]);
|
||||
if response_buf.len() >= 4 {
|
||||
let len = response_buf.len();
|
||||
if response_buf[len - 4..] == *b"\r\n\r\n" {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response_str = String::from_utf8_lossy(&response_buf).to_string();
|
||||
assert!(response_str.contains("101"), "Expected 101 Switching Protocols, got: {}", response_str);
|
||||
assert!(
|
||||
response_str.to_lowercase().contains("upgrade: websocket"),
|
||||
"Expected Upgrade header, got: {}",
|
||||
response_str
|
||||
);
|
||||
let response_str = String::from_utf8_lossy(&response_buf).to_string();
|
||||
assert!(
|
||||
response_str.contains("101"),
|
||||
"Expected 101 Switching Protocols, got: {}",
|
||||
response_str
|
||||
);
|
||||
assert!(
|
||||
response_str.to_lowercase().contains("upgrade: websocket"),
|
||||
"Expected Upgrade header, got: {}",
|
||||
response_str
|
||||
);
|
||||
|
||||
// After upgrade, send data and verify echo
|
||||
let test_data = b"Hello WebSocket!";
|
||||
stream.write_all(test_data).await.unwrap();
|
||||
// After upgrade, send data and verify echo
|
||||
let test_data = b"Hello WebSocket!";
|
||||
stream.write_all(test_data).await.unwrap();
|
||||
|
||||
// Read echoed data
|
||||
let mut echo_buf = vec![0u8; 256];
|
||||
let n = stream.read(&mut echo_buf).await.unwrap();
|
||||
let echoed = &echo_buf[..n];
|
||||
// Read echoed data
|
||||
let mut echo_buf = vec![0u8; 256];
|
||||
let n = stream.read(&mut echo_buf).await.unwrap();
|
||||
let echoed = &echo_buf[..n];
|
||||
|
||||
assert_eq!(echoed, test_data, "Expected echo of sent data");
|
||||
assert_eq!(echoed, test_data, "Expected echo of sent data");
|
||||
|
||||
"ok".to_string()
|
||||
}, 10)
|
||||
"ok".to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -431,12 +548,22 @@ async fn test_terminate_and_reencrypt_http_routing() {
|
||||
|
||||
// Create terminate-and-reencrypt routes
|
||||
let mut route1 = make_tls_terminate_route(
|
||||
proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1,
|
||||
proxy_port,
|
||||
"alpha.example.com",
|
||||
"127.0.0.1",
|
||||
backend1_port,
|
||||
&cert1,
|
||||
&key1,
|
||||
);
|
||||
route1.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
|
||||
|
||||
let mut route2 = make_tls_terminate_route(
|
||||
proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2,
|
||||
proxy_port,
|
||||
"beta.example.com",
|
||||
"127.0.0.1",
|
||||
backend2_port,
|
||||
&cert2,
|
||||
&key2,
|
||||
);
|
||||
route2.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
|
||||
|
||||
@@ -450,27 +577,32 @@ async fn test_terminate_and_reencrypt_http_routing() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Test alpha domain - HTTP request through TLS terminate-and-reencrypt
|
||||
let alpha_result = with_timeout(async {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
let alpha_result = with_timeout(
|
||||
async {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let server_name =
|
||||
rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
let request = "GET /api/data HTTP/1.1\r\nHost: alpha.example.com\r\nConnection: close\r\n\r\n";
|
||||
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
||||
let request =
|
||||
"GET /api/data HTTP/1.1\r\nHost: alpha.example.com\r\nConnection: close\r\n\r\n";
|
||||
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
||||
|
||||
let mut response = Vec::new();
|
||||
tls_stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
}, 10)
|
||||
let mut response = Vec::new();
|
||||
tls_stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -498,27 +630,32 @@ async fn test_terminate_and_reencrypt_http_routing() {
|
||||
);
|
||||
|
||||
// Test beta domain - different host goes to different backend
|
||||
let beta_result = with_timeout(async {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
let beta_result = with_timeout(
|
||||
async {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from("beta.example.com".to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let server_name =
|
||||
rustls::pki_types::ServerName::try_from("beta.example.com".to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
let request = "GET /other HTTP/1.1\r\nHost: beta.example.com\r\nConnection: close\r\n\r\n";
|
||||
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
||||
let request =
|
||||
"GET /other HTTP/1.1\r\nHost: beta.example.com\r\nConnection: close\r\n\r\n";
|
||||
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
||||
|
||||
let mut response = Vec::new();
|
||||
tls_stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
}, 10)
|
||||
let mut response = Vec::new();
|
||||
tls_stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -589,14 +726,12 @@ async fn test_terminate_and_reencrypt_websocket() {
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
let connector =
|
||||
tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let server_name =
|
||||
rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
// Send WebSocket upgrade request through TLS
|
||||
@@ -685,10 +820,13 @@ async fn test_protocol_field_in_route_config() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// HTTP request should match the route and get proxied
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "example.com", "GET", "/test").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let response = send_http_request(proxy_port, "example.com", "GET", "/test").await;
|
||||
extract_body(&response).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -20,13 +20,19 @@ async fn test_start_and_stop() {
|
||||
assert!(!wait_for_port(port, 200).await);
|
||||
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(port, 2000).await, "Port should be listening after start");
|
||||
assert!(
|
||||
wait_for_port(port, 2000).await,
|
||||
"Port should be listening after start"
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
|
||||
// Give the OS a moment to release the port
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
assert!(!wait_for_port(port, 200).await, "Port should not be listening after stop");
|
||||
assert!(
|
||||
!wait_for_port(port, 200).await,
|
||||
"Port should not be listening after stop"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -54,7 +60,12 @@ async fn test_update_routes_hot_reload() {
|
||||
let port = next_port();
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(port, Some("old.example.com"), "127.0.0.1", 8080)],
|
||||
routes: vec![make_test_route(
|
||||
port,
|
||||
Some("old.example.com"),
|
||||
"127.0.0.1",
|
||||
8080,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
@@ -62,9 +73,12 @@ async fn test_update_routes_hot_reload() {
|
||||
proxy.start().await.unwrap();
|
||||
|
||||
// Update routes atomically
|
||||
let new_routes = vec![
|
||||
make_test_route(port, Some("new.example.com"), "127.0.0.1", 9090),
|
||||
];
|
||||
let new_routes = vec![make_test_route(
|
||||
port,
|
||||
Some("new.example.com"),
|
||||
"127.0.0.1",
|
||||
9090,
|
||||
)];
|
||||
let result = proxy.update_routes(new_routes).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
@@ -87,15 +101,24 @@ async fn test_add_remove_listening_port() {
|
||||
|
||||
// Add a new port
|
||||
proxy.add_listening_port(port2).await.unwrap();
|
||||
assert!(wait_for_port(port2, 2000).await, "New port should be listening");
|
||||
assert!(
|
||||
wait_for_port(port2, 2000).await,
|
||||
"New port should be listening"
|
||||
);
|
||||
|
||||
// Remove the port
|
||||
proxy.remove_listening_port(port2).await.unwrap();
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
assert!(!wait_for_port(port2, 200).await, "Removed port should not be listening");
|
||||
assert!(
|
||||
!wait_for_port(port2, 200).await,
|
||||
"Removed port should not be listening"
|
||||
);
|
||||
|
||||
// Original port should still be listening
|
||||
assert!(wait_for_port(port1, 200).await, "Original port should still be listening");
|
||||
assert!(
|
||||
wait_for_port(port1, 200).await,
|
||||
"Original port should still be listening"
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -168,7 +191,11 @@ async fn test_metrics_track_connections() {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
|
||||
let stats = proxy.get_statistics();
|
||||
assert!(stats.total_connections > 0, "Expected total_connections > 0, got {}", stats.total_connections);
|
||||
assert!(
|
||||
stats.total_connections > 0,
|
||||
"Expected total_connections > 0, got {}",
|
||||
stats.total_connections
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -205,8 +232,11 @@ async fn test_metrics_track_bytes() {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
|
||||
let stats = proxy.get_statistics();
|
||||
assert!(stats.total_connections > 0,
|
||||
"Expected some connections tracked, got {}", stats.total_connections);
|
||||
assert!(
|
||||
stats.total_connections > 0,
|
||||
"Expected some connections tracked, got {}",
|
||||
stats.total_connections
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -228,23 +258,38 @@ async fn test_hot_reload_port_changes() {
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(port1, 2000).await);
|
||||
assert!(!wait_for_port(port2, 200).await, "port2 should not be listening yet");
|
||||
assert!(
|
||||
!wait_for_port(port2, 200).await,
|
||||
"port2 should not be listening yet"
|
||||
);
|
||||
|
||||
// Update routes to use port2 instead
|
||||
let new_routes = vec![
|
||||
make_test_route(port2, None, "127.0.0.1", backend_port),
|
||||
];
|
||||
let new_routes = vec![make_test_route(port2, None, "127.0.0.1", backend_port)];
|
||||
proxy.update_routes(new_routes).await.unwrap();
|
||||
|
||||
// Port2 should now be listening, port1 should be closed
|
||||
assert!(wait_for_port(port2, 2000).await, "port2 should be listening after reload");
|
||||
assert!(
|
||||
wait_for_port(port2, 2000).await,
|
||||
"port2 should be listening after reload"
|
||||
);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
assert!(!wait_for_port(port1, 200).await, "port1 should be closed after reload");
|
||||
assert!(
|
||||
!wait_for_port(port1, 200).await,
|
||||
"port1 should be closed after reload"
|
||||
);
|
||||
|
||||
// Verify port2 works
|
||||
let ports = proxy.get_listening_ports();
|
||||
assert!(ports.contains(&port2), "Expected port2 in listening ports: {:?}", ports);
|
||||
assert!(!ports.contains(&port1), "port1 should not be in listening ports: {:?}", ports);
|
||||
assert!(
|
||||
ports.contains(&port2),
|
||||
"Expected port2 in listening ports: {:?}",
|
||||
ports
|
||||
);
|
||||
assert!(
|
||||
!ports.contains(&port1),
|
||||
"port1 should not be in listening ports: {:?}",
|
||||
ports
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
@@ -24,19 +24,25 @@ async fn test_tcp_forward_echo() {
|
||||
proxy.start().await.unwrap();
|
||||
|
||||
// Wait for proxy to be ready
|
||||
assert!(wait_for_port(proxy_port, 2000).await, "Proxy port not ready");
|
||||
assert!(
|
||||
wait_for_port(proxy_port, 2000).await,
|
||||
"Proxy port not ready"
|
||||
);
|
||||
|
||||
// Connect and send data
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
stream.write_all(b"hello world").await.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
stream.write_all(b"hello world").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
5,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -61,21 +67,24 @@ async fn test_tcp_forward_large_payload() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Send 1MB of data
|
||||
let data = vec![b'A'; 1_000_000];
|
||||
stream.write_all(&data).await.unwrap();
|
||||
stream.shutdown().await.unwrap();
|
||||
// Send 1MB of data
|
||||
let data = vec![b'A'; 1_000_000];
|
||||
stream.write_all(&data).await.unwrap();
|
||||
stream.shutdown().await.unwrap();
|
||||
|
||||
// Read all back
|
||||
let mut received = Vec::new();
|
||||
stream.read_to_end(&mut received).await.unwrap();
|
||||
received.len()
|
||||
}, 10)
|
||||
// Read all back
|
||||
let mut received = Vec::new();
|
||||
stream.read_to_end(&mut received).await.unwrap();
|
||||
received.len()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -100,29 +109,32 @@ async fn test_tcp_forward_multiple_connections() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..10 {
|
||||
let port = proxy_port;
|
||||
handles.push(tokio::spawn(async move {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap();
|
||||
let msg = format!("connection-{}", i);
|
||||
stream.write_all(msg.as_bytes()).await.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..10 {
|
||||
let port = proxy_port;
|
||||
handles.push(tokio::spawn(async move {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap();
|
||||
let msg = format!("connection-{}", i);
|
||||
stream.write_all(msg.as_bytes()).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}));
|
||||
}
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}));
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
for handle in handles {
|
||||
results.push(handle.await.unwrap());
|
||||
}
|
||||
results
|
||||
}, 10)
|
||||
let mut results = Vec::new();
|
||||
for handle in handles {
|
||||
results.push(handle.await.unwrap());
|
||||
}
|
||||
results
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -149,14 +161,20 @@ async fn test_tcp_forward_backend_unreachable() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Connection should complete (proxy accepts it) but data should not flow
|
||||
let result = with_timeout(async {
|
||||
let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await;
|
||||
stream.is_ok()
|
||||
}, 5)
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await;
|
||||
stream.is_ok()
|
||||
},
|
||||
5,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result, "Should be able to connect to proxy even if backend is down");
|
||||
assert!(
|
||||
result,
|
||||
"Should be able to connect to proxy even if backend is down"
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -178,16 +196,19 @@ async fn test_tcp_forward_bidirectional() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
stream.write_all(b"test data").await.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
stream.write_all(b"test data").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
5,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -65,8 +65,18 @@ async fn test_tls_passthrough_sni_routing() {
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_passthrough_route(proxy_port, Some("one.example.com"), "127.0.0.1", backend1_port),
|
||||
make_tls_passthrough_route(proxy_port, Some("two.example.com"), "127.0.0.1", backend2_port),
|
||||
make_tls_passthrough_route(
|
||||
proxy_port,
|
||||
Some("one.example.com"),
|
||||
"127.0.0.1",
|
||||
backend1_port,
|
||||
),
|
||||
make_tls_passthrough_route(
|
||||
proxy_port,
|
||||
Some("two.example.com"),
|
||||
"127.0.0.1",
|
||||
backend2_port,
|
||||
),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
@@ -76,39 +86,53 @@ async fn test_tls_passthrough_sni_routing() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Send a fake ClientHello with SNI "one.example.com"
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("one.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("one.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
5,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Backend1 should have received the ClientHello and prefixed its response
|
||||
assert!(result.starts_with("BACKEND1:"), "Expected BACKEND1 prefix, got: {}", result);
|
||||
assert!(
|
||||
result.starts_with("BACKEND1:"),
|
||||
"Expected BACKEND1 prefix, got: {}",
|
||||
result
|
||||
);
|
||||
|
||||
// Now test routing to backend2
|
||||
let result2 = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("two.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
let result2 = with_timeout(
|
||||
async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("two.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
5,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result2.starts_with("BACKEND2:"), "Expected BACKEND2 prefix, got: {}", result2);
|
||||
assert!(
|
||||
result2.starts_with("BACKEND2:"),
|
||||
"Expected BACKEND2 prefix, got: {}",
|
||||
result2
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -121,9 +145,12 @@ async fn test_tls_passthrough_unknown_sni() {
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_passthrough_route(proxy_port, Some("known.example.com"), "127.0.0.1", backend_port),
|
||||
],
|
||||
routes: vec![make_tls_passthrough_route(
|
||||
proxy_port,
|
||||
Some("known.example.com"),
|
||||
"127.0.0.1",
|
||||
backend_port,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
@@ -132,21 +159,24 @@ async fn test_tls_passthrough_unknown_sni() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Send ClientHello with unknown SNI - should get no response (connection dropped)
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("unknown.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("unknown.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
// Should either get 0 bytes (closed) or an error
|
||||
match stream.read(&mut buf).await {
|
||||
Ok(0) => true, // Connection closed = no route matched
|
||||
Ok(_) => false, // Got data = route shouldn't have matched
|
||||
Err(_) => true, // Error = connection dropped
|
||||
}
|
||||
}, 5)
|
||||
let mut buf = vec![0u8; 4096];
|
||||
// Should either get 0 bytes (closed) or an error
|
||||
match stream.read(&mut buf).await {
|
||||
Ok(0) => true, // Connection closed = no route matched
|
||||
Ok(_) => false, // Got data = route shouldn't have matched
|
||||
Err(_) => true, // Error = connection dropped
|
||||
}
|
||||
},
|
||||
5,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -163,9 +193,12 @@ async fn test_tls_passthrough_wildcard_domain() {
|
||||
let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_passthrough_route(proxy_port, Some("*.example.com"), "127.0.0.1", backend_port),
|
||||
],
|
||||
routes: vec![make_tls_passthrough_route(
|
||||
proxy_port,
|
||||
Some("*.example.com"),
|
||||
"127.0.0.1",
|
||||
backend_port,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
@@ -174,21 +207,28 @@ async fn test_tls_passthrough_wildcard_domain() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Should match any subdomain of example.com
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("anything.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("anything.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
5,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.starts_with("WILDCARD:"), "Expected WILDCARD prefix, got: {}", result);
|
||||
assert!(
|
||||
result.starts_with("WILDCARD:"),
|
||||
"Expected WILDCARD prefix, got: {}",
|
||||
result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -222,24 +262,29 @@ async fn test_tls_passthrough_multiple_domains() {
|
||||
("beta.example.com", "B2:"),
|
||||
("gamma.example.com", "B3:"),
|
||||
] {
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello(domain);
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello(domain);
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
5,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
result.starts_with(expected_prefix),
|
||||
"Domain {} should route to {}, got: {}",
|
||||
domain, expected_prefix, result
|
||||
domain,
|
||||
expected_prefix,
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -74,7 +74,12 @@ async fn test_tls_terminate_basic() {
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
||||
proxy_port,
|
||||
domain,
|
||||
"127.0.0.1",
|
||||
backend_port,
|
||||
&cert_pem,
|
||||
&key_pem,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
@@ -84,23 +89,26 @@ async fn test_tls_terminate_basic() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Connect with TLS client
|
||||
let result = with_timeout(async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
tls_stream.write_all(b"hello TLS").await.unwrap();
|
||||
tls_stream.write_all(b"hello TLS").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 10)
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -125,7 +133,12 @@ async fn test_tls_terminate_and_reencrypt() {
|
||||
|
||||
// Create terminate-and-reencrypt route
|
||||
let mut route = make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &proxy_cert, &proxy_key,
|
||||
proxy_port,
|
||||
domain,
|
||||
"127.0.0.1",
|
||||
backend_port,
|
||||
&proxy_cert,
|
||||
&proxy_key,
|
||||
);
|
||||
route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
|
||||
|
||||
@@ -138,23 +151,26 @@ async fn test_tls_terminate_and_reencrypt() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
tls_stream.write_all(b"hello reencrypt").await.unwrap();
|
||||
tls_stream.write_all(b"hello reencrypt").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 10)
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -177,8 +193,22 @@ async fn test_tls_terminate_sni_cert_selection() {
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_terminate_route(proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1),
|
||||
make_tls_terminate_route(proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2),
|
||||
make_tls_terminate_route(
|
||||
proxy_port,
|
||||
"alpha.example.com",
|
||||
"127.0.0.1",
|
||||
backend1_port,
|
||||
&cert1,
|
||||
&key1,
|
||||
),
|
||||
make_tls_terminate_route(
|
||||
proxy_port,
|
||||
"beta.example.com",
|
||||
"127.0.0.1",
|
||||
backend2_port,
|
||||
&cert2,
|
||||
&key2,
|
||||
),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
@@ -188,27 +218,35 @@ async fn test_tls_terminate_sni_cert_selection() {
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Test alpha domain
|
||||
let result = with_timeout(async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
let server_name =
|
||||
rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
tls_stream.write_all(b"test").await.unwrap();
|
||||
tls_stream.write_all(b"test").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 10)
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
},
|
||||
10,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.starts_with("ALPHA:"), "Expected ALPHA prefix, got: {}", result);
|
||||
assert!(
|
||||
result.starts_with("ALPHA:"),
|
||||
"Expected ALPHA prefix, got: {}",
|
||||
result
|
||||
);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -224,7 +262,12 @@ async fn test_tls_terminate_large_payload() {
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
||||
proxy_port,
|
||||
domain,
|
||||
"127.0.0.1",
|
||||
backend_port,
|
||||
&cert_pem,
|
||||
&key_pem,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
@@ -233,26 +276,29 @@ async fn test_tls_terminate_large_payload() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
// Send 1MB of data
|
||||
let data = vec![b'X'; 1_000_000];
|
||||
tls_stream.write_all(&data).await.unwrap();
|
||||
tls_stream.shutdown().await.unwrap();
|
||||
// Send 1MB of data
|
||||
let data = vec![b'X'; 1_000_000];
|
||||
tls_stream.write_all(&data).await.unwrap();
|
||||
tls_stream.shutdown().await.unwrap();
|
||||
|
||||
let mut received = Vec::new();
|
||||
tls_stream.read_to_end(&mut received).await.unwrap();
|
||||
received.len()
|
||||
}, 15)
|
||||
let mut received = Vec::new();
|
||||
tls_stream.read_to_end(&mut received).await.unwrap();
|
||||
received.len()
|
||||
},
|
||||
15,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -272,7 +318,12 @@ async fn test_tls_terminate_concurrent() {
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
||||
proxy_port,
|
||||
domain,
|
||||
"127.0.0.1",
|
||||
backend_port,
|
||||
&cert_pem,
|
||||
&key_pem,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
@@ -281,37 +332,40 @@ async fn test_tls_terminate_concurrent() {
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..10 {
|
||||
let port = proxy_port;
|
||||
let dom = domain.to_string();
|
||||
handles.push(tokio::spawn(async move {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
let result = with_timeout(
|
||||
async {
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..10 {
|
||||
let port = proxy_port;
|
||||
let dom = domain.to_string();
|
||||
handles.push(tokio::spawn(async move {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap();
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(dom).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from(dom).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
let msg = format!("conn-{}", i);
|
||||
tls_stream.write_all(msg.as_bytes()).await.unwrap();
|
||||
let msg = format!("conn-{}", i);
|
||||
tls_stream.write_all(msg.as_bytes()).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}));
|
||||
}
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}));
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
for handle in handles {
|
||||
results.push(handle.await.unwrap());
|
||||
}
|
||||
results
|
||||
}, 15)
|
||||
let mut results = Vec::new();
|
||||
for handle in handles {
|
||||
results.push(handle.await.unwrap());
|
||||
}
|
||||
results
|
||||
},
|
||||
15,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
+33
-36
@@ -1,11 +1,5 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
|
||||
import {
|
||||
createHttpsTerminateRoute,
|
||||
createCompleteHttpsServer,
|
||||
createHttpRoute,
|
||||
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
|
||||
|
||||
import {
|
||||
mergeRouteConfigs,
|
||||
cloneRoute,
|
||||
@@ -19,8 +13,11 @@ import {
|
||||
|
||||
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
|
||||
|
||||
tap.test('route creation - createHttpsTerminateRoute produces correct structure', async () => {
|
||||
const route = createHttpsTerminateRoute('secure.example.com', { host: '127.0.0.1', port: 8443 });
|
||||
tap.test('route creation - HTTPS terminate route has correct structure', async () => {
|
||||
const route: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'secure.example.com' },
|
||||
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 8443 }], tls: { mode: 'terminate', certificate: 'auto' } }
|
||||
};
|
||||
expect(route).toHaveProperty('match');
|
||||
expect(route).toHaveProperty('action');
|
||||
expect(route.action.type).toEqual('forward');
|
||||
@@ -29,20 +26,10 @@ tap.test('route creation - createHttpsTerminateRoute produces correct structure'
|
||||
expect(route.match.domains).toEqual('secure.example.com');
|
||||
});
|
||||
|
||||
tap.test('route creation - createCompleteHttpsServer returns redirect and main route', async () => {
|
||||
const routes = createCompleteHttpsServer('app.example.com', { host: '127.0.0.1', port: 3000 });
|
||||
expect(routes).toBeArray();
|
||||
expect(routes.length).toBeGreaterThanOrEqual(2);
|
||||
// Should have an HTTP→HTTPS redirect and an HTTPS route
|
||||
const hasRedirect = routes.some((r) => r.action.type === 'forward' && r.action.redirect !== undefined);
|
||||
const hasHttps = routes.some((r) => r.action.tls?.mode === 'terminate');
|
||||
expect(hasRedirect || hasHttps).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('route validation - validateRoutes on a set of routes', async () => {
|
||||
const routes: IRouteConfig[] = [
|
||||
createHttpRoute('a.com', { host: '127.0.0.1', port: 3000 }),
|
||||
createHttpRoute('b.com', { host: '127.0.0.1', port: 4000 }),
|
||||
{ match: { ports: 80, domains: 'a.com' }, action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] } },
|
||||
{ match: { ports: 80, domains: 'b.com' }, action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 4000 }] } },
|
||||
];
|
||||
const result = validateRoutes(routes);
|
||||
expect(result.valid).toBeTrue();
|
||||
@@ -51,7 +38,7 @@ tap.test('route validation - validateRoutes on a set of routes', async () => {
|
||||
|
||||
tap.test('route validation - validateRoutes catches invalid route in set', async () => {
|
||||
const routes: any[] = [
|
||||
createHttpRoute('valid.com', { host: '127.0.0.1', port: 3000 }),
|
||||
{ match: { ports: 80, domains: 'valid.com' }, action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] } },
|
||||
{ match: { ports: 80 } }, // missing action
|
||||
];
|
||||
const result = validateRoutes(routes);
|
||||
@@ -60,23 +47,30 @@ tap.test('route validation - validateRoutes catches invalid route in set', async
|
||||
});
|
||||
|
||||
tap.test('path matching - routeMatchesPath with exact path', async () => {
|
||||
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
|
||||
route.match.path = '/api';
|
||||
const route: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com', path: '/api' },
|
||||
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
|
||||
};
|
||||
expect(routeMatchesPath(route, '/api')).toBeTrue();
|
||||
expect(routeMatchesPath(route, '/other')).toBeFalse();
|
||||
});
|
||||
|
||||
tap.test('path matching - route without path matches everything', async () => {
|
||||
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
|
||||
// No path set, should match any path
|
||||
const route: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
|
||||
};
|
||||
expect(routeMatchesPath(route, '/anything')).toBeTrue();
|
||||
expect(routeMatchesPath(route, '/')).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('route merging - mergeRouteConfigs combines routes', async () => {
|
||||
const base = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
|
||||
base.priority = 10;
|
||||
base.name = 'base-route';
|
||||
const base: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] },
|
||||
priority: 10,
|
||||
name: 'base-route'
|
||||
};
|
||||
|
||||
const merged = mergeRouteConfigs(base, {
|
||||
priority: 50,
|
||||
@@ -85,14 +79,16 @@ tap.test('route merging - mergeRouteConfigs combines routes', async () => {
|
||||
|
||||
expect(merged.priority).toEqual(50);
|
||||
expect(merged.name).toEqual('merged-route');
|
||||
// Original route fields should be preserved
|
||||
expect(merged.match.domains).toEqual('example.com');
|
||||
expect(merged.action.targets![0].host).toEqual('127.0.0.1');
|
||||
});
|
||||
|
||||
tap.test('route merging - mergeRouteConfigs does not mutate original', async () => {
|
||||
const base = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
|
||||
base.name = 'original';
|
||||
const base: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] },
|
||||
name: 'original'
|
||||
};
|
||||
|
||||
const merged = mergeRouteConfigs(base, { name: 'changed' });
|
||||
expect(base.name).toEqual('original');
|
||||
@@ -100,20 +96,21 @@ tap.test('route merging - mergeRouteConfigs does not mutate original', async ()
|
||||
});
|
||||
|
||||
tap.test('route cloning - cloneRoute produces independent copy', async () => {
|
||||
const original = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
|
||||
original.priority = 42;
|
||||
original.name = 'original-route';
|
||||
const original: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] },
|
||||
priority: 42,
|
||||
name: 'original-route'
|
||||
};
|
||||
|
||||
const cloned = cloneRoute(original);
|
||||
|
||||
// Should be equal in value
|
||||
expect(cloned.match.domains).toEqual('example.com');
|
||||
expect(cloned.priority).toEqual(42);
|
||||
expect(cloned.name).toEqual('original-route');
|
||||
expect(cloned.action.targets![0].host).toEqual('127.0.0.1');
|
||||
expect(cloned.action.targets![0].port).toEqual(3000);
|
||||
|
||||
// Should be independent - modifying clone shouldn't affect original
|
||||
cloned.name = 'cloned-route';
|
||||
cloned.priority = 99;
|
||||
expect(original.name).toEqual('original-route');
|
||||
|
||||
+38
-34
@@ -1,11 +1,5 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
|
||||
import {
|
||||
createHttpRoute,
|
||||
createHttpsTerminateRoute,
|
||||
createLoadBalancerRoute,
|
||||
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
|
||||
|
||||
import {
|
||||
findMatchingRoutes,
|
||||
findBestMatchingRoute,
|
||||
@@ -22,24 +16,11 @@ import {
|
||||
|
||||
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
|
||||
|
||||
tap.test('route creation - createHttpRoute produces correct structure', async () => {
|
||||
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
|
||||
expect(route).toHaveProperty('match');
|
||||
expect(route).toHaveProperty('action');
|
||||
expect(route.match.domains).toEqual('example.com');
|
||||
expect(route.action.type).toEqual('forward');
|
||||
expect(route.action.targets).toBeArray();
|
||||
expect(route.action.targets![0].host).toEqual('127.0.0.1');
|
||||
expect(route.action.targets![0].port).toEqual(3000);
|
||||
});
|
||||
|
||||
tap.test('route creation - createHttpRoute with array of domains', async () => {
|
||||
const route = createHttpRoute(['a.com', 'b.com'], { host: 'localhost', port: 8080 });
|
||||
expect(route.match.domains).toEqual(['a.com', 'b.com']);
|
||||
});
|
||||
|
||||
tap.test('route validation - validateRouteConfig accepts valid route', async () => {
|
||||
const route = createHttpRoute('valid.example.com', { host: '10.0.0.1', port: 8080 });
|
||||
const route: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'valid.example.com' },
|
||||
action: { type: 'forward', targets: [{ host: '10.0.0.1', port: 8080 }] }
|
||||
};
|
||||
const result = validateRouteConfig(route);
|
||||
expect(result.valid).toBeTrue();
|
||||
expect(result.errors).toHaveLength(0);
|
||||
@@ -67,30 +48,44 @@ tap.test('route validation - isValidPort checks correctly', async () => {
|
||||
});
|
||||
|
||||
tap.test('domain matching - exact domain', async () => {
|
||||
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
|
||||
const route: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
|
||||
};
|
||||
expect(routeMatchesDomain(route, 'example.com')).toBeTrue();
|
||||
expect(routeMatchesDomain(route, 'other.com')).toBeFalse();
|
||||
});
|
||||
|
||||
tap.test('domain matching - wildcard domain', async () => {
|
||||
const route = createHttpRoute('*.example.com', { host: '127.0.0.1', port: 3000 });
|
||||
const route: IRouteConfig = {
|
||||
match: { ports: 80, domains: '*.example.com' },
|
||||
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
|
||||
};
|
||||
expect(routeMatchesDomain(route, 'sub.example.com')).toBeTrue();
|
||||
expect(routeMatchesDomain(route, 'example.com')).toBeFalse();
|
||||
});
|
||||
|
||||
tap.test('port matching - single port', async () => {
|
||||
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
|
||||
// createHttpRoute defaults to port 80
|
||||
const route: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
|
||||
};
|
||||
expect(routeMatchesPort(route, 80)).toBeTrue();
|
||||
expect(routeMatchesPort(route, 443)).toBeFalse();
|
||||
});
|
||||
|
||||
tap.test('route finding - findBestMatchingRoute selects by priority', async () => {
|
||||
const lowPriority = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
|
||||
lowPriority.priority = 10;
|
||||
const lowPriority: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] },
|
||||
priority: 10
|
||||
};
|
||||
|
||||
const highPriority = createHttpRoute('example.com', { host: '127.0.0.1', port: 4000 });
|
||||
highPriority.priority = 100;
|
||||
const highPriority: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 4000 }] },
|
||||
priority: 100
|
||||
};
|
||||
|
||||
const routes: IRouteConfig[] = [lowPriority, highPriority];
|
||||
const best = findBestMatchingRoute(routes, { domain: 'example.com', port: 80 });
|
||||
@@ -100,9 +95,18 @@ tap.test('route finding - findBestMatchingRoute selects by priority', async () =
|
||||
});
|
||||
|
||||
tap.test('route finding - findMatchingRoutes returns all matches', async () => {
|
||||
const route1 = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
|
||||
const route2 = createHttpRoute('example.com', { host: '127.0.0.1', port: 4000 });
|
||||
const route3 = createHttpRoute('other.com', { host: '127.0.0.1', port: 5000 });
|
||||
const route1: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
|
||||
};
|
||||
const route2: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 4000 }] }
|
||||
};
|
||||
const route3: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'other.com' },
|
||||
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 5000 }] }
|
||||
};
|
||||
|
||||
const matches = findMatchingRoutes([route1, route2, route3], { domain: 'example.com', port: 80 });
|
||||
expect(matches).toHaveLength(2);
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
import * as http from 'http';
|
||||
import * as net from 'net';
|
||||
import * as tls from 'tls';
|
||||
import * as fs from 'fs';
|
||||
import * as path from 'path';
|
||||
import { fileURLToPath } from 'url';
|
||||
import { assertPortsFree, findFreePorts } from './helpers/port-allocator.js';
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = path.dirname(__filename);
|
||||
|
||||
const CERT_PEM = fs.readFileSync(path.join(__dirname, '..', 'assets', 'certs', 'cert.pem'), 'utf8');
|
||||
const KEY_PEM = fs.readFileSync(path.join(__dirname, '..', 'assets', 'certs', 'key.pem'), 'utf8');
|
||||
|
||||
let httpBackendPort: number;
|
||||
let tlsBackendPort: number;
|
||||
let httpProxyPort: number;
|
||||
let tlsProxyPort: number;
|
||||
|
||||
let httpBackend: http.Server;
|
||||
let tlsBackend: tls.Server;
|
||||
let proxy: SmartProxy;
|
||||
|
||||
async function pollMetrics(proxyToPoll: SmartProxy): Promise<void> {
|
||||
await (proxyToPoll as any).metricsAdapter.poll();
|
||||
}
|
||||
|
||||
async function waitForCondition(
|
||||
callback: () => Promise<boolean>,
|
||||
timeoutMs: number = 5000,
|
||||
stepMs: number = 100,
|
||||
): Promise<void> {
|
||||
const deadline = Date.now() + timeoutMs;
|
||||
while (Date.now() < deadline) {
|
||||
if (await callback()) {
|
||||
return;
|
||||
}
|
||||
await new Promise((resolve) => setTimeout(resolve, stepMs));
|
||||
}
|
||||
throw new Error(`Condition not met within ${timeoutMs}ms`);
|
||||
}
|
||||
|
||||
function hasIpDomainRequest(domain: string): boolean {
|
||||
const byIp = proxy.getMetrics().connections.domainRequestsByIP();
|
||||
for (const domainMap of byIp.values()) {
|
||||
if (domainMap.has(domain)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
tap.test('setup - backend servers for HTTP domain rate metrics', async () => {
|
||||
[httpBackendPort, tlsBackendPort, httpProxyPort, tlsProxyPort] = await findFreePorts(4);
|
||||
|
||||
httpBackend = http.createServer((req, res) => {
|
||||
let body = '';
|
||||
req.on('data', (chunk) => {
|
||||
body += chunk;
|
||||
});
|
||||
req.on('end', () => {
|
||||
res.writeHead(200, { 'Content-Type': 'text/plain' });
|
||||
res.end(`ok:${body}`);
|
||||
});
|
||||
});
|
||||
await new Promise<void>((resolve) => {
|
||||
httpBackend.listen(httpBackendPort, () => resolve());
|
||||
});
|
||||
|
||||
tlsBackend = tls.createServer({ cert: CERT_PEM, key: KEY_PEM }, (socket) => {
|
||||
socket.on('data', (data) => {
|
||||
socket.write(data);
|
||||
});
|
||||
socket.on('error', () => {});
|
||||
});
|
||||
await new Promise<void>((resolve) => {
|
||||
tlsBackend.listen(tlsBackendPort, () => resolve());
|
||||
});
|
||||
});
|
||||
|
||||
tap.test('setup - start proxy with HTTP and TLS passthrough routes', async () => {
|
||||
proxy = new SmartProxy({
|
||||
routes: [
|
||||
{
|
||||
id: 'http-domain-rates',
|
||||
name: 'http-domain-rates',
|
||||
match: { ports: httpProxyPort, domains: 'example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: httpBackendPort }],
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'tls-passthrough-domain-rates',
|
||||
name: 'tls-passthrough-domain-rates',
|
||||
match: { ports: tlsProxyPort, domains: 'passthrough.example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
tls: { mode: 'passthrough' },
|
||||
targets: [{ host: 'localhost', port: tlsBackendPort }],
|
||||
},
|
||||
},
|
||||
],
|
||||
metrics: { enabled: true, sampleIntervalMs: 100, retentionSeconds: 60 },
|
||||
});
|
||||
|
||||
await proxy.start();
|
||||
await new Promise((resolve) => setTimeout(resolve, 300));
|
||||
});
|
||||
|
||||
tap.test('HTTP requests populate per-domain HTTP request rates', async () => {
|
||||
for (let i = 0; i < 3; i++) {
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
const body = `payload-${i}`;
|
||||
const req = http.request(
|
||||
{
|
||||
hostname: 'localhost',
|
||||
port: httpProxyPort,
|
||||
path: '/echo',
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Host: 'Example.COM',
|
||||
'Content-Type': 'text/plain',
|
||||
'Content-Length': String(body.length),
|
||||
},
|
||||
},
|
||||
(res) => {
|
||||
res.resume();
|
||||
res.on('end', () => resolve());
|
||||
},
|
||||
);
|
||||
req.on('error', reject);
|
||||
req.end(body);
|
||||
});
|
||||
}
|
||||
|
||||
await waitForCondition(async () => {
|
||||
await pollMetrics(proxy);
|
||||
const domainMetrics = proxy.getMetrics().requests.byDomain().get('example.com');
|
||||
return (domainMetrics?.lastMinute ?? 0) >= 3 && (domainMetrics?.perSecond ?? 0) > 0;
|
||||
});
|
||||
|
||||
const exampleMetrics = proxy.getMetrics().requests.byDomain().get('example.com');
|
||||
expect(exampleMetrics).toBeTruthy();
|
||||
expect(exampleMetrics?.lastMinute).toEqual(3);
|
||||
expect(exampleMetrics?.perSecond).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
tap.test('TLS passthrough SNI does not inflate HTTP domain request rates', async () => {
|
||||
const tlsClient = tls.connect({
|
||||
host: 'localhost',
|
||||
port: tlsProxyPort,
|
||||
servername: 'passthrough.example.com',
|
||||
rejectUnauthorized: false,
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
tlsClient.once('secureConnect', () => resolve());
|
||||
tlsClient.once('error', reject);
|
||||
});
|
||||
|
||||
const echoPromise = new Promise<void>((resolve, reject) => {
|
||||
tlsClient.once('data', () => resolve());
|
||||
tlsClient.once('error', reject);
|
||||
});
|
||||
tlsClient.write(Buffer.from('hello over tls passthrough'));
|
||||
await echoPromise;
|
||||
|
||||
await waitForCondition(async () => {
|
||||
await pollMetrics(proxy);
|
||||
return hasIpDomainRequest('passthrough.example.com');
|
||||
});
|
||||
|
||||
const requestRates = proxy.getMetrics().requests.byDomain();
|
||||
expect(requestRates.has('passthrough.example.com')).toBeFalse();
|
||||
expect(requestRates.get('example.com')?.lastMinute).toEqual(3);
|
||||
expect(hasIpDomainRequest('passthrough.example.com')).toBeTrue();
|
||||
|
||||
tlsClient.destroy();
|
||||
});
|
||||
|
||||
tap.test('cleanup - stop proxy and close backend servers', async () => {
|
||||
await proxy.stop();
|
||||
await new Promise<void>((resolve) => httpBackend.close(() => resolve()));
|
||||
await new Promise<void>((resolve) => tlsBackend.close(() => resolve()));
|
||||
await assertPortsFree([httpBackendPort, tlsBackendPort, httpProxyPort, tlsProxyPort]);
|
||||
});
|
||||
|
||||
export default tap.start()
|
||||
@@ -2,146 +2,101 @@ import * as path from 'path';
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
|
||||
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
|
||||
import {
|
||||
createHttpRoute,
|
||||
createHttpsTerminateRoute,
|
||||
createHttpsPassthroughRoute,
|
||||
createHttpToHttpsRedirect,
|
||||
createCompleteHttpsServer,
|
||||
createLoadBalancerRoute,
|
||||
createApiRoute,
|
||||
createWebSocketRoute
|
||||
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
|
||||
import { SocketHandlers } from '../ts/proxies/smart-proxy/utils/socket-handlers.js';
|
||||
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
|
||||
|
||||
// Test to demonstrate various route configurations using the new helpers
|
||||
tap.test('Route-based configuration examples', async (tools) => {
|
||||
// Example 1: HTTP-only configuration
|
||||
const httpOnlyRoute = createHttpRoute(
|
||||
'http.example.com',
|
||||
{
|
||||
host: 'localhost',
|
||||
port: 3000
|
||||
},
|
||||
{
|
||||
name: 'Basic HTTP Route'
|
||||
}
|
||||
);
|
||||
const httpOnlyRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'http.example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'Basic HTTP Route'
|
||||
};
|
||||
|
||||
console.log('HTTP-only route created successfully:', httpOnlyRoute.name);
|
||||
expect(httpOnlyRoute.action.type).toEqual('forward');
|
||||
expect(httpOnlyRoute.match.domains).toEqual('http.example.com');
|
||||
|
||||
// Example 2: HTTPS Passthrough (SNI) configuration
|
||||
const httpsPassthroughRoute = createHttpsPassthroughRoute(
|
||||
'pass.example.com',
|
||||
{
|
||||
host: ['10.0.0.1', '10.0.0.2'], // Round-robin target IPs
|
||||
port: 443
|
||||
},
|
||||
{
|
||||
name: 'HTTPS Passthrough Route'
|
||||
}
|
||||
);
|
||||
const httpsPassthroughRoute: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'pass.example.com' },
|
||||
action: { type: 'forward', targets: [{ host: '10.0.0.1', port: 443 }, { host: '10.0.0.2', port: 443 }], tls: { mode: 'passthrough' } },
|
||||
name: 'HTTPS Passthrough Route'
|
||||
};
|
||||
|
||||
expect(httpsPassthroughRoute).toBeTruthy();
|
||||
expect(httpsPassthroughRoute.action.tls?.mode).toEqual('passthrough');
|
||||
expect(Array.isArray(httpsPassthroughRoute.action.targets)).toBeTrue();
|
||||
|
||||
// Example 3: HTTPS Termination to HTTP Backend
|
||||
const terminateToHttpRoute = createHttpsTerminateRoute(
|
||||
'secure.example.com',
|
||||
{
|
||||
host: 'localhost',
|
||||
port: 8080
|
||||
},
|
||||
{
|
||||
certificate: 'auto',
|
||||
name: 'HTTPS Termination to HTTP Backend'
|
||||
}
|
||||
);
|
||||
const terminateToHttpRoute: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'secure.example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 8080 }], tls: { mode: 'terminate', certificate: 'auto' } },
|
||||
name: 'HTTPS Termination to HTTP Backend'
|
||||
};
|
||||
|
||||
// Create the HTTP to HTTPS redirect for this domain
|
||||
const httpToHttpsRedirect = createHttpToHttpsRedirect(
|
||||
'secure.example.com',
|
||||
443,
|
||||
{
|
||||
name: 'HTTP to HTTPS Redirect for secure.example.com'
|
||||
}
|
||||
);
|
||||
const httpToHttpsRedirect: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'secure.example.com' },
|
||||
action: { type: 'socket-handler', socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301) },
|
||||
name: 'HTTP to HTTPS Redirect for secure.example.com'
|
||||
};
|
||||
|
||||
expect(terminateToHttpRoute).toBeTruthy();
|
||||
expect(terminateToHttpRoute.action.tls?.mode).toEqual('terminate');
|
||||
expect(httpToHttpsRedirect.action.type).toEqual('socket-handler');
|
||||
|
||||
// Example 4: Load Balancer with HTTPS
|
||||
const loadBalancerRoute = createLoadBalancerRoute(
|
||||
'proxy.example.com',
|
||||
['internal-api-1.local', 'internal-api-2.local'],
|
||||
8443,
|
||||
{
|
||||
tls: {
|
||||
mode: 'terminate-and-reencrypt',
|
||||
certificate: 'auto'
|
||||
},
|
||||
name: 'Load Balanced HTTPS Route'
|
||||
}
|
||||
);
|
||||
const loadBalancerRoute: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'proxy.example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [
|
||||
{ host: 'internal-api-1.local', port: 8443 },
|
||||
{ host: 'internal-api-2.local', port: 8443 }
|
||||
],
|
||||
tls: { mode: 'terminate-and-reencrypt', certificate: 'auto' }
|
||||
},
|
||||
name: 'Load Balanced HTTPS Route'
|
||||
};
|
||||
|
||||
expect(loadBalancerRoute).toBeTruthy();
|
||||
expect(loadBalancerRoute.action.tls?.mode).toEqual('terminate-and-reencrypt');
|
||||
expect(Array.isArray(loadBalancerRoute.action.targets)).toBeTrue();
|
||||
|
||||
// Example 5: API Route
|
||||
const apiRoute = createApiRoute(
|
||||
'api.example.com',
|
||||
'/api',
|
||||
{ host: 'localhost', port: 8081 },
|
||||
{
|
||||
name: 'API Route',
|
||||
useTls: true,
|
||||
addCorsHeaders: true
|
||||
}
|
||||
);
|
||||
const apiRoute: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'api.example.com', path: '/api' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8081 }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' }
|
||||
},
|
||||
name: 'API Route'
|
||||
};
|
||||
|
||||
expect(apiRoute.action.type).toEqual('forward');
|
||||
expect(apiRoute.match.path).toBeTruthy();
|
||||
|
||||
// Example 6: Complete HTTPS Server with HTTP Redirect
|
||||
const httpsServerRoutes = createCompleteHttpsServer(
|
||||
'complete.example.com',
|
||||
{
|
||||
host: 'localhost',
|
||||
port: 8080
|
||||
const httpsRoute: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'complete.example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 8080 }], tls: { mode: 'terminate', certificate: 'auto' } },
|
||||
name: 'Complete HTTPS Server'
|
||||
};
|
||||
|
||||
const httpsRedirectRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'complete.example.com' },
|
||||
action: { type: 'socket-handler', socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301) },
|
||||
name: 'Complete HTTPS Server - Redirect'
|
||||
};
|
||||
|
||||
const webSocketRoute: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'ws.example.com', path: '/ws' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8082 }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' },
|
||||
websocket: { enabled: true }
|
||||
},
|
||||
{
|
||||
certificate: 'auto',
|
||||
name: 'Complete HTTPS Server'
|
||||
}
|
||||
);
|
||||
|
||||
expect(Array.isArray(httpsServerRoutes)).toBeTrue();
|
||||
expect(httpsServerRoutes.length).toEqual(2); // HTTPS route and HTTP redirect
|
||||
expect(httpsServerRoutes[0].action.tls?.mode).toEqual('terminate');
|
||||
expect(httpsServerRoutes[1].action.type).toEqual('socket-handler');
|
||||
|
||||
// Example 7: Static File Server - removed (use nginx/apache behind proxy)
|
||||
|
||||
// Example 8: WebSocket Route
|
||||
const webSocketRoute = createWebSocketRoute(
|
||||
'ws.example.com',
|
||||
'/ws',
|
||||
{ host: 'localhost', port: 8082 },
|
||||
{
|
||||
useTls: true,
|
||||
name: 'WebSocket Route'
|
||||
}
|
||||
);
|
||||
name: 'WebSocket Route'
|
||||
};
|
||||
|
||||
expect(webSocketRoute.action.type).toEqual('forward');
|
||||
expect(webSocketRoute.action.websocket?.enabled).toBeTrue();
|
||||
|
||||
// Create a SmartProxy instance with all routes
|
||||
const allRoutes: IRouteConfig[] = [
|
||||
httpOnlyRoute,
|
||||
httpsPassthroughRoute,
|
||||
@@ -149,19 +104,17 @@ tap.test('Route-based configuration examples', async (tools) => {
|
||||
httpToHttpsRedirect,
|
||||
loadBalancerRoute,
|
||||
apiRoute,
|
||||
...httpsServerRoutes,
|
||||
httpsRoute,
|
||||
httpsRedirectRoute,
|
||||
webSocketRoute
|
||||
];
|
||||
|
||||
// We're not actually starting the SmartProxy in this test,
|
||||
// just verifying that the configuration is valid
|
||||
const smartProxy = new SmartProxy({
|
||||
routes: allRoutes
|
||||
});
|
||||
|
||||
// Just verify that all routes are configured correctly
|
||||
console.log(`Created ${allRoutes.length} example routes`);
|
||||
expect(allRoutes.length).toEqual(9); // One less without static file route
|
||||
expect(allRoutes.length).toEqual(9);
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
export default tap.start();
|
||||
|
||||
+18
-48
@@ -1,27 +1,8 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as plugins from '../ts/plugins.js';
|
||||
|
||||
// Import route-based helpers
|
||||
import {
|
||||
createHttpRoute,
|
||||
createHttpsTerminateRoute,
|
||||
createHttpsPassthroughRoute,
|
||||
createHttpToHttpsRedirect,
|
||||
createCompleteHttpsServer
|
||||
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
|
||||
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
|
||||
|
||||
// Create helper functions for backward compatibility
|
||||
const helpers = {
|
||||
httpOnly: (domains: string | string[], target: any) => createHttpRoute(domains, target),
|
||||
tlsTerminateToHttp: (domains: string | string[], target: any) =>
|
||||
createHttpsTerminateRoute(domains, target),
|
||||
tlsTerminateToHttps: (domains: string | string[], target: any) =>
|
||||
createHttpsTerminateRoute(domains, target, { reencrypt: true }),
|
||||
httpsPassthrough: (domains: string | string[], target: any) =>
|
||||
createHttpsPassthroughRoute(domains, target)
|
||||
};
|
||||
|
||||
// Route-based utility functions for testing
|
||||
function findRouteForDomain(routes: any[], domain: string): any {
|
||||
return routes.find(route => {
|
||||
const domains = Array.isArray(route.match.domains)
|
||||
@@ -31,55 +12,44 @@ function findRouteForDomain(routes: any[], domain: string): any {
|
||||
});
|
||||
}
|
||||
|
||||
// Replace the old test with route-based tests
|
||||
tap.test('Route Helpers - Create HTTP routes', async () => {
|
||||
const route = helpers.httpOnly('example.com', { host: 'localhost', port: 3000 });
|
||||
const route: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] }
|
||||
};
|
||||
expect(route.action.type).toEqual('forward');
|
||||
expect(route.match.domains).toEqual('example.com');
|
||||
expect(route.action.targets?.[0]).toEqual({ host: 'localhost', port: 3000 });
|
||||
});
|
||||
|
||||
tap.test('Route Helpers - Create HTTPS terminate to HTTP routes', async () => {
|
||||
const route = helpers.tlsTerminateToHttp('secure.example.com', { host: 'localhost', port: 3000 });
|
||||
const route: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'secure.example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }], tls: { mode: 'terminate', certificate: 'auto' } }
|
||||
};
|
||||
expect(route.action.type).toEqual('forward');
|
||||
expect(route.match.domains).toEqual('secure.example.com');
|
||||
expect(route.action.tls?.mode).toEqual('terminate');
|
||||
});
|
||||
|
||||
tap.test('Route Helpers - Create HTTPS passthrough routes', async () => {
|
||||
const route = helpers.httpsPassthrough('passthrough.example.com', { host: 'backend', port: 443 });
|
||||
const route: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'passthrough.example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'backend', port: 443 }], tls: { mode: 'passthrough' } }
|
||||
};
|
||||
expect(route.action.type).toEqual('forward');
|
||||
expect(route.match.domains).toEqual('passthrough.example.com');
|
||||
expect(route.action.tls?.mode).toEqual('passthrough');
|
||||
});
|
||||
|
||||
tap.test('Route Helpers - Create HTTPS to HTTPS routes', async () => {
|
||||
const route = helpers.tlsTerminateToHttps('reencrypt.example.com', { host: 'backend', port: 443 });
|
||||
const route: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'reencrypt.example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'backend', port: 443 }], tls: { mode: 'terminate-and-reencrypt', certificate: 'auto' } }
|
||||
};
|
||||
expect(route.action.type).toEqual('forward');
|
||||
expect(route.match.domains).toEqual('reencrypt.example.com');
|
||||
expect(route.action.tls?.mode).toEqual('terminate-and-reencrypt');
|
||||
});
|
||||
|
||||
tap.test('Route Helpers - Create complete HTTPS server with redirect', async () => {
|
||||
const routes = createCompleteHttpsServer(
|
||||
'full.example.com',
|
||||
{ host: 'localhost', port: 3000 },
|
||||
{ certificate: 'auto' }
|
||||
);
|
||||
|
||||
expect(routes.length).toEqual(2);
|
||||
|
||||
// Check HTTP to HTTPS redirect - find route by port
|
||||
const redirectRoute = routes.find(r => r.match.ports === 80);
|
||||
expect(redirectRoute.action.type).toEqual('socket-handler');
|
||||
expect(redirectRoute.action.socketHandler).toBeDefined();
|
||||
expect(redirectRoute.match.ports).toEqual(80);
|
||||
|
||||
// Check HTTPS route
|
||||
const httpsRoute = routes.find(r => r.action.type === 'forward');
|
||||
expect(httpsRoute.match.ports).toEqual(443);
|
||||
expect(httpsRoute.action.tls?.mode).toEqual('terminate');
|
||||
});
|
||||
|
||||
// Export test runner
|
||||
export default tap.start();
|
||||
export default tap.start();
|
||||
|
||||
@@ -83,6 +83,9 @@ tap.test('should verify new metrics API structure', async () => {
|
||||
expect(metrics.throughput).toHaveProperty('history');
|
||||
expect(metrics.throughput).toHaveProperty('byRoute');
|
||||
expect(metrics.throughput).toHaveProperty('byIP');
|
||||
|
||||
// Check request methods
|
||||
expect(metrics.requests).toHaveProperty('byDomain');
|
||||
});
|
||||
|
||||
tap.test('should track active connections', async (tools) => {
|
||||
@@ -273,4 +276,4 @@ tap.test('should clean up resources', async () => {
|
||||
await assertPortsFree([echoServerPort, proxyPort]);
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
export default tap.start();
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
|
||||
import { createNfTablesRoute, createNfTablesTerminateRoute } from '../ts/proxies/smart-proxy/utils/route-helpers.js';
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import * as child_process from 'child_process';
|
||||
import { promisify } from 'util';
|
||||
|
||||
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
|
||||
|
||||
const exec = promisify(child_process.exec);
|
||||
|
||||
// Check if we have root privileges to run NFTables tests
|
||||
async function checkRootPrivileges(): Promise<boolean> {
|
||||
try {
|
||||
// Check if we're running as root
|
||||
const { stdout } = await exec('id -u');
|
||||
return stdout.trim() === '0';
|
||||
} catch (err) {
|
||||
@@ -17,7 +16,6 @@ async function checkRootPrivileges(): Promise<boolean> {
|
||||
}
|
||||
}
|
||||
|
||||
// Check if tests should run
|
||||
const isRoot = await checkRootPrivileges();
|
||||
|
||||
if (!isRoot) {
|
||||
@@ -29,68 +27,70 @@ if (!isRoot) {
|
||||
console.log('');
|
||||
}
|
||||
|
||||
// Define the test with proper skip condition
|
||||
const testFn = isRoot ? tap.test : tap.skip.test;
|
||||
|
||||
testFn('NFTables integration tests', async () => {
|
||||
|
||||
|
||||
console.log('Running NFTables tests with root privileges');
|
||||
|
||||
// Create test routes
|
||||
const routes = [
|
||||
createNfTablesRoute('tcp-forward', {
|
||||
host: 'localhost',
|
||||
port: 8080
|
||||
}, {
|
||||
ports: 9080,
|
||||
protocol: 'tcp'
|
||||
}),
|
||||
|
||||
createNfTablesRoute('udp-forward', {
|
||||
host: 'localhost',
|
||||
port: 5353
|
||||
}, {
|
||||
ports: 5354,
|
||||
protocol: 'udp'
|
||||
}),
|
||||
|
||||
createNfTablesRoute('port-range', {
|
||||
host: 'localhost',
|
||||
port: 8080
|
||||
}, {
|
||||
ports: [{ from: 9000, to: 9100 }],
|
||||
protocol: 'tcp'
|
||||
})
|
||||
|
||||
const routes: IRouteConfig[] = [
|
||||
{
|
||||
match: { ports: 9080 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
forwardingEngine: 'nftables',
|
||||
targets: [{ host: 'localhost', port: 8080 }],
|
||||
nftables: { protocol: 'tcp' }
|
||||
},
|
||||
name: 'tcp-forward'
|
||||
},
|
||||
|
||||
{
|
||||
match: { ports: 5354 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
forwardingEngine: 'nftables',
|
||||
targets: [{ host: 'localhost', port: 5353 }],
|
||||
nftables: { protocol: 'udp' }
|
||||
},
|
||||
name: 'udp-forward'
|
||||
},
|
||||
|
||||
{
|
||||
match: { ports: [{ from: 9000, to: 9100 }] },
|
||||
action: {
|
||||
type: 'forward',
|
||||
forwardingEngine: 'nftables',
|
||||
targets: [{ host: 'localhost', port: 8080 }],
|
||||
nftables: { protocol: 'tcp' }
|
||||
},
|
||||
name: 'port-range'
|
||||
}
|
||||
];
|
||||
|
||||
|
||||
const smartProxy = new SmartProxy({
|
||||
enableDetailedLogging: true,
|
||||
routes
|
||||
});
|
||||
|
||||
// Start the proxy
|
||||
|
||||
await smartProxy.start();
|
||||
console.log('SmartProxy started with NFTables routes');
|
||||
|
||||
// Get NFTables status
|
||||
|
||||
const status = await smartProxy.getNfTablesStatus();
|
||||
console.log('NFTables status:', JSON.stringify(status, null, 2));
|
||||
|
||||
// Verify all routes are provisioned
|
||||
|
||||
expect(Object.keys(status).length).toEqual(routes.length);
|
||||
|
||||
|
||||
for (const routeStatus of Object.values(status)) {
|
||||
expect(routeStatus.active).toBeTrue();
|
||||
expect(routeStatus.ruleCount.total).toBeGreaterThan(0);
|
||||
}
|
||||
|
||||
// Stop the proxy
|
||||
|
||||
await smartProxy.stop();
|
||||
console.log('SmartProxy stopped');
|
||||
|
||||
// Verify all rules are cleaned up
|
||||
|
||||
const finalStatus = await smartProxy.getNfTablesStatus();
|
||||
expect(Object.keys(finalStatus).length).toEqual(0);
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
export default tap.start();
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
|
||||
import { createNfTablesRoute, createNfTablesTerminateRoute } from '../ts/proxies/smart-proxy/utils/route-helpers.js';
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import * as http from 'http';
|
||||
@@ -10,13 +9,13 @@ import { fileURLToPath } from 'url';
|
||||
import * as child_process from 'child_process';
|
||||
import { promisify } from 'util';
|
||||
|
||||
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
|
||||
|
||||
const exec = promisify(child_process.exec);
|
||||
|
||||
// Get __dirname equivalent for ES modules
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = path.dirname(__filename);
|
||||
|
||||
// Check if we have root privileges
|
||||
async function checkRootPrivileges(): Promise<boolean> {
|
||||
try {
|
||||
const { stdout } = await exec('id -u');
|
||||
@@ -26,7 +25,6 @@ async function checkRootPrivileges(): Promise<boolean> {
|
||||
}
|
||||
}
|
||||
|
||||
// Check if tests should run
|
||||
const runTests = await checkRootPrivileges();
|
||||
|
||||
if (!runTests) {
|
||||
@@ -36,10 +34,8 @@ if (!runTests) {
|
||||
console.log('Skipping NFTables integration tests');
|
||||
console.log('========================================');
|
||||
console.log('');
|
||||
// Skip tests when not running as root - tests are marked with tap.skip.test
|
||||
}
|
||||
|
||||
// Test server and client utilities
|
||||
let testTcpServer: net.Server;
|
||||
let testHttpServer: http.Server;
|
||||
let testHttpsServer: https.Server;
|
||||
@@ -53,10 +49,8 @@ const PROXY_HTTP_PORT = 5001;
|
||||
const PROXY_HTTPS_PORT = 5002;
|
||||
const TEST_DATA = 'Hello through NFTables!';
|
||||
|
||||
// Helper to create test certificates
|
||||
async function createTestCertificates() {
|
||||
try {
|
||||
// Import the certificate helper
|
||||
const certsModule = await import('./helpers/certificates.js');
|
||||
const certificates = certsModule.loadTestCertificates();
|
||||
return {
|
||||
@@ -65,7 +59,6 @@ async function createTestCertificates() {
|
||||
};
|
||||
} catch (err) {
|
||||
console.error('Failed to load test certificates:', err);
|
||||
// Use dummy certificates for testing
|
||||
return {
|
||||
cert: fs.readFileSync(path.join(__dirname, '..', 'assets', 'certs', 'cert.pem'), 'utf8'),
|
||||
key: fs.readFileSync(path.join(__dirname, '..', 'assets', 'certs', 'key.pem'), 'utf8')
|
||||
@@ -75,111 +68,112 @@ async function createTestCertificates() {
|
||||
|
||||
tap.skip.test('setup NFTables integration test environment', async () => {
|
||||
console.log('Running NFTables integration tests with root privileges');
|
||||
|
||||
// Create a basic TCP test server
|
||||
|
||||
testTcpServer = net.createServer((socket) => {
|
||||
socket.on('data', (data) => {
|
||||
socket.write(`Server says: ${data.toString()}`);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
testTcpServer.listen(TEST_TCP_PORT, () => {
|
||||
console.log(`TCP test server listening on port ${TEST_TCP_PORT}`);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
// Create an HTTP test server
|
||||
|
||||
testHttpServer = http.createServer((req, res) => {
|
||||
res.writeHead(200, { 'Content-Type': 'text/plain' });
|
||||
res.end(`HTTP Server says: ${TEST_DATA}`);
|
||||
});
|
||||
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
testHttpServer.listen(TEST_HTTP_PORT, () => {
|
||||
console.log(`HTTP test server listening on port ${TEST_HTTP_PORT}`);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
// Create an HTTPS test server
|
||||
|
||||
const certs = await createTestCertificates();
|
||||
testHttpsServer = https.createServer({ key: certs.key, cert: certs.cert }, (req, res) => {
|
||||
res.writeHead(200, { 'Content-Type': 'text/plain' });
|
||||
res.end(`HTTPS Server says: ${TEST_DATA}`);
|
||||
});
|
||||
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
testHttpsServer.listen(TEST_HTTPS_PORT, () => {
|
||||
console.log(`HTTPS test server listening on port ${TEST_HTTPS_PORT}`);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
// Create SmartProxy with various NFTables routes
|
||||
|
||||
smartProxy = new SmartProxy({
|
||||
enableDetailedLogging: true,
|
||||
routes: [
|
||||
// TCP forwarding route
|
||||
createNfTablesRoute('tcp-nftables', {
|
||||
host: 'localhost',
|
||||
port: TEST_TCP_PORT
|
||||
}, {
|
||||
ports: PROXY_TCP_PORT,
|
||||
protocol: 'tcp'
|
||||
}),
|
||||
|
||||
// HTTP forwarding route
|
||||
createNfTablesRoute('http-nftables', {
|
||||
host: 'localhost',
|
||||
port: TEST_HTTP_PORT
|
||||
}, {
|
||||
ports: PROXY_HTTP_PORT,
|
||||
protocol: 'tcp'
|
||||
}),
|
||||
|
||||
// HTTPS termination route
|
||||
createNfTablesTerminateRoute('https-nftables.example.com', {
|
||||
host: 'localhost',
|
||||
port: TEST_HTTPS_PORT
|
||||
}, {
|
||||
ports: PROXY_HTTPS_PORT,
|
||||
protocol: 'tcp',
|
||||
certificate: certs
|
||||
}),
|
||||
|
||||
// Route with IP allow list
|
||||
createNfTablesRoute('secure-tcp', {
|
||||
host: 'localhost',
|
||||
port: TEST_TCP_PORT
|
||||
}, {
|
||||
ports: 5003,
|
||||
protocol: 'tcp',
|
||||
ipAllowList: ['127.0.0.1', '::1']
|
||||
}),
|
||||
|
||||
// Route with QoS settings
|
||||
createNfTablesRoute('qos-tcp', {
|
||||
host: 'localhost',
|
||||
port: TEST_TCP_PORT
|
||||
}, {
|
||||
ports: 5004,
|
||||
protocol: 'tcp',
|
||||
maxRate: '10mbps',
|
||||
priority: 1
|
||||
})
|
||||
{
|
||||
match: { ports: PROXY_TCP_PORT },
|
||||
action: {
|
||||
type: 'forward',
|
||||
forwardingEngine: 'nftables',
|
||||
targets: [{ host: 'localhost', port: TEST_TCP_PORT }],
|
||||
nftables: { protocol: 'tcp' }
|
||||
},
|
||||
name: 'tcp-nftables'
|
||||
},
|
||||
|
||||
{
|
||||
match: { ports: PROXY_HTTP_PORT },
|
||||
action: {
|
||||
type: 'forward',
|
||||
forwardingEngine: 'nftables',
|
||||
targets: [{ host: 'localhost', port: TEST_HTTP_PORT }],
|
||||
nftables: { protocol: 'tcp' }
|
||||
},
|
||||
name: 'http-nftables'
|
||||
},
|
||||
|
||||
{
|
||||
match: { ports: PROXY_HTTPS_PORT, domains: 'https-nftables.example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
forwardingEngine: 'nftables',
|
||||
targets: [{ host: 'localhost', port: TEST_HTTPS_PORT }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' },
|
||||
nftables: { protocol: 'tcp' }
|
||||
},
|
||||
name: 'https-nftables'
|
||||
},
|
||||
|
||||
{
|
||||
match: { ports: 5003 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
forwardingEngine: 'nftables',
|
||||
targets: [{ host: 'localhost', port: TEST_TCP_PORT }],
|
||||
nftables: { protocol: 'tcp', ipAllowList: ['127.0.0.1', '::1'] }
|
||||
},
|
||||
name: 'secure-tcp'
|
||||
},
|
||||
|
||||
{
|
||||
match: { ports: 5004 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
forwardingEngine: 'nftables',
|
||||
targets: [{ host: 'localhost', port: TEST_TCP_PORT }],
|
||||
nftables: { protocol: 'tcp', maxRate: '10mbps', priority: 1 }
|
||||
},
|
||||
name: 'qos-tcp'
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
|
||||
console.log('SmartProxy created, now starting...');
|
||||
|
||||
// Start the proxy
|
||||
|
||||
try {
|
||||
await smartProxy.start();
|
||||
console.log('SmartProxy started successfully');
|
||||
|
||||
// Verify proxy is listening on expected ports
|
||||
|
||||
const listeningPorts = smartProxy.getListeningPorts();
|
||||
console.log(`SmartProxy is listening on ports: ${listeningPorts.join(', ')}`);
|
||||
} catch (err) {
|
||||
@@ -190,8 +184,7 @@ tap.skip.test('setup NFTables integration test environment', async () => {
|
||||
|
||||
tap.skip.test('should forward TCP connections through NFTables', async () => {
|
||||
console.log(`Attempting to connect to proxy TCP port ${PROXY_TCP_PORT}...`);
|
||||
|
||||
// First verify our test server is running
|
||||
|
||||
try {
|
||||
const testClient = new net.Socket();
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
@@ -205,40 +198,39 @@ tap.skip.test('should forward TCP connections through NFTables', async () => {
|
||||
} catch (err) {
|
||||
console.error(`Test server on port ${TEST_TCP_PORT} is not accessible: ${err}`);
|
||||
}
|
||||
|
||||
// Connect to the proxy port
|
||||
|
||||
const client = new net.Socket();
|
||||
|
||||
|
||||
const response = await new Promise<string>((resolve, reject) => {
|
||||
let responseData = '';
|
||||
const timeout = setTimeout(() => {
|
||||
client.destroy();
|
||||
reject(new Error(`Connection timeout after 5 seconds to proxy port ${PROXY_TCP_PORT}`));
|
||||
}, 5000);
|
||||
|
||||
|
||||
client.connect(PROXY_TCP_PORT, 'localhost', () => {
|
||||
console.log(`Connected to proxy port ${PROXY_TCP_PORT}, sending data...`);
|
||||
client.write(TEST_DATA);
|
||||
});
|
||||
|
||||
|
||||
client.on('data', (data) => {
|
||||
console.log(`Received data from proxy: ${data.toString()}`);
|
||||
responseData += data.toString();
|
||||
client.end();
|
||||
});
|
||||
|
||||
|
||||
client.on('end', () => {
|
||||
clearTimeout(timeout);
|
||||
resolve(responseData);
|
||||
});
|
||||
|
||||
|
||||
client.on('error', (err) => {
|
||||
clearTimeout(timeout);
|
||||
console.error(`Connection error on proxy port ${PROXY_TCP_PORT}: ${err.message}`);
|
||||
reject(err);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
expect(response).toEqual(`Server says: ${TEST_DATA}`);
|
||||
});
|
||||
|
||||
@@ -254,21 +246,20 @@ tap.skip.test('should forward HTTP connections through NFTables', async () => {
|
||||
});
|
||||
}).on('error', reject);
|
||||
});
|
||||
|
||||
|
||||
expect(response).toEqual(`HTTP Server says: ${TEST_DATA}`);
|
||||
});
|
||||
|
||||
tap.skip.test('should handle HTTPS termination with NFTables', async () => {
|
||||
// Skip this test if running without proper certificates
|
||||
const response = await new Promise<string>((resolve, reject) => {
|
||||
const options = {
|
||||
hostname: 'localhost',
|
||||
port: PROXY_HTTPS_PORT,
|
||||
path: '/',
|
||||
method: 'GET',
|
||||
rejectUnauthorized: false // For self-signed cert
|
||||
rejectUnauthorized: false
|
||||
};
|
||||
|
||||
|
||||
https.get(options, (res) => {
|
||||
let data = '';
|
||||
res.on('data', (chunk) => {
|
||||
@@ -279,43 +270,40 @@ tap.skip.test('should handle HTTPS termination with NFTables', async () => {
|
||||
});
|
||||
}).on('error', reject);
|
||||
});
|
||||
|
||||
|
||||
expect(response).toEqual(`HTTPS Server says: ${TEST_DATA}`);
|
||||
});
|
||||
|
||||
tap.skip.test('should respect IP allow lists in NFTables', async () => {
|
||||
// This test should pass since we're connecting from localhost
|
||||
const client = new net.Socket();
|
||||
|
||||
|
||||
const connected = await new Promise<boolean>((resolve) => {
|
||||
const timeout = setTimeout(() => {
|
||||
client.destroy();
|
||||
resolve(false);
|
||||
}, 2000);
|
||||
|
||||
|
||||
client.connect(5003, 'localhost', () => {
|
||||
clearTimeout(timeout);
|
||||
client.end();
|
||||
resolve(true);
|
||||
});
|
||||
|
||||
|
||||
client.on('error', () => {
|
||||
clearTimeout(timeout);
|
||||
resolve(false);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
expect(connected).toBeTrue();
|
||||
});
|
||||
|
||||
tap.skip.test('should get NFTables status', async () => {
|
||||
const status = await smartProxy.getNfTablesStatus();
|
||||
|
||||
// Check that we have status for our routes
|
||||
|
||||
const statusKeys = Object.keys(status);
|
||||
expect(statusKeys.length).toBeGreaterThan(0);
|
||||
|
||||
// Check status structure for one of the routes
|
||||
|
||||
const firstStatus = status[statusKeys[0]];
|
||||
expect(firstStatus).toHaveProperty('active');
|
||||
expect(firstStatus).toHaveProperty('ruleCount');
|
||||
@@ -324,21 +312,20 @@ tap.skip.test('should get NFTables status', async () => {
|
||||
});
|
||||
|
||||
tap.skip.test('cleanup NFTables integration test environment', async () => {
|
||||
// Stop the proxy and test servers
|
||||
await smartProxy.stop();
|
||||
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
testTcpServer.close(() => {
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
testHttpServer.close(() => {
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
testHttpsServer.close(() => {
|
||||
resolve();
|
||||
@@ -346,4 +333,4 @@ tap.skip.test('cleanup NFTables integration test environment', async () => {
|
||||
});
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
export default tap.start();
|
||||
|
||||
+75
-111
@@ -1,30 +1,20 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
|
||||
import {
|
||||
createPortMappingRoute,
|
||||
createOffsetPortMappingRoute,
|
||||
createDynamicRoute,
|
||||
createSmartLoadBalancer,
|
||||
createPortOffset
|
||||
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
|
||||
import type { IRouteConfig, IRouteContext } from '../ts/proxies/smart-proxy/models/route-types.js';
|
||||
import { findFreePorts, assertPortsFree } from './helpers/port-allocator.js';
|
||||
|
||||
// Test server and client utilities
|
||||
let testServers: Array<{ server: net.Server; port: number }> = [];
|
||||
let smartProxy: SmartProxy;
|
||||
|
||||
let TEST_PORTS: number[]; // 3 test server ports
|
||||
let PROXY_PORTS: number[]; // 6 proxy ports
|
||||
let TEST_PORTS: number[];
|
||||
let PROXY_PORTS: number[];
|
||||
const TEST_DATA = 'Hello through dynamic port mapper!';
|
||||
|
||||
// Cleanup function to close all servers and proxies
|
||||
function cleanup() {
|
||||
console.log('Starting cleanup...');
|
||||
const promises = [];
|
||||
|
||||
// Close test servers
|
||||
|
||||
for (const { server, port } of testServers) {
|
||||
promises.push(new Promise<void>(resolve => {
|
||||
console.log(`Closing test server on port ${port}`);
|
||||
@@ -34,31 +24,28 @@ function cleanup() {
|
||||
});
|
||||
}));
|
||||
}
|
||||
|
||||
// Stop SmartProxy
|
||||
|
||||
if (smartProxy) {
|
||||
console.log('Stopping SmartProxy...');
|
||||
promises.push(smartProxy.stop().then(() => {
|
||||
console.log('SmartProxy stopped');
|
||||
}));
|
||||
}
|
||||
|
||||
|
||||
return Promise.all(promises);
|
||||
}
|
||||
|
||||
// Helper: Creates a test TCP server that listens on a given port
|
||||
function createTestServer(port: number): Promise<net.Server> {
|
||||
return new Promise((resolve) => {
|
||||
const server = net.createServer((socket) => {
|
||||
socket.on('data', (data) => {
|
||||
// Echo the received data back with a server identifier
|
||||
socket.write(`Server ${port} says: ${data.toString()}`);
|
||||
});
|
||||
socket.on('error', (error) => {
|
||||
console.error(`[Test Server] Socket error on port ${port}:`, error);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
server.listen(port, () => {
|
||||
console.log(`[Test Server] Listening on port ${port}`);
|
||||
testServers.push({ server, port });
|
||||
@@ -67,32 +54,31 @@ function createTestServer(port: number): Promise<net.Server> {
|
||||
});
|
||||
}
|
||||
|
||||
// Helper: Creates a test client connection with timeout
|
||||
function createTestClient(port: number, data: string): Promise<string> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const client = new net.Socket();
|
||||
let response = '';
|
||||
|
||||
|
||||
const timeout = setTimeout(() => {
|
||||
client.destroy();
|
||||
reject(new Error(`Client connection timeout to port ${port}`));
|
||||
}, 5000);
|
||||
|
||||
|
||||
client.connect(port, 'localhost', () => {
|
||||
console.log(`[Test Client] Connected to server on port ${port}`);
|
||||
client.write(data);
|
||||
});
|
||||
|
||||
|
||||
client.on('data', (chunk) => {
|
||||
response += chunk.toString();
|
||||
client.end();
|
||||
});
|
||||
|
||||
|
||||
client.on('end', () => {
|
||||
clearTimeout(timeout);
|
||||
resolve(response);
|
||||
});
|
||||
|
||||
|
||||
client.on('error', (error) => {
|
||||
clearTimeout(timeout);
|
||||
reject(error);
|
||||
@@ -100,123 +86,108 @@ function createTestClient(port: number, data: string): Promise<string> {
|
||||
});
|
||||
}
|
||||
|
||||
// Set up test environment
|
||||
tap.test('setup port mapping test environment', async () => {
|
||||
const allPorts = await findFreePorts(9);
|
||||
TEST_PORTS = allPorts.slice(0, 3);
|
||||
PROXY_PORTS = allPorts.slice(3, 9);
|
||||
|
||||
// Create multiple test servers on different ports
|
||||
await Promise.all([
|
||||
createTestServer(TEST_PORTS[0]),
|
||||
createTestServer(TEST_PORTS[1]),
|
||||
createTestServer(TEST_PORTS[2]),
|
||||
]);
|
||||
|
||||
// Compute dynamic offset between proxy and test ports
|
||||
const portOffset = TEST_PORTS[1] - PROXY_PORTS[1];
|
||||
|
||||
// Create a SmartProxy with dynamic port mapping routes
|
||||
smartProxy = new SmartProxy({
|
||||
routes: [
|
||||
// Simple function that returns the same port (identity mapping)
|
||||
createPortMappingRoute({
|
||||
sourcePortRange: PROXY_PORTS[0],
|
||||
targetHost: 'localhost',
|
||||
portMapper: (context) => TEST_PORTS[0],
|
||||
name: 'Identity Port Mapping'
|
||||
}),
|
||||
|
||||
// Offset port mapping using dynamic offset
|
||||
createOffsetPortMappingRoute({
|
||||
ports: PROXY_PORTS[1],
|
||||
targetHost: 'localhost',
|
||||
offset: portOffset,
|
||||
name: `Offset Port Mapping (${portOffset})`
|
||||
}),
|
||||
|
||||
// Dynamic route with conditional port mapping
|
||||
createDynamicRoute({
|
||||
ports: [PROXY_PORTS[2], PROXY_PORTS[3]],
|
||||
targetHost: (context) => {
|
||||
// Dynamic host selection based on port
|
||||
return context.port === PROXY_PORTS[2] ? 'localhost' : '127.0.0.1';
|
||||
{
|
||||
match: { ports: PROXY_PORTS[0] },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: (context: IRouteContext) => TEST_PORTS[0]
|
||||
}]
|
||||
},
|
||||
portMapper: (context) => {
|
||||
// Port mapping logic based on incoming port
|
||||
if (context.port === PROXY_PORTS[2]) {
|
||||
return TEST_PORTS[0];
|
||||
} else {
|
||||
return TEST_PORTS[2];
|
||||
}
|
||||
name: 'Identity Port Mapping'
|
||||
},
|
||||
|
||||
{
|
||||
match: { ports: PROXY_PORTS[1] },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: (context: IRouteContext) => context.port + portOffset
|
||||
}]
|
||||
},
|
||||
name: `Offset Port Mapping (${portOffset})`
|
||||
},
|
||||
|
||||
{
|
||||
match: { ports: [PROXY_PORTS[2], PROXY_PORTS[3]] },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: (context: IRouteContext) => {
|
||||
return context.port === PROXY_PORTS[2] ? 'localhost' : '127.0.0.1';
|
||||
},
|
||||
port: (context: IRouteContext) => {
|
||||
if (context.port === PROXY_PORTS[2]) {
|
||||
return TEST_PORTS[0];
|
||||
} else {
|
||||
return TEST_PORTS[2];
|
||||
}
|
||||
}
|
||||
}]
|
||||
},
|
||||
name: 'Dynamic Host and Port Mapping'
|
||||
}),
|
||||
},
|
||||
|
||||
// Smart load balancer for domain-based routing
|
||||
createSmartLoadBalancer({
|
||||
ports: PROXY_PORTS[4],
|
||||
domainTargets: {
|
||||
'test1.example.com': 'localhost',
|
||||
'test2.example.com': '127.0.0.1'
|
||||
{
|
||||
match: { ports: PROXY_PORTS[4] },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: (context: IRouteContext) => {
|
||||
if (context.domain === 'test1.example.com') return 'localhost';
|
||||
if (context.domain === 'test2.example.com') return '127.0.0.1';
|
||||
return 'localhost';
|
||||
},
|
||||
port: (context: IRouteContext) => {
|
||||
if (context.domain === 'test1.example.com') {
|
||||
return TEST_PORTS[0];
|
||||
} else {
|
||||
return TEST_PORTS[1];
|
||||
}
|
||||
}
|
||||
}]
|
||||
},
|
||||
portMapper: (context) => {
|
||||
// Use different backend ports based on domain
|
||||
if (context.domain === 'test1.example.com') {
|
||||
return TEST_PORTS[0];
|
||||
} else {
|
||||
return TEST_PORTS[1];
|
||||
}
|
||||
},
|
||||
defaultTarget: 'localhost',
|
||||
name: 'Smart Domain Load Balancer'
|
||||
})
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
// Start the SmartProxy
|
||||
await smartProxy.start();
|
||||
});
|
||||
|
||||
// Test 1: Simple identity port mapping
|
||||
tap.test('should map port using identity function', async () => {
|
||||
const response = await createTestClient(PROXY_PORTS[0], TEST_DATA);
|
||||
expect(response).toEqual(`Server ${TEST_PORTS[0]} says: ${TEST_DATA}`);
|
||||
});
|
||||
|
||||
// Test 2: Offset port mapping
|
||||
tap.test('should map port using offset function', async () => {
|
||||
const response = await createTestClient(PROXY_PORTS[1], TEST_DATA);
|
||||
expect(response).toEqual(`Server ${TEST_PORTS[1]} says: ${TEST_DATA}`);
|
||||
});
|
||||
|
||||
// Test 3: Dynamic port and host mapping (conditional logic)
|
||||
tap.test('should map port using dynamic logic', async () => {
|
||||
const response = await createTestClient(PROXY_PORTS[2], TEST_DATA);
|
||||
expect(response).toEqual(`Server ${TEST_PORTS[0]} says: ${TEST_DATA}`);
|
||||
});
|
||||
|
||||
// Test 4: Test reuse of createPortOffset helper
|
||||
tap.test('should use createPortOffset helper for port mapping', async () => {
|
||||
// Test the createPortOffset helper with dynamic offset
|
||||
const portOffset = TEST_PORTS[1] - PROXY_PORTS[1];
|
||||
const offsetFn = createPortOffset(portOffset);
|
||||
const context = {
|
||||
port: PROXY_PORTS[1],
|
||||
clientIp: '127.0.0.1',
|
||||
serverIp: '127.0.0.1',
|
||||
isTls: false,
|
||||
timestamp: Date.now(),
|
||||
connectionId: 'test-connection'
|
||||
} as IRouteContext;
|
||||
|
||||
const mappedPort = offsetFn(context);
|
||||
expect(mappedPort).toEqual(TEST_PORTS[1]);
|
||||
});
|
||||
|
||||
// Test 5: Test error handling for invalid port mapping functions
|
||||
tap.test('should handle errors in port mapping functions', async () => {
|
||||
// Create a route with a function that throws an error
|
||||
const errorRoute: IRouteConfig = {
|
||||
match: {
|
||||
ports: PROXY_PORTS[5]
|
||||
@@ -232,34 +203,27 @@ tap.test('should handle errors in port mapping functions', async () => {
|
||||
},
|
||||
name: 'Error Route'
|
||||
};
|
||||
|
||||
// Add the route to SmartProxy
|
||||
|
||||
await smartProxy.updateRoutes([...smartProxy.settings.routes, errorRoute]);
|
||||
|
||||
// The connection should fail or timeout
|
||||
|
||||
try {
|
||||
await createTestClient(PROXY_PORTS[5], TEST_DATA);
|
||||
// Connection should not succeed
|
||||
expect(false).toBeTrue();
|
||||
} catch (error) {
|
||||
// Connection failed as expected
|
||||
expect(true).toBeTrue();
|
||||
}
|
||||
});
|
||||
|
||||
// Cleanup
|
||||
tap.test('cleanup port mapping test environment', async () => {
|
||||
// Add timeout to prevent hanging if SmartProxy shutdown has issues
|
||||
const cleanupPromise = cleanup();
|
||||
const timeoutPromise = new Promise((_, reject) =>
|
||||
const timeoutPromise = new Promise((_, reject) =>
|
||||
setTimeout(() => reject(new Error('Cleanup timeout after 5 seconds')), 5000)
|
||||
);
|
||||
|
||||
|
||||
try {
|
||||
await Promise.race([cleanupPromise, timeoutPromise]);
|
||||
} catch (error) {
|
||||
console.error('Cleanup error:', error);
|
||||
// Force cleanup even if there's an error
|
||||
testServers = [];
|
||||
smartProxy = null as any;
|
||||
}
|
||||
@@ -267,4 +231,4 @@ tap.test('cleanup port mapping test environment', async () => {
|
||||
await assertPortsFree([...TEST_PORTS, ...PROXY_PORTS]);
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
export default tap.start();
|
||||
|
||||
+266
-223
@@ -6,7 +6,7 @@ import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
// Import from core modules
|
||||
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
|
||||
|
||||
// Import route utilities and helpers
|
||||
// Import route utilities
|
||||
import {
|
||||
findMatchingRoutes,
|
||||
findBestMatchingRoute,
|
||||
@@ -28,16 +28,7 @@ import {
|
||||
assertValidRoute
|
||||
} from '../ts/proxies/smart-proxy/utils/route-validator.js';
|
||||
|
||||
import {
|
||||
createHttpRoute,
|
||||
createHttpsTerminateRoute,
|
||||
createHttpsPassthroughRoute,
|
||||
createHttpToHttpsRedirect,
|
||||
createCompleteHttpsServer,
|
||||
createLoadBalancerRoute,
|
||||
createApiRoute,
|
||||
createWebSocketRoute
|
||||
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
|
||||
import { SocketHandlers } from '../ts/proxies/smart-proxy/utils/socket-handlers.js';
|
||||
|
||||
// Import test helpers
|
||||
import { loadTestCertificates } from './helpers/certificates.js';
|
||||
@@ -47,12 +38,12 @@ import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.
|
||||
// --------------------------------- Route Creation Tests ---------------------------------
|
||||
|
||||
tap.test('Routes: Should create basic HTTP route', async () => {
|
||||
// Create a simple HTTP route
|
||||
const httpRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 }, {
|
||||
const httpRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'Basic HTTP Route'
|
||||
});
|
||||
};
|
||||
|
||||
// Validate the route configuration
|
||||
expect(httpRoute.match.ports).toEqual(80);
|
||||
expect(httpRoute.match.domains).toEqual('example.com');
|
||||
expect(httpRoute.action.type).toEqual('forward');
|
||||
@@ -62,14 +53,17 @@ tap.test('Routes: Should create basic HTTP route', async () => {
|
||||
});
|
||||
|
||||
tap.test('Routes: Should create HTTPS route with TLS termination', async () => {
|
||||
// Create an HTTPS route with TLS termination
|
||||
const httpsRoute = createHttpsTerminateRoute('secure.example.com', { host: 'localhost', port: 8080 }, {
|
||||
certificate: 'auto',
|
||||
const httpsRoute: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'secure.example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8080 }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' }
|
||||
},
|
||||
name: 'HTTPS Route'
|
||||
});
|
||||
};
|
||||
|
||||
// Validate the route configuration
|
||||
expect(httpsRoute.match.ports).toEqual(443); // Default HTTPS port
|
||||
expect(httpsRoute.match.ports).toEqual(443);
|
||||
expect(httpsRoute.match.domains).toEqual('secure.example.com');
|
||||
expect(httpsRoute.action.type).toEqual('forward');
|
||||
expect(httpsRoute.action.tls?.mode).toEqual('terminate');
|
||||
@@ -80,10 +74,15 @@ tap.test('Routes: Should create HTTPS route with TLS termination', async () => {
|
||||
});
|
||||
|
||||
tap.test('Routes: Should create HTTP to HTTPS redirect', async () => {
|
||||
// Create an HTTP to HTTPS redirect
|
||||
const redirectRoute = createHttpToHttpsRedirect('example.com', 443);
|
||||
const redirectRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: {
|
||||
type: 'socket-handler',
|
||||
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
|
||||
},
|
||||
name: 'HTTP to HTTPS Redirect for example.com'
|
||||
};
|
||||
|
||||
// Validate the route configuration
|
||||
expect(redirectRoute.match.ports).toEqual(80);
|
||||
expect(redirectRoute.match.domains).toEqual('example.com');
|
||||
expect(redirectRoute.action.type).toEqual('socket-handler');
|
||||
@@ -91,22 +90,34 @@ tap.test('Routes: Should create HTTP to HTTPS redirect', async () => {
|
||||
});
|
||||
|
||||
tap.test('Routes: Should create complete HTTPS server with redirects', async () => {
|
||||
// Create a complete HTTPS server setup
|
||||
const routes = createCompleteHttpsServer('example.com', { host: 'localhost', port: 8080 }, {
|
||||
certificate: 'auto'
|
||||
});
|
||||
const routes: IRouteConfig[] = [
|
||||
{
|
||||
match: { ports: 443, domains: 'example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8080 }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' }
|
||||
},
|
||||
name: 'HTTPS Terminate Route for example.com'
|
||||
},
|
||||
{
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: {
|
||||
type: 'socket-handler',
|
||||
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
|
||||
},
|
||||
name: 'HTTP to HTTPS Redirect for example.com'
|
||||
}
|
||||
];
|
||||
|
||||
// Validate that we got two routes (HTTPS route and HTTP redirect)
|
||||
expect(routes.length).toEqual(2);
|
||||
|
||||
// Validate HTTPS route
|
||||
const httpsRoute = routes[0];
|
||||
expect(httpsRoute.match.ports).toEqual(443);
|
||||
expect(httpsRoute.match.domains).toEqual('example.com');
|
||||
expect(httpsRoute.action.type).toEqual('forward');
|
||||
expect(httpsRoute.action.tls?.mode).toEqual('terminate');
|
||||
|
||||
// Validate HTTP redirect route
|
||||
const redirectRoute = routes[1];
|
||||
expect(redirectRoute.match.ports).toEqual(80);
|
||||
expect(redirectRoute.action.type).toEqual('socket-handler');
|
||||
@@ -114,21 +125,17 @@ tap.test('Routes: Should create complete HTTPS server with redirects', async ()
|
||||
});
|
||||
|
||||
tap.test('Routes: Should create load balancer route', async () => {
|
||||
// Create a load balancer route
|
||||
const lbRoute = createLoadBalancerRoute(
|
||||
'app.example.com',
|
||||
['10.0.0.1', '10.0.0.2', '10.0.0.3'],
|
||||
8080,
|
||||
{
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto'
|
||||
},
|
||||
name: 'Load Balanced Route'
|
||||
}
|
||||
);
|
||||
const lbRoute: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'app.example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: ['10.0.0.1', '10.0.0.2', '10.0.0.3'], port: 8080 }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' },
|
||||
loadBalancing: { algorithm: 'round-robin' }
|
||||
},
|
||||
name: 'Load Balanced Route'
|
||||
};
|
||||
|
||||
// Validate the route configuration
|
||||
expect(lbRoute.match.domains).toEqual('app.example.com');
|
||||
expect(lbRoute.action.type).toEqual('forward');
|
||||
expect(Array.isArray(lbRoute.action.targets?.[0]?.host)).toBeTrue();
|
||||
@@ -139,23 +146,32 @@ tap.test('Routes: Should create load balancer route', async () => {
|
||||
});
|
||||
|
||||
tap.test('Routes: Should create API route with CORS', async () => {
|
||||
// Create an API route with CORS headers
|
||||
const apiRoute = createApiRoute('api.example.com', '/v1', { host: 'localhost', port: 3000 }, {
|
||||
useTls: true,
|
||||
certificate: 'auto',
|
||||
addCorsHeaders: true,
|
||||
const apiRoute: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'api.example.com', path: '/v1/*' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 3000 }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' }
|
||||
},
|
||||
headers: {
|
||||
response: {
|
||||
'Access-Control-Allow-Origin': '*',
|
||||
'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
|
||||
'Access-Control-Allow-Headers': 'Content-Type, Authorization',
|
||||
'Access-Control-Max-Age': '86400'
|
||||
}
|
||||
},
|
||||
priority: 100,
|
||||
name: 'API Route'
|
||||
});
|
||||
};
|
||||
|
||||
// Validate the route configuration
|
||||
expect(apiRoute.match.domains).toEqual('api.example.com');
|
||||
expect(apiRoute.match.path).toEqual('/v1/*');
|
||||
expect(apiRoute.action.type).toEqual('forward');
|
||||
expect(apiRoute.action.tls?.mode).toEqual('terminate');
|
||||
expect(apiRoute.action.targets?.[0]?.host).toEqual('localhost');
|
||||
expect(apiRoute.action.targets?.[0]?.port).toEqual(3000);
|
||||
|
||||
// Check CORS headers
|
||||
|
||||
expect(apiRoute.headers).toBeDefined();
|
||||
if (apiRoute.headers?.response) {
|
||||
expect(apiRoute.headers.response['Access-Control-Allow-Origin']).toEqual('*');
|
||||
@@ -164,23 +180,25 @@ tap.test('Routes: Should create API route with CORS', async () => {
|
||||
});
|
||||
|
||||
tap.test('Routes: Should create WebSocket route', async () => {
|
||||
// Create a WebSocket route
|
||||
const wsRoute = createWebSocketRoute('ws.example.com', '/socket', { host: 'localhost', port: 5000 }, {
|
||||
useTls: true,
|
||||
certificate: 'auto',
|
||||
pingInterval: 15000,
|
||||
const wsRoute: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'ws.example.com', path: '/socket' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 5000 }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' },
|
||||
websocket: { enabled: true, pingInterval: 15000 }
|
||||
},
|
||||
priority: 100,
|
||||
name: 'WebSocket Route'
|
||||
});
|
||||
};
|
||||
|
||||
// Validate the route configuration
|
||||
expect(wsRoute.match.domains).toEqual('ws.example.com');
|
||||
expect(wsRoute.match.path).toEqual('/socket');
|
||||
expect(wsRoute.action.type).toEqual('forward');
|
||||
expect(wsRoute.action.tls?.mode).toEqual('terminate');
|
||||
expect(wsRoute.action.targets?.[0]?.host).toEqual('localhost');
|
||||
expect(wsRoute.action.targets?.[0]?.port).toEqual(5000);
|
||||
|
||||
// Check WebSocket configuration
|
||||
|
||||
expect(wsRoute.action.websocket).toBeDefined();
|
||||
if (wsRoute.action.websocket) {
|
||||
expect(wsRoute.action.websocket.enabled).toBeTrue();
|
||||
@@ -191,22 +209,27 @@ tap.test('Routes: Should create WebSocket route', async () => {
|
||||
// Static file serving has been removed - should be handled by external servers
|
||||
|
||||
tap.test('SmartProxy: Should create instance with route-based config', async () => {
|
||||
// Create TLS certificates for testing
|
||||
const certs = loadTestCertificates();
|
||||
|
||||
// Create a SmartProxy instance with route-based configuration
|
||||
const proxy = new SmartProxy({
|
||||
routes: [
|
||||
createHttpRoute('example.com', { host: 'localhost', port: 3000 }, {
|
||||
{
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'HTTP Route'
|
||||
}),
|
||||
createHttpsTerminateRoute('secure.example.com', { host: 'localhost', port: 8443 }, {
|
||||
certificate: {
|
||||
key: certs.privateKey,
|
||||
cert: certs.publicKey
|
||||
},
|
||||
{
|
||||
match: { ports: 443, domains: 'secure.example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8443 }],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: { key: certs.privateKey, cert: certs.publicKey }
|
||||
}
|
||||
},
|
||||
name: 'HTTPS Route'
|
||||
})
|
||||
}
|
||||
],
|
||||
defaults: {
|
||||
target: {
|
||||
@@ -218,13 +241,11 @@ tap.test('SmartProxy: Should create instance with route-based config', async ()
|
||||
maxConnections: 100
|
||||
}
|
||||
},
|
||||
// Additional settings
|
||||
initialDataTimeout: 10000,
|
||||
inactivityTimeout: 300000,
|
||||
enableDetailedLogging: true
|
||||
});
|
||||
|
||||
// Simply verify the instance was created successfully
|
||||
expect(typeof proxy).toEqual('object');
|
||||
expect(typeof proxy.start).toEqual('function');
|
||||
expect(typeof proxy.stop).toEqual('function');
|
||||
@@ -233,94 +254,109 @@ tap.test('SmartProxy: Should create instance with route-based config', async ()
|
||||
// --------------------------------- Edge Case Tests ---------------------------------
|
||||
|
||||
tap.test('Edge Case - Empty Routes Array', async () => {
|
||||
// Attempting to find routes in an empty array
|
||||
const emptyRoutes: IRouteConfig[] = [];
|
||||
const matches = findMatchingRoutes(emptyRoutes, { domain: 'example.com', port: 80 });
|
||||
|
||||
|
||||
expect(matches).toBeInstanceOf(Array);
|
||||
expect(matches.length).toEqual(0);
|
||||
|
||||
|
||||
const bestMatch = findBestMatchingRoute(emptyRoutes, { domain: 'example.com', port: 80 });
|
||||
expect(bestMatch).toBeUndefined();
|
||||
});
|
||||
|
||||
tap.test('Edge Case - Multiple Matching Routes with Same Priority', async () => {
|
||||
// Create multiple routes with identical priority but different targets
|
||||
const route1 = createHttpRoute('example.com', { host: 'server1', port: 3000 });
|
||||
const route2 = createHttpRoute('example.com', { host: 'server2', port: 3000 });
|
||||
const route3 = createHttpRoute('example.com', { host: 'server3', port: 3000 });
|
||||
|
||||
// Set all to the same priority
|
||||
const route1: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'server1', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com'
|
||||
};
|
||||
const route2: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'server2', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com'
|
||||
};
|
||||
const route3: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'server3', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com'
|
||||
};
|
||||
|
||||
route1.priority = 100;
|
||||
route2.priority = 100;
|
||||
route3.priority = 100;
|
||||
|
||||
|
||||
const routes = [route1, route2, route3];
|
||||
|
||||
// Find matching routes
|
||||
|
||||
const matches = findMatchingRoutes(routes, { domain: 'example.com', port: 80 });
|
||||
|
||||
// Should find all three routes
|
||||
|
||||
expect(matches.length).toEqual(3);
|
||||
|
||||
// First match could be any of the routes since they have the same priority
|
||||
// But the implementation should be consistent (likely keep the original order)
|
||||
|
||||
const bestMatch = findBestMatchingRoute(routes, { domain: 'example.com', port: 80 });
|
||||
expect(bestMatch).not.toBeUndefined();
|
||||
});
|
||||
|
||||
tap.test('Edge Case - Wildcard Domains and Path Matching', async () => {
|
||||
// Create routes with wildcard domains and path patterns
|
||||
const wildcardApiRoute = createApiRoute('*.example.com', '/api', { host: 'api-server', port: 3000 }, {
|
||||
useTls: true,
|
||||
certificate: 'auto'
|
||||
});
|
||||
|
||||
const exactApiRoute = createApiRoute('api.example.com', '/api', { host: 'specific-api-server', port: 3001 }, {
|
||||
useTls: true,
|
||||
certificate: 'auto',
|
||||
priority: 200 // Higher priority
|
||||
});
|
||||
|
||||
const wildcardApiRoute: IRouteConfig = {
|
||||
match: { ports: 443, domains: '*.example.com', path: '/api/*' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'api-server', port: 3000 }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' }
|
||||
},
|
||||
priority: 100,
|
||||
name: 'API Route for *.example.com'
|
||||
};
|
||||
|
||||
const exactApiRoute: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'api.example.com', path: '/api/*' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'specific-api-server', port: 3001 }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' }
|
||||
},
|
||||
priority: 200,
|
||||
name: 'API Route for api.example.com'
|
||||
};
|
||||
|
||||
const routes = [wildcardApiRoute, exactApiRoute];
|
||||
|
||||
// Test with a specific subdomain that matches both routes
|
||||
|
||||
const matches = findMatchingRoutes(routes, { domain: 'api.example.com', path: '/api/users', port: 443 });
|
||||
|
||||
// Should match both routes
|
||||
|
||||
expect(matches.length).toEqual(2);
|
||||
|
||||
// The exact domain match should have higher priority
|
||||
|
||||
const bestMatch = findBestMatchingRoute(routes, { domain: 'api.example.com', path: '/api/users', port: 443 });
|
||||
expect(bestMatch).not.toBeUndefined();
|
||||
if (bestMatch) {
|
||||
expect(bestMatch.action.targets[0].port).toEqual(3001); // Should match the exact domain route
|
||||
expect(bestMatch.action.targets[0].port).toEqual(3001);
|
||||
}
|
||||
|
||||
// Test with a different subdomain - should only match the wildcard route
|
||||
|
||||
const otherMatches = findMatchingRoutes(routes, { domain: 'other.example.com', path: '/api/products', port: 443 });
|
||||
expect(otherMatches.length).toEqual(1);
|
||||
expect(otherMatches[0].action.targets[0].port).toEqual(3000); // Should match the wildcard domain route
|
||||
expect(otherMatches[0].action.targets[0].port).toEqual(3000);
|
||||
});
|
||||
|
||||
tap.test('Edge Case - Disabled Routes', async () => {
|
||||
// Create enabled and disabled routes
|
||||
const enabledRoute = createHttpRoute('example.com', { host: 'server1', port: 3000 });
|
||||
const disabledRoute = createHttpRoute('example.com', { host: 'server2', port: 3001 });
|
||||
const enabledRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'server1', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com'
|
||||
};
|
||||
const disabledRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'server2', port: 3001 }] },
|
||||
name: 'HTTP Route for example.com'
|
||||
};
|
||||
disabledRoute.enabled = false;
|
||||
|
||||
|
||||
const routes = [enabledRoute, disabledRoute];
|
||||
|
||||
// Find matching routes
|
||||
|
||||
const matches = findMatchingRoutes(routes, { domain: 'example.com', port: 80 });
|
||||
|
||||
// Should only find the enabled route
|
||||
|
||||
expect(matches.length).toEqual(1);
|
||||
expect(matches[0].action.targets[0].port).toEqual(3000);
|
||||
});
|
||||
|
||||
tap.test('Edge Case - Complex Path and Headers Matching', async () => {
|
||||
// Create route with complex path and headers matching
|
||||
const complexRoute: IRouteConfig = {
|
||||
match: {
|
||||
domains: 'api.example.com',
|
||||
@@ -344,22 +380,20 @@ tap.test('Edge Case - Complex Path and Headers Matching', async () => {
|
||||
},
|
||||
name: 'Complex API Route'
|
||||
};
|
||||
|
||||
// Test with matching criteria
|
||||
|
||||
const matchingPath = routeMatchesPath(complexRoute, '/api/v2/users');
|
||||
expect(matchingPath).toBeTrue();
|
||||
|
||||
|
||||
const matchingHeaders = routeMatchesHeaders(complexRoute, {
|
||||
'Content-Type': 'application/json',
|
||||
'X-API-Key': 'valid-key',
|
||||
'Accept': 'application/json'
|
||||
});
|
||||
expect(matchingHeaders).toBeTrue();
|
||||
|
||||
// Test with non-matching criteria
|
||||
|
||||
const nonMatchingPath = routeMatchesPath(complexRoute, '/api/v1/users');
|
||||
expect(nonMatchingPath).toBeFalse();
|
||||
|
||||
|
||||
const nonMatchingHeaders = routeMatchesHeaders(complexRoute, {
|
||||
'Content-Type': 'application/json',
|
||||
'X-API-Key': 'invalid-key'
|
||||
@@ -368,7 +402,6 @@ tap.test('Edge Case - Complex Path and Headers Matching', async () => {
|
||||
});
|
||||
|
||||
tap.test('Edge Case - Port Range Matching', async () => {
|
||||
// Create route with port range matching
|
||||
const portRangeRoute: IRouteConfig = {
|
||||
match: {
|
||||
domains: 'example.com',
|
||||
@@ -383,17 +416,14 @@ tap.test('Edge Case - Port Range Matching', async () => {
|
||||
},
|
||||
name: 'Port Range Route'
|
||||
};
|
||||
|
||||
// Test with ports in the range
|
||||
expect(routeMatchesPort(portRangeRoute, 8000)).toBeTrue(); // Lower bound
|
||||
expect(routeMatchesPort(portRangeRoute, 8500)).toBeTrue(); // Middle
|
||||
expect(routeMatchesPort(portRangeRoute, 9000)).toBeTrue(); // Upper bound
|
||||
|
||||
// Test with ports outside the range
|
||||
expect(routeMatchesPort(portRangeRoute, 7999)).toBeFalse(); // Just below
|
||||
expect(routeMatchesPort(portRangeRoute, 9001)).toBeFalse(); // Just above
|
||||
|
||||
// Test with multiple port ranges
|
||||
|
||||
expect(routeMatchesPort(portRangeRoute, 8000)).toBeTrue();
|
||||
expect(routeMatchesPort(portRangeRoute, 8500)).toBeTrue();
|
||||
expect(routeMatchesPort(portRangeRoute, 9000)).toBeTrue();
|
||||
|
||||
expect(routeMatchesPort(portRangeRoute, 7999)).toBeFalse();
|
||||
expect(routeMatchesPort(portRangeRoute, 9001)).toBeFalse();
|
||||
|
||||
const multiRangeRoute: IRouteConfig = {
|
||||
match: {
|
||||
domains: 'example.com',
|
||||
@@ -411,7 +441,7 @@ tap.test('Edge Case - Port Range Matching', async () => {
|
||||
},
|
||||
name: 'Multi Range Route'
|
||||
};
|
||||
|
||||
|
||||
expect(routeMatchesPort(multiRangeRoute, 85)).toBeTrue();
|
||||
expect(routeMatchesPort(multiRangeRoute, 8500)).toBeTrue();
|
||||
expect(routeMatchesPort(multiRangeRoute, 100)).toBeFalse();
|
||||
@@ -420,55 +450,56 @@ tap.test('Edge Case - Port Range Matching', async () => {
|
||||
// --------------------------------- Wildcard Domain Tests ---------------------------------
|
||||
|
||||
tap.test('Wildcard Domain Handling', async () => {
|
||||
// Create routes with different wildcard patterns
|
||||
const simpleDomainRoute = createHttpRoute('example.com', { host: 'server1', port: 3000 });
|
||||
const wildcardSubdomainRoute = createHttpRoute('*.example.com', { host: 'server2', port: 3001 });
|
||||
const specificSubdomainRoute = createHttpRoute('api.example.com', { host: 'server3', port: 3002 });
|
||||
const simpleDomainRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'server1', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com'
|
||||
};
|
||||
const wildcardSubdomainRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: '*.example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'server2', port: 3001 }] },
|
||||
name: 'HTTP Route for *.example.com'
|
||||
};
|
||||
const specificSubdomainRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'api.example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'server3', port: 3002 }] },
|
||||
name: 'HTTP Route for api.example.com'
|
||||
};
|
||||
|
||||
// Set explicit priorities to ensure deterministic matching
|
||||
specificSubdomainRoute.priority = 200; // Highest priority for specific domain
|
||||
wildcardSubdomainRoute.priority = 100; // Medium priority for wildcard
|
||||
simpleDomainRoute.priority = 50; // Lowest priority for generic domain
|
||||
specificSubdomainRoute.priority = 200;
|
||||
wildcardSubdomainRoute.priority = 100;
|
||||
simpleDomainRoute.priority = 50;
|
||||
|
||||
const routes = [simpleDomainRoute, wildcardSubdomainRoute, specificSubdomainRoute];
|
||||
|
||||
// Test exact domain match
|
||||
expect(routeMatchesDomain(simpleDomainRoute, 'example.com')).toBeTrue();
|
||||
expect(routeMatchesDomain(simpleDomainRoute, 'sub.example.com')).toBeFalse();
|
||||
|
||||
// Test wildcard subdomain match
|
||||
expect(routeMatchesDomain(wildcardSubdomainRoute, 'any.example.com')).toBeTrue();
|
||||
expect(routeMatchesDomain(wildcardSubdomainRoute, 'nested.sub.example.com')).toBeTrue();
|
||||
expect(routeMatchesDomain(wildcardSubdomainRoute, 'example.com')).toBeFalse();
|
||||
|
||||
// Test specific subdomain match
|
||||
expect(routeMatchesDomain(specificSubdomainRoute, 'api.example.com')).toBeTrue();
|
||||
expect(routeMatchesDomain(specificSubdomainRoute, 'other.example.com')).toBeFalse();
|
||||
expect(routeMatchesDomain(specificSubdomainRoute, 'sub.api.example.com')).toBeFalse();
|
||||
|
||||
// Test finding best match when multiple domains match
|
||||
const specificSubdomainRequest = { domain: 'api.example.com', port: 80 };
|
||||
const bestSpecificMatch = findBestMatchingRoute(routes, specificSubdomainRequest);
|
||||
expect(bestSpecificMatch).not.toBeUndefined();
|
||||
if (bestSpecificMatch) {
|
||||
// Find which route was matched
|
||||
const matchedPort = bestSpecificMatch.action.targets[0].port;
|
||||
console.log(`Matched route with port: ${matchedPort}`);
|
||||
|
||||
// Verify it's the specific subdomain route (with highest priority)
|
||||
expect(bestSpecificMatch.priority).toEqual(200);
|
||||
}
|
||||
|
||||
// Test with a subdomain that matches wildcard but not specific
|
||||
const otherSubdomainRequest = { domain: 'other.example.com', port: 80 };
|
||||
const bestWildcardMatch = findBestMatchingRoute(routes, otherSubdomainRequest);
|
||||
expect(bestWildcardMatch).not.toBeUndefined();
|
||||
if (bestWildcardMatch) {
|
||||
// Find which route was matched
|
||||
const matchedPort = bestWildcardMatch.action.targets[0].port;
|
||||
console.log(`Matched route with port: ${matchedPort}`);
|
||||
|
||||
// Verify it's the wildcard subdomain route (with medium priority)
|
||||
expect(bestWildcardMatch.priority).toEqual(100);
|
||||
}
|
||||
});
|
||||
@@ -476,56 +507,83 @@ tap.test('Wildcard Domain Handling', async () => {
|
||||
// --------------------------------- Integration Tests ---------------------------------
|
||||
|
||||
tap.test('Route Integration - Combining Multiple Route Types', async () => {
|
||||
// Create a comprehensive set of routes for a full application
|
||||
const routes: IRouteConfig[] = [
|
||||
// Main website with HTTPS and HTTP redirect
|
||||
...createCompleteHttpsServer('example.com', { host: 'web-server', port: 8080 }, {
|
||||
certificate: 'auto'
|
||||
}),
|
||||
|
||||
// API endpoints
|
||||
createApiRoute('api.example.com', '/v1', { host: 'api-server', port: 3000 }, {
|
||||
useTls: true,
|
||||
certificate: 'auto',
|
||||
addCorsHeaders: true
|
||||
}),
|
||||
|
||||
// WebSocket for real-time updates
|
||||
createWebSocketRoute('ws.example.com', '/live', { host: 'websocket-server', port: 5000 }, {
|
||||
useTls: true,
|
||||
certificate: 'auto'
|
||||
}),
|
||||
|
||||
|
||||
// Legacy system with passthrough
|
||||
createHttpsPassthroughRoute('legacy.example.com', { host: 'legacy-server', port: 443 })
|
||||
{
|
||||
match: { ports: 443, domains: 'example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'web-server', port: 8080 }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' }
|
||||
},
|
||||
name: 'HTTPS Terminate Route for example.com'
|
||||
},
|
||||
{
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: {
|
||||
type: 'socket-handler',
|
||||
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
|
||||
},
|
||||
name: 'HTTP to HTTPS Redirect for example.com'
|
||||
},
|
||||
{
|
||||
match: { ports: 443, domains: 'api.example.com', path: '/v1/*' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'api-server', port: 3000 }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' }
|
||||
},
|
||||
headers: {
|
||||
response: {
|
||||
'Access-Control-Allow-Origin': '*',
|
||||
'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
|
||||
'Access-Control-Allow-Headers': 'Content-Type, Authorization',
|
||||
'Access-Control-Max-Age': '86400'
|
||||
}
|
||||
},
|
||||
priority: 100,
|
||||
name: 'API Route for api.example.com'
|
||||
},
|
||||
{
|
||||
match: { ports: 443, domains: 'ws.example.com', path: '/live' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'websocket-server', port: 5000 }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' },
|
||||
websocket: { enabled: true }
|
||||
},
|
||||
priority: 100,
|
||||
name: 'WebSocket Route for ws.example.com'
|
||||
},
|
||||
{
|
||||
match: { ports: 443, domains: 'legacy.example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'legacy-server', port: 443 }],
|
||||
tls: { mode: 'passthrough' }
|
||||
},
|
||||
name: 'HTTPS Passthrough Route for legacy.example.com'
|
||||
}
|
||||
];
|
||||
|
||||
// Validate all routes
|
||||
|
||||
const validationResult = validateRoutes(routes);
|
||||
expect(validationResult.valid).toBeTrue();
|
||||
expect(validationResult.errors.length).toEqual(0);
|
||||
|
||||
// Test route matching for different endpoints
|
||||
|
||||
// Web server (HTTPS)
|
||||
|
||||
const webServerMatch = findBestMatchingRoute(routes, { domain: 'example.com', port: 443 });
|
||||
expect(webServerMatch).not.toBeUndefined();
|
||||
if (webServerMatch) {
|
||||
expect(webServerMatch.action.type).toEqual('forward');
|
||||
expect(webServerMatch.action.targets[0].host).toEqual('web-server');
|
||||
}
|
||||
|
||||
// Web server (HTTP redirect via socket handler)
|
||||
|
||||
const webRedirectMatch = findBestMatchingRoute(routes, { domain: 'example.com', port: 80 });
|
||||
expect(webRedirectMatch).not.toBeUndefined();
|
||||
if (webRedirectMatch) {
|
||||
expect(webRedirectMatch.action.type).toEqual('socket-handler');
|
||||
}
|
||||
|
||||
// API server
|
||||
const apiMatch = findBestMatchingRoute(routes, {
|
||||
domain: 'api.example.com',
|
||||
|
||||
const apiMatch = findBestMatchingRoute(routes, {
|
||||
domain: 'api.example.com',
|
||||
port: 443,
|
||||
path: '/v1/users'
|
||||
});
|
||||
@@ -534,10 +592,9 @@ tap.test('Route Integration - Combining Multiple Route Types', async () => {
|
||||
expect(apiMatch.action.type).toEqual('forward');
|
||||
expect(apiMatch.action.targets[0].host).toEqual('api-server');
|
||||
}
|
||||
|
||||
// WebSocket server
|
||||
const wsMatch = findBestMatchingRoute(routes, {
|
||||
domain: 'ws.example.com',
|
||||
|
||||
const wsMatch = findBestMatchingRoute(routes, {
|
||||
domain: 'ws.example.com',
|
||||
port: 443,
|
||||
path: '/live'
|
||||
});
|
||||
@@ -547,12 +604,9 @@ tap.test('Route Integration - Combining Multiple Route Types', async () => {
|
||||
expect(wsMatch.action.targets[0].host).toEqual('websocket-server');
|
||||
expect(wsMatch.action.websocket?.enabled).toBeTrue();
|
||||
}
|
||||
|
||||
// Static assets route was removed - static file serving should be handled externally
|
||||
|
||||
// Legacy system
|
||||
const legacyMatch = findBestMatchingRoute(routes, {
|
||||
domain: 'legacy.example.com',
|
||||
|
||||
const legacyMatch = findBestMatchingRoute(routes, {
|
||||
domain: 'legacy.example.com',
|
||||
port: 443
|
||||
});
|
||||
expect(legacyMatch).not.toBeUndefined();
|
||||
@@ -565,7 +619,6 @@ tap.test('Route Integration - Combining Multiple Route Types', async () => {
|
||||
// --------------------------------- Protocol Match Field Tests ---------------------------------
|
||||
|
||||
tap.test('Routes: Should accept protocol field on route match', async () => {
|
||||
// Create a route with protocol: 'http'
|
||||
const httpOnlyRoute: IRouteConfig = {
|
||||
match: {
|
||||
ports: 443,
|
||||
@@ -583,16 +636,13 @@ tap.test('Routes: Should accept protocol field on route match', async () => {
|
||||
name: 'HTTP-only Route',
|
||||
};
|
||||
|
||||
// Validate the route - protocol field should not cause errors
|
||||
const validation = validateRouteConfig(httpOnlyRoute);
|
||||
expect(validation.valid).toBeTrue();
|
||||
|
||||
// Verify the protocol field is preserved
|
||||
expect(httpOnlyRoute.match.protocol).toEqual('http');
|
||||
});
|
||||
|
||||
tap.test('Routes: Should accept protocol tcp on route match', async () => {
|
||||
// Create a route with protocol: 'tcp'
|
||||
const tcpOnlyRoute: IRouteConfig = {
|
||||
match: {
|
||||
ports: 443,
|
||||
@@ -616,28 +666,26 @@ tap.test('Routes: Should accept protocol tcp on route match', async () => {
|
||||
});
|
||||
|
||||
tap.test('Routes: Protocol field should work with terminate-and-reencrypt', async () => {
|
||||
// Create a terminate-and-reencrypt route that only accepts HTTP
|
||||
const reencryptRoute = createHttpsTerminateRoute(
|
||||
'secure.example.com',
|
||||
{ host: 'backend', port: 443 },
|
||||
{ reencrypt: true, certificate: 'auto', name: 'Reencrypt HTTP Route' }
|
||||
);
|
||||
const reencryptRoute: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'secure.example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'backend', port: 443 }],
|
||||
tls: { mode: 'terminate-and-reencrypt', certificate: 'auto' }
|
||||
},
|
||||
name: 'Reencrypt HTTP Route'
|
||||
};
|
||||
|
||||
// Set protocol restriction to http
|
||||
reencryptRoute.match.protocol = 'http';
|
||||
|
||||
// Validate the route
|
||||
const validation = validateRouteConfig(reencryptRoute);
|
||||
expect(validation.valid).toBeTrue();
|
||||
|
||||
// Verify TLS mode
|
||||
expect(reencryptRoute.action.tls?.mode).toEqual('terminate-and-reencrypt');
|
||||
// Verify protocol field is preserved
|
||||
expect(reencryptRoute.match.protocol).toEqual('http');
|
||||
});
|
||||
|
||||
tap.test('Routes: Protocol field should not affect domain/port matching', async () => {
|
||||
// Routes with and without protocol field should both match the same domain/port
|
||||
const routeWithProtocol: IRouteConfig = {
|
||||
match: {
|
||||
ports: 443,
|
||||
@@ -669,11 +717,9 @@ tap.test('Routes: Protocol field should not affect domain/port matching', async
|
||||
|
||||
const routes = [routeWithProtocol, routeWithoutProtocol];
|
||||
|
||||
// Both routes should match the domain/port (protocol is a hint for Rust-side matching)
|
||||
const matches = findMatchingRoutes(routes, { domain: 'example.com', port: 443 });
|
||||
expect(matches.length).toEqual(2);
|
||||
|
||||
// The one with higher priority should be first
|
||||
const best = findBestMatchingRoute(routes, { domain: 'example.com', port: 443 });
|
||||
expect(best).not.toBeUndefined();
|
||||
expect(best!.name).toEqual('With Protocol');
|
||||
@@ -696,11 +742,9 @@ tap.test('Routes: Protocol field preserved through route cloning', async () => {
|
||||
|
||||
const cloned = cloneRoute(original);
|
||||
|
||||
// Verify protocol is preserved in clone
|
||||
expect(cloned.match.protocol).toEqual('http');
|
||||
expect(cloned.action.tls?.mode).toEqual('terminate-and-reencrypt');
|
||||
|
||||
// Modify clone should not affect original
|
||||
cloned.match.protocol = 'tcp';
|
||||
expect(original.match.protocol).toEqual('http');
|
||||
});
|
||||
@@ -720,10 +764,9 @@ tap.test('Routes: Protocol field preserved through route merging', async () => {
|
||||
name: 'Merge Base',
|
||||
};
|
||||
|
||||
// Merge with override that changes name but not protocol
|
||||
const merged = mergeRouteConfigs(base, { name: 'Merged Route' });
|
||||
expect(merged.match.protocol).toEqual('http');
|
||||
expect(merged.name).toEqual('Merged Route');
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
export default tap.start();
|
||||
|
||||
+213
-468
@@ -1,21 +1,7 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as plugins from '../ts/plugins.js';
|
||||
|
||||
// Import from individual modules to avoid naming conflicts
|
||||
import {
|
||||
// Route helpers
|
||||
createHttpRoute,
|
||||
createHttpsTerminateRoute,
|
||||
createApiRoute,
|
||||
createWebSocketRoute,
|
||||
createHttpToHttpsRedirect,
|
||||
createHttpsPassthroughRoute,
|
||||
createCompleteHttpsServer,
|
||||
createLoadBalancerRoute
|
||||
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
|
||||
|
||||
import {
|
||||
// Route validators
|
||||
validateRouteConfig,
|
||||
validateRoutes,
|
||||
isValidDomain,
|
||||
@@ -27,7 +13,6 @@ import {
|
||||
} from '../ts/proxies/smart-proxy/utils/route-validator.js';
|
||||
|
||||
import {
|
||||
// Route utilities
|
||||
mergeRouteConfigs,
|
||||
findMatchingRoutes,
|
||||
findBestMatchingRoute,
|
||||
@@ -39,16 +24,6 @@ import {
|
||||
cloneRoute
|
||||
} from '../ts/proxies/smart-proxy/utils/route-utils.js';
|
||||
|
||||
import {
|
||||
// Route patterns
|
||||
createApiGatewayRoute,
|
||||
createWebSocketRoute as createWebSocketPattern,
|
||||
createLoadBalancerRoute as createLbPattern,
|
||||
addRateLimiting,
|
||||
addBasicAuth,
|
||||
addJwtAuth
|
||||
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
|
||||
|
||||
import type {
|
||||
IRouteConfig,
|
||||
IRouteMatch,
|
||||
@@ -84,7 +59,7 @@ tap.test('Route Validation - isValidPort', async () => {
|
||||
expect(isValidPort(443)).toBeTrue();
|
||||
expect(isValidPort(8080)).toBeTrue();
|
||||
expect(isValidPort([80, 443])).toBeTrue();
|
||||
|
||||
|
||||
// Invalid ports
|
||||
expect(isValidPort(0)).toBeFalse();
|
||||
expect(isValidPort(65536)).toBeFalse();
|
||||
@@ -101,7 +76,7 @@ tap.test('Route Validation - validateRouteMatch', async () => {
|
||||
const validResult = validateRouteMatch(validMatch);
|
||||
expect(validResult.valid).toBeTrue();
|
||||
expect(validResult.errors.length).toEqual(0);
|
||||
|
||||
|
||||
// Invalid match configuration (invalid domain)
|
||||
const invalidMatch: IRouteMatch = {
|
||||
ports: 80,
|
||||
@@ -111,7 +86,7 @@ tap.test('Route Validation - validateRouteMatch', async () => {
|
||||
expect(invalidResult.valid).toBeFalse();
|
||||
expect(invalidResult.errors.length).toBeGreaterThan(0);
|
||||
expect(invalidResult.errors[0]).toInclude('Invalid domain');
|
||||
|
||||
|
||||
// Invalid match configuration (invalid port)
|
||||
const invalidPortMatch: IRouteMatch = {
|
||||
ports: 0,
|
||||
@@ -121,7 +96,7 @@ tap.test('Route Validation - validateRouteMatch', async () => {
|
||||
expect(invalidPortResult.valid).toBeFalse();
|
||||
expect(invalidPortResult.errors.length).toBeGreaterThan(0);
|
||||
expect(invalidPortResult.errors[0]).toInclude('Invalid port');
|
||||
|
||||
|
||||
// Test path validation
|
||||
const invalidPathMatch: IRouteMatch = {
|
||||
ports: 80,
|
||||
@@ -146,7 +121,7 @@ tap.test('Route Validation - validateRouteAction', async () => {
|
||||
const validForwardResult = validateRouteAction(validForwardAction);
|
||||
expect(validForwardResult.valid).toBeTrue();
|
||||
expect(validForwardResult.errors.length).toEqual(0);
|
||||
|
||||
|
||||
// Valid socket-handler action
|
||||
const validSocketAction: IRouteAction = {
|
||||
type: 'socket-handler',
|
||||
@@ -157,7 +132,7 @@ tap.test('Route Validation - validateRouteAction', async () => {
|
||||
const validSocketResult = validateRouteAction(validSocketAction);
|
||||
expect(validSocketResult.valid).toBeTrue();
|
||||
expect(validSocketResult.errors.length).toEqual(0);
|
||||
|
||||
|
||||
// Invalid action (missing targets)
|
||||
const invalidAction: IRouteAction = {
|
||||
type: 'forward'
|
||||
@@ -166,7 +141,7 @@ tap.test('Route Validation - validateRouteAction', async () => {
|
||||
expect(invalidResult.valid).toBeFalse();
|
||||
expect(invalidResult.errors.length).toBeGreaterThan(0);
|
||||
expect(invalidResult.errors[0]).toInclude('Targets array is required');
|
||||
|
||||
|
||||
// Invalid action (missing socket handler)
|
||||
const invalidSocketAction: IRouteAction = {
|
||||
type: 'socket-handler'
|
||||
@@ -179,11 +154,15 @@ tap.test('Route Validation - validateRouteAction', async () => {
|
||||
|
||||
tap.test('Route Validation - validateRouteConfig', async () => {
|
||||
// Valid route config
|
||||
const validRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
|
||||
const validRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com',
|
||||
};
|
||||
const validResult = validateRouteConfig(validRoute);
|
||||
expect(validResult.valid).toBeTrue();
|
||||
expect(validResult.errors.length).toEqual(0);
|
||||
|
||||
|
||||
// Invalid route config (missing targets)
|
||||
const invalidRoute: IRouteConfig = {
|
||||
match: {
|
||||
@@ -203,7 +182,11 @@ tap.test('Route Validation - validateRouteConfig', async () => {
|
||||
tap.test('Route Validation - validateRoutes', async () => {
|
||||
// Create valid and invalid routes
|
||||
const routes = [
|
||||
createHttpRoute('example.com', { host: 'localhost', port: 3000 }),
|
||||
{
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com',
|
||||
} as IRouteConfig,
|
||||
{
|
||||
match: {
|
||||
domains: 'invalid..domain',
|
||||
@@ -217,9 +200,13 @@ tap.test('Route Validation - validateRoutes', async () => {
|
||||
}
|
||||
}
|
||||
} as IRouteConfig,
|
||||
createHttpsTerminateRoute('secure.example.com', { host: 'localhost', port: 3001 })
|
||||
{
|
||||
match: { ports: 443, domains: 'secure.example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3001 }], tls: { mode: 'terminate', certificate: 'auto' } },
|
||||
name: 'HTTPS Terminate Route for secure.example.com',
|
||||
} as IRouteConfig
|
||||
];
|
||||
|
||||
|
||||
const result = validateRoutes(routes);
|
||||
expect(result.valid).toBeFalse();
|
||||
expect(result.errors.length).toEqual(1);
|
||||
@@ -230,13 +217,13 @@ tap.test('Route Validation - validateRoutes', async () => {
|
||||
|
||||
tap.test('Route Validation - hasRequiredPropertiesForAction', async () => {
|
||||
// Forward action
|
||||
const forwardRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
|
||||
const forwardRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com',
|
||||
};
|
||||
expect(hasRequiredPropertiesForAction(forwardRoute, 'forward')).toBeTrue();
|
||||
|
||||
// Socket handler action (redirect functionality)
|
||||
const redirectRoute = createHttpToHttpsRedirect('example.com');
|
||||
expect(hasRequiredPropertiesForAction(redirectRoute, 'socket-handler')).toBeTrue();
|
||||
|
||||
|
||||
// Socket handler action
|
||||
const socketRoute: IRouteConfig = {
|
||||
match: {
|
||||
@@ -252,7 +239,7 @@ tap.test('Route Validation - hasRequiredPropertiesForAction', async () => {
|
||||
name: 'Socket Handler Route'
|
||||
};
|
||||
expect(hasRequiredPropertiesForAction(socketRoute, 'socket-handler')).toBeTrue();
|
||||
|
||||
|
||||
// Missing required properties
|
||||
const invalidForwardRoute: IRouteConfig = {
|
||||
match: {
|
||||
@@ -269,9 +256,13 @@ tap.test('Route Validation - hasRequiredPropertiesForAction', async () => {
|
||||
|
||||
tap.test('Route Validation - assertValidRoute', async () => {
|
||||
// Valid route
|
||||
const validRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
|
||||
const validRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com',
|
||||
};
|
||||
expect(() => assertValidRoute(validRoute)).not.toThrow();
|
||||
|
||||
|
||||
// Invalid route
|
||||
const invalidRoute: IRouteConfig = {
|
||||
match: {
|
||||
@@ -290,8 +281,12 @@ tap.test('Route Validation - assertValidRoute', async () => {
|
||||
|
||||
tap.test('Route Utilities - mergeRouteConfigs', async () => {
|
||||
// Base route
|
||||
const baseRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
|
||||
|
||||
const baseRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com',
|
||||
};
|
||||
|
||||
// Override with different name and port
|
||||
const overrideRoute: Partial<IRouteConfig> = {
|
||||
name: 'Merged Route',
|
||||
@@ -299,16 +294,16 @@ tap.test('Route Utilities - mergeRouteConfigs', async () => {
|
||||
ports: 8080
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Merge configs
|
||||
const mergedRoute = mergeRouteConfigs(baseRoute, overrideRoute);
|
||||
|
||||
|
||||
// Check merged properties
|
||||
expect(mergedRoute.name).toEqual('Merged Route');
|
||||
expect(mergedRoute.match.ports).toEqual(8080);
|
||||
expect(mergedRoute.match.domains).toEqual('example.com');
|
||||
expect(mergedRoute.action.type).toEqual('forward');
|
||||
|
||||
|
||||
// Test merging action properties
|
||||
const actionOverride: Partial<IRouteConfig> = {
|
||||
action: {
|
||||
@@ -319,11 +314,11 @@ tap.test('Route Utilities - mergeRouteConfigs', async () => {
|
||||
}]
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
const actionMergedRoute = mergeRouteConfigs(baseRoute, actionOverride);
|
||||
expect(actionMergedRoute.action.targets?.[0]?.host).toEqual('new-host.local');
|
||||
expect(actionMergedRoute.action.targets?.[0]?.port).toEqual(5000);
|
||||
|
||||
|
||||
// Test replacing action with socket handler
|
||||
const typeChangeOverride: Partial<IRouteConfig> = {
|
||||
action: {
|
||||
@@ -336,7 +331,7 @@ tap.test('Route Utilities - mergeRouteConfigs', async () => {
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
const typeChangedRoute = mergeRouteConfigs(baseRoute, typeChangeOverride);
|
||||
expect(typeChangedRoute.action.type).toEqual('socket-handler');
|
||||
expect(typeChangedRoute.action.socketHandler).toBeDefined();
|
||||
@@ -345,37 +340,53 @@ tap.test('Route Utilities - mergeRouteConfigs', async () => {
|
||||
|
||||
tap.test('Route Matching - routeMatchesDomain', async () => {
|
||||
// Create route with wildcard domain
|
||||
const wildcardRoute = createHttpRoute('*.example.com', { host: 'localhost', port: 3000 });
|
||||
|
||||
const wildcardRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: '*.example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'HTTP Route for *.example.com',
|
||||
};
|
||||
|
||||
// Create route with exact domain
|
||||
const exactRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
|
||||
|
||||
const exactRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com',
|
||||
};
|
||||
|
||||
// Create route with multiple domains
|
||||
const multiDomainRoute = createHttpRoute(['example.com', 'example.org'], { host: 'localhost', port: 3000 });
|
||||
|
||||
const multiDomainRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: ['example.com', 'example.org'] },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com,example.org',
|
||||
};
|
||||
|
||||
// Test wildcard domain matching
|
||||
expect(routeMatchesDomain(wildcardRoute, 'sub.example.com')).toBeTrue();
|
||||
expect(routeMatchesDomain(wildcardRoute, 'another.example.com')).toBeTrue();
|
||||
expect(routeMatchesDomain(wildcardRoute, 'example.com')).toBeFalse();
|
||||
expect(routeMatchesDomain(wildcardRoute, 'example.org')).toBeFalse();
|
||||
|
||||
|
||||
// Test exact domain matching
|
||||
expect(routeMatchesDomain(exactRoute, 'example.com')).toBeTrue();
|
||||
expect(routeMatchesDomain(exactRoute, 'sub.example.com')).toBeFalse();
|
||||
|
||||
|
||||
// Test multiple domains matching
|
||||
expect(routeMatchesDomain(multiDomainRoute, 'example.com')).toBeTrue();
|
||||
expect(routeMatchesDomain(multiDomainRoute, 'example.org')).toBeTrue();
|
||||
expect(routeMatchesDomain(multiDomainRoute, 'example.net')).toBeFalse();
|
||||
|
||||
|
||||
// Test case insensitivity
|
||||
expect(routeMatchesDomain(exactRoute, 'Example.Com')).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('Route Matching - routeMatchesPort', async () => {
|
||||
// Create routes with different port configurations
|
||||
const singlePortRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
|
||||
|
||||
const singlePortRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com',
|
||||
};
|
||||
|
||||
const multiPortRoute: IRouteConfig = {
|
||||
match: {
|
||||
domains: 'example.com',
|
||||
@@ -389,7 +400,7 @@ tap.test('Route Matching - routeMatchesPort', async () => {
|
||||
}]
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
const portRangeRoute: IRouteConfig = {
|
||||
match: {
|
||||
domains: 'example.com',
|
||||
@@ -403,16 +414,16 @@ tap.test('Route Matching - routeMatchesPort', async () => {
|
||||
}]
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Test single port matching
|
||||
expect(routeMatchesPort(singlePortRoute, 80)).toBeTrue();
|
||||
expect(routeMatchesPort(singlePortRoute, 443)).toBeFalse();
|
||||
|
||||
|
||||
// Test multi-port matching
|
||||
expect(routeMatchesPort(multiPortRoute, 80)).toBeTrue();
|
||||
expect(routeMatchesPort(multiPortRoute, 8080)).toBeTrue();
|
||||
expect(routeMatchesPort(multiPortRoute, 3000)).toBeFalse();
|
||||
|
||||
|
||||
// Test port range matching
|
||||
expect(routeMatchesPort(portRangeRoute, 8000)).toBeTrue();
|
||||
expect(routeMatchesPort(portRangeRoute, 8500)).toBeTrue();
|
||||
@@ -437,11 +448,11 @@ tap.test('Route Matching - routeMatchesPath', async () => {
|
||||
}]
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Test prefix matching with wildcard (not trailing slash)
|
||||
const prefixPathRoute: IRouteConfig = {
|
||||
match: {
|
||||
domains: 'example.com',
|
||||
domains: 'example.com',
|
||||
ports: 80,
|
||||
path: '/api/*'
|
||||
},
|
||||
@@ -453,7 +464,7 @@ tap.test('Route Matching - routeMatchesPath', async () => {
|
||||
}]
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
const wildcardPathRoute: IRouteConfig = {
|
||||
match: {
|
||||
domains: 'example.com',
|
||||
@@ -468,17 +479,17 @@ tap.test('Route Matching - routeMatchesPath', async () => {
|
||||
}]
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Test exact path matching
|
||||
expect(routeMatchesPath(exactPathRoute, '/api')).toBeTrue();
|
||||
expect(routeMatchesPath(exactPathRoute, '/api/users')).toBeFalse();
|
||||
expect(routeMatchesPath(exactPathRoute, '/app')).toBeFalse();
|
||||
|
||||
|
||||
// Test prefix path matching with wildcard
|
||||
expect(routeMatchesPath(prefixPathRoute, '/api/')).toBeFalse(); // Wildcard requires content after /api/
|
||||
expect(routeMatchesPath(prefixPathRoute, '/api/users')).toBeTrue();
|
||||
expect(routeMatchesPath(prefixPathRoute, '/app/')).toBeFalse();
|
||||
|
||||
|
||||
// Test wildcard path matching
|
||||
expect(routeMatchesPath(wildcardPathRoute, '/api/users')).toBeTrue();
|
||||
expect(routeMatchesPath(wildcardPathRoute, '/api/products')).toBeTrue();
|
||||
@@ -504,30 +515,59 @@ tap.test('Route Matching - routeMatchesHeaders', async () => {
|
||||
}]
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Test header matching
|
||||
expect(routeMatchesHeaders(headerRoute, {
|
||||
'Content-Type': 'application/json',
|
||||
'X-Custom-Header': 'value'
|
||||
})).toBeTrue();
|
||||
|
||||
|
||||
expect(routeMatchesHeaders(headerRoute, {
|
||||
'Content-Type': 'application/json',
|
||||
'X-Custom-Header': 'value',
|
||||
'Extra-Header': 'something'
|
||||
})).toBeTrue();
|
||||
|
||||
|
||||
expect(routeMatchesHeaders(headerRoute, {
|
||||
'Content-Type': 'application/json'
|
||||
})).toBeFalse();
|
||||
|
||||
|
||||
expect(routeMatchesHeaders(headerRoute, {
|
||||
'Content-Type': 'text/html',
|
||||
'X-Custom-Header': 'value'
|
||||
})).toBeFalse();
|
||||
|
||||
|
||||
const regexHeaderRoute: IRouteConfig = {
|
||||
match: {
|
||||
domains: 'example.com',
|
||||
ports: 80,
|
||||
headers: {
|
||||
'Content-Type': /^application\/(json|problem\+json)$/i,
|
||||
}
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 3000
|
||||
}]
|
||||
}
|
||||
};
|
||||
|
||||
expect(routeMatchesHeaders(regexHeaderRoute, {
|
||||
'Content-Type': 'Application/Problem+Json',
|
||||
})).toBeTrue();
|
||||
|
||||
expect(routeMatchesHeaders(regexHeaderRoute, {
|
||||
'Content-Type': 'text/html',
|
||||
})).toBeFalse();
|
||||
|
||||
// Route without header matching should match any headers
|
||||
const noHeaderRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
|
||||
const noHeaderRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com',
|
||||
};
|
||||
expect(routeMatchesHeaders(noHeaderRoute, {
|
||||
'Content-Type': 'application/json'
|
||||
})).toBeTrue();
|
||||
@@ -536,78 +576,118 @@ tap.test('Route Matching - routeMatchesHeaders', async () => {
|
||||
tap.test('Route Finding - findMatchingRoutes', async () => {
|
||||
// Create multiple routes
|
||||
const routes: IRouteConfig[] = [
|
||||
createHttpRoute('example.com', { host: 'localhost', port: 3000 }),
|
||||
createHttpsTerminateRoute('secure.example.com', { host: 'localhost', port: 3001 }),
|
||||
createApiRoute('api.example.com', '/v1', { host: 'localhost', port: 3002 }),
|
||||
createWebSocketRoute('ws.example.com', '/socket', { host: 'localhost', port: 3003 })
|
||||
{
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com',
|
||||
},
|
||||
{
|
||||
match: { ports: 443, domains: 'secure.example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3001 }], tls: { mode: 'terminate', certificate: 'auto' } },
|
||||
name: 'HTTPS Route for secure.example.com',
|
||||
},
|
||||
{
|
||||
match: { ports: 443, domains: 'api.example.com', path: '/v1/*' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3002 }], tls: { mode: 'terminate', certificate: 'auto' } },
|
||||
name: 'API Route for api.example.com',
|
||||
},
|
||||
{
|
||||
match: { ports: 443, domains: 'ws.example.com', path: '/socket' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3003 }], tls: { mode: 'terminate', certificate: 'auto' }, websocket: { enabled: true } },
|
||||
name: 'WebSocket Route for ws.example.com',
|
||||
},
|
||||
];
|
||||
|
||||
|
||||
// Set priorities
|
||||
routes[0].priority = 10;
|
||||
routes[1].priority = 20;
|
||||
routes[2].priority = 30;
|
||||
routes[3].priority = 40;
|
||||
|
||||
|
||||
// Find routes for different criteria
|
||||
const httpMatches = findMatchingRoutes(routes, { domain: 'example.com', port: 80 });
|
||||
expect(httpMatches.length).toEqual(1);
|
||||
expect(httpMatches[0].name).toInclude('HTTP Route');
|
||||
|
||||
|
||||
const httpsMatches = findMatchingRoutes(routes, { domain: 'secure.example.com', port: 443 });
|
||||
expect(httpsMatches.length).toEqual(1);
|
||||
expect(httpsMatches[0].name).toInclude('HTTPS Route');
|
||||
|
||||
|
||||
const apiMatches = findMatchingRoutes(routes, { domain: 'api.example.com', path: '/v1/users' });
|
||||
expect(apiMatches.length).toEqual(1);
|
||||
expect(apiMatches[0].name).toInclude('API Route');
|
||||
|
||||
|
||||
const wsMatches = findMatchingRoutes(routes, { domain: 'ws.example.com', path: '/socket' });
|
||||
expect(wsMatches.length).toEqual(1);
|
||||
expect(wsMatches[0].name).toInclude('WebSocket Route');
|
||||
|
||||
|
||||
// Test finding multiple routes that match same criteria
|
||||
const route1 = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
|
||||
const route1: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com',
|
||||
};
|
||||
route1.priority = 10;
|
||||
|
||||
const route2 = createHttpRoute('example.com', { host: 'localhost', port: 3001 });
|
||||
|
||||
const route2: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3001 }] },
|
||||
name: 'HTTP Route for example.com',
|
||||
};
|
||||
route2.priority = 20;
|
||||
route2.match.path = '/api';
|
||||
|
||||
|
||||
const multiMatchRoutes = [route1, route2];
|
||||
|
||||
|
||||
const multiMatches = findMatchingRoutes(multiMatchRoutes, { domain: 'example.com', port: 80 });
|
||||
expect(multiMatches.length).toEqual(2);
|
||||
expect(multiMatches[0].priority).toEqual(20); // Higher priority should be first
|
||||
expect(multiMatches[1].priority).toEqual(10);
|
||||
|
||||
|
||||
// Test disabled routes
|
||||
const disabledRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
|
||||
const disabledRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com',
|
||||
};
|
||||
disabledRoute.enabled = false;
|
||||
|
||||
|
||||
const enabledRoutes = findMatchingRoutes([disabledRoute], { domain: 'example.com', port: 80 });
|
||||
expect(enabledRoutes.length).toEqual(0);
|
||||
});
|
||||
|
||||
tap.test('Route Finding - findBestMatchingRoute', async () => {
|
||||
// Create multiple routes with different priorities
|
||||
const route1 = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
|
||||
const route1: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com',
|
||||
};
|
||||
route1.priority = 10;
|
||||
|
||||
const route2 = createHttpRoute('example.com', { host: 'localhost', port: 3001 });
|
||||
|
||||
const route2: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3001 }] },
|
||||
name: 'HTTP Route for example.com',
|
||||
};
|
||||
route2.priority = 20;
|
||||
route2.match.path = '/api';
|
||||
|
||||
const route3 = createHttpRoute('example.com', { host: 'localhost', port: 3002 });
|
||||
|
||||
const route3: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3002 }] },
|
||||
name: 'HTTP Route for example.com',
|
||||
};
|
||||
route3.priority = 30;
|
||||
route3.match.path = '/api/users';
|
||||
|
||||
|
||||
const routes = [route1, route2, route3];
|
||||
|
||||
|
||||
// Find best route for different criteria
|
||||
const bestGeneral = findBestMatchingRoute(routes, { domain: 'example.com', port: 80 });
|
||||
expect(bestGeneral).not.toBeUndefined();
|
||||
expect(bestGeneral?.priority).toEqual(30);
|
||||
|
||||
|
||||
// Test when no routes match
|
||||
const noMatch = findBestMatchingRoute(routes, { domain: 'unknown.com', port: 80 });
|
||||
expect(noMatch).toBeUndefined();
|
||||
@@ -615,389 +695,54 @@ tap.test('Route Finding - findBestMatchingRoute', async () => {
|
||||
|
||||
tap.test('Route Utilities - generateRouteId', async () => {
|
||||
// Test ID generation for different route types
|
||||
const httpRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
|
||||
const httpRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com',
|
||||
};
|
||||
const httpId = generateRouteId(httpRoute);
|
||||
expect(httpId).toInclude('example-com');
|
||||
expect(httpId).toInclude('80');
|
||||
expect(httpId).toInclude('forward');
|
||||
|
||||
const httpsRoute = createHttpsTerminateRoute('secure.example.com', { host: 'localhost', port: 3001 });
|
||||
|
||||
const httpsRoute: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'secure.example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3001 }], tls: { mode: 'terminate', certificate: 'auto' } },
|
||||
name: 'HTTPS Terminate Route for secure.example.com',
|
||||
};
|
||||
const httpsId = generateRouteId(httpsRoute);
|
||||
expect(httpsId).toInclude('secure-example-com');
|
||||
expect(httpsId).toInclude('443');
|
||||
expect(httpsId).toInclude('forward');
|
||||
|
||||
const multiDomainRoute = createHttpRoute(['example.com', 'example.org'], { host: 'localhost', port: 3000 });
|
||||
|
||||
const multiDomainRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: ['example.com', 'example.org'] },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
|
||||
name: 'HTTP Route for example.com,example.org',
|
||||
};
|
||||
const multiDomainId = generateRouteId(multiDomainRoute);
|
||||
expect(multiDomainId).toInclude('example-com-example-org');
|
||||
});
|
||||
|
||||
tap.test('Route Utilities - cloneRoute', async () => {
|
||||
// Create a route and clone it
|
||||
const originalRoute = createHttpsTerminateRoute('example.com', { host: 'localhost', port: 3000 }, {
|
||||
certificate: 'auto',
|
||||
name: 'Original Route'
|
||||
});
|
||||
|
||||
const originalRoute: IRouteConfig = {
|
||||
match: { ports: 443, domains: 'example.com' },
|
||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }], tls: { mode: 'terminate', certificate: 'auto' } },
|
||||
name: 'Original Route',
|
||||
};
|
||||
|
||||
const clonedRoute = cloneRoute(originalRoute);
|
||||
|
||||
|
||||
// Check that the values are identical
|
||||
expect(clonedRoute.name).toEqual(originalRoute.name);
|
||||
expect(clonedRoute.match.domains).toEqual(originalRoute.match.domains);
|
||||
expect(clonedRoute.action.type).toEqual(originalRoute.action.type);
|
||||
expect(clonedRoute.action.targets?.[0]?.port).toEqual(originalRoute.action.targets?.[0]?.port);
|
||||
|
||||
|
||||
// Modify the clone and check that the original is unchanged
|
||||
clonedRoute.name = 'Modified Clone';
|
||||
expect(originalRoute.name).toEqual('Original Route');
|
||||
});
|
||||
|
||||
// --------------------------------- Route Helper Tests ---------------------------------
|
||||
|
||||
tap.test('Route Helpers - createHttpRoute', async () => {
|
||||
const route = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
|
||||
|
||||
expect(route.match.domains).toEqual('example.com');
|
||||
expect(route.match.ports).toEqual(80);
|
||||
expect(route.action.type).toEqual('forward');
|
||||
expect(route.action.targets?.[0]?.host).toEqual('localhost');
|
||||
expect(route.action.targets?.[0]?.port).toEqual(3000);
|
||||
|
||||
const validationResult = validateRouteConfig(route);
|
||||
expect(validationResult.valid).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('Route Helpers - createHttpsTerminateRoute', async () => {
|
||||
const route = createHttpsTerminateRoute('example.com', { host: 'localhost', port: 3000 }, {
|
||||
certificate: 'auto'
|
||||
});
|
||||
|
||||
expect(route.match.domains).toEqual('example.com');
|
||||
expect(route.match.ports).toEqual(443);
|
||||
expect(route.action.type).toEqual('forward');
|
||||
expect(route.action.tls.mode).toEqual('terminate');
|
||||
expect(route.action.tls.certificate).toEqual('auto');
|
||||
|
||||
const validationResult = validateRouteConfig(route);
|
||||
expect(validationResult.valid).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('Route Helpers - createHttpToHttpsRedirect', async () => {
|
||||
const route = createHttpToHttpsRedirect('example.com');
|
||||
|
||||
expect(route.match.domains).toEqual('example.com');
|
||||
expect(route.match.ports).toEqual(80);
|
||||
expect(route.action.type).toEqual('socket-handler');
|
||||
expect(route.action.socketHandler).toBeDefined();
|
||||
|
||||
const validationResult = validateRouteConfig(route);
|
||||
expect(validationResult.valid).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('Route Helpers - createHttpsPassthroughRoute', async () => {
|
||||
const route = createHttpsPassthroughRoute('example.com', { host: 'localhost', port: 3000 });
|
||||
|
||||
expect(route.match.domains).toEqual('example.com');
|
||||
expect(route.match.ports).toEqual(443);
|
||||
expect(route.action.type).toEqual('forward');
|
||||
expect(route.action.tls.mode).toEqual('passthrough');
|
||||
|
||||
const validationResult = validateRouteConfig(route);
|
||||
expect(validationResult.valid).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('Route Helpers - createCompleteHttpsServer', async () => {
|
||||
const routes = createCompleteHttpsServer('example.com', { host: 'localhost', port: 3000 }, {
|
||||
certificate: 'auto'
|
||||
});
|
||||
|
||||
expect(routes.length).toEqual(2);
|
||||
|
||||
// HTTPS route
|
||||
expect(routes[0].match.domains).toEqual('example.com');
|
||||
expect(routes[0].match.ports).toEqual(443);
|
||||
expect(routes[0].action.type).toEqual('forward');
|
||||
expect(routes[0].action.tls.mode).toEqual('terminate');
|
||||
|
||||
// HTTP redirect route
|
||||
expect(routes[1].match.domains).toEqual('example.com');
|
||||
expect(routes[1].match.ports).toEqual(80);
|
||||
expect(routes[1].action.type).toEqual('socket-handler');
|
||||
|
||||
const validation1 = validateRouteConfig(routes[0]);
|
||||
const validation2 = validateRouteConfig(routes[1]);
|
||||
expect(validation1.valid).toBeTrue();
|
||||
expect(validation2.valid).toBeTrue();
|
||||
});
|
||||
|
||||
// createStaticFileRoute has been removed - static file serving should be handled by
|
||||
// external servers (nginx/apache) behind the proxy
|
||||
|
||||
tap.test('Route Helpers - createApiRoute', async () => {
|
||||
const route = createApiRoute('api.example.com', '/v1', { host: 'localhost', port: 3000 }, {
|
||||
useTls: true,
|
||||
certificate: 'auto',
|
||||
addCorsHeaders: true
|
||||
});
|
||||
|
||||
expect(route.match.domains).toEqual('api.example.com');
|
||||
expect(route.match.ports).toEqual(443);
|
||||
expect(route.match.path).toEqual('/v1/*');
|
||||
expect(route.action.type).toEqual('forward');
|
||||
expect(route.action.tls.mode).toEqual('terminate');
|
||||
|
||||
// Check CORS headers if they exist
|
||||
if (route.headers && route.headers.response) {
|
||||
expect(route.headers.response['Access-Control-Allow-Origin']).toEqual('*');
|
||||
}
|
||||
|
||||
const validationResult = validateRouteConfig(route);
|
||||
expect(validationResult.valid).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('Route Helpers - createWebSocketRoute', async () => {
|
||||
const route = createWebSocketRoute('ws.example.com', '/socket', { host: 'localhost', port: 3000 }, {
|
||||
useTls: true,
|
||||
certificate: 'auto',
|
||||
pingInterval: 15000
|
||||
});
|
||||
|
||||
expect(route.match.domains).toEqual('ws.example.com');
|
||||
expect(route.match.ports).toEqual(443);
|
||||
expect(route.match.path).toEqual('/socket');
|
||||
expect(route.action.type).toEqual('forward');
|
||||
expect(route.action.tls.mode).toEqual('terminate');
|
||||
|
||||
// Check websocket configuration if it exists
|
||||
if (route.action.websocket) {
|
||||
expect(route.action.websocket.enabled).toBeTrue();
|
||||
expect(route.action.websocket.pingInterval).toEqual(15000);
|
||||
}
|
||||
|
||||
const validationResult = validateRouteConfig(route);
|
||||
expect(validationResult.valid).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('Route Helpers - createLoadBalancerRoute', async () => {
|
||||
const route = createLoadBalancerRoute(
|
||||
'loadbalancer.example.com',
|
||||
['server1.local', 'server2.local', 'server3.local'],
|
||||
8080,
|
||||
{
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto'
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
expect(route.match.domains).toEqual('loadbalancer.example.com');
|
||||
expect(route.match.ports).toEqual(443);
|
||||
expect(route.action.type).toEqual('forward');
|
||||
expect(route.action.targets).toBeDefined();
|
||||
if (route.action.targets && Array.isArray(route.action.targets[0]?.host)) {
|
||||
expect((route.action.targets[0].host as string[]).length).toEqual(3);
|
||||
}
|
||||
expect(route.action.targets?.[0]?.port).toEqual(8080);
|
||||
expect(route.action.tls.mode).toEqual('terminate');
|
||||
|
||||
const validationResult = validateRouteConfig(route);
|
||||
expect(validationResult.valid).toBeTrue();
|
||||
});
|
||||
|
||||
// --------------------------------- Route Pattern Tests ---------------------------------
|
||||
|
||||
tap.test('Route Patterns - createApiGatewayRoute', async () => {
|
||||
// Create API Gateway route
|
||||
const apiGatewayRoute = createApiGatewayRoute(
|
||||
'api.example.com',
|
||||
'/v1',
|
||||
{ host: 'localhost', port: 3000 },
|
||||
{
|
||||
useTls: true,
|
||||
addCorsHeaders: true
|
||||
}
|
||||
);
|
||||
|
||||
// Validate route configuration
|
||||
expect(apiGatewayRoute.match.domains).toEqual('api.example.com');
|
||||
expect(apiGatewayRoute.match.path).toInclude('/v1');
|
||||
expect(apiGatewayRoute.action.type).toEqual('forward');
|
||||
expect(apiGatewayRoute.action.targets?.[0]?.port).toEqual(3000);
|
||||
|
||||
// Check TLS configuration
|
||||
if (apiGatewayRoute.action.tls) {
|
||||
expect(apiGatewayRoute.action.tls.mode).toEqual('terminate');
|
||||
}
|
||||
|
||||
// Check CORS headers
|
||||
if (apiGatewayRoute.headers && apiGatewayRoute.headers.response) {
|
||||
expect(apiGatewayRoute.headers.response['Access-Control-Allow-Origin']).toEqual('*');
|
||||
}
|
||||
|
||||
const result = validateRouteConfig(apiGatewayRoute);
|
||||
expect(result.valid).toBeTrue();
|
||||
});
|
||||
|
||||
// createStaticFileServerRoute has been removed - static file serving should be handled by
|
||||
// external servers (nginx/apache) behind the proxy
|
||||
|
||||
tap.test('Route Patterns - createWebSocketPattern', async () => {
|
||||
// Create WebSocket route pattern
|
||||
const wsRoute = createWebSocketPattern(
|
||||
'ws.example.com',
|
||||
{ host: 'localhost', port: 3000 },
|
||||
{
|
||||
useTls: true,
|
||||
path: '/socket',
|
||||
pingInterval: 10000
|
||||
}
|
||||
);
|
||||
|
||||
// Validate route configuration
|
||||
expect(wsRoute.match.domains).toEqual('ws.example.com');
|
||||
expect(wsRoute.match.path).toEqual('/socket');
|
||||
expect(wsRoute.action.type).toEqual('forward');
|
||||
expect(wsRoute.action.targets?.[0]?.port).toEqual(3000);
|
||||
|
||||
// Check TLS configuration
|
||||
if (wsRoute.action.tls) {
|
||||
expect(wsRoute.action.tls.mode).toEqual('terminate');
|
||||
}
|
||||
|
||||
// Check websocket configuration if it exists
|
||||
if (wsRoute.action.websocket) {
|
||||
expect(wsRoute.action.websocket.enabled).toBeTrue();
|
||||
expect(wsRoute.action.websocket.pingInterval).toEqual(10000);
|
||||
}
|
||||
|
||||
const result = validateRouteConfig(wsRoute);
|
||||
expect(result.valid).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('Route Patterns - createLoadBalancerRoute pattern', async () => {
|
||||
// Create load balancer route pattern with missing algorithm as it might not be implemented yet
|
||||
try {
|
||||
const lbRoute = createLbPattern(
|
||||
'lb.example.com',
|
||||
[
|
||||
{ host: 'server1.local', port: 8080 },
|
||||
{ host: 'server2.local', port: 8080 },
|
||||
{ host: 'server3.local', port: 8080 }
|
||||
],
|
||||
{
|
||||
useTls: true
|
||||
}
|
||||
);
|
||||
|
||||
// Validate route configuration
|
||||
expect(lbRoute.match.domains).toEqual('lb.example.com');
|
||||
expect(lbRoute.action.type).toEqual('forward');
|
||||
|
||||
// Check target hosts
|
||||
if (lbRoute.action.targets && Array.isArray(lbRoute.action.targets[0]?.host)) {
|
||||
expect((lbRoute.action.targets[0].host as string[]).length).toEqual(3);
|
||||
}
|
||||
|
||||
// Check TLS configuration
|
||||
if (lbRoute.action.tls) {
|
||||
expect(lbRoute.action.tls.mode).toEqual('terminate');
|
||||
}
|
||||
|
||||
const result = validateRouteConfig(lbRoute);
|
||||
expect(result.valid).toBeTrue();
|
||||
} catch (error) {
|
||||
// If the pattern is not implemented yet, skip this test
|
||||
console.log('Load balancer pattern might not be fully implemented yet');
|
||||
}
|
||||
});
|
||||
|
||||
tap.test('Route Security - addRateLimiting', async () => {
|
||||
// Create base route
|
||||
const baseRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
|
||||
|
||||
// Add rate limiting
|
||||
const secureRoute = addRateLimiting(baseRoute, {
|
||||
maxRequests: 100,
|
||||
window: 60, // 1 minute
|
||||
keyBy: 'ip'
|
||||
});
|
||||
|
||||
// Check if rate limiting is applied
|
||||
if (secureRoute.security) {
|
||||
expect(secureRoute.security.rateLimit?.enabled).toBeTrue();
|
||||
expect(secureRoute.security.rateLimit?.maxRequests).toEqual(100);
|
||||
expect(secureRoute.security.rateLimit?.window).toEqual(60);
|
||||
expect(secureRoute.security.rateLimit?.keyBy).toEqual('ip');
|
||||
} else {
|
||||
// Skip this test if security features are not implemented yet
|
||||
console.log('Security features not implemented yet in route configuration');
|
||||
}
|
||||
|
||||
// Just check that the route itself is valid
|
||||
const result = validateRouteConfig(secureRoute);
|
||||
expect(result.valid).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('Route Security - addBasicAuth', async () => {
|
||||
// Create base route
|
||||
const baseRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
|
||||
|
||||
// Add basic authentication
|
||||
const authRoute = addBasicAuth(baseRoute, {
|
||||
users: [
|
||||
{ username: 'admin', password: 'secret' },
|
||||
{ username: 'user', password: 'password' }
|
||||
],
|
||||
realm: 'Protected Area',
|
||||
excludePaths: ['/public']
|
||||
});
|
||||
|
||||
// Check if basic auth is applied
|
||||
if (authRoute.security) {
|
||||
expect(authRoute.security.basicAuth?.enabled).toBeTrue();
|
||||
expect(authRoute.security.basicAuth?.users.length).toEqual(2);
|
||||
expect(authRoute.security.basicAuth?.realm).toEqual('Protected Area');
|
||||
expect(authRoute.security.basicAuth?.excludePaths).toInclude('/public');
|
||||
} else {
|
||||
// Skip this test if security features are not implemented yet
|
||||
console.log('Security features not implemented yet in route configuration');
|
||||
}
|
||||
|
||||
// Check that the route itself is valid
|
||||
const result = validateRouteConfig(authRoute);
|
||||
expect(result.valid).toBeTrue();
|
||||
});
|
||||
|
||||
tap.test('Route Security - addJwtAuth', async () => {
|
||||
// Create base route
|
||||
const baseRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
|
||||
|
||||
// Add JWT authentication
|
||||
const jwtRoute = addJwtAuth(baseRoute, {
|
||||
secret: 'your-jwt-secret-key',
|
||||
algorithm: 'HS256',
|
||||
issuer: 'auth.example.com',
|
||||
audience: 'api.example.com',
|
||||
expiresIn: 3600
|
||||
});
|
||||
|
||||
// Check if JWT auth is applied
|
||||
if (jwtRoute.security) {
|
||||
expect(jwtRoute.security.jwtAuth?.enabled).toBeTrue();
|
||||
expect(jwtRoute.security.jwtAuth?.secret).toEqual('your-jwt-secret-key');
|
||||
expect(jwtRoute.security.jwtAuth?.algorithm).toEqual('HS256');
|
||||
expect(jwtRoute.security.jwtAuth?.issuer).toEqual('auth.example.com');
|
||||
expect(jwtRoute.security.jwtAuth?.audience).toEqual('api.example.com');
|
||||
expect(jwtRoute.security.jwtAuth?.expiresIn).toEqual(3600);
|
||||
} else {
|
||||
// Skip this test if security features are not implemented yet
|
||||
console.log('Security features not implemented yet in route configuration');
|
||||
}
|
||||
|
||||
// Check that the route itself is valid
|
||||
const result = validateRouteConfig(jwtRoute);
|
||||
expect(result.valid).toBeTrue();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
export default tap.start();
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
|
||||
import type { ISmartProxyOptions } from '../ts/proxies/smart-proxy/models/interfaces.js';
|
||||
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
|
||||
import { RoutePreprocessor } from '../ts/proxies/smart-proxy/route-preprocessor.js';
|
||||
import { buildRustProxyOptions } from '../ts/proxies/smart-proxy/utils/rust-config.js';
|
||||
|
||||
tap.test('Rust contract - preprocessor serializes regex headers for Rust', async () => {
|
||||
const route: IRouteConfig = {
|
||||
name: 'contract-route',
|
||||
match: {
|
||||
ports: [443, { from: 8443, to: 8444 }],
|
||||
domains: ['api.example.com', '*.example.com'],
|
||||
transport: 'udp',
|
||||
protocol: 'http3',
|
||||
headers: {
|
||||
'Content-Type': /^application\/json$/i,
|
||||
},
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
match: {
|
||||
ports: [443],
|
||||
path: '/api/*',
|
||||
method: ['GET'],
|
||||
headers: {
|
||||
'X-Env': /^(prod|stage)$/,
|
||||
},
|
||||
},
|
||||
host: ['backend-a', 'backend-b'],
|
||||
port: 'preserve',
|
||||
sendProxyProtocol: true,
|
||||
backendTransport: 'tcp',
|
||||
}],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto',
|
||||
},
|
||||
sendProxyProtocol: true,
|
||||
udp: {
|
||||
maxSessionsPerIP: 321,
|
||||
quic: {
|
||||
enableHttp3: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
security: {
|
||||
ipAllowList: [{
|
||||
ip: '10.0.0.0/8',
|
||||
domains: ['api.example.com'],
|
||||
}],
|
||||
},
|
||||
};
|
||||
|
||||
const preprocessor = new RoutePreprocessor();
|
||||
const [rustRoute] = preprocessor.preprocessForRust([route]);
|
||||
|
||||
expect(rustRoute.match.headers?.['Content-Type']).toEqual('/^application\\/json$/i');
|
||||
expect(rustRoute.match.transport).toEqual('udp');
|
||||
expect(rustRoute.match.protocol).toEqual('http3');
|
||||
expect(rustRoute.action.targets?.[0].match?.headers?.['X-Env']).toEqual('/^(prod|stage)$/');
|
||||
expect(rustRoute.action.targets?.[0].port).toEqual('preserve');
|
||||
expect(rustRoute.action.targets?.[0].backendTransport).toEqual('tcp');
|
||||
expect(rustRoute.action.sendProxyProtocol).toBeTrue();
|
||||
expect(rustRoute.action.udp?.maxSessionsPerIp).toEqual(321);
|
||||
});
|
||||
|
||||
tap.test('Rust contract - preprocessor converts dynamic targets to relay-safe payloads', async () => {
|
||||
const route: IRouteConfig = {
|
||||
name: 'dynamic-contract-route',
|
||||
match: {
|
||||
ports: 8080,
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: () => 'dynamic-backend.internal',
|
||||
port: () => 9443,
|
||||
}],
|
||||
},
|
||||
};
|
||||
|
||||
const preprocessor = new RoutePreprocessor();
|
||||
const [rustRoute] = preprocessor.preprocessForRust([route]);
|
||||
|
||||
expect(rustRoute.action.type).toEqual('socket-handler');
|
||||
expect(rustRoute.action.targets?.[0].host).toEqual('localhost');
|
||||
expect(rustRoute.action.targets?.[0].port).toEqual(0);
|
||||
expect(preprocessor.getOriginalRoute('dynamic-contract-route')).toEqual(route);
|
||||
});
|
||||
|
||||
tap.test('Rust contract - top-level config keeps shared SmartProxy settings', async () => {
|
||||
const settings: ISmartProxyOptions = {
|
||||
routes: [{
|
||||
name: 'top-level-contract-route',
|
||||
match: {
|
||||
ports: 443,
|
||||
domains: 'api.example.com',
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'backend.internal',
|
||||
port: 8443,
|
||||
}],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto',
|
||||
},
|
||||
},
|
||||
}],
|
||||
preserveSourceIP: true,
|
||||
proxyIPs: ['10.0.0.1'],
|
||||
acceptProxyProtocol: true,
|
||||
sendProxyProtocol: true,
|
||||
noDelay: true,
|
||||
keepAlive: true,
|
||||
keepAliveInitialDelay: 1500,
|
||||
maxPendingDataSize: 4096,
|
||||
disableInactivityCheck: true,
|
||||
enableKeepAliveProbes: true,
|
||||
enableDetailedLogging: true,
|
||||
enableTlsDebugLogging: true,
|
||||
enableRandomizedTimeouts: true,
|
||||
connectionTimeout: 5000,
|
||||
initialDataTimeout: 7000,
|
||||
socketTimeout: 9000,
|
||||
inactivityCheckInterval: 1100,
|
||||
maxConnectionLifetime: 13000,
|
||||
inactivityTimeout: 15000,
|
||||
gracefulShutdownTimeout: 17000,
|
||||
maxConnectionsPerIP: 20,
|
||||
connectionRateLimitPerMinute: 30,
|
||||
keepAliveTreatment: 'extended',
|
||||
keepAliveInactivityMultiplier: 2,
|
||||
extendedKeepAliveLifetime: 19000,
|
||||
metrics: {
|
||||
enabled: true,
|
||||
sampleIntervalMs: 250,
|
||||
retentionSeconds: 60,
|
||||
},
|
||||
acme: {
|
||||
enabled: true,
|
||||
email: 'ops@example.com',
|
||||
environment: 'staging',
|
||||
useProduction: false,
|
||||
skipConfiguredCerts: true,
|
||||
renewThresholdDays: 14,
|
||||
renewCheckIntervalHours: 12,
|
||||
autoRenew: true,
|
||||
port: 80,
|
||||
},
|
||||
};
|
||||
|
||||
const preprocessor = new RoutePreprocessor();
|
||||
const routes = preprocessor.preprocessForRust(settings.routes);
|
||||
const config = buildRustProxyOptions(settings, routes);
|
||||
|
||||
expect(config.preserveSourceIp).toBeTrue();
|
||||
expect(config.proxyIps).toEqual(['10.0.0.1']);
|
||||
expect(config.acceptProxyProtocol).toBeTrue();
|
||||
expect(config.sendProxyProtocol).toBeTrue();
|
||||
expect(config.noDelay).toBeTrue();
|
||||
expect(config.keepAlive).toBeTrue();
|
||||
expect(config.keepAliveInitialDelay).toEqual(1500);
|
||||
expect(config.maxPendingDataSize).toEqual(4096);
|
||||
expect(config.disableInactivityCheck).toBeTrue();
|
||||
expect(config.enableKeepAliveProbes).toBeTrue();
|
||||
expect(config.enableDetailedLogging).toBeTrue();
|
||||
expect(config.enableTlsDebugLogging).toBeTrue();
|
||||
expect(config.enableRandomizedTimeouts).toBeTrue();
|
||||
expect(config.connectionTimeout).toEqual(5000);
|
||||
expect(config.initialDataTimeout).toEqual(7000);
|
||||
expect(config.socketTimeout).toEqual(9000);
|
||||
expect(config.inactivityCheckInterval).toEqual(1100);
|
||||
expect(config.maxConnectionLifetime).toEqual(13000);
|
||||
expect(config.inactivityTimeout).toEqual(15000);
|
||||
expect(config.gracefulShutdownTimeout).toEqual(17000);
|
||||
expect(config.maxConnectionsPerIp).toEqual(20);
|
||||
expect(config.connectionRateLimitPerMinute).toEqual(30);
|
||||
expect(config.keepAliveTreatment).toEqual('extended');
|
||||
expect(config.keepAliveInactivityMultiplier).toEqual(2);
|
||||
expect(config.extendedKeepAliveLifetime).toEqual(19000);
|
||||
expect(config.metrics?.sampleIntervalMs).toEqual(250);
|
||||
expect(config.acme?.email).toEqual('ops@example.com');
|
||||
expect(config.acme?.environment).toEqual('staging');
|
||||
expect(config.acme?.skipConfiguredCerts).toBeTrue();
|
||||
expect(config.acme?.renewThresholdDays).toEqual(14);
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -188,10 +188,12 @@ tap.test('TCP forward - real-time byte tracking', async (tools) => {
|
||||
const byRoute = m.throughput.byRoute();
|
||||
console.log('TCP forward — throughput byRoute:', Array.from(byRoute.entries()));
|
||||
|
||||
// After close, per-IP data should be evicted (memory leak fix)
|
||||
// After close, per-IP buckets are retained briefly for final throughput sampling,
|
||||
// but active connection counts must already be zero.
|
||||
const byIPAfter = m.connections.byIP();
|
||||
console.log('TCP forward — connections byIP after close:', Array.from(byIPAfter.entries()));
|
||||
expect(byIPAfter.size).toEqual(0);
|
||||
expect(byIPAfter.size).toBeGreaterThan(0);
|
||||
expect(Array.from(byIPAfter.values()).every((count) => count === 0)).toEqual(true);
|
||||
|
||||
await proxy.stop();
|
||||
await tools.delayFor(200);
|
||||
|
||||
@@ -0,0 +1,418 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
import * as http from 'http';
|
||||
import WebSocket, { WebSocketServer } from 'ws';
|
||||
import { findFreePorts, assertPortsFree } from './helpers/port-allocator.js';
|
||||
|
||||
/**
|
||||
* Helper: create a WebSocket client that connects through the proxy.
|
||||
* Registers the message handler BEFORE awaiting open to avoid race conditions.
|
||||
*/
|
||||
function connectWs(
|
||||
url: string,
|
||||
headers: Record<string, string> = {},
|
||||
opts: WebSocket.ClientOptions = {},
|
||||
): { ws: WebSocket; messages: string[]; opened: Promise<void> } {
|
||||
const messages: string[] = [];
|
||||
const ws = new WebSocket(url, { headers, ...opts });
|
||||
|
||||
// Register message handler immediately — before open fires
|
||||
ws.on('message', (data) => {
|
||||
messages.push(data.toString());
|
||||
});
|
||||
|
||||
const opened = new Promise<void>((resolve, reject) => {
|
||||
const timeout = setTimeout(() => reject(new Error('WebSocket open timeout')), 5000);
|
||||
ws.on('open', () => { clearTimeout(timeout); resolve(); });
|
||||
ws.on('error', (err) => { clearTimeout(timeout); reject(err); });
|
||||
});
|
||||
|
||||
return { ws, messages, opened };
|
||||
}
|
||||
|
||||
/** Wait until `predicate` returns true, with a hard timeout. */
|
||||
function waitFor(predicate: () => boolean, timeoutMs = 5000): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const deadline = setTimeout(() => reject(new Error('waitFor timeout')), timeoutMs);
|
||||
const check = () => {
|
||||
if (predicate()) { clearTimeout(deadline); resolve(); }
|
||||
else setTimeout(check, 30);
|
||||
};
|
||||
check();
|
||||
});
|
||||
}
|
||||
|
||||
/** Graceful close helper */
|
||||
function closeWs(ws: WebSocket): Promise<void> {
|
||||
return new Promise((resolve) => {
|
||||
if (ws.readyState === WebSocket.CLOSED) return resolve();
|
||||
ws.on('close', () => resolve());
|
||||
ws.close();
|
||||
setTimeout(resolve, 2000); // fallback
|
||||
});
|
||||
}
|
||||
|
||||
// ─── Test 1: Basic WebSocket upgrade and bidirectional messaging ───
|
||||
tap.test('should proxy WebSocket connections with bidirectional messaging', async () => {
|
||||
const [PROXY_PORT, BACKEND_PORT] = await findFreePorts(2);
|
||||
|
||||
// Backend: echoes messages with prefix, sends greeting on connect
|
||||
const backendServer = http.createServer();
|
||||
const wss = new WebSocketServer({ server: backendServer });
|
||||
const backendMessages: string[] = [];
|
||||
|
||||
wss.on('connection', (ws) => {
|
||||
ws.on('message', (data) => {
|
||||
const msg = data.toString();
|
||||
backendMessages.push(msg);
|
||||
ws.send(`echo: ${msg}`);
|
||||
});
|
||||
ws.send('hello from backend');
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
backendServer.listen(BACKEND_PORT, '127.0.0.1', () => resolve());
|
||||
});
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'ws-test-route',
|
||||
match: { ports: PROXY_PORT },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: '127.0.0.1', port: BACKEND_PORT }],
|
||||
websocket: { enabled: true },
|
||||
},
|
||||
}],
|
||||
});
|
||||
await proxy.start();
|
||||
|
||||
// Connect client — message handler registered before open
|
||||
const { ws, messages, opened } = connectWs(
|
||||
`ws://127.0.0.1:${PROXY_PORT}/`,
|
||||
{ Host: 'test.local' },
|
||||
);
|
||||
await opened;
|
||||
|
||||
// Wait for the backend greeting
|
||||
await waitFor(() => messages.length >= 1);
|
||||
expect(messages[0]).toEqual('hello from backend');
|
||||
|
||||
// Send 3 messages, expect 3 echoes
|
||||
ws.send('ping 1');
|
||||
ws.send('ping 2');
|
||||
ws.send('ping 3');
|
||||
|
||||
await waitFor(() => messages.length >= 4);
|
||||
|
||||
expect(messages).toContain('echo: ping 1');
|
||||
expect(messages).toContain('echo: ping 2');
|
||||
expect(messages).toContain('echo: ping 3');
|
||||
expect(backendMessages).toInclude('ping 1');
|
||||
expect(backendMessages).toInclude('ping 2');
|
||||
expect(backendMessages).toInclude('ping 3');
|
||||
|
||||
await closeWs(ws);
|
||||
await proxy.stop();
|
||||
await new Promise<void>((resolve) => backendServer.close(() => resolve()));
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
await assertPortsFree([PROXY_PORT, BACKEND_PORT]);
|
||||
});
|
||||
|
||||
// ─── Test 2: Multiple concurrent WebSocket connections ───
|
||||
tap.test('should handle multiple concurrent WebSocket connections', async () => {
|
||||
const [PROXY_PORT, BACKEND_PORT] = await findFreePorts(2);
|
||||
|
||||
const backendServer = http.createServer();
|
||||
const wss = new WebSocketServer({ server: backendServer });
|
||||
|
||||
let connectionCount = 0;
|
||||
wss.on('connection', (ws) => {
|
||||
const id = ++connectionCount;
|
||||
ws.on('message', (data) => {
|
||||
ws.send(`conn${id}: ${data.toString()}`);
|
||||
});
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
backendServer.listen(BACKEND_PORT, '127.0.0.1', () => resolve());
|
||||
});
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'ws-multi-route',
|
||||
match: { ports: PROXY_PORT },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: '127.0.0.1', port: BACKEND_PORT }],
|
||||
websocket: { enabled: true },
|
||||
},
|
||||
}],
|
||||
});
|
||||
await proxy.start();
|
||||
|
||||
const NUM_CLIENTS = 5;
|
||||
const clients: { ws: WebSocket; messages: string[] }[] = [];
|
||||
|
||||
for (let i = 0; i < NUM_CLIENTS; i++) {
|
||||
const c = connectWs(
|
||||
`ws://127.0.0.1:${PROXY_PORT}/`,
|
||||
{ Host: 'test.local' },
|
||||
);
|
||||
await c.opened;
|
||||
clients.push(c);
|
||||
}
|
||||
|
||||
// Each client sends a unique message
|
||||
for (let i = 0; i < NUM_CLIENTS; i++) {
|
||||
clients[i].ws.send(`hello from client ${i}`);
|
||||
}
|
||||
|
||||
// Wait for all replies
|
||||
await waitFor(() => clients.every((c) => c.messages.length >= 1));
|
||||
|
||||
for (let i = 0; i < NUM_CLIENTS; i++) {
|
||||
expect(clients[i].messages.length).toBeGreaterThanOrEqual(1);
|
||||
expect(clients[i].messages[0]).toInclude(`hello from client ${i}`);
|
||||
}
|
||||
expect(connectionCount).toEqual(NUM_CLIENTS);
|
||||
|
||||
for (const c of clients) await closeWs(c.ws);
|
||||
await proxy.stop();
|
||||
await new Promise<void>((resolve) => backendServer.close(() => resolve()));
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
await assertPortsFree([PROXY_PORT, BACKEND_PORT]);
|
||||
});
|
||||
|
||||
// ─── Test 3: WebSocket with binary data ───
|
||||
tap.test('should proxy binary WebSocket frames', async () => {
|
||||
const [PROXY_PORT, BACKEND_PORT] = await findFreePorts(2);
|
||||
|
||||
const backendServer = http.createServer();
|
||||
const wss = new WebSocketServer({ server: backendServer });
|
||||
|
||||
wss.on('connection', (ws) => {
|
||||
ws.on('message', (data) => {
|
||||
ws.send(data, { binary: true });
|
||||
});
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
backendServer.listen(BACKEND_PORT, '127.0.0.1', () => resolve());
|
||||
});
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'ws-binary-route',
|
||||
match: { ports: PROXY_PORT },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: '127.0.0.1', port: BACKEND_PORT }],
|
||||
websocket: { enabled: true },
|
||||
},
|
||||
}],
|
||||
});
|
||||
await proxy.start();
|
||||
|
||||
const receivedBuffers: Buffer[] = [];
|
||||
const ws = new WebSocket(`ws://127.0.0.1:${PROXY_PORT}/`, {
|
||||
headers: { Host: 'test.local' },
|
||||
});
|
||||
ws.on('message', (data) => {
|
||||
receivedBuffers.push(Buffer.from(data as ArrayBuffer));
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
const timeout = setTimeout(() => reject(new Error('timeout')), 5000);
|
||||
ws.on('open', () => { clearTimeout(timeout); resolve(); });
|
||||
ws.on('error', (err) => { clearTimeout(timeout); reject(err); });
|
||||
});
|
||||
|
||||
// Send a 256-byte buffer with known content
|
||||
const sentBuffer = Buffer.alloc(256);
|
||||
for (let i = 0; i < 256; i++) sentBuffer[i] = i;
|
||||
ws.send(sentBuffer);
|
||||
|
||||
await waitFor(() => receivedBuffers.length >= 1);
|
||||
|
||||
expect(receivedBuffers[0].length).toEqual(256);
|
||||
expect(Buffer.compare(receivedBuffers[0], sentBuffer)).toEqual(0);
|
||||
|
||||
await closeWs(ws);
|
||||
await proxy.stop();
|
||||
await new Promise<void>((resolve) => backendServer.close(() => resolve()));
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
await assertPortsFree([PROXY_PORT, BACKEND_PORT]);
|
||||
});
|
||||
|
||||
// ─── Test 4: WebSocket path and query string preserved ───
|
||||
tap.test('should preserve path and query string through proxy', async () => {
|
||||
const [PROXY_PORT, BACKEND_PORT] = await findFreePorts(2);
|
||||
|
||||
const backendServer = http.createServer();
|
||||
const wss = new WebSocketServer({ server: backendServer });
|
||||
|
||||
let receivedUrl = '';
|
||||
wss.on('connection', (ws, req) => {
|
||||
receivedUrl = req.url || '';
|
||||
ws.send(`url: ${receivedUrl}`);
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
backendServer.listen(BACKEND_PORT, '127.0.0.1', () => resolve());
|
||||
});
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'ws-path-route',
|
||||
match: { ports: PROXY_PORT },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: '127.0.0.1', port: BACKEND_PORT }],
|
||||
websocket: { enabled: true },
|
||||
},
|
||||
}],
|
||||
});
|
||||
await proxy.start();
|
||||
|
||||
const { ws, messages, opened } = connectWs(
|
||||
`ws://127.0.0.1:${PROXY_PORT}/chat/room1?token=abc123`,
|
||||
{ Host: 'test.local' },
|
||||
);
|
||||
await opened;
|
||||
|
||||
await waitFor(() => messages.length >= 1);
|
||||
|
||||
expect(receivedUrl).toEqual('/chat/room1?token=abc123');
|
||||
expect(messages[0]).toEqual('url: /chat/room1?token=abc123');
|
||||
|
||||
await closeWs(ws);
|
||||
await proxy.stop();
|
||||
await new Promise<void>((resolve) => backendServer.close(() => resolve()));
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
await assertPortsFree([PROXY_PORT, BACKEND_PORT]);
|
||||
});
|
||||
|
||||
// ─── Test 5: Clean close propagation ───
|
||||
tap.test('should handle clean WebSocket close from client', async () => {
|
||||
const [PROXY_PORT, BACKEND_PORT] = await findFreePorts(2);
|
||||
|
||||
const backendServer = http.createServer();
|
||||
const wss = new WebSocketServer({ server: backendServer });
|
||||
|
||||
let backendGotClose = false;
|
||||
let backendCloseCode = 0;
|
||||
wss.on('connection', (ws) => {
|
||||
ws.on('close', (code) => {
|
||||
backendGotClose = true;
|
||||
backendCloseCode = code;
|
||||
});
|
||||
ws.on('message', (data) => {
|
||||
ws.send(data);
|
||||
});
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
backendServer.listen(BACKEND_PORT, '127.0.0.1', () => resolve());
|
||||
});
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'ws-close-route',
|
||||
match: { ports: PROXY_PORT },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: '127.0.0.1', port: BACKEND_PORT }],
|
||||
websocket: { enabled: true },
|
||||
},
|
||||
}],
|
||||
});
|
||||
await proxy.start();
|
||||
|
||||
const { ws, messages, opened } = connectWs(
|
||||
`ws://127.0.0.1:${PROXY_PORT}/`,
|
||||
{ Host: 'test.local' },
|
||||
);
|
||||
await opened;
|
||||
|
||||
// Confirm connection works with a round-trip
|
||||
ws.send('test');
|
||||
await waitFor(() => messages.length >= 1);
|
||||
|
||||
// Close with code 1000
|
||||
let clientCloseCode = 0;
|
||||
const closed = new Promise<void>((resolve) => {
|
||||
ws.on('close', (code) => {
|
||||
clientCloseCode = code;
|
||||
resolve();
|
||||
});
|
||||
setTimeout(resolve, 3000);
|
||||
});
|
||||
ws.close(1000, 'done');
|
||||
await closed;
|
||||
|
||||
// Wait for backend to register
|
||||
await waitFor(() => backendGotClose, 3000);
|
||||
|
||||
expect(backendGotClose).toBeTrue();
|
||||
expect(clientCloseCode).toEqual(1000);
|
||||
|
||||
await proxy.stop();
|
||||
await new Promise<void>((resolve) => backendServer.close(() => resolve()));
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
await assertPortsFree([PROXY_PORT, BACKEND_PORT]);
|
||||
});
|
||||
|
||||
// ─── Test 6: Large messages ───
|
||||
tap.test('should handle large WebSocket messages', async () => {
|
||||
const [PROXY_PORT, BACKEND_PORT] = await findFreePorts(2);
|
||||
|
||||
const backendServer = http.createServer();
|
||||
const wss = new WebSocketServer({ server: backendServer, maxPayload: 5 * 1024 * 1024 });
|
||||
|
||||
wss.on('connection', (ws) => {
|
||||
ws.on('message', (data) => {
|
||||
const buf = Buffer.from(data as ArrayBuffer);
|
||||
ws.send(`received ${buf.length} bytes`);
|
||||
});
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
backendServer.listen(BACKEND_PORT, '127.0.0.1', () => resolve());
|
||||
});
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'ws-large-route',
|
||||
match: { ports: PROXY_PORT },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: '127.0.0.1', port: BACKEND_PORT }],
|
||||
websocket: { enabled: true },
|
||||
},
|
||||
}],
|
||||
});
|
||||
await proxy.start();
|
||||
|
||||
const { ws, messages, opened } = connectWs(
|
||||
`ws://127.0.0.1:${PROXY_PORT}/`,
|
||||
{ Host: 'test.local' },
|
||||
{ maxPayload: 5 * 1024 * 1024 },
|
||||
);
|
||||
await opened;
|
||||
|
||||
// Send a 1MB message
|
||||
const largePayload = Buffer.alloc(1024 * 1024, 0x42);
|
||||
ws.send(largePayload);
|
||||
|
||||
await waitFor(() => messages.length >= 1);
|
||||
expect(messages[0]).toEqual(`received ${1024 * 1024} bytes`);
|
||||
|
||||
await closeWs(ws);
|
||||
await proxy.stop();
|
||||
await new Promise<void>((resolve) => backendServer.close(() => resolve()));
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
await assertPortsFree([PROXY_PORT, BACKEND_PORT]);
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -3,6 +3,6 @@
|
||||
*/
|
||||
export const commitinfo = {
|
||||
name: '@push.rocks/smartproxy',
|
||||
version: '26.2.2',
|
||||
version: '27.9.0',
|
||||
description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.'
|
||||
}
|
||||
|
||||
+1
-1
@@ -7,7 +7,7 @@ export { SmartProxy } from './proxies/smart-proxy/index.js';
|
||||
export { SharedRouteManager as RouteManager } from './core/routing/route-manager.js';
|
||||
|
||||
// Export smart-proxy models
|
||||
export type { ISmartProxyOptions, IConnectionRecord, IRouteConfig, IRouteMatch, IRouteAction, IRouteTls, IRouteContext } from './proxies/smart-proxy/models/index.js';
|
||||
export type { ISmartProxyOptions, ISmartProxySecurityPolicy, IConnectionRecord, IRouteConfig, IRouteMatch, IRouteAction, IRouteTls, IRouteContext } from './proxies/smart-proxy/models/index.js';
|
||||
export type { TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './proxies/smart-proxy/models/interfaces.js';
|
||||
export * from './proxies/smart-proxy/utils/index.js';
|
||||
|
||||
|
||||
@@ -19,12 +19,14 @@ export { tsclass };
|
||||
import * as smartcrypto from '@push.rocks/smartcrypto';
|
||||
import * as smartlog from '@push.rocks/smartlog';
|
||||
import * as smartlogDestinationLocal from '@push.rocks/smartlog/destination-local';
|
||||
import * as smartnftables from '@push.rocks/smartnftables';
|
||||
import * as smartrust from '@push.rocks/smartrust';
|
||||
|
||||
export {
|
||||
smartcrypto,
|
||||
smartlog,
|
||||
smartlogDestinationLocal,
|
||||
smartnftables,
|
||||
smartrust,
|
||||
};
|
||||
|
||||
|
||||
@@ -26,6 +26,8 @@ interface IDatagramRelayMessage {
|
||||
* - TS→Rust: { type: "reply", sourceIp, sourcePort, destPort, payloadBase64 }
|
||||
*/
|
||||
export class DatagramHandlerServer {
|
||||
private static readonly MAX_BUFFER_SIZE = 50 * 1024 * 1024; // 50 MB
|
||||
|
||||
private server: plugins.net.Server | null = null;
|
||||
private connection: plugins.net.Socket | null = null;
|
||||
private socketPath: string;
|
||||
@@ -100,6 +102,11 @@ export class DatagramHandlerServer {
|
||||
|
||||
socket.on('data', (chunk: Buffer) => {
|
||||
this.readBuffer = Buffer.concat([this.readBuffer, chunk]);
|
||||
if (this.readBuffer.length > DatagramHandlerServer.MAX_BUFFER_SIZE) {
|
||||
logger.log('error', `DatagramHandlerServer: buffer exceeded ${DatagramHandlerServer.MAX_BUFFER_SIZE} bytes, resetting`);
|
||||
this.readBuffer = Buffer.alloc(0);
|
||||
return;
|
||||
}
|
||||
this.processFrames();
|
||||
});
|
||||
|
||||
|
||||
@@ -2,6 +2,6 @@
|
||||
* SmartProxy models
|
||||
*/
|
||||
// Export everything except IAcmeOptions from interfaces
|
||||
export type { ISmartProxyOptions, ISmartProxyCertStore, IConnectionRecord, TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './interfaces.js';
|
||||
export type { ISmartProxyOptions, ISmartProxySecurityPolicy, ISmartProxyCertStore, IConnectionRecord, TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './interfaces.js';
|
||||
export * from './route-types.js';
|
||||
export * from './metrics-types.js';
|
||||
|
||||
@@ -29,6 +29,11 @@ export interface ISmartProxyCertStore {
|
||||
}
|
||||
import type { IRouteConfig } from './route-types.js';
|
||||
|
||||
export interface ISmartProxySecurityPolicy {
|
||||
blockedIps?: string[];
|
||||
blockedCidrs?: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Provision object for static or HTTP-01 certificate
|
||||
*/
|
||||
@@ -137,6 +142,7 @@ export interface ISmartProxyOptions {
|
||||
// Rate limiting and security
|
||||
maxConnectionsPerIP?: number; // Maximum simultaneous connections from a single IP
|
||||
connectionRateLimitPerMinute?: number; // Max new connections per minute from a single IP
|
||||
securityPolicy?: ISmartProxySecurityPolicy; // Global ingress block policy, enforced before routing
|
||||
|
||||
// Enhanced keep-alive settings
|
||||
keepAliveTreatment?: 'standard' | 'extended' | 'immortal'; // How to treat keep-alive connections
|
||||
@@ -276,4 +282,4 @@ export interface IConnectionRecord {
|
||||
path?: string;
|
||||
headers?: Record<string, string>;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,9 +29,31 @@ export interface IThroughputHistoryPoint {
|
||||
out: number;
|
||||
}
|
||||
|
||||
export interface IRequestRateMetrics {
|
||||
perSecond: number;
|
||||
lastMinute: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Main metrics interface with clean, grouped API
|
||||
*/
|
||||
/**
|
||||
* Protocol distribution for frontend (client→proxy) or backend (proxy→upstream).
|
||||
* Tracks active and total counts for h1/h2/h3/ws/other.
|
||||
*/
|
||||
export interface IProtocolDistribution {
|
||||
h1Active: number;
|
||||
h1Total: number;
|
||||
h2Active: number;
|
||||
h2Total: number;
|
||||
h3Active: number;
|
||||
h3Total: number;
|
||||
wsActive: number;
|
||||
wsTotal: number;
|
||||
otherActive: number;
|
||||
otherTotal: number;
|
||||
}
|
||||
|
||||
export interface IMetrics {
|
||||
// Connection metrics
|
||||
connections: {
|
||||
@@ -40,6 +62,12 @@ export interface IMetrics {
|
||||
byRoute(): Map<string, number>;
|
||||
byIP(): Map<string, number>;
|
||||
topIPs(limit?: number): Array<{ ip: string; count: number }>;
|
||||
/** Per-IP domain request counts: IP -> { domain -> count }. */
|
||||
domainRequestsByIP(): Map<string, Map<string, number>>;
|
||||
/** Top IP-domain pairs sorted by request count descending. */
|
||||
topDomainRequests(limit?: number): Array<{ ip: string; domain: string; count: number }>;
|
||||
frontendProtocols(): IProtocolDistribution;
|
||||
backendProtocols(): IProtocolDistribution;
|
||||
};
|
||||
|
||||
// Throughput metrics (bytes per second)
|
||||
@@ -58,6 +86,7 @@ export interface IMetrics {
|
||||
perSecond(): number;
|
||||
perMinute(): number;
|
||||
total(): number;
|
||||
byDomain(): Map<string, IRequestRateMetrics>;
|
||||
};
|
||||
|
||||
// Cumulative totals
|
||||
@@ -162,4 +191,4 @@ export interface IByteTracker {
|
||||
bytesOut: number;
|
||||
startTime: number;
|
||||
lastUpdate: number;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,8 +141,10 @@ export interface IRouteAuthentication {
|
||||
* Security options for routes
|
||||
*/
|
||||
export interface IRouteSecurity {
|
||||
// Access control lists
|
||||
ipAllowList?: string[]; // IP addresses that are allowed to connect
|
||||
// Access control lists.
|
||||
// Entries can be plain IP/CIDR strings (full route access) or
|
||||
// objects { ip, domains } to scope access to specific domains on this route.
|
||||
ipAllowList?: Array<string | { ip: string; domains: string[] }>;
|
||||
ipBlockList?: string[]; // IP addresses that are blocked from connecting
|
||||
|
||||
// Connection limits
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
import type { IProtocolCacheEntry, IProtocolDistribution } from './metrics-types.js';
|
||||
import type { IAcmeOptions, ISmartProxyOptions, ISmartProxySecurityPolicy } from './interfaces.js';
|
||||
import type {
|
||||
IRouteAction,
|
||||
IRouteConfig,
|
||||
IRouteMatch,
|
||||
IRouteTarget,
|
||||
ITargetMatch,
|
||||
IRouteUdp,
|
||||
} from './route-types.js';
|
||||
|
||||
export type TRustHeaderMatchers = Record<string, string>;
|
||||
|
||||
export interface IRustRouteMatch extends Omit<IRouteMatch, 'headers'> {
|
||||
headers?: TRustHeaderMatchers;
|
||||
}
|
||||
|
||||
export interface IRustTargetMatch extends Omit<ITargetMatch, 'headers'> {
|
||||
headers?: TRustHeaderMatchers;
|
||||
}
|
||||
|
||||
export interface IRustRouteTarget extends Omit<IRouteTarget, 'host' | 'port' | 'match'> {
|
||||
host: string | string[];
|
||||
port: number | 'preserve';
|
||||
match?: IRustTargetMatch;
|
||||
}
|
||||
|
||||
export interface IRustRouteUdp extends Omit<IRouteUdp, 'maxSessionsPerIP'> {
|
||||
maxSessionsPerIp?: number;
|
||||
}
|
||||
|
||||
export interface IRustDefaultConfig extends Omit<NonNullable<ISmartProxyOptions['defaults']>, 'preserveSourceIP'> {
|
||||
preserveSourceIp?: boolean;
|
||||
}
|
||||
|
||||
export interface IRustRouteAction
|
||||
extends Omit<IRouteAction, 'targets' | 'socketHandler' | 'datagramHandler' | 'forwardingEngine' | 'nftables' | 'udp'> {
|
||||
targets?: IRustRouteTarget[];
|
||||
udp?: IRustRouteUdp;
|
||||
}
|
||||
|
||||
export interface IRustRouteConfig extends Omit<IRouteConfig, 'match' | 'action'> {
|
||||
match: IRustRouteMatch;
|
||||
action: IRustRouteAction;
|
||||
}
|
||||
|
||||
export interface IRustAcmeOptions extends Omit<IAcmeOptions, 'routeForwards'> {}
|
||||
|
||||
export interface IRustProxyOptions {
|
||||
routes: IRustRouteConfig[];
|
||||
preserveSourceIp?: boolean;
|
||||
proxyIps?: string[];
|
||||
acceptProxyProtocol?: boolean;
|
||||
sendProxyProtocol?: boolean;
|
||||
defaults?: IRustDefaultConfig;
|
||||
connectionTimeout?: number;
|
||||
initialDataTimeout?: number;
|
||||
socketTimeout?: number;
|
||||
inactivityCheckInterval?: number;
|
||||
maxConnectionLifetime?: number;
|
||||
inactivityTimeout?: number;
|
||||
gracefulShutdownTimeout?: number;
|
||||
noDelay?: boolean;
|
||||
keepAlive?: boolean;
|
||||
keepAliveInitialDelay?: number;
|
||||
maxPendingDataSize?: number;
|
||||
disableInactivityCheck?: boolean;
|
||||
enableKeepAliveProbes?: boolean;
|
||||
enableDetailedLogging?: boolean;
|
||||
enableTlsDebugLogging?: boolean;
|
||||
enableRandomizedTimeouts?: boolean;
|
||||
maxConnectionsPerIp?: number;
|
||||
connectionRateLimitPerMinute?: number;
|
||||
keepAliveTreatment?: ISmartProxyOptions['keepAliveTreatment'];
|
||||
keepAliveInactivityMultiplier?: number;
|
||||
extendedKeepAliveLifetime?: number;
|
||||
metrics?: ISmartProxyOptions['metrics'];
|
||||
securityPolicy?: ISmartProxySecurityPolicy;
|
||||
acme?: IRustAcmeOptions;
|
||||
}
|
||||
|
||||
export interface IRustStatistics {
|
||||
activeConnections: number;
|
||||
totalConnections: number;
|
||||
routesCount: number;
|
||||
listeningPorts: number[];
|
||||
uptimeSeconds: number;
|
||||
}
|
||||
|
||||
export interface IRustCertificateStatus {
|
||||
domain: string;
|
||||
source: string;
|
||||
expiresAt: number;
|
||||
isValid: boolean;
|
||||
}
|
||||
|
||||
export interface IRustThroughputSample {
|
||||
timestampMs: number;
|
||||
bytesIn: number;
|
||||
bytesOut: number;
|
||||
}
|
||||
|
||||
export interface IRustRouteMetrics {
|
||||
activeConnections: number;
|
||||
totalConnections: number;
|
||||
bytesIn: number;
|
||||
bytesOut: number;
|
||||
throughputInBytesPerSec: number;
|
||||
throughputOutBytesPerSec: number;
|
||||
throughputRecentInBytesPerSec: number;
|
||||
throughputRecentOutBytesPerSec: number;
|
||||
}
|
||||
|
||||
export interface IRustIpMetrics {
|
||||
activeConnections: number;
|
||||
totalConnections: number;
|
||||
bytesIn: number;
|
||||
bytesOut: number;
|
||||
throughputInBytesPerSec: number;
|
||||
throughputOutBytesPerSec: number;
|
||||
domainRequests: Record<string, number>;
|
||||
}
|
||||
|
||||
export interface IRustBackendMetrics {
|
||||
activeConnections: number;
|
||||
totalConnections: number;
|
||||
protocol: string;
|
||||
connectErrors: number;
|
||||
handshakeErrors: number;
|
||||
requestErrors: number;
|
||||
totalConnectTimeUs: number;
|
||||
connectCount: number;
|
||||
poolHits: number;
|
||||
poolMisses: number;
|
||||
h2Failures: number;
|
||||
}
|
||||
|
||||
export interface IRustHttpDomainRequestMetrics {
|
||||
requestsPerSecond: number;
|
||||
requestsLastMinute: number;
|
||||
}
|
||||
|
||||
export interface IRustMetricsSnapshot {
|
||||
activeConnections: number;
|
||||
totalConnections: number;
|
||||
bytesIn: number;
|
||||
bytesOut: number;
|
||||
throughputInBytesPerSec: number;
|
||||
throughputOutBytesPerSec: number;
|
||||
throughputRecentInBytesPerSec: number;
|
||||
throughputRecentOutBytesPerSec: number;
|
||||
routes: Record<string, IRustRouteMetrics>;
|
||||
ips: Record<string, IRustIpMetrics>;
|
||||
backends: Record<string, IRustBackendMetrics>;
|
||||
throughputHistory: IRustThroughputSample[];
|
||||
totalHttpRequests: number;
|
||||
httpRequestsPerSec: number;
|
||||
httpRequestsPerSecRecent: number;
|
||||
httpDomainRequests: Record<string, IRustHttpDomainRequestMetrics>;
|
||||
activeUdpSessions: number;
|
||||
totalUdpSessions: number;
|
||||
totalDatagramsIn: number;
|
||||
totalDatagramsOut: number;
|
||||
detectedProtocols: IProtocolCacheEntry[];
|
||||
frontendProtocols: IProtocolDistribution;
|
||||
backendProtocols: IProtocolDistribution;
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { IRouteConfig, IRouteAction, IRouteTarget } from './models/route-types.js';
|
||||
import { logger } from '../../core/utils/logger.js';
|
||||
import type { IRustRouteConfig } from './models/rust-types.js';
|
||||
import { serializeRouteForRust } from './utils/rust-config.js';
|
||||
|
||||
/**
|
||||
* Preprocesses routes before sending them to Rust.
|
||||
@@ -24,7 +25,7 @@ export class RoutePreprocessor {
|
||||
* - Non-serializable fields are stripped
|
||||
* - Original routes are preserved in the local map for handler lookup
|
||||
*/
|
||||
public preprocessForRust(routes: IRouteConfig[]): IRouteConfig[] {
|
||||
public preprocessForRust(routes: IRouteConfig[]): IRustRouteConfig[] {
|
||||
this.originalRoutes.clear();
|
||||
return routes.map((route, index) => this.preprocessRoute(route, index));
|
||||
}
|
||||
@@ -43,7 +44,7 @@ export class RoutePreprocessor {
|
||||
return new Map(this.originalRoutes);
|
||||
}
|
||||
|
||||
private preprocessRoute(route: IRouteConfig, index: number): IRouteConfig {
|
||||
private preprocessRoute(route: IRouteConfig, index: number): IRustRouteConfig {
|
||||
const routeKey = route.name || route.id || `route_${index}`;
|
||||
|
||||
// Check if this route needs TS-side handling
|
||||
@@ -57,7 +58,7 @@ export class RoutePreprocessor {
|
||||
// Create a clean copy for Rust
|
||||
const cleanRoute: IRouteConfig = {
|
||||
...route,
|
||||
action: this.cleanAction(route.action, routeKey, needsTsHandling),
|
||||
action: this.cleanAction(route.action, needsTsHandling),
|
||||
};
|
||||
|
||||
// Ensure we have a name for handler lookup
|
||||
@@ -65,7 +66,7 @@ export class RoutePreprocessor {
|
||||
cleanRoute.name = routeKey;
|
||||
}
|
||||
|
||||
return cleanRoute;
|
||||
return serializeRouteForRust(cleanRoute);
|
||||
}
|
||||
|
||||
private routeNeedsTsHandling(route: IRouteConfig): boolean {
|
||||
@@ -91,15 +92,16 @@ export class RoutePreprocessor {
|
||||
return false;
|
||||
}
|
||||
|
||||
private cleanAction(action: IRouteAction, routeKey: string, needsTsHandling: boolean): IRouteAction {
|
||||
const cleanAction: IRouteAction = { ...action };
|
||||
private cleanAction(action: IRouteAction, needsTsHandling: boolean): IRouteAction {
|
||||
let cleanAction: IRouteAction = { ...action };
|
||||
|
||||
if (needsTsHandling) {
|
||||
// Convert to socket-handler type for Rust (Rust will relay back to TS)
|
||||
cleanAction.type = 'socket-handler';
|
||||
// Remove the JS handlers (not serializable)
|
||||
delete (cleanAction as any).socketHandler;
|
||||
delete (cleanAction as any).datagramHandler;
|
||||
const { socketHandler: _socketHandler, datagramHandler: _datagramHandler, ...serializableAction } = cleanAction;
|
||||
cleanAction = {
|
||||
...serializableAction,
|
||||
type: 'socket-handler',
|
||||
};
|
||||
}
|
||||
|
||||
// Clean targets - replace functions with static values
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user