Compare commits

...

116 Commits

Author SHA1 Message Date
jkunz e806f7257f v27.9.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-26 15:11:10 +00:00
jkunz af4908b63f feat(smart-proxy): add hot-reloadable global ingress security policy across Rust and TypeScript proxy layers 2026-04-26 15:11:10 +00:00
jkunz 8fa3a51b03 v27.8.2
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-26 11:25:24 +00:00
jkunz 088ef6ab09 fix(rustproxy-metrics): retain inactive per-IP metric buckets briefly to capture final throughput before pruning 2026-04-26 11:25:24 +00:00
jkunz fdb5ec59bc v27.8.1
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-26 09:17:11 +00:00
jkunz 1ea290a085 fix(rustproxy-metrics): preserve high-throughput IPs in metrics snapshots when active-connection rankings are saturated 2026-04-26 09:17:11 +00:00
jkunz cb71f32b90 v27.8.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-14 12:43:59 +00:00
jkunz 46155ab12c feat(metrics): add per-domain HTTP request rate metrics 2026-04-14 12:43:59 +00:00
jkunz 490a310b54 v27.7.4
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-14 09:17:55 +00:00
jkunz 6c5180573a fix(rustproxy metrics): use stable route metrics keys across HTTP and passthrough listeners 2026-04-14 09:17:55 +00:00
jkunz 30e5ab308f v27.7.3
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-14 01:14:33 +00:00
jkunz d2a54b3491 fix(repo): no changes detected 2026-04-14 01:14:33 +00:00
jkunz dc922c97df v27.7.2
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-14 00:55:25 +00:00
jkunz 8d1bae7604 fix(docs): clarify metrics documentation for domain normalization and saturating gauges 2026-04-14 00:55:25 +00:00
jkunz 200e86e311 v27.7.1
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-14 00:54:12 +00:00
jkunz a53a2c4ca5 fix(rustproxy-http,rustproxy-metrics): fix domain-scoped request host detection and harden connection metrics cleanup 2026-04-14 00:54:12 +00:00
jkunz 6ee7237357 v27.7.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-13 23:21:54 +00:00
jkunz b5b4c608f0 feat(smart-proxy): add typed Rust config serialization and regex header contract coverage 2026-04-13 23:21:54 +00:00
jkunz af132f40fc v27.6.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-13 18:33:28 +00:00
jkunz 781634446a feat(metrics): track per-IP domain request metrics across HTTP and TCP passthrough traffic 2026-04-13 18:33:28 +00:00
jkunz e988d935b6 v27.5.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-06 12:46:09 +00:00
jkunz 99a026627d feat(security): add domain-scoped IP allow list support across HTTP and passthrough filtering 2026-04-06 12:46:09 +00:00
jkunz 572e31587a v27.4.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-04 19:25:06 +00:00
jkunz 8587fb997c feat(rustproxy): add HTTP/3 proxy service wiring for QUIC listeners 2026-04-04 19:25:06 +00:00
jkunz 9ba101c59b v27.3.1
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-04 18:54:05 +00:00
jkunz 1ad3e61c15 fix(metrics): correct frontend and backend protocol connection tracking across h1, h2, h3, and websocket traffic 2026-04-04 18:54:05 +00:00
jkunz 3bfa451341 v27.3.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-04 17:59:04 +00:00
jkunz 7b3ab7378b feat(test): add end-to-end WebSocket proxy test coverage 2026-04-04 17:59:04 +00:00
jkunz 527c616cd4 v27.2.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-04 16:52:25 +00:00
jkunz b04eb0ab17 feat(metrics): add frontend and backend protocol distribution metrics 2026-04-04 16:52:25 +00:00
jkunz a55ff20391 v27.1.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-27 22:34:13 +00:00
jkunz 3c24bf659b feat(rustproxy-passthrough): add selective connection recycling for route, security, and certificate updates 2026-03-27 22:34:13 +00:00
jkunz 5be93c8d38 v27.0.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-26 20:45:41 +00:00
jkunz 788ccea81e BREAKING CHANGE(smart-proxy): remove route helper APIs and standardize route configuration on plain route objects 2026-03-26 20:45:41 +00:00
jkunz 47140e5403 v26.3.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-26 13:11:57 +00:00
jkunz a6ffa24e36 feat(nftables): move NFTables forwarding management from the Rust engine to @push.rocks/smartnftables 2026-03-26 13:11:57 +00:00
jkunz c0e432fd9b v26.2.4
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-26 07:05:57 +00:00
jkunz a3d8a3a388 fix(rustproxy-http): improve HTTP/3 connection reuse and clean up stale proxy state 2026-03-26 07:05:57 +00:00
jkunz 437d1a3329 v26.2.3
Default (tags) / security (push) Failing after 3s
Default (tags) / test (push) Failing after 3s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-25 07:26:47 +00:00
jkunz 746d93663d fix(repo): no changes to commit 2026-03-25 07:26:47 +00:00
jkunz a3f3fee253 v26.2.2
Default (tags) / security (push) Failing after 4s
Default (tags) / test (push) Failing after 5s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-25 07:22:17 +00:00
jkunz 53dee1fffc fix(proxy): improve connection cleanup and route validation handling 2026-03-25 07:22:17 +00:00
jkunz 34dc0cb9b6 v26.2.1
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-23 11:11:55 +00:00
jkunz c83c43194b fix(rustproxy-http): include the upstream request URL when caching H3 Alt-Svc discoveries 2026-03-23 11:11:55 +00:00
jkunz d026d7c266 v26.2.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-23 09:42:07 +00:00
jkunz 3b01144c51 feat(protocol-cache): add sliding TTL re-probing and eviction for backend protocol detection 2026-03-23 09:42:07 +00:00
jkunz 56f5697e1b v26.1.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-22 10:20:00 +00:00
jkunz f04875885f feat(rustproxy-http): add protocol failure suppression, h3 fallback escalation, and protocol cache metrics exposure 2026-03-22 10:20:00 +00:00
jkunz d12812bb8d v26.0.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-21 22:23:38 +00:00
jkunz fc04a0210b BREAKING CHANGE(ts-api,rustproxy): remove deprecated TypeScript protocol and utility exports while hardening QUIC, HTTP/3, WebSocket, and rate limiter cleanup paths 2026-03-21 22:23:38 +00:00
jkunz 33fdf42a70 v25.17.10
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 08:57:18 +00:00
jkunz fb1c59ac9a fix(rustproxy-http): reuse the shared HTTP proxy service for HTTP/3 request handling 2026-03-20 08:57:18 +00:00
jkunz ea8224c400 v25.17.9
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 08:30:09 +00:00
jkunz da1cc58a3d fix(rustproxy-http): correct HTTP/3 host extraction and avoid protocol filtering during UDP route lookup 2026-03-20 08:30:09 +00:00
jkunz 606c620849 v25.17.8
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 08:06:32 +00:00
jkunz 4ae09ac6ae fix(rustproxy): use SNI-based certificate resolution for QUIC TLS connections 2026-03-20 08:06:32 +00:00
jkunz 2fce910795 v25.17.7
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 07:50:41 +00:00
jkunz ff09cef350 fix(readme): document QUIC and HTTP/3 compatibility caveats 2026-03-20 07:50:41 +00:00
jkunz d0148b2ac3 v25.17.6
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 07:48:26 +00:00
jkunz 7217e15649 fix(rustproxy-http): disable HTTP/3 GREASE for client and server connections 2026-03-20 07:48:26 +00:00
jkunz bfcf92a855 v25.17.5
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 07:43:32 +00:00
jkunz 8e0804cd20 fix(rustproxy): add HTTP/3 integration test for QUIC response stream FIN handling 2026-03-20 07:43:32 +00:00
jkunz c63f6fcd5f v25.17.4
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 03:19:57 +00:00
jkunz f3cd4d193e fix(rustproxy-http): prevent HTTP/3 response body streaming from hanging on backend completion 2026-03-20 03:19:57 +00:00
jkunz 81de611255 v25.17.3
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 02:54:44 +00:00
jkunz 91598b3be9 fix(repository): no changes detected 2026-03-20 02:54:44 +00:00
jkunz 4e3c548012 v25.17.2
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 02:53:41 +00:00
jkunz 1a2d7529db fix(rustproxy-http): enable TLS connections for HTTP/3 upstream requests when backend re-encryption or TLS is configured 2026-03-20 02:53:41 +00:00
jkunz 31514f54ae v25.17.1
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 02:35:22 +00:00
jkunz 247653c9d0 fix(rustproxy-routing): allow QUIC UDP TLS connections without SNI to match domain-restricted routes 2026-03-20 02:35:22 +00:00
jkunz 07d88f6f6a v25.17.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-19 23:16:42 +00:00
jkunz 4b64de2c67 feat(rustproxy-passthrough): add PROXY protocol v2 client IP handling for UDP and QUIC listeners 2026-03-19 23:16:42 +00:00
jkunz e8db7bc96d v25.16.3
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-19 22:00:07 +00:00
jkunz 2621dea9fa fix(rustproxy): upgrade fallback UDP listeners to QUIC when TLS certificates become available 2026-03-19 22:00:07 +00:00
jkunz bb5b9b3d12 v25.16.2
Default (tags) / security (push) Failing after 12s
Default (tags) / test (push) Failing after 13s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-19 21:24:05 +00:00
jkunz d70c2d77ed fix(rustproxy-http): cache backend Alt-Svc only from original upstream responses during protocol auto-detection 2026-03-19 21:24:05 +00:00
jkunz 4cf13c36f8 v25.16.1
Default (tags) / security (push) Failing after 19s
Default (tags) / test (push) Failing after 18s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-19 20:57:48 +00:00
jkunz 37c7233780 fix(http-proxy): avoid repeated HTTP/3 recaching after QUIC fallback and document backend protocol selection 2026-03-19 20:57:48 +00:00
jkunz 15d0a721d5 v25.16.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-19 20:27:57 +00:00
jkunz af970c447e feat(quic,http3): add HTTP/3 proxy handling and hot-reload QUIC TLS configuration 2026-03-19 20:27:57 +00:00
jkunz 9e1103e7a7 v25.15.0
Default (tags) / security (push) Failing after 1s
Default (tags) / test (push) Failing after 1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-19 18:55:31 +00:00
jkunz 2b990527ac feat(readme): document UDP, QUIC, and HTTP/3 support in the README 2026-03-19 18:55:31 +00:00
jkunz 9595f0a9fc v25.14.1
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-19 16:21:37 +00:00
jkunz 0fb3988123 fix(deps): update build and runtime dependencies and align route validation test expectations 2026-03-19 16:21:37 +00:00
jkunz 53938df8db v25.14.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 2s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-19 16:09:51 +00:00
jkunz e890bda8fc feat(udp,http3): add UDP datagram handler relay support and stream HTTP/3 request bodies to backends 2026-03-19 16:09:51 +00:00
jkunz bbe8b729ea v25.13.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-19 15:06:27 +00:00
jkunz 4fb91cd868 feat(smart-proxy): add UDP transport support with QUIC/HTTP3 routing and datagram handler relay 2026-03-19 15:06:27 +00:00
jkunz cfa958cf3d v25.12.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-19 12:41:26 +00:00
jkunz db2e586da2 feat(proxy-protocol): add PROXY protocol v2 support to the Rust passthrough listener and streamline TypeScript proxy protocol exports 2026-03-19 12:41:26 +00:00
jkunz 91832c368d v25.11.24
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-17 16:47:57 +00:00
jkunz c9d0fccb2d fix(rustproxy-http): improve async static file serving, websocket handshake buffering, and shared metric metadata handling 2026-03-17 16:47:57 +00:00
jkunz 5dccbbc9d1 v25.11.23
Default (tags) / security (push) Failing after 1s
Default (tags) / test (push) Failing after 1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-17 12:22:51 +00:00
jkunz 92d7113c6c fix(rustproxy-http,rustproxy-metrics): reduce per-frame metrics overhead by batching body byte accounting 2026-03-17 12:22:51 +00:00
jkunz 8f6bb30367 v25.11.22
Default (tags) / security (push) Failing after 1s
Default (tags) / test (push) Failing after 1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-17 12:12:24 +00:00
jkunz ef9bac80ff fix(rustproxy-http): reuse healthy HTTP/2 upstream connections after requests with bodies 2026-03-17 12:12:24 +00:00
jkunz 9c78701038 v25.11.21
Default (tags) / security (push) Failing after 1s
Default (tags) / test (push) Failing after 1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-17 11:33:34 +00:00
jkunz 26fd9409a7 fix(rustproxy-http): reuse pooled HTTP/2 connections for requests with and without bodies 2026-03-17 11:33:34 +00:00
jkunz cfff128499 v25.11.20
Default (tags) / security (push) Failing after 1s
Default (tags) / test (push) Failing after 1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-17 01:32:35 +00:00
jkunz 3baff354bd fix(rustproxy-http): avoid downgrading cached backend protocol on H2 stream errors 2026-03-17 01:32:35 +00:00
jkunz c2eacd1b30 v25.11.19
Default (tags) / security (push) Failing after 1s
Default (tags) / test (push) Failing after 2s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-16 20:53:39 +00:00
jkunz 1fdbfcf0aa fix(rustproxy-http): avoid reusing pooled HTTP/2 connections for requests with bodies to prevent upload flow-control stalls 2026-03-16 20:53:39 +00:00
jkunz 9b184acc8c v25.11.18
Default (tags) / security (push) Failing after 1s
Default (tags) / test (push) Failing after 1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-16 17:42:14 +00:00
jkunz b475968f4e fix(repo): no changes to commit 2026-03-16 17:42:14 +00:00
jkunz 878eab6e88 v25.11.17
Default (tags) / security (push) Failing after 1s
Default (tags) / test (push) Failing after 1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-16 14:30:43 +00:00
jkunz 77abe0804d fix(rustproxy-http): prevent stale HTTP/2 connection drivers from evicting newer pooled connections 2026-03-16 14:30:43 +00:00
jkunz ae0342d018 v25.11.16
Default (tags) / security (push) Failing after 1s
Default (tags) / test (push) Failing after 1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-16 13:58:22 +00:00
jkunz 365981d9cf fix(repo): no changes to commit 2026-03-16 13:58:22 +00:00
jkunz 2cc0ff0030 v25.11.15
Default (tags) / security (push) Failing after 1s
Default (tags) / test (push) Failing after 1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-16 13:54:56 +00:00
jkunz 72935e7ee0 fix(rustproxy-http): implement vectored write support for backend streams 2026-03-16 13:54:56 +00:00
jkunz 61db285e04 v25.11.14
Default (tags) / security (push) Failing after 1s
Default (tags) / test (push) Failing after 1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-16 13:44:56 +00:00
jkunz d165829022 fix(rustproxy-http): forward vectored write support in ShutdownOnDrop AsyncWrite wrapper 2026-03-16 13:44:56 +00:00
jkunz 5e6cf391ab v25.11.13
Default (tags) / security (push) Failing after 1s
Default (tags) / test (push) Failing after 1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-16 13:17:02 +00:00
jkunz 2b1a21c599 fix(rustproxy-http): remove hot-path debug logging from HTTP/1 connection pool hits 2026-03-16 13:17:02 +00:00
jkunz b8e1c9f3cf v25.11.12
Default (tags) / security (push) Failing after 1s
Default (tags) / test (push) Failing after 1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-16 13:12:24 +00:00
jkunz c65369540c fix(rustproxy-http): remove connection pool hit logging and keep logging limited to actual failures 2026-03-16 13:12:24 +00:00
196 changed files with 19986 additions and 21402 deletions
View File
+392
View File
@@ -1,5 +1,397 @@
# Changelog
## 2026-04-26 - 27.9.0 - feat(smart-proxy)
add hot-reloadable global ingress security policy across Rust and TypeScript proxy layers
- adds global securityPolicy config with blocked IP and CIDR support to SmartProxy and RustProxy options
- introduces management IPC support to update the security policy at runtime via setSecurityPolicy
- enforces the global block list early for TCP, UDP, and QUIC traffic before route selection and backend handling
## 2026-04-26 - 27.8.2 - fix(rustproxy-metrics)
retain inactive per-IP metric buckets briefly to capture final throughput before pruning
- adds a bounded retention window for closed IP buckets so short-lived transfers are still included in per-IP throughput sampling
- prunes expired inactive IP tracking by TTL and hard cap to prevent unbounded metric map growth
- updates Rust and throughput tests to expect zero active connections during the temporary retention period
## 2026-04-26 - 27.8.1 - fix(rustproxy-metrics)
preserve high-throughput IPs in metrics snapshots when active-connection rankings are saturated
- Select snapshot IPs using a blend of active-connection and throughput rankings instead of only active connections
- Adds a regression test to ensure a high-bandwidth IP remains included when many other IPs have more active connections
## 2026-04-14 - 27.8.0 - feat(metrics)
add per-domain HTTP request rate metrics
- Record canonicalized HTTP request rates per domain in the Rust metrics collector and expose per-second and last-minute values in snapshots.
- Add TypeScript metrics interfaces and adapter support for requests.byDomain().
- Cover HTTP domain rate tracking and ensure TLS passthrough SNI traffic does not affect HTTP request rate metrics.
## 2026-04-14 - 27.7.4 - fix(rustproxy metrics)
use stable route metrics keys across HTTP and passthrough listeners
- adds a shared RouteConfig::metrics_key helper that prefers route name and falls back to route id
- updates HTTP, TCP, UDP, and QUIC metrics labeling to use the shared route metrics key consistently
- keeps route cancellation and rate limiter indexing bound to route config ids where required
- adds tests covering metrics key selection behavior
## 2026-04-14 - 27.7.3 - fix(repo)
no changes detected
## 2026-04-14 - 27.7.2 - fix(docs)
clarify metrics documentation for domain normalization and saturating gauges
- Document that per-IP domain keys are normalized to lowercase and have trailing dots stripped before counting.
- Clarify that the saturating close pattern also applies to connection and UDP active gauges.
## 2026-04-14 - 27.7.1 - fix(rustproxy-http,rustproxy-metrics)
fix domain-scoped request host detection and harden connection metrics cleanup
- use a shared request host extractor that falls back to URI authority so domain-scoped IP allow lists work for HTTP/2 and HTTP/3 requests without a Host header
- add request filter and host extraction tests covering domain-scoped ACL behavior
- prevent connection counters from underflowing during close handling and clean up per-IP metrics entries more safely
- normalize tracked domain keys in metrics to reduce duplicate entries caused by case or trailing-dot variations
## 2026-04-13 - 27.7.0 - feat(smart-proxy)
add typed Rust config serialization and regex header contract coverage
- serialize SmartProxy routes and top-level options into explicit Rust-safe types, including header regex literals, UDP field normalization, ACME, defaults, and proxy settings
- support JS-style regex header literals with flags in Rust header matching and add cross-contract tests for route preprocessing and config deserialization
- improve TypeScript safety for Rust bridge and metrics integration by replacing loose any-based payloads with dedicated Rust type definitions
## 2026-04-13 - 27.6.0 - feat(metrics)
track per-IP domain request metrics across HTTP and TCP passthrough traffic
- records domain request counts per frontend IP from HTTP Host headers and TCP SNI
- exposes per-IP domain maps and top IP-domain request pairs through the TypeScript metrics adapter
- bounds per-IP domain tracking and prunes stale entries to limit memory growth
- adds metrics system documentation covering architecture, collected data, and known gaps
## 2026-04-06 - 27.5.0 - feat(security)
add domain-scoped IP allow list support across HTTP and passthrough filtering
- extend route security types to accept IP allow entries scoped to specific domains
- apply domain-aware IP checks using Host headers for HTTP and SNI context for QUIC and passthrough connections
- preserve compatibility for existing plain allow list entries and add validation and tests for scoped matching
## 2026-04-04 - 27.4.0 - feat(rustproxy)
add HTTP/3 proxy service wiring for QUIC listeners
- registers H3ProxyService with the UDP listener manager so QUIC connections can serve HTTP/3
- keeps proxy IP configuration intact while enabling HTTP/3 handling during listener setup
## 2026-04-04 - 27.3.1 - fix(metrics)
correct frontend and backend protocol connection tracking across h1, h2, h3, and websocket traffic
- move frontend protocol accounting from per-request to connection lifetime tracking for HTTP/1, HTTP/2, and HTTP/3
- add backend protocol guards to connection drivers so active protocol metrics reflect live upstream connections
- prevent protocol counter underflow by using atomic saturating decrements in the metrics collector
- read backend protocol distribution directly from cached aggregate counters in the Rust metrics adapter
## 2026-04-04 - 27.3.0 - feat(test)
add end-to-end WebSocket proxy test coverage
- add comprehensive WebSocket e2e tests for upgrade handling, bidirectional messaging, header forwarding, close propagation, and large payloads
- add ws and @types/ws as development dependencies to support the new test suite
## 2026-04-04 - 27.2.0 - feat(metrics)
add frontend and backend protocol distribution metrics
- track active and total frontend protocol counts for h1, h2, h3, websocket, and other traffic
- add backend protocol counters with RAII guards to ensure metrics are decremented on all exit paths
- expose protocol distribution through the TypeScript metrics interfaces and Rust metrics adapter
## 2026-03-27 - 27.1.0 - feat(rustproxy-passthrough)
add selective connection recycling for route, security, and certificate updates
- introduce a shared connection registry to track active TCP and QUIC connections by route, source IP, and domain
- recycle only affected connections when route actions or security rules change instead of broadly invalidating traffic
- gracefully recycle existing connections when TLS certificates change for a domain
- apply route-level IP security checks to QUIC connections and share route cancellation state with UDP listeners
## 2026-03-26 - 27.0.0 - BREAKING CHANGE(smart-proxy)
remove route helper APIs and standardize route configuration on plain route objects
- Removes TypeScript route helper exports and related Rust config helpers in favor of defining routes directly with match and action properties.
- Updates documentation and tests to use plain IRouteConfig objects and SocketHandlers imports instead of helper factory functions.
- Moves socket handlers to a top-level utils export and keeps direct socket-handler route configuration as the supported pattern.
## 2026-03-26 - 26.3.0 - feat(nftables)
move NFTables forwarding management from the Rust engine to @push.rocks/smartnftables
- add @push.rocks/smartnftables as a runtime dependency and export it via the plugin layer
- remove the internal rustproxy-nftables crate along with Rust-side NFTables rule application and status management
- apply and clean up NFTables port-forwarding rules in the TypeScript SmartProxy lifecycle and route update flow
- change getNfTablesStatus to return local smartnftables status instead of querying the Rust bridge
- update README documentation to describe NFTables support as provided through @push.rocks/smartnftables
## 2026-03-26 - 26.2.4 - fix(rustproxy-http)
improve HTTP/3 connection reuse and clean up stale proxy state
- Reuse pooled HTTP/3 SendRequest handles to skip repeated SETTINGS handshakes and reduce request overhead on QUIC pool hits
- Add periodic cleanup for per-route rate limiters and orphaned backend metrics to prevent unbounded memory growth after traffic or backend errors stop
- Enforce HTTP max connection lifetime alongside idle timeouts and apply configured lifetime values from the TCP listener
- Reduce HTTP/3 body copying by using owned Bytes paths for request and response streaming, and replace the custom response body adapter with a stream-based implementation
- Harden auxiliary proxy components by capping datagram handler buffer growth and removing duplicate RustProxy exit listeners
## 2026-03-25 - 26.2.3 - fix(repo)
no changes to commit
## 2026-03-25 - 26.2.2 - fix(proxy)
improve connection cleanup and route validation handling
- add timeouts for HTTP/1 upstream connection drivers to prevent lingering tasks
- ensure QUIC relay sessions cancel and abort background tasks on drop
- avoid registering unnamed routes as duplicates and label unnamed catch-all conflicts clearly
- fix offset mapping route helper to forward only remaining route options without overriding derived values
- update project config filename and toolchain versions for the current build setup
## 2026-03-23 - 26.2.1 - fix(rustproxy-http)
include the upstream request URL when caching H3 Alt-Svc discoveries
- Tracks the request path that triggered Alt-Svc discovery in connection activity state
- Adds request URL context to Alt-Svc debug logging and protocol cache insertion reasons for better traceability
## 2026-03-23 - 26.2.0 - feat(protocol-cache)
add sliding TTL re-probing and eviction for backend protocol detection
- extend protocol cache entries and metrics with last accessed and last probed timestamps
- trigger periodic ALPN re-probes for cached H1/H2 entries while keeping active entries alive with a sliding 1 day TTL
- log protocol transitions with reasons and evict cache entries when all protocol fallback attempts fail
## 2026-03-22 - 26.1.0 - feat(rustproxy-http)
add protocol failure suppression, h3 fallback escalation, and protocol cache metrics exposure
- introduces escalating cooldowns for failed H2/H3 protocol detection to prevent repeated upgrades to unstable backends
- adds within-request escalation to cached HTTP/3 when TCP or TLS backend connections fail in auto-detect mode
- exposes detected protocol cache entries and suppression state through Rust metrics and the TypeScript metrics adapter
## 2026-03-21 - 26.0.0 - BREAKING CHANGE(ts-api,rustproxy)
remove deprecated TypeScript protocol and utility exports while hardening QUIC, HTTP/3, WebSocket, and rate limiter cleanup paths
- Removes large parts of the public TypeScript surface including detection, TLS, router, websocket, proxy/common protocol, and multiple core utility exports and files.
- Adds parent-child cancellation handling for HTTP/3 and QUIC stream forwarding to stop orphaned tasks and close idle or overlong streams.
- Improves cleanup reliability with RAII guards for WebSocket upstream tracking and QUIC connection metrics, plus periodic cleanup for rate limiter and proxy address maps.
- Cleans backend metrics state when active backend connections drop to zero and tracks passthrough backend sockets for shutdown cleanup.
## 2026-03-20 - 25.17.10 - fix(rustproxy-http)
reuse the shared HTTP proxy service for HTTP/3 request handling
- Refactors H3ProxyService to delegate requests to the shared HttpProxyService instead of maintaining separate routing and backend forwarding logic.
- Aligns HTTP/3 with the TCP/HTTP path for route matching, connection pooling, and ALPN-based upstream protocol detection.
- Generalizes request handling and filters to accept boxed/generic HTTP bodies so both HTTP/3 and existing HTTP paths share the same proxy pipeline.
- Updates the HTTP/3 integration route matcher to allow transport matching across shared HTTP and QUIC handling.
## 2026-03-20 - 25.17.9 - fix(rustproxy-http)
correct HTTP/3 host extraction and avoid protocol filtering during UDP route lookup
- Use the URI host or strip the port from the Host header so HTTP/3 requests match routes consistently with TCP/HTTP handling.
- Remove protocol filtering from HTTP/3 route lookup because QUIC transport already constrains routing to UDP and protocol validation happens earlier.
## 2026-03-20 - 25.17.8 - fix(rustproxy)
use SNI-based certificate resolution for QUIC TLS connections
- Replaces static first-certificate selection with the shared CertResolver used by the TCP/TLS path.
- Ensures QUIC connections can present the correct certificate per requested domain.
- Keeps HTTP/3 ALPN configuration while improving multi-domain TLS handling.
## 2026-03-20 - 25.17.7 - fix(readme)
document QUIC and HTTP/3 compatibility caveats
- Add notes explaining that GREASE frames are disabled on both server and client HTTP/3 paths to avoid interoperability issues
- Document that the current HTTP/3 stack depends on pre-1.0 h3 ecosystem components and may still have rough edges
## 2026-03-20 - 25.17.6 - fix(rustproxy-http)
disable HTTP/3 GREASE for client and server connections
- Switch the HTTP/3 server connection setup to use the builder API with send_grease(false)
- Switch the HTTP/3 client handshake to use the builder API with send_grease(false) to improve compatibility
## 2026-03-20 - 25.17.5 - fix(rustproxy)
add HTTP/3 integration test for QUIC response stream FIN handling
- adds an integration test covering HTTP/3 proxying over QUIC with TLS termination
- verifies response bodies fully arrive and the client receives stream termination instead of hanging
- adds test-only dependencies for quinn, h3, h3-quinn, rustls, bytes, and http
## 2026-03-20 - 25.17.4 - fix(rustproxy-http)
prevent HTTP/3 response body streaming from hanging on backend completion
- extract and track Content-Length before consuming the response body
- stop the HTTP/3 body loop when the stream reports end-of-stream or the expected byte count has been sent
- add a per-frame idle timeout to avoid indefinite waits on stalled or close-delimited backend bodies
## 2026-03-20 - 25.17.3 - fix(repository)
no changes detected
## 2026-03-20 - 25.17.2 - fix(rustproxy-http)
enable TLS connections for HTTP/3 upstream requests when backend re-encryption or TLS is configured
- Pass backend TLS client configuration into the HTTP/3 request handler.
- Detect TLS-required upstream targets using route and target TLS settings before connecting.
- Build backend request URIs with the correct http or https scheme to match the upstream connection.
## 2026-03-20 - 25.17.1 - fix(rustproxy-routing)
allow QUIC UDP TLS connections without SNI to match domain-restricted routes
- Exempts UDP transport from the no-SNI rejection logic because QUIC encrypts the TLS ClientHello and SNI is unavailable at accept time
- Adds regression tests to confirm QUIC route matching succeeds without SNI while TCP TLS without SNI remains rejected
## 2026-03-19 - 25.17.0 - feat(rustproxy-passthrough)
add PROXY protocol v2 client IP handling for UDP and QUIC listeners
- propagate trusted proxy IP configuration into UDP and QUIC listener managers
- extract and preserve real client addresses from PROXY protocol v2 headers for HTTP/3 and QUIC stream handling
- apply rate limiting, session limits, routing, and metrics using the resolved client IP while preserving correct proxy return-path routing
## 2026-03-19 - 25.16.3 - fix(rustproxy)
upgrade fallback UDP listeners to QUIC when TLS certificates become available
- Rebuild and apply QUIC TLS configuration during route and certificate updates instead of only when adding new UDP ports.
- Add logic to drain UDP sessions, stop raw fallback listeners, and start QUIC endpoints on existing ports once TLS is available.
- Retry QUIC endpoint creation during upgrade and fall back to rebinding raw UDP if the upgrade cannot complete.
## 2026-03-19 - 25.16.2 - fix(rustproxy-http)
cache backend Alt-Svc only from original upstream responses during protocol auto-detection
- Moves Alt-Svc discovery into streaming response construction so it reads backend headers before response filters inject client-facing Alt-Svc values
- Stores the protocol cache key in connection activity during auto-detect mode and clears it after HTTP/3 connection failure to avoid re-caching failed H3 routes
- Prevents fallback requests from reintroducing stale or self-injected Alt-Svc entries that could cause repeated H3 retry loops
## 2026-03-19 - 25.16.1 - fix(http-proxy)
avoid repeated HTTP/3 recaching after QUIC fallback and document backend protocol selection
- Suppress Alt-Svc HTTP/3 recaching after a failed QUIC backend connection to prevent repeated H3 timeout fallback loops
- Force an ALPN probe on TCP fallback so auto detection correctly reselects HTTP/2 or HTTP/1.1 after H3 connection failure
- Add README documentation for best-effort backendProtocol selection and supported protocol modes
## 2026-03-19 - 25.16.0 - feat(quic,http3)
add HTTP/3 proxy handling and hot-reload QUIC TLS configuration
- initialize and wire H3ProxyService into QUIC listeners so HTTP/3 requests are handled instead of being kept as placeholder connections
- add backend HTTP/3 support with protocol caching that stores Alt-Svc advertised H3 ports for auto-detection
- hot-swap TLS certificates across active QUIC endpoints and require terminating TLS for QUIC route validation
- document QUIC route setup with required TLS and ACME configuration
## 2026-03-19 - 25.15.0 - feat(readme)
document UDP, QUIC, and HTTP/3 support in the README
- Adds README examples for UDP datagram handlers, QUIC/HTTP3 forwarding, and dual-stack TCP/UDP routes
- Expands configuration and API reference sections to cover transport matching, UDP/QUIC options, backend transport selection, and UDP metrics
- Updates architecture and feature descriptions to reflect UDP, QUIC, HTTP/3, and datagram handler capabilities
## 2026-03-19 - 25.14.1 - fix(deps)
update build and runtime dependencies and align route validation test expectations
- split the test preparation step into a dedicated test:before script while keeping test execution separate
- bump development tooling and runtime package versions in package.json
- adjust the route validation test to match the current generic handler error message
## 2026-03-19 - 25.14.0 - feat(udp,http3)
add UDP datagram handler relay support and stream HTTP/3 request bodies to backends
- establish a persistent Unix socket relay for UDP datagram handlers and process handler replies back to clients
- update route validation and smart proxy route reload logic to support datagramHandler routes
- record UDP, QUIC, and HTTP/3 byte metrics more accurately, including request bytes in and UDP session cleanup connection tracking
- add integration tests for UDP forwarding, datagram handlers, and UDP metrics
## 2026-03-19 - 25.13.0 - feat(smart-proxy)
add UDP transport support with QUIC/HTTP3 routing and datagram handler relay
- adds UDP listener and session tracking infrastructure in the Rust proxy, including UDP metrics and hot-reload support for transport-specific ports
- introduces QUIC and HTTP/3 support in routing and HTTP handling, including Alt-Svc advertisement and QUIC TLS configuration
- extends route configuration types in Rust and TypeScript with transport, UDP, QUIC, backend transport, and mixed port range support
- adds a TypeScript datagram handler relay server and bridge command so UDP socket-handler routes can dispatch datagrams to application callbacks
- updates nftables rule generation so protocol=all creates both TCP and UDP rules
## 2026-03-19 - 25.12.0 - feat(proxy-protocol)
add PROXY protocol v2 support to the Rust passthrough listener and streamline TypeScript proxy protocol exports
- detect and parse PROXY protocol v2 headers in the Rust TCP listener, including TCP and UDP address families
- add Rust v2 header generation, incomplete-header handling, and broader parser test coverage
- remove deprecated TypeScript proxy protocol parser exports and tests, leaving shared type definitions only
## 2026-03-17 - 25.11.24 - fix(rustproxy-http)
improve async static file serving, websocket handshake buffering, and shared metric metadata handling
- convert static file serving to async filesystem operations and await directory/file checks
- preserve and forward bytes read past the WebSocket handshake header terminator to avoid dropping buffered upstream data
- reuse Arc<str> values for route and source identifiers across counting bodies and metric reporting
- standardize backend key propagation across H1/H2 forwarding, retry, and fallback paths for consistent logging and metrics
## 2026-03-17 - 25.11.23 - fix(rustproxy-http,rustproxy-metrics)
reduce per-frame metrics overhead by batching body byte accounting
- Buffer HTTP body byte counts and flush them every 64 KB, at end of stream, and on drop to keep totals accurate while preserving throughput sampling.
- Skip zero-value counter updates in metrics collection to avoid unnecessary atomic and DashMap operations for the unused direction.
## 2026-03-17 - 25.11.22 - fix(rustproxy-http)
reuse healthy HTTP/2 upstream connections after requests with bodies
- Registers successful HTTP/2 connections in the pool regardless of whether the proxied request included a body
- Continues to avoid pooling upstream connections that returned 502 Bad Gateway responses
## 2026-03-17 - 25.11.21 - fix(rustproxy-http)
reuse pooled HTTP/2 connections for requests with and without bodies
- remove the bodyless-request restriction from HTTP/2 pool checkout
- always return successful HTTP/2 senders to the connection pool after requests
## 2026-03-17 - 25.11.20 - fix(rustproxy-http)
avoid downgrading cached backend protocol on H2 stream errors
- Treat HTTP/2 stream-level failures as retryable request errors instead of evidence that the backend only supports HTTP/1.1
- Keep protocol cache entries unchanged after successful H2 handshakes so future requests continue using HTTP/2
- Lower log severity for this fallback path from warning to debug while still recording backend H2 failure metrics
## 2026-03-16 - 25.11.19 - fix(rustproxy-http)
avoid reusing pooled HTTP/2 connections for requests with bodies to prevent upload flow-control stalls
- Limit HTTP/2 pool checkout to bodyless requests such as GET, HEAD, and DELETE
- Skip re-registering HTTP/2 connections in the pool after requests that send a body
- Prevent stalled uploads caused by depleted connection-level flow control windows on reused HTTP/2 connections
## 2026-03-16 - 25.11.18 - fix(repo)
no changes to commit
## 2026-03-16 - 25.11.17 - fix(rustproxy-http)
prevent stale HTTP/2 connection drivers from evicting newer pooled connections
- add generation IDs to pooled HTTP/2 senders so pool removal only affects the matching connection
- update HTTP/2 proxy and retry paths to register generation-tagged connections and skip eviction before registration completes
## 2026-03-16 - 25.11.16 - fix(repo)
no changes to commit
## 2026-03-16 - 25.11.15 - fix(rustproxy-http)
implement vectored write support for backend streams
- Add poll_write_vectored forwarding for both plain and TLS backend stream variants
- Expose is_write_vectored so the proxy can correctly report vectored write capability
## 2026-03-16 - 25.11.14 - fix(rustproxy-http)
forward vectored write support in ShutdownOnDrop AsyncWrite wrapper
- Implements poll_write_vectored by delegating to the wrapped writer
- Exposes is_write_vectored so the wrapper preserves underlying AsyncWrite capabilities
## 2026-03-16 - 25.11.13 - fix(rustproxy-http)
remove hot-path debug logging from HTTP/1 connection pool hits
- Stops emitting debug logs when reusing HTTP/1 idle connections in the connection pool.
- Keeps pool hit behavior unchanged while reducing overhead on a frequently executed path.
## 2026-03-16 - 25.11.12 - fix(rustproxy-http)
remove connection pool hit logging and keep logging limited to actual failures
- Removes debug and warning logs for HTTP/2 connection pool hits and age checks.
- Keeps pool behavior unchanged while reducing noisy per-request logging in the Rust HTTP proxy layer.
## 2026-03-16 - 25.11.11 - fix(rustproxy-http)
improve HTTP/2 proxy error logging with warning-level connection failures and debug error details
Generated
+1600 -1913
View File
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -1,4 +1,4 @@
Copyright (c) 2019 Lossless GmbH (hello@lossless.com)
Copyright (c) 2019 Task Venture Capital GmbH (hello@task.vc)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
+19 -15
View File
@@ -1,6 +1,6 @@
{
"name": "@push.rocks/smartproxy",
"version": "25.11.11",
"version": "27.9.0",
"private": false,
"description": "A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.",
"main": "dist_ts/index.js",
@@ -9,27 +9,31 @@
"author": "Lossless GmbH",
"license": "MIT",
"scripts": {
"test": "(tsrust) && (tstest test/**/test*.ts --verbose --timeout 60 --logfile)",
"test:before": "(tsrust)",
"test": "(tstest test/**/test*.ts --verbose --timeout 60 --logfile)",
"build": "(tsbuild tsfolders --allowimplicitany) && (tsrust)",
"format": "(gitzone format)",
"buildDocs": "tsdoc"
},
"devDependencies": {
"@git.zone/tsbuild": "^4.1.2",
"@git.zone/tsrun": "^2.0.1",
"@git.zone/tsrust": "^1.3.0",
"@git.zone/tstest": "^3.1.8",
"@push.rocks/smartserve": "^2.0.1",
"@types/node": "^25.2.3",
"typescript": "^5.9.3",
"why-is-node-running": "^3.2.2"
"@git.zone/tsbuild": "^4.4.0",
"@git.zone/tsrun": "^2.0.2",
"@git.zone/tsrust": "^1.3.2",
"@git.zone/tstest": "^3.6.0",
"@push.rocks/smartserve": "^2.0.3",
"@types/node": "^25.5.0",
"@types/ws": "^8.18.1",
"typescript": "^6.0.2",
"why-is-node-running": "^3.2.2",
"ws": "^8.20.0"
},
"dependencies": {
"@push.rocks/smartcrypto": "^2.0.4",
"@push.rocks/smartlog": "^3.1.10",
"@push.rocks/smartrust": "^1.2.1",
"@tsclass/tsclass": "^9.3.0",
"minimatch": "^10.2.0"
"@push.rocks/smartlog": "^3.2.1",
"@push.rocks/smartnftables": "^1.0.1",
"@push.rocks/smartrust": "^1.3.2",
"@tsclass/tsclass": "^9.5.0",
"minimatch": "^10.2.4"
},
"files": [
"ts/**/*",
@@ -40,7 +44,7 @@
"dist_ts_web/**/*",
"assets/**/*",
"cli.js",
"npmextra.json",
".smartconfig.json",
"readme.md",
"changelog.md"
],
+2344 -2759
View File
File diff suppressed because it is too large Load Diff
+38 -16
View File
@@ -462,35 +462,57 @@ For TLS termination modes (`terminate` and `terminate-and-reencrypt`), SmartProx
**HTTP to HTTPS Redirect**:
```typescript
import { createHttpToHttpsRedirect } from '@push.rocks/smartproxy';
import { SocketHandlers } from '@push.rocks/smartproxy';
const redirectRoute = createHttpToHttpsRedirect(['example.com', 'www.example.com']);
const redirectRoute = {
name: 'http-to-https',
match: { ports: 80, domains: ['example.com', 'www.example.com'] },
action: {
type: 'socket-handler' as const,
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
}
};
```
**Complete HTTPS Server (with redirect)**:
```typescript
import { createCompleteHttpsServer } from '@push.rocks/smartproxy';
const routes = createCompleteHttpsServer(
'example.com',
{ host: 'localhost', port: 8080 },
{ certificate: 'auto' }
);
const routes = [
{
name: 'https-server',
match: { ports: 443, domains: 'example.com' },
action: {
type: 'forward' as const,
targets: [{ host: 'localhost', port: 8080 }],
tls: { mode: 'terminate' as const, certificate: 'auto' as const }
}
},
{
name: 'http-redirect',
match: { ports: 80, domains: 'example.com' },
action: {
type: 'socket-handler' as const,
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
}
}
];
```
**Load Balancer with Health Checks**:
```typescript
import { createLoadBalancerRoute } from '@push.rocks/smartproxy';
const lbRoute = createLoadBalancerRoute(
'api.example.com',
[
const lbRoute = {
name: 'load-balancer',
match: { ports: 443, domains: 'api.example.com' },
action: {
type: 'forward' as const,
targets: [
{ host: 'backend1', port: 8080 },
{ host: 'backend2', port: 8080 },
{ host: 'backend3', port: 8080 }
],
{ tls: { mode: 'terminate', certificate: 'auto' } }
);
tls: { mode: 'terminate' as const, certificate: 'auto' as const },
loadBalancing: { algorithm: 'round-robin' as const }
}
};
```
### Smart SNI Requirement (v22.3+)
+376 -161
View File
@@ -1,6 +1,6 @@
# @push.rocks/smartproxy 🚀
**A high-performance, Rust-powered proxy toolkit for Node.js** — unified route-based configuration for SSL/TLS termination, HTTP/HTTPS reverse proxying, WebSocket support, load balancing, custom protocol handlers, and kernel-level NFTables forwarding.
**A high-performance, Rust-powered proxy toolkit for Node.js** — unified route-based configuration for SSL/TLS termination, HTTP/HTTPS reverse proxying, WebSocket support, UDP/QUIC/HTTP3, load balancing, custom protocol handlers, and kernel-level NFTables forwarding via [`@push.rocks/smartnftables`](https://code.foss.global/push.rocks/smartnftables).
## 📦 Installation
@@ -16,9 +16,9 @@ For reporting bugs, issues, or security vulnerabilities, please visit [community
## 🎯 What is SmartProxy?
SmartProxy is a production-ready proxy solution that takes the complexity out of traffic management. Under the hood, all networking — TCP, TLS, HTTP reverse proxy, connection tracking, security enforcement, and NFTables — is handled by a **Rust engine** for maximum performance, while you configure everything through a clean TypeScript API with full type safety.
SmartProxy is a production-ready proxy solution that takes the complexity out of traffic management. Under the hood, all networking — TCP, UDP, TLS, HTTP reverse proxy, QUIC/HTTP3, connection tracking, security enforcement, and NFTables — is handled by a **Rust engine** for maximum performance, while you configure everything through a clean TypeScript API with full type safety.
Whether you're building microservices, deploying edge infrastructure, or need a battle-tested reverse proxy with automatic Let's Encrypt certificates, SmartProxy has you covered.
Whether you're building microservices, deploying edge infrastructure, proxying UDP-based protocols, or need a battle-tested reverse proxy with automatic Let's Encrypt certificates, SmartProxy has you covered.
### ⚡ Key Features
@@ -29,11 +29,12 @@ Whether you're building microservices, deploying edge infrastructure, or need a
| 🔒 **Automatic SSL/TLS** | Zero-config HTTPS with Let's Encrypt ACME integration |
| 🎯 **Flexible Matching** | Route by port, domain, path, protocol, client IP, TLS version, headers, or custom logic |
| 🚄 **High-Performance** | Choose between user-space or kernel-level (NFTables) forwarding |
| 📡 **UDP & QUIC/HTTP3** | First-class UDP transport, datagram handlers, QUIC tunneling, and HTTP/3 support |
| ⚖️ **Load Balancing** | Round-robin, least-connections, IP-hash with health checks |
| 🛡️ **Enterprise Security** | IP filtering, rate limiting, basic auth, JWT auth, connection limits |
| 🔌 **WebSocket Support** | First-class WebSocket proxying with ping/pong keep-alive |
| 🎮 **Custom Protocols** | Socket handlers for implementing any protocol in TypeScript |
| 📊 **Live Metrics** | Real-time throughput, connection counts, and performance data |
| 🎮 **Custom Protocols** | Socket and datagram handlers for implementing any protocol in TypeScript |
| 📊 **Live Metrics** | Real-time throughput, connection counts, UDP sessions, and performance data |
| 🔧 **Dynamic Management** | Add/remove ports and routes at runtime without restarts |
| 🔄 **PROXY Protocol** | Full PROXY protocol v1/v2 support for preserving client information |
| 💾 **Consumer Cert Storage** | Bring your own persistence — SmartProxy never writes certs to disk |
@@ -43,7 +44,7 @@ Whether you're building microservices, deploying edge infrastructure, or need a
Get up and running in 30 seconds:
```typescript
import { SmartProxy, createCompleteHttpsServer } from '@push.rocks/smartproxy';
import { SmartProxy, SocketHandlers } from '@push.rocks/smartproxy';
// Create a proxy with automatic HTTPS
const proxy = new SmartProxy({
@@ -52,13 +53,25 @@ const proxy = new SmartProxy({
useProduction: true
},
routes: [
// Complete HTTPS setup in one call! ✨
...createCompleteHttpsServer('app.example.com', {
host: 'localhost',
port: 3000
}, {
certificate: 'auto' // Automatic Let's Encrypt cert 🎩
})
// HTTPS route with automatic Let's Encrypt cert
{
name: 'https-app',
match: { ports: 443, domains: 'app.example.com' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 3000 }],
tls: { mode: 'terminate', certificate: 'auto' }
}
},
// HTTP → HTTPS redirect
{
name: 'http-redirect',
match: { ports: 80, domains: 'app.example.com' },
action: {
type: 'socket-handler',
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
}
}
]
});
@@ -89,7 +102,7 @@ SmartProxy uses a powerful **match/action** pattern that makes routing predictab
```
Every route consists of:
- **Match** — What traffic to capture (ports, domains, paths, protocol, IPs, headers)
- **Match** — What traffic to capture (ports, domains, paths, transport, protocol, IPs, headers)
- **Action** — What to do with it (`forward` or `socket-handler`)
- **Security** (optional) — IP allow/block lists, rate limits, authentication
- **Headers** (optional) — Request/response header manipulation with template variables
@@ -110,31 +123,38 @@ SmartProxy supports three TLS handling modes:
### 🌐 HTTP to HTTPS Redirect
```typescript
import { SmartProxy, createHttpToHttpsRedirect } from '@push.rocks/smartproxy';
import { SmartProxy, SocketHandlers } from '@push.rocks/smartproxy';
const proxy = new SmartProxy({
routes: [
createHttpToHttpsRedirect(['example.com', '*.example.com'])
]
routes: [{
name: 'http-to-https',
match: { ports: 80, domains: ['example.com', '*.example.com'] },
action: {
type: 'socket-handler',
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
}
}]
});
```
### ⚖️ Load Balancer with Health Checks
```typescript
import { SmartProxy, createLoadBalancerRoute } from '@push.rocks/smartproxy';
import { SmartProxy } from '@push.rocks/smartproxy';
const proxy = new SmartProxy({
routes: [
createLoadBalancerRoute(
'app.example.com',
[
routes: [{
name: 'load-balancer',
match: { ports: 443, domains: 'app.example.com' },
action: {
type: 'forward',
targets: [
{ host: 'server1.internal', port: 8080 },
{ host: 'server2.internal', port: 8080 },
{ host: 'server3.internal', port: 8080 }
],
{
tls: { mode: 'terminate', certificate: 'auto' },
loadBalancing: {
algorithm: 'round-robin',
healthCheck: {
path: '/health',
@@ -144,78 +164,92 @@ const proxy = new SmartProxy({
healthyThreshold: 2
}
}
)
]
}
}]
});
```
### 🔌 WebSocket Proxy
```typescript
import { SmartProxy, createWebSocketRoute } from '@push.rocks/smartproxy';
import { SmartProxy } from '@push.rocks/smartproxy';
const proxy = new SmartProxy({
routes: [
createWebSocketRoute(
'ws.example.com',
{ host: 'websocket-server', port: 8080 },
{
path: '/socket',
useTls: true,
certificate: 'auto',
routes: [{
name: 'websocket',
match: { ports: 443, domains: 'ws.example.com', path: '/socket' },
priority: 100,
action: {
type: 'forward',
targets: [{ host: 'websocket-server', port: 8080 }],
tls: { mode: 'terminate', certificate: 'auto' },
websocket: {
enabled: true,
pingInterval: 30000,
pingTimeout: 10000
}
)
]
}
}]
});
```
### 🚦 API Gateway with Rate Limiting
```typescript
import { SmartProxy, createApiGatewayRoute, addRateLimiting } from '@push.rocks/smartproxy';
import { SmartProxy } from '@push.rocks/smartproxy';
let apiRoute = createApiGatewayRoute(
'api.example.com',
'/api',
{ host: 'api-backend', port: 8080 },
{
useTls: true,
certificate: 'auto',
addCorsHeaders: true
const proxy = new SmartProxy({
routes: [{
name: 'api-gateway',
match: { ports: 443, domains: 'api.example.com', path: '/api/*' },
priority: 100,
action: {
type: 'forward',
targets: [{ host: 'api-backend', port: 8080 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
headers: {
response: {
'Access-Control-Allow-Origin': '*',
'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
'Access-Control-Allow-Headers': 'Content-Type, Authorization',
'Access-Control-Max-Age': '86400'
}
);
// Add rate limiting — 100 requests per minute per IP
apiRoute = addRateLimiting(apiRoute, {
},
security: {
rateLimit: {
enabled: true,
maxRequests: 100,
window: 60,
keyBy: 'ip'
}
}
}]
});
const proxy = new SmartProxy({ routes: [apiRoute] });
```
### 🎮 Custom Protocol Handler
### 🎮 Custom Protocol Handler (TCP)
SmartProxy lets you implement any protocol with full socket control. Routes with JavaScript socket handlers are automatically relayed from the Rust engine back to your TypeScript code:
```typescript
import { SmartProxy, createSocketHandlerRoute, SocketHandlers } from '@push.rocks/smartproxy';
import { SmartProxy, SocketHandlers } from '@push.rocks/smartproxy';
const proxy = new SmartProxy({
routes: [
// Use pre-built handlers
const echoRoute = createSocketHandlerRoute(
'echo.example.com',
7777,
SocketHandlers.echo
);
{
name: 'echo-server',
match: { ports: 7777, domains: 'echo.example.com' },
action: { type: 'socket-handler', socketHandler: SocketHandlers.echo }
},
// Or create your own custom protocol
const customRoute = createSocketHandlerRoute(
'custom.example.com',
9999,
async (socket) => {
{
name: 'custom-protocol',
match: { ports: 9999, domains: 'custom.example.com' },
action: {
type: 'socket-handler',
socketHandler: async (socket) => {
console.log(`New connection on custom protocol`);
socket.write('Welcome to my custom protocol!\n');
@@ -229,9 +263,10 @@ const customRoute = createSocketHandlerRoute(
}
});
}
);
const proxy = new SmartProxy({ routes: [echoRoute, customRoute] });
}
}
]
});
```
**Pre-built Socket Handlers:**
@@ -247,25 +282,162 @@ const proxy = new SmartProxy({ routes: [echoRoute, customRoute] });
| `SocketHandlers.httpBlock(status, message)` | HTTP block response |
| `SocketHandlers.block(message)` | Block with optional message |
### ⚡ High-Performance NFTables Forwarding
### 📡 UDP Datagram Handler
For ultra-low latency on Linux, use kernel-level forwarding (requires root):
Handle raw UDP datagrams with custom TypeScript logic — perfect for DNS, game servers, IoT protocols, or any UDP-based service:
```typescript
import { SmartProxy, createNfTablesTerminateRoute } from '@push.rocks/smartproxy';
import { SmartProxy } from '@push.rocks/smartproxy';
import type { IRouteConfig, TDatagramHandler, IDatagramInfo } from '@push.rocks/smartproxy';
// Custom UDP echo handler
const udpHandler: TDatagramHandler = (datagram, info, reply) => {
console.log(`UDP from ${info.sourceIp}:${info.sourcePort} on port ${info.destPort}`);
reply(datagram); // Echo it back
};
const proxy = new SmartProxy({
routes: [
createNfTablesTerminateRoute(
'fast.example.com',
{ host: 'backend', port: 8080 },
{
routes: [{
name: 'udp-echo',
match: {
ports: 5353,
transport: 'udp' // 👈 Listen for UDP datagrams
},
action: {
type: 'socket-handler',
datagramHandler: udpHandler, // 👈 Process each datagram
udp: {
sessionTimeout: 60000, // Session idle timeout (ms)
maxSessionsPerIP: 100,
maxDatagramSize: 65535
}
}
}]
});
await proxy.start();
```
### 📡 QUIC / HTTP3 Forwarding
Forward QUIC traffic to backends with optional protocol translation (e.g., receive QUIC, forward as TCP/HTTP1):
```typescript
import { SmartProxy } from '@push.rocks/smartproxy';
import type { IRouteConfig } from '@push.rocks/smartproxy';
const quicRoute: IRouteConfig = {
name: 'quic-to-backend',
match: {
ports: 443,
certificate: 'auto',
transport: 'udp',
protocol: 'quic' // 👈 Match QUIC protocol
},
action: {
type: 'forward',
targets: [{
host: 'backend-server',
port: 8443,
backendTransport: 'tcp' // 👈 Translate QUIC → TCP for backend
}],
tls: {
mode: 'terminate',
certificate: 'auto' // 👈 QUIC requires TLS 1.3
},
udp: {
quic: {
enableHttp3: true,
maxIdleTimeout: 30000,
maxConcurrentBidiStreams: 100,
altSvcPort: 443, // Advertise in Alt-Svc header
altSvcMaxAge: 86400
}
}
}
};
const proxy = new SmartProxy({
acme: { email: 'ssl@example.com' },
routes: [quicRoute]
});
```
### 🚄 Best-Effort Backend Protocol (H3 > H2 > H1)
SmartProxy automatically uses the **highest protocol your backend supports** for HTTP requests. The backend protocol is independent of the client protocol — a client using HTTP/1.1 can be forwarded over HTTP/3 to the backend, and vice versa.
```typescript
const route: IRouteConfig = {
name: 'auto-protocol',
match: { ports: 443, domains: 'app.example.com' },
action: {
type: 'forward',
targets: [{ host: 'backend', port: 8443 }],
tls: { mode: 'terminate', certificate: 'auto' },
options: {
backendProtocol: 'auto' // 👈 Default — best-effort selection
}
}
};
```
**How protocol discovery works (browser model):**
1. First request → TLS ALPN probe detects H2 or H1
2. Backend response inspected for `Alt-Svc: h3=":port"` header
3. If H3 advertised → cached and used for subsequent requests via QUIC
4. Graceful fallback: H3 failure → H2 → H1 with automatic cache invalidation
| `backendProtocol` | Behavior |
|---|---|
| `'auto'` (default) | Best-effort: H3 > H2 > H1 with Alt-Svc discovery |
| `'http1'` | Always HTTP/1.1 |
| `'http2'` | Always HTTP/2 (hard-fail if unsupported) |
| `'http3'` | Always HTTP/3 via QUIC (hard-fail if unsupported) |
> **Note:** WebSocket upgrades always use HTTP/1.1 to the backend regardless of `backendProtocol`, since there's no performance benefit from H2/H3 Extended CONNECT for tunneled connections, and backend support is rare.
### 🔁 Dual-Stack TCP + UDP Route
Listen on both TCP and UDP with a single route — handle each transport with its own handler:
```typescript
const dualStackRoute: IRouteConfig = {
name: 'dual-stack-dns',
match: {
ports: 53,
transport: 'all' // 👈 Listen on both TCP and UDP
},
action: {
type: 'socket-handler',
socketHandler: handleTcpDns, // 👈 TCP connections
datagramHandler: handleUdpDns, // 👈 UDP datagrams
}
};
```
### ⚡ High-Performance NFTables Forwarding
For ultra-low latency on Linux, use kernel-level forwarding via [`@push.rocks/smartnftables`](https://code.foss.global/push.rocks/smartnftables) (requires root):
```typescript
import { SmartProxy } from '@push.rocks/smartproxy';
const proxy = new SmartProxy({
routes: [{
name: 'nftables-fast',
match: { ports: 443, domains: 'fast.example.com' },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'backend', port: 8080 }],
tls: { mode: 'terminate', certificate: 'auto' },
nftables: {
protocol: 'tcp',
preserveSourceIP: true // Backend sees real client IP
}
)
]
}
}]
});
```
@@ -274,15 +446,18 @@ const proxy = new SmartProxy({
Forward encrypted traffic to backends without terminating TLS — the proxy routes based on the SNI hostname alone:
```typescript
import { SmartProxy, createHttpsPassthroughRoute } from '@push.rocks/smartproxy';
import { SmartProxy } from '@push.rocks/smartproxy';
const proxy = new SmartProxy({
routes: [
createHttpsPassthroughRoute('secure.example.com', {
host: 'backend-that-handles-tls',
port: 8443
})
]
routes: [{
name: 'sni-passthrough',
match: { ports: 443, domains: 'secure.example.com' },
action: {
type: 'forward',
targets: [{ host: 'backend-that-handles-tls', port: 8443 }],
tls: { mode: 'passthrough' }
}
}]
});
```
@@ -389,15 +564,7 @@ Comprehensive per-route security options:
}
```
**Security modifier helpers** let you add security to any existing route:
```typescript
import { addRateLimiting, addBasicAuth, addJwtAuth } from '@push.rocks/smartproxy';
let route = createHttpsTerminateRoute('api.example.com', { host: 'backend', port: 8080 });
route = addRateLimiting(route, { maxRequests: 100, window: 60, keyBy: 'ip' });
route = addBasicAuth(route, { users: [{ username: 'admin', password: 'secret' }] });
```
Security options are configured directly on each route's `security` property — no separate helpers needed.
### 📊 Runtime Management
@@ -419,6 +586,10 @@ console.log(`Bytes in: ${metrics.totals.bytesIn()}`);
console.log(`Requests/sec: ${metrics.requests.perSecond()}`);
console.log(`Throughput in: ${metrics.throughput.instant().in} bytes/sec`);
// UDP metrics
console.log(`UDP sessions: ${metrics.udp.activeSessions()}`);
console.log(`Datagrams in: ${metrics.udp.datagramsIn()}`);
// Get detailed statistics from the Rust engine
const stats = await proxy.getStatistics();
@@ -545,7 +716,7 @@ SmartProxy uses a hybrid **Rust + TypeScript** architecture:
```
┌─────────────────────────────────────────────────────┐
│ Your Application │
│ (TypeScript — routes, config, socket handlers) │
│ (TypeScript — routes, config, handlers)
└──────────────────┬──────────────────────────────────┘
│ IPC (JSON over stdin/stdout)
┌──────────────────▼──────────────────────────────────┐
@@ -555,23 +726,28 @@ SmartProxy uses a hybrid **Rust + TypeScript** architecture:
│ │ Listener│ │ Reverse │ │ Matcher │ │ Cert Mgr │ │
│ │ │ │ Proxy │ │ │ │ │ │
│ └─────────┘ └─────────┘ └─────────┘ └──────────┘ │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌──────────┐
│ │ Security│ │ Metrics │ │ Connec- │ │ NFTables │
│ │ Enforce │ │ Collect │ tion │ │ Mgr
│ │ │ │ │ Tracker │ │
│ └─────────┘ └─────────┘ └─────────┘ └──────────┘
│ ┌─────────┐ ┌─────────┐ ┌─────────┐
│ │ UDP │ │ Security│ │ Metrics │
│ │ QUIC │ │ Enforce │ │ Collect │
│ │ HTTP/3 │ │
│ └─────────┘ └─────────┘ └─────────┘
└──────────────────┬──────────────────────────────────┘
│ Unix Socket Relay
┌──────────────────▼──────────────────────────────────┐
TypeScript Socket Handler Server
│ (for JS-defined socket handlers & dynamic routes)
│ TypeScript Socket & Datagram Handler Servers
│ (for JS socket handlers, datagram handlers,
│ and dynamic routes) │
├─────────────────────────────────────────────────────┤
│ @push.rocks/smartnftables (kernel-level NFTables) │
│ (DNAT/SNAT, firewall, rate limiting via nft CLI) │
└─────────────────────────────────────────────────────┘
```
- **Rust Engine** handles all networking, TLS, HTTP proxying, connection management, security, and metrics
- **TypeScript** provides the npm API, configuration types, route helpers, validation, and socket handler callbacks
- **Rust Engine** handles all networking: TCP, UDP, TLS, QUIC, HTTP proxying, connection management, security, and metrics
- **TypeScript** provides the npm API, configuration types, validation, and handler callbacks
- **NFTables** managed by [`@push.rocks/smartnftables`](https://code.foss.global/push.rocks/smartnftables) — kernel-level DNAT/SNAT forwarding, firewall rules, and rate limiting via the `nft` CLI
- **IPC** — The TypeScript wrapper uses JSON commands/events over stdin/stdout to communicate with the Rust binary
- **Socket Relay** — A Unix domain socket server for routes requiring TypeScript-side handling (socket handlers, dynamic host/port functions)
- **Socket/Datagram Relay** — Unix domain socket servers for routes requiring TypeScript-side handling (socket handlers, datagram handlers, dynamic host/port functions)
## 🎯 Route Configuration Reference
@@ -579,22 +755,26 @@ SmartProxy uses a hybrid **Rust + TypeScript** architecture:
```typescript
interface IRouteMatch {
ports: number | number[] | Array<{ from: number; to: number }>; // Required — port(s) to listen on
ports: TPortRange; // Required — port(s) to listen on
transport?: 'tcp' | 'udp' | 'all'; // Transport protocol (default: 'tcp')
domains?: string | string[]; // 'example.com', '*.example.com'
path?: string; // '/api/*', '/users/:id'
clientIp?: string[]; // ['10.0.0.0/8', '192.168.*']
tlsVersion?: string[]; // ['TLSv1.2', 'TLSv1.3']
headers?: Record<string, string | RegExp>; // Match by HTTP headers
protocol?: 'http' | 'tcp'; // Match specific protocol ('http' includes h2 + WebSocket upgrades)
protocol?: 'http' | 'tcp' | 'udp' | 'quic' | 'http3'; // Application-layer protocol
}
// Port range supports single numbers, arrays, and ranges
type TPortRange = number | Array<number | { from: number; to: number }>;
```
### Action Types
| Type | Description |
|------|-------------|
| `forward` | Proxy to one or more backend targets (with optional TLS, WebSocket, load balancing) |
| `socket-handler` | Custom socket handling function in TypeScript |
| `forward` | Proxy to one or more backend targets (with optional TLS, WebSocket, load balancing, UDP/QUIC) |
| `socket-handler` | Custom socket/datagram handling function in TypeScript |
### Target Options
@@ -610,6 +790,7 @@ interface IRouteTarget {
sendProxyProtocol?: boolean;
headers?: IRouteHeaders;
advanced?: IRouteAdvanced;
backendTransport?: 'tcp' | 'udp'; // Backend transport (e.g., receive QUIC, forward as TCP)
}
```
@@ -666,47 +847,56 @@ interface IRouteLoadBalancing {
}
```
## 🛠️ Helper Functions Reference
### Backend Protocol Options
All helpers are fully typed and return `IRouteConfig` or `IRouteConfig[]`:
```typescript
// Set on action.options
{
action: {
type: 'forward',
targets: [...],
options: {
backendProtocol: 'auto' | 'http1' | 'http2' | 'http3'
}
}
}
```
| Value | Backend Behavior |
|-------|-----------------|
| `'auto'` | Best-effort: discovers H3 via Alt-Svc, probes H2 via ALPN, falls back to H1 |
| `'http1'` | Always HTTP/1.1 (no ALPN probe) |
| `'http2'` | Always HTTP/2 (hard-fail if handshake fails) |
| `'http3'` | Always HTTP/3 over QUIC (3s connect timeout, hard-fail if unreachable) |
### UDP & QUIC Options
```typescript
interface IRouteUdp {
sessionTimeout?: number; // Idle timeout per UDP session (ms, default: 60000)
maxSessionsPerIP?: number; // Max concurrent sessions per IP (default: 1000)
maxDatagramSize?: number; // Max datagram size in bytes (default: 65535)
quic?: IRouteQuic;
}
interface IRouteQuic {
maxIdleTimeout?: number; // QUIC idle timeout (ms, default: 30000)
maxConcurrentBidiStreams?: number; // Max bidi streams (default: 100)
maxConcurrentUniStreams?: number; // Max uni streams (default: 100)
enableHttp3?: boolean; // Enable HTTP/3 (default: false)
altSvcPort?: number; // Port for Alt-Svc header
altSvcMaxAge?: number; // Alt-Svc max age in seconds (default: 86400)
initialCongestionWindow?: number; // Initial congestion window (bytes)
}
```
## 🛠️ Exports Reference
```typescript
import {
// HTTP/HTTPS
createHttpRoute, // Plain HTTP route
createHttpsTerminateRoute, // HTTPS with TLS termination
createHttpsPassthroughRoute, // SNI passthrough (no termination)
createHttpToHttpsRedirect, // HTTP → HTTPS redirect
createCompleteHttpsServer, // HTTPS + redirect combo (returns IRouteConfig[])
// Load Balancing
createLoadBalancerRoute, // Multi-backend with health checks
createSmartLoadBalancer, // Dynamic domain-based backend selection
// API & WebSocket
createApiRoute, // API route with path matching
createApiGatewayRoute, // API gateway with CORS
createWebSocketRoute, // WebSocket-enabled route
// Custom Protocols
createSocketHandlerRoute, // Custom socket handler
SocketHandlers, // Pre-built handlers (echo, proxy, block, etc.)
// NFTables (Linux, requires root)
createNfTablesRoute, // Kernel-level packet forwarding
createNfTablesTerminateRoute, // NFTables + TLS termination
createCompleteNfTablesHttpsServer, // NFTables HTTPS + redirect combo
// Dynamic Routing
createPortMappingRoute, // Port mapping with context
createOffsetPortMappingRoute, // Simple port offset
createDynamicRoute, // Dynamic host/port via functions
createPortOffset, // Port offset factory
// Security Modifiers
addRateLimiting, // Add rate limiting to any route
addBasicAuth, // Add basic auth to any route
addJwtAuth, // Add JWT auth to any route
// Core
SmartProxy, // Main proxy class
SocketHandlers, // Pre-built socket handlers (echo, proxy, block, httpRedirect, httpServer, etc.)
// Route Utilities
mergeRouteConfigs, // Deep-merge two route configs
@@ -718,6 +908,8 @@ import {
} from '@push.rocks/smartproxy';
```
All routes are configured as plain `IRouteConfig` objects with `match` and `action` properties — see the examples throughout this document.
## 📖 API Documentation
### SmartProxy Class
@@ -748,11 +940,13 @@ class SmartProxy extends EventEmitter {
getCertificateStatus(routeName: string): Promise<any>;
getEligibleDomainsForCertificates(): string[];
// NFTables
getNfTablesStatus(): Promise<Record<string, any>>;
// NFTables (managed by @push.rocks/smartnftables)
getNfTablesStatus(): INftStatus | null;
// Events
on(event: 'error', handler: (err: Error) => void): this;
on(event: 'certificate-issued', handler: (ev: ICertificateIssuedEvent) => void): this;
on(event: 'certificate-failed', handler: (ev: ICertificateFailedEvent) => void): this;
}
```
@@ -775,6 +969,8 @@ interface ISmartProxyOptions {
// Custom certificate provisioning
certProvisionFunction?: (domain: string) => Promise<ICert | 'http01'>;
certProvisionFallbackToAcme?: boolean; // Fall back to ACME on failure (default: true)
certProvisionTimeout?: number; // Timeout per provision call (ms)
certProvisionConcurrency?: number; // Max concurrent provisions
// Consumer-managed certificate persistence (see "Consumer-Managed Certificate Storage")
certStore?: ISmartProxyCertStore;
@@ -782,6 +978,9 @@ interface ISmartProxyOptions {
// Self-signed fallback
disableDefaultCert?: boolean; // Disable '*' self-signed fallback (default: false)
// Rust binary path override
rustBinaryPath?: string; // Custom path to the Rust proxy binary
// Global defaults
defaults?: {
target?: { host: string; port: number };
@@ -794,11 +993,11 @@ interface ISmartProxyOptions {
sendProxyProtocol?: boolean; // Send PROXY protocol to targets
// Timeouts
connectionTimeout?: number; // Backend connection timeout (default: 30s)
initialDataTimeout?: number; // Initial data/SNI timeout (default: 120s)
socketTimeout?: number; // Socket inactivity timeout (default: 1h)
maxConnectionLifetime?: number; // Max connection lifetime (default: 24h)
inactivityTimeout?: number; // Inactivity timeout (default: 4h)
connectionTimeout?: number; // Backend connection timeout (default: 60s)
initialDataTimeout?: number; // Initial data/SNI timeout (default: 60s)
socketTimeout?: number; // Socket inactivity timeout (default: 60s)
maxConnectionLifetime?: number; // Max connection lifetime (default: 1h)
inactivityTimeout?: number; // Inactivity timeout (default: 75s)
gracefulShutdownTimeout?: number; // Shutdown grace period (default: 30s)
// Connection limits
@@ -807,8 +1006,8 @@ interface ISmartProxyOptions {
// Keep-alive
keepAliveTreatment?: 'standard' | 'extended' | 'immortal';
keepAliveInactivityMultiplier?: number; // (default: 6)
extendedKeepAliveLifetime?: number; // (default: 7 days)
keepAliveInactivityMultiplier?: number; // (default: 4)
extendedKeepAliveLifetime?: number; // (default: 1h)
// Metrics
metrics?: {
@@ -868,11 +1067,22 @@ metrics.requests.perSecond(); // Requests per second
metrics.requests.perMinute(); // Requests per minute
metrics.requests.total(); // Total requests
// UDP metrics
metrics.udp.activeSessions(); // Current active UDP sessions
metrics.udp.totalSessions(); // Total UDP sessions since start
metrics.udp.datagramsIn(); // Datagrams received
metrics.udp.datagramsOut(); // Datagrams sent
// Cumulative totals
metrics.totals.bytesIn(); // Total bytes received
metrics.totals.bytesOut(); // Total bytes sent
metrics.totals.connections(); // Total connections
// Backend metrics
metrics.backends.byBackend(); // Map<backend, IBackendMetrics>
metrics.backends.protocols(); // Map<backend, protocol>
metrics.backends.topByErrors(10); // Top N error-prone backends
// Percentiles
metrics.percentiles.connectionDuration(); // { p50, p95, p99 }
metrics.percentiles.bytesTransferred(); // { in: { p50, p95, p99 }, out: { p50, p95, p99 } }
@@ -896,11 +1106,16 @@ metrics.percentiles.bytesTransferred(); // { in: { p50, p95, p99 }, out: { p5
### Rust Binary Not Found
SmartProxy searches for the Rust binary in this order:
1. `SMARTPROXY_RUST_BINARY` environment variable
2. Platform-specific npm package (`@push.rocks/smartproxy-linux-x64`, etc.)
3. `dist_rust/rustproxy` relative to the package root (built by `tsrust`)
4. Local dev build (`./rust/target/release/rustproxy`)
5. System PATH (`rustproxy`)
1. `rustBinaryPath` option in `ISmartProxyOptions`
2. `SMARTPROXY_RUST_BINARY` environment variable
3. Platform-specific npm package (`@push.rocks/smartproxy-linux-x64`, etc.)
4. `dist_rust/rustproxy` relative to the package root (built by `tsrust`)
5. Local dev build (`./rust/target/release/rustproxy`)
6. System PATH (`rustproxy`)
### QUIC / HTTP3 Caveats
- **GREASE frames are disabled.** The underlying h3 crate sends [GREASE frames](https://www.rfc-editor.org/rfc/rfc9114.html#frame-reserved) by default to test protocol extensibility. However, some HTTP/3 clients and servers don't properly ignore unknown frame types, causing 400/500 errors or stream hangs ([h3#206](https://github.com/hyperium/h3/issues/206)). SmartProxy disables GREASE on both the server side (for incoming H3 requests) and the client side (for H3 backend connections) to maximize compatibility.
- **HTTP/3 is pre-release.** The h3 ecosystem (h3 0.0.8, h3-quinn 0.0.10, quinn 0.11) is still pre-1.0. Expect rough edges.
### Performance Tuning
- ✅ Use NFTables forwarding for high-traffic routes (Linux only)
@@ -924,7 +1139,7 @@ SmartProxy searches for the Rust binary in this order:
## License and Legal Information
This repository contains open-source code licensed under the MIT License. A copy of the license can be found in the [LICENSE](./LICENSE) file.
This repository contains open-source code licensed under the MIT License. A copy of the license can be found in the [license](./license) file.
**Please note:** The MIT License does not grant permission to use the trade names, trademarks, service marks, or product names of the project, except as required for reasonable and customary use in describing the origin of the work and reproducing the content of the NOTICE file.
+484
View File
@@ -0,0 +1,484 @@
# SmartProxy Metrics System
## Architecture
Two-tier design separating the data plane from the observation plane:
**Hot path (per-chunk, lock-free):** All recording in the proxy data plane touches only `AtomicU64` counters. No `Mutex` is ever acquired on the forwarding path. `CountingBody` batches flushes every 64KB to reduce DashMap shard contention.
**Cold path (1Hz sampling):** A background tokio task drains pending atomics into `ThroughputTracker` circular buffers (Mutex-guarded), producing per-second throughput history. Same task prunes orphaned entries and cleans up rate limiter state.
**Read path (on-demand):** `snapshot()` reads all atomics and locks ThroughputTrackers to build a serializable `Metrics` struct. TypeScript polls this at 1s intervals via IPC.
```
Data Plane (lock-free) Background (1Hz) Read Path
───────────────────── ────────────────── ─────────
record_bytes() ──> AtomicU64 ──┐
record_http_request() ──> AtomicU64 ──┤
connection_opened/closed() ──> AtomicU64 ──┤ sample_all() snapshot()
backend_*() ──> DashMap<AtomicU64> ──┤────> drain atomics ──────> Metrics struct
protocol_*() ──> AtomicU64 ──┤ feed ThroughputTrackers ──> JSON
datagram_*() ──> AtomicU64 ──┘ prune orphans ──> IPC stdout
──> TS cache
──> IMetrics API
```
### Key Types
| Type | Crate | Purpose |
|---|---|---|
| `MetricsCollector` | `rustproxy-metrics` | Central store. All DashMaps, atomics, and throughput trackers |
| `ThroughputTracker` | `rustproxy-metrics` | Circular buffer of 1Hz samples. Default 3600 capacity (1 hour) |
| `ForwardMetricsCtx` | `rustproxy-passthrough` | Carries `Arc<MetricsCollector>` + route_id + source_ip through TCP forwarding |
| `CountingBody` | `rustproxy-http` | Wraps HTTP bodies, batches byte recording per 64KB, flushes on drop |
| `ProtocolGuard` | `rustproxy-http` | RAII guard for frontend/backend protocol active/total counters |
| `ConnectionGuard` | `rustproxy-passthrough` | RAII guard calling `connection_closed()` on drop |
| `RustMetricsAdapter` | TypeScript | Polls Rust via IPC, implements `IMetrics` interface over cached JSON |
---
## What's Collected
### Global Counters
| Metric | Type | Updated by |
|---|---|---|
| Active connections | AtomicU64 | `connection_opened/closed` |
| Total connections (lifetime) | AtomicU64 | `connection_opened` |
| Total bytes in | AtomicU64 | `record_bytes` |
| Total bytes out | AtomicU64 | `record_bytes` |
| Total HTTP requests | AtomicU64 | `record_http_request` |
| Active UDP sessions | AtomicU64 | `udp_session_opened/closed` |
| Total UDP sessions | AtomicU64 | `udp_session_opened` |
| Total datagrams in | AtomicU64 | `record_datagram_in` |
| Total datagrams out | AtomicU64 | `record_datagram_out` |
### Per-Route Metrics (keyed by route ID string)
| Metric | Storage |
|---|---|
| Active connections | `DashMap<String, AtomicU64>` |
| Total connections | `DashMap<String, AtomicU64>` |
| Bytes in / out | `DashMap<String, AtomicU64>` |
| Pending throughput (in, out) | `DashMap<String, (AtomicU64, AtomicU64)>` |
| Throughput history | `DashMap<String, Mutex<ThroughputTracker>>` |
Entries are pruned via `retain_routes()` when routes are removed.
### Per-IP Metrics (keyed by IP string)
| Metric | Storage |
|---|---|
| Active connections | `DashMap<String, AtomicU64>` |
| Total connections | `DashMap<String, AtomicU64>` |
| Bytes in / out | `DashMap<String, AtomicU64>` |
| Pending throughput (in, out) | `DashMap<String, (AtomicU64, AtomicU64)>` |
| Throughput history | `DashMap<String, Mutex<ThroughputTracker>>` |
| Domain requests | `DashMap<String, DashMap<String, AtomicU64>>` (IP → domain → count) |
All seven maps for an IP are evicted when its active connection count drops to 0. Safety-net pruning in `sample_all()` catches entries orphaned by races. Snapshots cap at 100 IPs, sorted by active connections descending.
**Domain request tracking:** Records which domains each frontend IP has requested. Populated from HTTP Host headers (for HTTP/1.1, HTTP/2, HTTP/3 requests) and from SNI (for TLS passthrough connections). Domain keys are normalized to lowercase with any trailing dot stripped so the same hostname does not fragment across multiple counters. Capped at 256 domains per IP (`MAX_DOMAINS_PER_IP`) to prevent subdomain-spray abuse. Inner DashMap uses 2 shards to minimise base memory per IP (~200 bytes). Common case (IP + domain both known) is two DashMap reads + one atomic increment with zero allocation.
### Per-Backend Metrics (keyed by "host:port")
| Metric | Storage |
|---|---|
| Active connections | `DashMap<String, AtomicU64>` |
| Total connections | `DashMap<String, AtomicU64>` |
| Detected protocol (h1/h2/h3) | `DashMap<String, String>` |
| Connect errors | `DashMap<String, AtomicU64>` |
| Handshake errors | `DashMap<String, AtomicU64>` |
| Request errors | `DashMap<String, AtomicU64>` |
| Total connect time (microseconds) | `DashMap<String, AtomicU64>` |
| Connect count | `DashMap<String, AtomicU64>` |
| Pool hits | `DashMap<String, AtomicU64>` |
| Pool misses | `DashMap<String, AtomicU64>` |
| H2 failures (fallback to H1) | `DashMap<String, AtomicU64>` |
All per-backend maps are evicted when active count reaches 0. Pruned via `retain_backends()`.
### Frontend Protocol Distribution
Tracked via `ProtocolGuard` RAII guards and `FrontendProtocolTracker`. Five protocol categories, each with active + total counters (AtomicU64):
| Protocol | Where detected |
|---|---|
| h1 | `FrontendProtocolTracker` on first HTTP/1.x request |
| h2 | `FrontendProtocolTracker` on first HTTP/2 request |
| h3 | `ProtocolGuard::frontend("h3")` in H3ProxyService |
| ws | `ProtocolGuard::frontend("ws")` on WebSocket upgrade |
| other | Fallback (TCP passthrough without HTTP) |
Uses `fetch_update` for saturating decrements to prevent underflow races. The same saturating-close pattern is also used for connection and UDP active gauges.
### Backend Protocol Distribution
Same five categories (h1/h2/h3/ws/other), tracked via `ProtocolGuard::backend()` at connection establishment. Backend h2 failures (fallback to h1) are separately counted.
### Throughput History
`ThroughputTracker` is a circular buffer storing `ThroughputSample { timestamp_ms, bytes_in, bytes_out }` at 1Hz.
- Global tracker: 1 instance, default 3600 capacity
- Per-route trackers: 1 per active route
- Per-IP trackers: 1 per connected IP (evicted with the IP)
- HTTP request tracker: reuses ThroughputTracker with bytes_in = request count, bytes_out = 0
Query methods:
- `instant()` — last 1 second average
- `recent()` — last 10 seconds average
- `throughput(N)` — last N seconds average
- `history(N)` — last N raw samples in chronological order
Snapshots return 60 samples of global throughput history.
### Protocol Detection Cache
Not part of MetricsCollector. Maintained by `HttpProxyService`'s protocol detection system. Injected into the metrics snapshot at read time by `get_metrics()`.
Each entry records: host, port, domain, detected protocol (h1/h2/h3), H3 port, age, last accessed, last probed, suppression flags, cooldown timers, consecutive failure counts.
---
## Instrumentation Points
### TCP Passthrough (`rustproxy-passthrough`)
**Connection lifecycle**`tcp_listener.rs`:
- Accept: `conn_tracker.connection_opened(&ip)` (rate limiter) + `ConnectionTrackerGuard` RAII
- Route match: `metrics.connection_opened(route_id, source_ip)` + `ConnectionGuard` RAII
- Close: Both guards call their respective `_closed()` methods on drop
**Byte recording**`forwarder.rs` (`forward_bidirectional_with_timeouts`):
- Initial peeked data recorded immediately
- Per-chunk in both directions: `record_bytes(n, 0, ...)` / `record_bytes(0, n, ...)`
- Same pattern in `forward_bidirectional_split_with_timeouts` (tcp_listener.rs) for TLS-terminated paths
### HTTP Proxy (`rustproxy-http`)
**Request counting**`proxy_service.rs`:
- `record_http_request()` called once per request after route matching succeeds
**Body byte counting**`counting_body.rs` wrapping:
- Request bodies: `CountingBody::new(body, ..., Direction::In)` — counts client-to-upstream bytes
- Response bodies: `CountingBody::new(body, ..., Direction::Out)` — counts upstream-to-client bytes
- Batched flush every 64KB (`BYTE_FLUSH_THRESHOLD = 65_536`), remainder flushed on drop
- Also updates `connection_activity` atomic (idle watchdog) and `active_requests` counter (streaming detection)
**Backend metrics**`proxy_service.rs`:
- `backend_connection_opened(key, connect_time)` — after TCP/TLS connect succeeds
- `backend_connection_closed(key)` — on teardown
- `backend_connect_error(key)` — TCP/TLS connect failure or timeout
- `backend_handshake_error(key)` — H1/H2 protocol handshake failure
- `backend_request_error(key)` — send_request failure
- `backend_h2_failure(key)` — H2 attempted, fell back to H1
- `backend_pool_hit(key)` / `backend_pool_miss(key)` — connection pool reuse
- `set_backend_protocol(key, proto)` — records detected protocol
**WebSocket**`proxy_service.rs`:
- Does NOT use CountingBody; records bytes directly per-chunk in both directions of the bidirectional copy loop
### QUIC (`rustproxy-passthrough`)
**Connection level**`quic_handler.rs`:
- `connection_opened` / `connection_closed` via `QuicConnGuard` RAII
- `conn_tracker.connection_opened/closed` for rate limiting
**Stream level**:
- For QUIC-to-TCP stream forwarding: `record_bytes(bytes_in, bytes_out, ...)` called once per stream at completion (post-hoc, not per-chunk)
- For HTTP/3: delegates to `HttpProxyService.handle_request()`, so all HTTP proxy metrics apply
**H3 specifics**`h3_service.rs`:
- `ProtocolGuard::frontend("h3")` tracks the H3 connection
- H3 request bodies: `record_bytes(data.len(), 0, ...)` called directly (not CountingBody) since H3 uses `stream.send_data()`
- H3 response bodies: wrapped in CountingBody like HTTP/1 and HTTP/2
### UDP (`rustproxy-passthrough`)
**Session lifecycle**`udp_listener.rs` / `udp_session.rs`:
- `udp_session_opened()` + `connection_opened(route_id, source_ip)` on new session
- `udp_session_closed()` + `connection_closed(route_id, source_ip)` on idle reap or port drain
**Datagram counting**`udp_listener.rs`:
- Inbound: `record_bytes(len, 0, ...)` + `record_datagram_in()`
- Outbound (backend reply): `record_bytes(0, len, ...)` + `record_datagram_out()`
---
## Sampling Loop
`lib.rs` spawns a tokio task at configurable interval (default 1000ms):
```rust
loop {
tokio::select! {
_ = cancel => break,
_ = interval.tick() => {
metrics.sample_all();
conn_tracker.cleanup_stale_timestamps();
http_proxy.cleanup_all_rate_limiters();
}
}
}
```
`sample_all()` performs in one pass:
1. Drains `global_pending_tp_in/out` into global ThroughputTracker, samples
2. Drains per-route pending counters into per-route trackers, samples each
3. Samples idle route trackers (no new data) to advance their window
4. Drains per-IP pending counters into per-IP trackers, samples each
5. Drains `pending_http_requests` into HTTP request throughput tracker
6. Prunes orphaned per-IP entries (bytes/throughput maps with no matching ip_connections key)
7. Prunes orphaned per-backend entries (error/stats maps with no matching active/total key)
---
## Data Flow: Rust to TypeScript
```
MetricsCollector.snapshot()
├── reads all AtomicU64 counters
├── iterates DashMaps (routes, IPs, backends)
├── locks ThroughputTrackers for instant/recent rates + history
└── produces Metrics struct
RustProxy::get_metrics()
├── calls snapshot()
├── enriches with detectedProtocols from HTTP proxy protocol cache
└── returns Metrics
management.rs "getMetrics" IPC command
├── calls get_metrics()
├── serde_json::to_value (camelCase)
└── writes JSON to stdout
RustProxyBridge (TypeScript)
├── reads JSON from Rust process stdout
└── returns parsed object
RustMetricsAdapter
├── setInterval polls bridge.getMetrics() every 1s
├── stores raw JSON in this.cache
└── IMetrics methods read synchronously from cache
SmartProxy.getMetrics()
└── returns the RustMetricsAdapter instance
```
### IPC JSON Shape (Metrics)
```json
{
"activeConnections": 42,
"totalConnections": 1000,
"bytesIn": 123456789,
"bytesOut": 987654321,
"throughputInBytesPerSec": 50000,
"throughputOutBytesPerSec": 80000,
"throughputRecentInBytesPerSec": 45000,
"throughputRecentOutBytesPerSec": 75000,
"routes": {
"<route-id>": {
"activeConnections": 5,
"totalConnections": 100,
"bytesIn": 0, "bytesOut": 0,
"throughputInBytesPerSec": 0, "throughputOutBytesPerSec": 0,
"throughputRecentInBytesPerSec": 0, "throughputRecentOutBytesPerSec": 0
}
},
"ips": {
"<ip>": {
"activeConnections": 2, "totalConnections": 10,
"bytesIn": 0, "bytesOut": 0,
"throughputInBytesPerSec": 0, "throughputOutBytesPerSec": 0,
"domainRequests": {
"example.com": 4821,
"api.example.com": 312
}
}
},
"backends": {
"<host:port>": {
"activeConnections": 3, "totalConnections": 50,
"protocol": "h2",
"connectErrors": 0, "handshakeErrors": 0, "requestErrors": 0,
"totalConnectTimeUs": 150000, "connectCount": 50,
"poolHits": 40, "poolMisses": 10, "h2Failures": 1
}
},
"throughputHistory": [
{ "timestampMs": 1713000000000, "bytesIn": 50000, "bytesOut": 80000 }
],
"totalHttpRequests": 5000,
"httpRequestsPerSec": 100,
"httpRequestsPerSecRecent": 95,
"activeUdpSessions": 0, "totalUdpSessions": 5,
"totalDatagramsIn": 1000, "totalDatagramsOut": 1000,
"frontendProtocols": {
"h1Active": 10, "h1Total": 500,
"h2Active": 5, "h2Total": 200,
"h3Active": 1, "h3Total": 50,
"wsActive": 2, "wsTotal": 30,
"otherActive": 0, "otherTotal": 0
},
"backendProtocols": { "...same shape..." },
"detectedProtocols": [
{
"host": "backend", "port": 443, "domain": "example.com",
"protocol": "h2", "h3Port": 443,
"ageSecs": 120, "lastAccessedSecs": 5, "lastProbedSecs": 120,
"h2Suppressed": false, "h3Suppressed": false,
"h2CooldownRemainingSecs": null, "h3CooldownRemainingSecs": null,
"h2ConsecutiveFailures": null, "h3ConsecutiveFailures": null
}
]
}
```
### IPC JSON Shape (Statistics)
Lightweight administrative summary, fetched on-demand (not polled):
```json
{
"activeConnections": 42,
"totalConnections": 1000,
"routesCount": 5,
"listeningPorts": [80, 443, 8443],
"uptimeSeconds": 86400
}
```
---
## TypeScript Consumer API
`SmartProxy.getMetrics()` returns an `IMetrics` object. All members are synchronous methods reading from the polled cache.
### connections
| Method | Return | Source |
|---|---|---|
| `active()` | `number` | `cache.activeConnections` |
| `total()` | `number` | `cache.totalConnections` |
| `byRoute()` | `Map<string, number>` | `cache.routes[name].activeConnections` |
| `byIP()` | `Map<string, number>` | `cache.ips[ip].activeConnections` |
| `topIPs(limit?)` | `Array<{ip, count}>` | `cache.ips` sorted by active desc, default 10 |
| `domainRequestsByIP()` | `Map<string, Map<string, number>>` | `cache.ips[ip].domainRequests` |
| `topDomainRequests(limit?)` | `Array<{ip, domain, count}>` | Flattened from all IPs, sorted by count desc, default 20 |
| `frontendProtocols()` | `IProtocolDistribution` | `cache.frontendProtocols.*` |
| `backendProtocols()` | `IProtocolDistribution` | `cache.backendProtocols.*` |
### throughput
| Method | Return | Source |
|---|---|---|
| `instant()` | `{in, out}` | `cache.throughputInBytesPerSec/Out` |
| `recent()` | `{in, out}` | `cache.throughputRecentInBytesPerSec/Out` |
| `average()` | `{in, out}` | Falls back to `instant()` (not wired to windowed average) |
| `custom(seconds)` | `{in, out}` | Falls back to `instant()` (not wired) |
| `history(seconds)` | `IThroughputHistoryPoint[]` | `cache.throughputHistory` sliced to last N entries |
| `byRoute(windowSeconds?)` | `Map<string, {in, out}>` | `cache.routes[name].throughputIn/OutBytesPerSec` |
| `byIP(windowSeconds?)` | `Map<string, {in, out}>` | `cache.ips[ip].throughputIn/OutBytesPerSec` |
### requests
| Method | Return | Source |
|---|---|---|
| `perSecond()` | `number` | `cache.httpRequestsPerSec` |
| `perMinute()` | `number` | `cache.httpRequestsPerSecRecent * 60` |
| `total()` | `number` | `cache.totalHttpRequests` (fallback: totalConnections) |
### totals
| Method | Return | Source |
|---|---|---|
| `bytesIn()` | `number` | `cache.bytesIn` |
| `bytesOut()` | `number` | `cache.bytesOut` |
| `connections()` | `number` | `cache.totalConnections` |
### backends
| Method | Return | Source |
|---|---|---|
| `byBackend()` | `Map<string, IBackendMetrics>` | `cache.backends[key].*` with computed `avgConnectTimeMs` and `poolHitRate` |
| `protocols()` | `Map<string, string>` | `cache.backends[key].protocol` |
| `topByErrors(limit?)` | `IBackendMetrics[]` | Sorted by total errors desc |
| `detectedProtocols()` | `IProtocolCacheEntry[]` | `cache.detectedProtocols` passthrough |
`IBackendMetrics`: `{ protocol, activeConnections, totalConnections, connectErrors, handshakeErrors, requestErrors, avgConnectTimeMs, poolHitRate, h2Failures }`
### udp
| Method | Return | Source |
|---|---|---|
| `activeSessions()` | `number` | `cache.activeUdpSessions` |
| `totalSessions()` | `number` | `cache.totalUdpSessions` |
| `datagramsIn()` | `number` | `cache.totalDatagramsIn` |
| `datagramsOut()` | `number` | `cache.totalDatagramsOut` |
### percentiles (stub)
`connectionDuration()` and `bytesTransferred()` always return zeros. Not implemented.
---
## Configuration
```typescript
interface IMetricsConfig {
enabled: boolean; // default true
sampleIntervalMs: number; // default 1000 (1Hz sampling + TS polling)
retentionSeconds: number; // default 3600 (ThroughputTracker capacity)
enableDetailedTracking: boolean;
enablePercentiles: boolean;
cacheResultsMs: number;
prometheusEnabled: boolean; // not wired
prometheusPath: string; // not wired
prometheusPrefix: string; // not wired
}
```
Rust-side config (`MetricsConfig` in `rustproxy-config`):
```rust
pub struct MetricsConfig {
pub enabled: Option<bool>,
pub sample_interval_ms: Option<u64>, // default 1000
pub retention_seconds: Option<u64>, // default 3600
}
```
---
## Design Decisions
**Lock-free hot path.** `record_bytes()` is the most frequently called method (per-chunk in TCP, per-64KB in HTTP). It only touches `AtomicU64` with `Relaxed` ordering and short-circuits zero-byte directions to skip DashMap lookups entirely.
**CountingBody batching.** HTTP body frames are typically 16KB. Flushing to MetricsCollector every 64KB reduces DashMap shard contention by ~4x compared to per-frame recording.
**RAII guards everywhere.** `ConnectionGuard`, `ConnectionTrackerGuard`, `QuicConnGuard`, `ProtocolGuard`, `FrontendProtocolTracker` all use Drop to guarantee counter cleanup on all exit paths including panics.
**Saturating decrements.** Protocol counters use `fetch_update` loops instead of `fetch_sub` to prevent underflow to `u64::MAX` from concurrent close races.
**Bounded memory.** Per-IP entries evicted on last connection close. Per-backend entries evicted on last connection close. Snapshot caps IPs and backends at 100 each. `sample_all()` prunes orphaned entries every second.
**Two-phase throughput.** Pending bytes accumulate in lock-free atomics. The 1Hz cold path drains them into Mutex-guarded ThroughputTrackers. This means the hot path never contends on a Mutex, while the cold path does minimal work (one drain + one sample per tracker).
---
## Known Gaps
| Gap | Status |
|---|---|
| `throughput.average()` / `throughput.custom(seconds)` | Fall back to `instant()`. Not wired to Rust windowed queries. |
| `percentiles.connectionDuration()` / `percentiles.bytesTransferred()` | Stub returning zeros. No histogram in Rust. |
| Prometheus export | Config fields exist but not wired to any exporter. |
| `LogDeduplicator` | Implemented in `rustproxy-metrics` but not connected to any call site. |
| Rate limit hit counters | Rate-limited requests return 429 but no counter is recorded in MetricsCollector. |
| QUIC stream byte counting | Post-hoc (per-stream totals after close), not per-chunk like TCP. |
| Throughput history in snapshot | Capped at 60 samples. TS `history(seconds)` cannot return more than 60 points regardless of `retentionSeconds`. |
| Per-route total connections / bytes | Available in Rust JSON but `IMetrics.connections.byRoute()` only exposes active connections. |
| Per-IP total connections / bytes | Available in Rust JSON but `IMetrics.connections.byIP()` only exposes active connections. |
| IPC response typing | `RustProxyBridge` declares `result: any` for both metrics commands. No type-safe response. |
+452 -15
View File
@@ -157,12 +157,24 @@ dependencies = [
"shlex",
]
[[package]]
name = "cesu8"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c"
[[package]]
name = "cfg-if"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
[[package]]
name = "cfg_aliases"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "clap"
version = "4.5.57"
@@ -218,6 +230,16 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
[[package]]
name = "combine"
version = "4.6.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
dependencies = [
"bytes",
"memchr",
]
[[package]]
name = "core-foundation"
version = "0.10.1"
@@ -285,6 +307,24 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "fastbloom"
version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e7f34442dbe69c60fe8eaf58a8cafff81a1f278816d8ab4db255b3bef4ac3c4"
dependencies = [
"getrandom 0.3.4",
"libm",
"rand",
"siphasher",
]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "find-msvc-tools"
version = "0.1.9"
@@ -303,6 +343,21 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.31"
@@ -310,6 +365,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
@@ -318,6 +374,34 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
[[package]]
name = "futures-executor"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
[[package]]
name = "futures-macro"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.31"
@@ -336,10 +420,16 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
]
[[package]]
@@ -362,9 +452,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"r-efi",
"wasip2",
"wasm-bindgen",
]
[[package]]
@@ -392,6 +484,34 @@ dependencies = [
"tracing",
]
[[package]]
name = "h3"
version = "0.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10872b55cfb02a821b69dc7cf8dc6a71d6af25eb9a79662bec4a9d016056b3be"
dependencies = [
"bytes",
"fastrand",
"futures-util",
"http",
"pin-project-lite",
"tokio",
]
[[package]]
name = "h3-quinn"
version = "0.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b2e732c8d91a74731663ac8479ab505042fbf547b9a207213ab7fbcbfc4f8b4"
dependencies = [
"bytes",
"futures",
"h3",
"quinn",
"tokio",
"tokio-util",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
@@ -565,6 +685,28 @@ version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2"
[[package]]
name = "jni"
version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97"
dependencies = [
"cesu8",
"cfg-if",
"combine",
"jni-sys",
"log",
"thiserror 1.0.69",
"walkdir",
"windows-sys 0.45.0",
]
[[package]]
name = "jni-sys"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
[[package]]
name = "jobserver"
version = "0.1.34"
@@ -612,6 +754,12 @@ version = "0.2.180"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc"
[[package]]
name = "libm"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981"
[[package]]
name = "libmimalloc-sys"
version = "0.1.44"
@@ -637,6 +785,12 @@ version = "0.4.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
[[package]]
name = "lru-slab"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
[[package]]
name = "matchers"
version = "0.2.0"
@@ -784,6 +938,15 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
[[package]]
name = "ppv-lite86"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
dependencies = [
"zerocopy",
]
[[package]]
name = "proc-macro2"
version = "1.0.106"
@@ -793,6 +956,64 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "quinn"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20"
dependencies = [
"bytes",
"cfg_aliases",
"futures-io",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"socket2 0.6.2",
"thiserror 2.0.18",
"tokio",
"tracing",
"web-time",
]
[[package]]
name = "quinn-proto"
version = "0.11.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098"
dependencies = [
"bytes",
"fastbloom",
"getrandom 0.3.4",
"lru-slab",
"rand",
"ring",
"rustc-hash",
"rustls",
"rustls-pki-types",
"rustls-platform-verifier",
"slab",
"thiserror 2.0.18",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2 0.6.2",
"tracing",
"windows-sys 0.60.2",
]
[[package]]
name = "quote"
version = "1.0.44"
@@ -808,6 +1029,35 @@ version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
[[package]]
name = "rand"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1"
dependencies = [
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c"
dependencies = [
"getrandom 0.3.4",
]
[[package]]
name = "rcgen"
version = "0.13.2"
@@ -873,6 +1123,12 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "rustc-hash"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustls"
version = "0.23.36"
@@ -916,9 +1172,37 @@ version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd"
dependencies = [
"web-time",
"zeroize",
]
[[package]]
name = "rustls-platform-verifier"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784"
dependencies = [
"core-foundation",
"core-foundation-sys",
"jni",
"log",
"once_cell",
"rustls",
"rustls-native-certs",
"rustls-platform-verifier-android",
"rustls-webpki",
"security-framework",
"security-framework-sys",
"webpki-root-certs",
"windows-sys 0.61.2",
]
[[package]]
name = "rustls-platform-verifier-android"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f"
[[package]]
name = "rustls-webpki"
version = "0.103.9"
@@ -940,16 +1224,20 @@ dependencies = [
"bytes",
"clap",
"dashmap",
"h3",
"h3-quinn",
"http",
"http-body-util",
"hyper",
"hyper-util",
"mimalloc",
"quinn",
"rcgen",
"rustls",
"rustls-pemfile",
"rustproxy-config",
"rustproxy-http",
"rustproxy-metrics",
"rustproxy-nftables",
"rustproxy-passthrough",
"rustproxy-routing",
"rustproxy-security",
@@ -981,10 +1269,14 @@ dependencies = [
"arc-swap",
"bytes",
"dashmap",
"futures",
"h3",
"h3-quinn",
"http-body",
"http-body-util",
"hyper",
"hyper-util",
"quinn",
"regex",
"rustls",
"rustproxy-config",
@@ -1011,33 +1303,23 @@ dependencies = [
"tracing",
]
[[package]]
name = "rustproxy-nftables"
version = "0.1.0"
dependencies = [
"anyhow",
"libc",
"rustproxy-config",
"serde",
"serde_json",
"thiserror 2.0.18",
"tokio",
"tracing",
]
[[package]]
name = "rustproxy-passthrough"
version = "0.1.0"
dependencies = [
"anyhow",
"arc-swap",
"base64",
"dashmap",
"quinn",
"rcgen",
"rustls",
"rustls-pemfile",
"rustproxy-config",
"rustproxy-http",
"rustproxy-metrics",
"rustproxy-routing",
"rustproxy-security",
"serde",
"serde_json",
"socket2 0.5.10",
@@ -1096,6 +1378,15 @@ version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]]
name = "schannel"
version = "0.1.28"
@@ -1214,6 +1505,12 @@ dependencies = [
"time",
]
[[package]]
name = "siphasher"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e"
[[package]]
name = "slab"
version = "0.4.12"
@@ -1349,6 +1646,21 @@ dependencies = [
"time-core",
]
[[package]]
name = "tinyvec"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokio"
version = "1.49.0"
@@ -1412,6 +1724,7 @@ version = "0.1.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100"
dependencies = [
"log",
"pin-project-lite",
"tracing-attributes",
"tracing-core",
@@ -1497,6 +1810,16 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]]
name = "walkdir"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
dependencies = [
"same-file",
"winapi-util",
]
[[package]]
name = "want"
version = "0.3.1"
@@ -1566,12 +1889,49 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "web-time"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "webpki-root-certs"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "804f18a4ac2676ffb4e8b5b5fa9ae38af06df08162314f96a68d2a363e21a8ca"
dependencies = [
"rustls-pki-types",
]
[[package]]
name = "winapi-util"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "windows-link"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5"
[[package]]
name = "windows-sys"
version = "0.45.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0"
dependencies = [
"windows-targets 0.42.2",
]
[[package]]
name = "windows-sys"
version = "0.52.0"
@@ -1599,6 +1959,21 @@ dependencies = [
"windows-link",
]
[[package]]
name = "windows-targets"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071"
dependencies = [
"windows_aarch64_gnullvm 0.42.2",
"windows_aarch64_msvc 0.42.2",
"windows_i686_gnu 0.42.2",
"windows_i686_msvc 0.42.2",
"windows_x86_64_gnu 0.42.2",
"windows_x86_64_gnullvm 0.42.2",
"windows_x86_64_msvc 0.42.2",
]
[[package]]
name = "windows-targets"
version = "0.52.6"
@@ -1632,6 +2007,12 @@ dependencies = [
"windows_x86_64_msvc 0.53.1",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.6"
@@ -1644,6 +2025,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53"
[[package]]
name = "windows_aarch64_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.6"
@@ -1656,6 +2043,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006"
[[package]]
name = "windows_i686_gnu"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f"
[[package]]
name = "windows_i686_gnu"
version = "0.52.6"
@@ -1680,6 +2073,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c"
[[package]]
name = "windows_i686_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060"
[[package]]
name = "windows_i686_msvc"
version = "0.52.6"
@@ -1692,6 +2091,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2"
[[package]]
name = "windows_x86_64_gnu"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.6"
@@ -1704,6 +2109,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.6"
@@ -1716,6 +2127,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1"
[[package]]
name = "windows_x86_64_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.6"
@@ -1743,6 +2160,26 @@ dependencies = [
"time",
]
[[package]]
name = "zerocopy"
version = "0.8.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2578b716f8a7a858b7f02d5bd870c14bf4ddbbcf3a4c05414ba6503640505e3"
dependencies = [
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.8.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e6cc098ea4d3bd6246687de65af3f920c430e236bee1e3bf2e441463f08a02f"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "zeroize"
version = "1.8.2"
+7 -2
View File
@@ -7,7 +7,6 @@ members = [
"crates/rustproxy-tls",
"crates/rustproxy-passthrough",
"crates/rustproxy-http",
"crates/rustproxy-nftables",
"crates/rustproxy-metrics",
"crates/rustproxy-security",
]
@@ -91,6 +90,13 @@ libc = "0.2"
# Socket-level options (keepalive, etc.)
socket2 = { version = "0.5", features = ["all"] }
# QUIC transport
quinn = "0.11"
# HTTP/3 protocol
h3 = "0.0.8"
h3-quinn = "0.0.10"
# mimalloc allocator (prevents glibc fragmentation / slow RSS growth)
mimalloc = "0.1"
@@ -100,6 +106,5 @@ rustproxy-routing = { path = "crates/rustproxy-routing" }
rustproxy-tls = { path = "crates/rustproxy-tls" }
rustproxy-passthrough = { path = "crates/rustproxy-passthrough" }
rustproxy-http = { path = "crates/rustproxy-http" }
rustproxy-nftables = { path = "crates/rustproxy-nftables" }
rustproxy-metrics = { path = "crates/rustproxy-metrics" }
rustproxy-security = { path = "crates/rustproxy-security" }
-337
View File
@@ -1,337 +0,0 @@
use crate::route_types::*;
use crate::tls_types::*;
/// Create a simple HTTP forwarding route.
/// Equivalent to SmartProxy's `createHttpRoute()`.
pub fn create_http_route(
domains: impl Into<DomainSpec>,
target_host: impl Into<String>,
target_port: u16,
) -> RouteConfig {
RouteConfig {
id: None,
route_match: RouteMatch {
ports: PortRange::Single(80),
domains: Some(domains.into()),
path: None,
client_ip: None,
tls_version: None,
headers: None,
protocol: None,
},
action: RouteAction {
action_type: RouteActionType::Forward,
targets: Some(vec![RouteTarget {
target_match: None,
host: HostSpec::Single(target_host.into()),
port: PortSpec::Fixed(target_port),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
priority: None,
}]),
tls: None,
websocket: None,
load_balancing: None,
advanced: None,
options: None,
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
},
headers: None,
security: None,
name: None,
description: None,
priority: None,
tags: None,
enabled: None,
}
}
/// Create an HTTPS termination route.
/// Equivalent to SmartProxy's `createHttpsTerminateRoute()`.
pub fn create_https_terminate_route(
domains: impl Into<DomainSpec>,
target_host: impl Into<String>,
target_port: u16,
) -> RouteConfig {
let mut route = create_http_route(domains, target_host, target_port);
route.route_match.ports = PortRange::Single(443);
route.action.tls = Some(RouteTls {
mode: TlsMode::Terminate,
certificate: Some(CertificateSpec::Auto("auto".to_string())),
acme: None,
versions: None,
ciphers: None,
honor_cipher_order: None,
session_timeout: None,
});
route
}
/// Create a TLS passthrough route.
/// Equivalent to SmartProxy's `createHttpsPassthroughRoute()`.
pub fn create_https_passthrough_route(
domains: impl Into<DomainSpec>,
target_host: impl Into<String>,
target_port: u16,
) -> RouteConfig {
let mut route = create_http_route(domains, target_host, target_port);
route.route_match.ports = PortRange::Single(443);
route.action.tls = Some(RouteTls {
mode: TlsMode::Passthrough,
certificate: None,
acme: None,
versions: None,
ciphers: None,
honor_cipher_order: None,
session_timeout: None,
});
route
}
/// Create an HTTP-to-HTTPS redirect route.
/// Equivalent to SmartProxy's `createHttpToHttpsRedirect()`.
pub fn create_http_to_https_redirect(
domains: impl Into<DomainSpec>,
) -> RouteConfig {
let domains = domains.into();
RouteConfig {
id: None,
route_match: RouteMatch {
ports: PortRange::Single(80),
domains: Some(domains),
path: None,
client_ip: None,
tls_version: None,
headers: None,
protocol: None,
},
action: RouteAction {
action_type: RouteActionType::Forward,
targets: None,
tls: None,
websocket: None,
load_balancing: None,
advanced: Some(RouteAdvanced {
timeout: None,
headers: None,
keep_alive: None,
static_files: None,
test_response: Some(RouteTestResponse {
status: 301,
headers: {
let mut h = std::collections::HashMap::new();
h.insert("Location".to_string(), "https://{domain}{path}".to_string());
h
},
body: String::new(),
}),
url_rewrite: None,
}),
options: None,
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
},
headers: None,
security: None,
name: Some("HTTP to HTTPS Redirect".to_string()),
description: None,
priority: None,
tags: None,
enabled: None,
}
}
/// Create a complete HTTPS server with HTTP redirect.
/// Equivalent to SmartProxy's `createCompleteHttpsServer()`.
pub fn create_complete_https_server(
domain: impl Into<String>,
target_host: impl Into<String>,
target_port: u16,
) -> Vec<RouteConfig> {
let domain = domain.into();
let target_host = target_host.into();
vec![
create_http_to_https_redirect(DomainSpec::Single(domain.clone())),
create_https_terminate_route(
DomainSpec::Single(domain),
target_host,
target_port,
),
]
}
/// Create a load balancer route.
/// Equivalent to SmartProxy's `createLoadBalancerRoute()`.
pub fn create_load_balancer_route(
domains: impl Into<DomainSpec>,
targets: Vec<(String, u16)>,
tls: Option<RouteTls>,
) -> RouteConfig {
let route_targets: Vec<RouteTarget> = targets
.into_iter()
.map(|(host, port)| RouteTarget {
target_match: None,
host: HostSpec::Single(host),
port: PortSpec::Fixed(port),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
priority: None,
})
.collect();
let port = if tls.is_some() { 443 } else { 80 };
RouteConfig {
id: None,
route_match: RouteMatch {
ports: PortRange::Single(port),
domains: Some(domains.into()),
path: None,
client_ip: None,
tls_version: None,
headers: None,
protocol: None,
},
action: RouteAction {
action_type: RouteActionType::Forward,
targets: Some(route_targets),
tls,
websocket: None,
load_balancing: Some(RouteLoadBalancing {
algorithm: LoadBalancingAlgorithm::RoundRobin,
health_check: None,
}),
advanced: None,
options: None,
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
},
headers: None,
security: None,
name: Some("Load Balancer".to_string()),
description: None,
priority: None,
tags: None,
enabled: None,
}
}
// Convenience conversions for DomainSpec
impl From<&str> for DomainSpec {
fn from(s: &str) -> Self {
DomainSpec::Single(s.to_string())
}
}
impl From<String> for DomainSpec {
fn from(s: String) -> Self {
DomainSpec::Single(s)
}
}
impl From<Vec<String>> for DomainSpec {
fn from(v: Vec<String>) -> Self {
DomainSpec::List(v)
}
}
impl From<Vec<&str>> for DomainSpec {
fn from(v: Vec<&str>) -> Self {
DomainSpec::List(v.into_iter().map(|s| s.to_string()).collect())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tls_types::TlsMode;
#[test]
fn test_create_http_route() {
let route = create_http_route("example.com", "localhost", 8080);
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
let domains = route.route_match.domains.as_ref().unwrap().to_vec();
assert_eq!(domains, vec!["example.com"]);
let target = &route.action.targets.as_ref().unwrap()[0];
assert_eq!(target.host.first(), "localhost");
assert_eq!(target.port.resolve(80), 8080);
assert!(route.action.tls.is_none());
}
#[test]
fn test_create_https_terminate_route() {
let route = create_https_terminate_route("api.example.com", "backend", 3000);
assert_eq!(route.route_match.ports.to_ports(), vec![443]);
let tls = route.action.tls.as_ref().unwrap();
assert_eq!(tls.mode, TlsMode::Terminate);
assert!(tls.certificate.as_ref().unwrap().is_auto());
}
#[test]
fn test_create_https_passthrough_route() {
let route = create_https_passthrough_route("secure.example.com", "backend", 443);
assert_eq!(route.route_match.ports.to_ports(), vec![443]);
let tls = route.action.tls.as_ref().unwrap();
assert_eq!(tls.mode, TlsMode::Passthrough);
assert!(tls.certificate.is_none());
}
#[test]
fn test_create_http_to_https_redirect() {
let route = create_http_to_https_redirect("example.com");
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
assert!(route.action.targets.is_none());
let test_response = route.action.advanced.as_ref().unwrap().test_response.as_ref().unwrap();
assert_eq!(test_response.status, 301);
assert!(test_response.headers.contains_key("Location"));
}
#[test]
fn test_create_complete_https_server() {
let routes = create_complete_https_server("example.com", "backend", 8080);
assert_eq!(routes.len(), 2);
// First route is HTTP redirect
assert_eq!(routes[0].route_match.ports.to_ports(), vec![80]);
// Second route is HTTPS terminate
assert_eq!(routes[1].route_match.ports.to_ports(), vec![443]);
}
#[test]
fn test_create_load_balancer_route() {
let targets = vec![
("backend1".to_string(), 8080),
("backend2".to_string(), 8080),
("backend3".to_string(), 8080),
];
let route = create_load_balancer_route("*.example.com", targets, None);
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
assert_eq!(route.action.targets.as_ref().unwrap().len(), 3);
let lb = route.action.load_balancing.as_ref().unwrap();
assert_eq!(lb.algorithm, LoadBalancingAlgorithm::RoundRobin);
}
#[test]
fn test_domain_spec_from_str() {
let spec: DomainSpec = "example.com".into();
assert_eq!(spec.to_vec(), vec!["example.com"]);
}
#[test]
fn test_domain_spec_from_vec() {
let spec: DomainSpec = vec!["a.com", "b.com"].into();
assert_eq!(spec.to_vec(), vec!["a.com", "b.com"]);
}
}
+4 -6
View File
@@ -3,17 +3,15 @@
//! Configuration types for RustProxy, fully compatible with SmartProxy's JSON schema.
//! All types use `#[serde(rename_all = "camelCase")]` to match TypeScript field naming.
pub mod route_types;
pub mod proxy_options;
pub mod tls_types;
pub mod route_types;
pub mod security_types;
pub mod tls_types;
pub mod validation;
pub mod helpers;
// Re-export all primary types
pub use route_types::*;
pub use proxy_options::*;
pub use tls_types::*;
pub use route_types::*;
pub use security_types::*;
pub use tls_types::*;
pub use validation::*;
pub use helpers::*;
+293 -19
View File
@@ -97,6 +97,16 @@ pub struct MetricsConfig {
pub retention_seconds: Option<u64>,
}
/// Global ingress security policy.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SecurityPolicy {
#[serde(skip_serializing_if = "Option::is_none")]
pub blocked_ips: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub blocked_cidrs: Option<Vec<String>>,
}
/// RustProxy configuration options.
/// Matches TypeScript: `ISmartProxyOptions`
///
@@ -129,7 +139,6 @@ pub struct RustProxyOptions {
pub defaults: Option<DefaultConfig>,
// ─── Timeout Settings ────────────────────────────────────────────
/// Timeout for establishing connection to backend (ms), default: 30000
#[serde(skip_serializing_if = "Option::is_none")]
pub connection_timeout: Option<u64>,
@@ -159,7 +168,6 @@ pub struct RustProxyOptions {
pub graceful_shutdown_timeout: Option<u64>,
// ─── Socket Optimization ─────────────────────────────────────────
/// Disable Nagle's algorithm (default: true)
#[serde(skip_serializing_if = "Option::is_none")]
pub no_delay: Option<bool>,
@@ -177,7 +185,6 @@ pub struct RustProxyOptions {
pub max_pending_data_size: Option<u64>,
// ─── Enhanced Features ───────────────────────────────────────────
/// Disable inactivity checking entirely
#[serde(skip_serializing_if = "Option::is_none")]
pub disable_inactivity_check: Option<bool>,
@@ -199,7 +206,6 @@ pub struct RustProxyOptions {
pub enable_randomized_timeouts: Option<bool>,
// ─── Rate Limiting ───────────────────────────────────────────────
/// Maximum simultaneous connections from a single IP
#[serde(skip_serializing_if = "Option::is_none")]
pub max_connections_per_ip: Option<u64>,
@@ -213,7 +219,6 @@ pub struct RustProxyOptions {
pub max_connections: Option<u64>,
// ─── Keep-Alive Settings ─────────────────────────────────────────
/// How to treat keep-alive connections
#[serde(skip_serializing_if = "Option::is_none")]
pub keep_alive_treatment: Option<KeepAliveTreatment>,
@@ -227,7 +232,6 @@ pub struct RustProxyOptions {
pub extended_keep_alive_lifetime: Option<u64>,
// ─── HttpProxy Integration ───────────────────────────────────────
/// Array of ports to forward to HttpProxy
#[serde(skip_serializing_if = "Option::is_none")]
pub use_http_proxy: Option<Vec<u16>>,
@@ -237,13 +241,15 @@ pub struct RustProxyOptions {
pub http_proxy_port: Option<u16>,
// ─── Metrics ─────────────────────────────────────────────────────
/// Metrics configuration
#[serde(skip_serializing_if = "Option::is_none")]
pub metrics: Option<MetricsConfig>,
// ─── ACME ────────────────────────────────────────────────────────
/// Global ingress security policy, enforced before route selection.
#[serde(skip_serializing_if = "Option::is_none")]
pub security_policy: Option<SecurityPolicy>,
// ─── ACME ────────────────────────────────────────────────────────
/// Global ACME configuration
#[serde(skip_serializing_if = "Option::is_none")]
pub acme: Option<AcmeOptions>,
@@ -283,6 +289,7 @@ impl Default for RustProxyOptions {
use_http_proxy: None,
http_proxy_port: None,
metrics: None,
security_policy: None,
acme: None,
}
}
@@ -318,7 +325,8 @@ impl RustProxyOptions {
/// Get all unique ports that routes listen on.
pub fn all_listening_ports(&self) -> Vec<u16> {
let mut ports: Vec<u16> = self.routes
let mut ports: Vec<u16> = self
.routes
.iter()
.flat_map(|r| r.listening_ports())
.collect();
@@ -331,12 +339,73 @@ impl RustProxyOptions {
#[cfg(test)]
mod tests {
use super::*;
use crate::helpers::*;
use crate::route_types::*;
use crate::tls_types::*;
fn make_route(domain: &str, host: &str, port: u16, listen_port: u16) -> RouteConfig {
RouteConfig {
id: None,
route_match: RouteMatch {
ports: PortRange::Single(listen_port),
domains: Some(DomainSpec::Single(domain.to_string())),
path: None,
client_ip: None,
transport: None,
tls_version: None,
headers: None,
protocol: None,
},
action: RouteAction {
action_type: RouteActionType::Forward,
targets: Some(vec![RouteTarget {
target_match: None,
host: HostSpec::Single(host.to_string()),
port: PortSpec::Fixed(port),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
}]),
tls: None,
websocket: None,
load_balancing: None,
advanced: None,
options: None,
send_proxy_protocol: None,
udp: None,
},
headers: None,
security: None,
name: None,
description: None,
priority: None,
tags: None,
enabled: None,
}
}
fn make_passthrough_route(domain: &str, host: &str, port: u16) -> RouteConfig {
let mut route = make_route(domain, host, port, 443);
route.action.tls = Some(RouteTls {
mode: TlsMode::Passthrough,
certificate: None,
acme: None,
versions: None,
ciphers: None,
honor_cipher_order: None,
session_timeout: None,
});
route
}
#[test]
fn test_serde_roundtrip_minimal() {
let options = RustProxyOptions {
routes: vec![create_http_route("example.com", "localhost", 8080)],
routes: vec![make_route("example.com", "localhost", 8080, 80)],
..Default::default()
};
let json = serde_json::to_string(&options).unwrap();
@@ -348,8 +417,8 @@ mod tests {
fn test_serde_roundtrip_full() {
let options = RustProxyOptions {
routes: vec![
create_http_route("a.com", "backend1", 8080),
create_https_passthrough_route("b.com", "backend2", 443),
make_route("a.com", "backend1", 8080, 80),
make_passthrough_route("b.com", "backend2", 443),
],
connection_timeout: Some(5000),
socket_timeout: Some(60000),
@@ -374,6 +443,209 @@ mod tests {
assert_eq!(parsed.connection_timeout, Some(5000));
}
#[test]
fn test_deserialize_ts_contract_route_shapes() {
let value = serde_json::json!({
"routes": [{
"name": "contract-route",
"match": {
"ports": [443, { "from": 8443, "to": 8444 }],
"domains": ["api.example.com", "*.example.com"],
"transport": "udp",
"protocol": "http3",
"headers": {
"content-type": "/^application\\/json$/i"
}
},
"action": {
"type": "forward",
"targets": [{
"match": {
"ports": [443],
"path": "/api/*",
"method": ["GET"],
"headers": {
"x-env": "/^(prod|stage)$/"
}
},
"host": ["backend-a", "backend-b"],
"port": "preserve",
"sendProxyProtocol": true,
"backendTransport": "tcp"
}],
"tls": {
"mode": "terminate",
"certificate": "auto"
},
"sendProxyProtocol": true,
"udp": {
"maxSessionsPerIp": 321,
"quic": {
"enableHttp3": true
}
}
},
"security": {
"ipAllowList": [{
"ip": "10.0.0.0/8",
"domains": ["api.example.com"]
}]
}
}],
"preserveSourceIp": true,
"proxyIps": ["10.0.0.1"],
"acceptProxyProtocol": true,
"sendProxyProtocol": true,
"noDelay": true,
"keepAlive": true,
"keepAliveInitialDelay": 1500,
"maxPendingDataSize": 4096,
"disableInactivityCheck": true,
"enableKeepAliveProbes": true,
"enableDetailedLogging": true,
"enableTlsDebugLogging": true,
"enableRandomizedTimeouts": true,
"connectionTimeout": 5000,
"initialDataTimeout": 7000,
"socketTimeout": 9000,
"inactivityCheckInterval": 1100,
"maxConnectionLifetime": 13000,
"inactivityTimeout": 15000,
"gracefulShutdownTimeout": 17000,
"maxConnectionsPerIp": 20,
"connectionRateLimitPerMinute": 30,
"keepAliveTreatment": "extended",
"keepAliveInactivityMultiplier": 2.0,
"extendedKeepAliveLifetime": 19000,
"metrics": {
"enabled": true,
"sampleIntervalMs": 250,
"retentionSeconds": 60
},
"acme": {
"enabled": true,
"email": "ops@example.com",
"environment": "staging",
"useProduction": false,
"skipConfiguredCerts": true,
"renewThresholdDays": 14,
"renewCheckIntervalHours": 12,
"autoRenew": true,
"port": 80
}
});
let options: RustProxyOptions = serde_json::from_value(value).unwrap();
assert_eq!(options.routes.len(), 1);
assert_eq!(options.preserve_source_ip, Some(true));
assert_eq!(options.proxy_ips, Some(vec!["10.0.0.1".to_string()]));
assert_eq!(options.accept_proxy_protocol, Some(true));
assert_eq!(options.send_proxy_protocol, Some(true));
assert_eq!(options.no_delay, Some(true));
assert_eq!(options.keep_alive, Some(true));
assert_eq!(options.keep_alive_initial_delay, Some(1500));
assert_eq!(options.max_pending_data_size, Some(4096));
assert_eq!(options.disable_inactivity_check, Some(true));
assert_eq!(options.enable_keep_alive_probes, Some(true));
assert_eq!(options.enable_detailed_logging, Some(true));
assert_eq!(options.enable_tls_debug_logging, Some(true));
assert_eq!(options.enable_randomized_timeouts, Some(true));
assert_eq!(options.connection_timeout, Some(5000));
assert_eq!(options.initial_data_timeout, Some(7000));
assert_eq!(options.socket_timeout, Some(9000));
assert_eq!(options.inactivity_check_interval, Some(1100));
assert_eq!(options.max_connection_lifetime, Some(13000));
assert_eq!(options.inactivity_timeout, Some(15000));
assert_eq!(options.graceful_shutdown_timeout, Some(17000));
assert_eq!(options.max_connections_per_ip, Some(20));
assert_eq!(options.connection_rate_limit_per_minute, Some(30));
assert_eq!(
options.keep_alive_treatment,
Some(KeepAliveTreatment::Extended)
);
assert_eq!(options.keep_alive_inactivity_multiplier, Some(2.0));
assert_eq!(options.extended_keep_alive_lifetime, Some(19000));
let route = &options.routes[0];
assert_eq!(route.route_match.transport, Some(TransportProtocol::Udp));
assert_eq!(route.route_match.protocol.as_deref(), Some("http3"));
assert_eq!(
route
.route_match
.headers
.as_ref()
.unwrap()
.get("content-type")
.unwrap(),
"/^application\\/json$/i"
);
let target = &route.action.targets.as_ref().unwrap()[0];
assert!(matches!(target.host, HostSpec::List(_)));
assert!(matches!(target.port, PortSpec::Special(ref p) if p == "preserve"));
assert_eq!(target.backend_transport, Some(TransportProtocol::Tcp));
assert_eq!(target.send_proxy_protocol, Some(true));
assert_eq!(
target
.target_match
.as_ref()
.unwrap()
.headers
.as_ref()
.unwrap()
.get("x-env")
.unwrap(),
"/^(prod|stage)$/"
);
assert_eq!(route.action.send_proxy_protocol, Some(true));
assert_eq!(
route.action.udp.as_ref().unwrap().max_sessions_per_ip,
Some(321)
);
assert_eq!(
route
.action
.udp
.as_ref()
.unwrap()
.quic
.as_ref()
.unwrap()
.enable_http3,
Some(true)
);
let allow_list = route
.security
.as_ref()
.unwrap()
.ip_allow_list
.as_ref()
.unwrap();
assert!(matches!(
&allow_list[0],
crate::security_types::IpAllowEntry::DomainScoped { ip, domains }
if ip == "10.0.0.0/8" && domains == &vec!["api.example.com".to_string()]
));
let metrics = options.metrics.as_ref().unwrap();
assert_eq!(metrics.enabled, Some(true));
assert_eq!(metrics.sample_interval_ms, Some(250));
assert_eq!(metrics.retention_seconds, Some(60));
let acme = options.acme.as_ref().unwrap();
assert_eq!(acme.enabled, Some(true));
assert_eq!(acme.email.as_deref(), Some("ops@example.com"));
assert_eq!(acme.environment, Some(AcmeEnvironment::Staging));
assert_eq!(acme.use_production, Some(false));
assert_eq!(acme.skip_configured_certs, Some(true));
assert_eq!(acme.renew_threshold_days, Some(14));
assert_eq!(acme.renew_check_interval_hours, Some(12));
assert_eq!(acme.auto_renew, Some(true));
assert_eq!(acme.port, Some(80));
}
#[test]
fn test_default_timeouts() {
let options = RustProxyOptions::default();
@@ -402,9 +674,9 @@ mod tests {
fn test_all_listening_ports() {
let options = RustProxyOptions {
routes: vec![
create_http_route("a.com", "backend", 8080), // port 80
create_https_passthrough_route("b.com", "backend", 443), // port 443
create_http_route("c.com", "backend", 9090), // port 80 (duplicate)
make_route("a.com", "backend", 8080, 80), // port 80
make_passthrough_route("b.com", "backend", 443), // port 443
make_route("c.com", "backend", 9090, 80), // port 80 (duplicate)
],
..Default::default()
};
@@ -428,9 +700,11 @@ mod tests {
#[test]
fn test_deserialize_example_json() {
let content = std::fs::read_to_string(
concat!(env!("CARGO_MANIFEST_DIR"), "/../../config/example.json")
).unwrap();
let content = std::fs::read_to_string(concat!(
env!("CARGO_MANIFEST_DIR"),
"/../../config/example.json"
))
.unwrap();
let options: RustProxyOptions = serde_json::from_str(&content).unwrap();
assert_eq!(options.routes.len(), 4);
let ports = options.all_listening_ports();
+194 -62
View File
@@ -1,22 +1,30 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::tls_types::RouteTls;
use crate::security_types::RouteSecurity;
use crate::tls_types::RouteTls;
// ─── Port Range ──────────────────────────────────────────────────────
/// Port range specification format.
/// Matches TypeScript: `type TPortRange = number | number[] | Array<{ from: number; to: number }>`
/// Matches TypeScript: `type TPortRange = number | Array<number | { from: number; to: number }>`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PortRange {
/// Single port number
Single(u16),
/// Array of port numbers
List(Vec<u16>),
/// Array of port ranges
Ranges(Vec<PortRangeSpec>),
/// Array of port numbers, ranges, or mixed
List(Vec<PortRangeItem>),
}
/// A single item in a port range array: either a number or a from-to range.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PortRangeItem {
/// Single port number
Port(u16),
/// A from-to port range
Range(PortRangeSpec),
}
impl PortRange {
@@ -24,10 +32,13 @@ impl PortRange {
pub fn to_ports(&self) -> Vec<u16> {
match self {
PortRange::Single(p) => vec![*p],
PortRange::List(ports) => ports.clone(),
PortRange::Ranges(ranges) => {
ranges.iter().flat_map(|r| r.from..=r.to).collect()
}
PortRange::List(items) => items
.iter()
.flat_map(|item| match item {
PortRangeItem::Port(p) => vec![*p],
PortRangeItem::Range(r) => (r.from..=r.to).collect(),
})
.collect(),
}
}
}
@@ -50,16 +61,6 @@ pub enum RouteActionType {
SocketHandler,
}
// ─── Forwarding Engine ───────────────────────────────────────────────
/// Forwarding engine specification.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ForwardingEngine {
Node,
Nftables,
}
// ─── Route Match ─────────────────────────────────────────────────────
/// Domain specification: single string or array.
@@ -79,8 +80,34 @@ impl DomainSpec {
}
}
// Convenience conversions for DomainSpec
impl From<&str> for DomainSpec {
fn from(s: &str) -> Self {
DomainSpec::Single(s.to_string())
}
}
impl From<String> for DomainSpec {
fn from(s: String) -> Self {
DomainSpec::Single(s)
}
}
impl From<Vec<String>> for DomainSpec {
fn from(v: Vec<String>) -> Self {
DomainSpec::List(v)
}
}
impl From<Vec<&str>> for DomainSpec {
fn from(v: Vec<&str>) -> Self {
DomainSpec::List(v.into_iter().map(|s| s.to_string()).collect())
}
}
/// Header match value: either exact string or regex pattern.
/// In JSON, all values come as strings. Regex patterns are prefixed with `/` and suffixed with `/`.
/// In JSON, all values come as strings. Regex patterns use JS-style literal syntax,
/// e.g. `/^application\/json$/` or `/^application\/json$/i`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum HeaderMatchValue {
@@ -95,6 +122,10 @@ pub struct RouteMatch {
/// Listen on these ports (required)
pub ports: PortRange,
/// Transport protocol: tcp (default), udp, or all (both TCP and UDP)
#[serde(skip_serializing_if = "Option::is_none")]
pub transport: Option<TransportProtocol>,
/// Optional domain patterns to match (default: all domains)
#[serde(skip_serializing_if = "Option::is_none")]
pub domains: Option<DomainSpec>,
@@ -115,7 +146,7 @@ pub struct RouteMatch {
#[serde(skip_serializing_if = "Option::is_none")]
pub headers: Option<HashMap<String, String>>,
/// Match specific protocol: "http" (includes h2 + websocket) or "tcp"
/// Match specific protocol: "http", "tcp", "udp", "quic", "http3"
#[serde(skip_serializing_if = "Option::is_none")]
pub protocol: Option<String>,
}
@@ -327,38 +358,6 @@ pub struct RouteAdvanced {
pub url_rewrite: Option<RouteUrlRewrite>,
}
// ─── NFTables Options ────────────────────────────────────────────────
/// NFTables protocol type.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NfTablesProtocol {
Tcp,
Udp,
All,
}
/// NFTables-specific configuration options.
/// Matches TypeScript: `INfTablesOptions`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NfTablesOptions {
#[serde(skip_serializing_if = "Option::is_none")]
pub preserve_source_ip: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub protocol: Option<NfTablesProtocol>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_rate: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub priority: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub table_name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub use_ip_sets: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub use_advanced_nat: Option<bool>,
}
// ─── Backend Protocol ────────────────────────────────────────────────
/// Backend protocol.
@@ -367,9 +366,19 @@ pub struct NfTablesOptions {
pub enum BackendProtocol {
Http1,
Http2,
Http3,
Auto,
}
/// Transport protocol for route matching.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TransportProtocol {
Tcp,
Udp,
All,
}
/// Action options.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -470,6 +479,10 @@ pub struct RouteTarget {
#[serde(skip_serializing_if = "Option::is_none")]
pub advanced: Option<RouteAdvanced>,
/// Override transport for backend connection (e.g., receive QUIC but forward as TCP)
#[serde(skip_serializing_if = "Option::is_none")]
pub backend_transport: Option<TransportProtocol>,
/// Priority for matching (higher values checked first, default: 0)
#[serde(skip_serializing_if = "Option::is_none")]
pub priority: Option<i32>,
@@ -513,17 +526,71 @@ pub struct RouteAction {
#[serde(skip_serializing_if = "Option::is_none")]
pub options: Option<ActionOptions>,
/// Forwarding engine specification
#[serde(skip_serializing_if = "Option::is_none")]
pub forwarding_engine: Option<ForwardingEngine>,
/// NFTables-specific options
#[serde(skip_serializing_if = "Option::is_none")]
pub nftables: Option<NfTablesOptions>,
/// PROXY protocol support (default for all targets)
#[serde(skip_serializing_if = "Option::is_none")]
pub send_proxy_protocol: Option<bool>,
/// UDP-specific settings (session tracking, datagram limits, QUIC config)
#[serde(skip_serializing_if = "Option::is_none")]
pub udp: Option<RouteUdp>,
}
// ─── UDP & QUIC Config ──────────────────────────────────────────────
/// UDP-specific settings for route actions.
/// Matches TypeScript: `IRouteUdp`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteUdp {
/// Idle timeout for a UDP session/flow in ms. Default: 60000
#[serde(skip_serializing_if = "Option::is_none")]
pub session_timeout: Option<u64>,
/// Max concurrent UDP sessions per source IP. Default: 1000
#[serde(skip_serializing_if = "Option::is_none")]
pub max_sessions_per_ip: Option<u32>,
/// Max accepted datagram size in bytes. Default: 65535
#[serde(skip_serializing_if = "Option::is_none")]
pub max_datagram_size: Option<u32>,
/// QUIC-specific configuration
#[serde(skip_serializing_if = "Option::is_none")]
pub quic: Option<RouteQuic>,
}
/// QUIC and HTTP/3 settings.
/// Matches TypeScript: `IRouteQuic`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteQuic {
/// QUIC connection idle timeout in ms. Default: 30000
#[serde(skip_serializing_if = "Option::is_none")]
pub max_idle_timeout: Option<u64>,
/// Max concurrent bidirectional streams per QUIC connection. Default: 100
#[serde(skip_serializing_if = "Option::is_none")]
pub max_concurrent_bidi_streams: Option<u32>,
/// Max concurrent unidirectional streams per QUIC connection. Default: 100
#[serde(skip_serializing_if = "Option::is_none")]
pub max_concurrent_uni_streams: Option<u32>,
/// Enable HTTP/3 over this QUIC endpoint. Default: false
#[serde(skip_serializing_if = "Option::is_none")]
pub enable_http3: Option<bool>,
/// Port to advertise in Alt-Svc header on TCP HTTP responses
#[serde(skip_serializing_if = "Option::is_none")]
pub alt_svc_port: Option<u16>,
/// Max age for Alt-Svc advertisement in seconds. Default: 86400
#[serde(skip_serializing_if = "Option::is_none")]
pub alt_svc_max_age: Option<u64>,
/// Initial congestion window size in bytes
#[serde(skip_serializing_if = "Option::is_none")]
pub initial_congestion_window: Option<u32>,
}
// ─── Route Config ────────────────────────────────────────────────────
@@ -589,6 +656,11 @@ impl RouteConfig {
self.route_match.ports.to_ports()
}
/// Stable key used for frontend route-scoped metrics.
pub fn metrics_key(&self) -> Option<&str> {
self.name.as_deref().or(self.id.as_deref())
}
/// Get the TLS mode for this route (from action-level or first target).
pub fn tls_mode(&self) -> Option<&crate::tls_types::TlsMode> {
// Check action-level TLS first
@@ -606,3 +678,63 @@ impl RouteConfig {
None
}
}
#[cfg(test)]
mod tests {
use super::*;
fn test_route(name: Option<&str>, id: Option<&str>) -> RouteConfig {
RouteConfig {
id: id.map(str::to_string),
route_match: RouteMatch {
ports: PortRange::Single(443),
transport: None,
domains: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
protocol: None,
},
action: RouteAction {
action_type: RouteActionType::Forward,
targets: None,
tls: None,
websocket: None,
load_balancing: None,
advanced: None,
options: None,
send_proxy_protocol: None,
udp: None,
},
headers: None,
security: None,
name: name.map(str::to_string),
description: None,
priority: None,
tags: None,
enabled: None,
}
}
#[test]
fn metrics_key_prefers_name() {
let route = test_route(Some("named-route"), Some("route-id"));
assert_eq!(route.metrics_key(), Some("named-route"));
}
#[test]
fn metrics_key_falls_back_to_id() {
let route = test_route(None, Some("route-id"));
assert_eq!(route.metrics_key(), Some("route-id"));
}
#[test]
fn metrics_key_is_absent_without_name_or_id() {
let route = test_route(None, None);
assert_eq!(route.metrics_key(), None);
}
}
@@ -103,14 +103,27 @@ pub struct JwtAuthConfig {
pub exclude_paths: Option<Vec<String>>,
}
/// An entry in the IP allow list: either a plain IP/CIDR string
/// or a domain-scoped entry that restricts the IP to specific domains.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum IpAllowEntry {
/// Plain IP/CIDR — allowed for all domains on this route
Plain(String),
/// Domain-scoped — allowed only when the requested domain matches
DomainScoped { ip: String, domains: Vec<String> },
}
/// Security options for routes.
/// Matches TypeScript: `IRouteSecurity`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RouteSecurity {
/// IP addresses that are allowed to connect
/// IP addresses that are allowed to connect.
/// Entries can be plain strings (full route access) or objects with
/// `{ ip, domains }` to scope access to specific domains.
#[serde(skip_serializing_if = "Option::is_none")]
pub ip_allow_list: Option<Vec<String>>,
pub ip_allow_list: Option<Vec<IpAllowEntry>>,
/// IP addresses that are blocked from connecting
#[serde(skip_serializing_if = "Option::is_none")]
pub ip_block_list: Option<Vec<String>>,
+60 -9
View File
@@ -1,6 +1,6 @@
use thiserror::Error;
use crate::route_types::{RouteConfig, RouteActionType};
use crate::route_types::{RouteActionType, RouteConfig};
/// Validation errors for route configurations.
#[derive(Debug, Error)]
@@ -30,9 +30,10 @@ pub enum ValidationError {
/// Validate a single route configuration.
pub fn validate_route(route: &RouteConfig) -> Result<(), Vec<ValidationError>> {
let mut errors = Vec::new();
let name = route.name.clone().unwrap_or_else(|| {
route.id.clone().unwrap_or_else(|| "unnamed".to_string())
});
let name = route
.name
.clone()
.unwrap_or_else(|| route.id.clone().unwrap_or_else(|| "unnamed".to_string()));
// Check ports
let ports = route.listening_ports();
@@ -104,7 +105,49 @@ mod tests {
use crate::route_types::*;
fn make_valid_route() -> RouteConfig {
crate::helpers::create_http_route("example.com", "localhost", 8080)
RouteConfig {
id: None,
route_match: RouteMatch {
ports: PortRange::Single(80),
domains: Some(DomainSpec::Single("example.com".to_string())),
path: None,
client_ip: None,
transport: None,
tls_version: None,
headers: None,
protocol: None,
},
action: RouteAction {
action_type: RouteActionType::Forward,
targets: Some(vec![RouteTarget {
target_match: None,
host: HostSpec::Single("localhost".to_string()),
port: PortSpec::Fixed(8080),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
}]),
tls: None,
websocket: None,
load_balancing: None,
advanced: None,
options: None,
send_proxy_protocol: None,
udp: None,
},
headers: None,
security: None,
name: None,
description: None,
priority: None,
tags: None,
enabled: None,
}
}
#[test]
@@ -118,7 +161,9 @@ mod tests {
let mut route = make_valid_route();
route.action.targets = None;
let errors = validate_route(&route).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::MissingTargets { .. })));
assert!(errors
.iter()
.any(|e| matches!(e, ValidationError::MissingTargets { .. })));
}
#[test]
@@ -126,7 +171,9 @@ mod tests {
let mut route = make_valid_route();
route.action.targets = Some(vec![]);
let errors = validate_route(&route).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::EmptyTargets { .. })));
assert!(errors
.iter()
.any(|e| matches!(e, ValidationError::EmptyTargets { .. })));
}
#[test]
@@ -134,7 +181,9 @@ mod tests {
let mut route = make_valid_route();
route.route_match.ports = PortRange::Single(0);
let errors = validate_route(&route).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. })));
assert!(errors
.iter()
.any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. })));
}
#[test]
@@ -144,7 +193,9 @@ mod tests {
let mut r2 = make_valid_route();
r2.id = Some("route-1".to_string());
let errors = validate_routes(&[r1, r2]).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::DuplicateId { .. })));
assert!(errors
.iter()
.any(|e| matches!(e, ValidationError::DuplicateId { .. })));
}
#[test]
+4
View File
@@ -27,3 +27,7 @@ arc-swap = { workspace = true }
dashmap = { workspace = true }
tokio-util = { workspace = true }
socket2 = { workspace = true }
quinn = { workspace = true }
h3 = { workspace = true }
h3-quinn = { workspace = true }
futures = { version = "0.3", default-features = false, features = ["std"] }
+168 -30
View File
@@ -1,8 +1,9 @@
//! Backend connection pool for HTTP/1.1 and HTTP/2.
//! Backend connection pool for HTTP/1.1, HTTP/2, and HTTP/3 (QUIC).
//!
//! Reuses idle keep-alive connections to avoid per-request TCP+TLS handshakes.
//! HTTP/2 connections are multiplexed (clone the sender for each request).
//! HTTP/2 and HTTP/3 connections are multiplexed (clone the sender / share the connection).
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
@@ -10,7 +11,6 @@ use bytes::Bytes;
use dashmap::DashMap;
use http_body_util::combinators::BoxBody;
use hyper::client::conn::{http1, http2};
use tracing::{debug, warn};
/// Maximum idle connections per backend key.
const MAX_IDLE_PER_KEY: usize = 16;
@@ -19,9 +19,17 @@ const IDLE_TIMEOUT: Duration = Duration::from_secs(90);
/// Background eviction interval.
const EVICTION_INTERVAL: Duration = Duration::from_secs(30);
/// Maximum age for pooled HTTP/2 connections before proactive eviction.
/// Prevents staleness from backends that close idle connections (e.g. nginx GOAWAY).
/// 120s is well within typical server GOAWAY windows (nginx: ~60s idle, envoy: ~60s).
const MAX_H2_AGE: Duration = Duration::from_secs(120);
/// Maximum age for pooled QUIC/HTTP/3 connections.
const MAX_H3_AGE: Duration = Duration::from_secs(120);
/// Protocol for pool key discrimination.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub enum PoolProtocol {
H1,
H2,
H3,
}
/// Identifies a unique backend endpoint.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
@@ -29,7 +37,7 @@ pub struct PoolKey {
pub host: String,
pub port: u16,
pub use_tls: bool,
pub h2: bool,
pub protocol: PoolProtocol,
}
/// An idle HTTP/1.1 sender with a timestamp for eviction.
@@ -38,10 +46,24 @@ struct IdleH1 {
idle_since: Instant,
}
/// A pooled HTTP/2 sender (multiplexed, Clone-able).
/// A pooled HTTP/2 sender (multiplexed, Clone-able) with a generation tag.
struct PooledH2 {
sender: http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
created_at: Instant,
/// Unique generation ID. Connection drivers use this to only remove their OWN
/// entry, preventing phantom eviction when multiple connections share the same key.
generation: u64,
}
/// A pooled QUIC/HTTP/3 connection (multiplexed like H2).
/// Stores the h3 `SendRequest` handle so pool hits skip the h3 SETTINGS handshake.
pub struct PooledH3 {
/// Multiplexed h3 request handle — clone to open a new stream.
pub send_request: h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
/// Raw QUIC connection — kept for liveness probing (close_reason) only.
pub connection: quinn::Connection,
pub created_at: Instant,
pub generation: u64,
}
/// Backend connection pool.
@@ -50,6 +72,10 @@ pub struct ConnectionPool {
h1_pool: Arc<DashMap<PoolKey, Vec<IdleH1>>>,
/// HTTP/2 multiplexed connections indexed by backend key.
h2_pool: Arc<DashMap<PoolKey, PooledH2>>,
/// HTTP/3 (QUIC) connections indexed by backend key.
h3_pool: Arc<DashMap<PoolKey, PooledH3>>,
/// Monotonic generation counter for H2/H3 pool entries.
h2_generation: AtomicU64,
/// Handle for the background eviction task.
eviction_handle: Option<tokio::task::JoinHandle<()>>,
}
@@ -59,30 +85,40 @@ impl ConnectionPool {
pub fn new() -> Self {
let h1_pool: Arc<DashMap<PoolKey, Vec<IdleH1>>> = Arc::new(DashMap::new());
let h2_pool: Arc<DashMap<PoolKey, PooledH2>> = Arc::new(DashMap::new());
let h3_pool: Arc<DashMap<PoolKey, PooledH3>> = Arc::new(DashMap::new());
let h1_clone = Arc::clone(&h1_pool);
let h2_clone = Arc::clone(&h2_pool);
let h3_clone = Arc::clone(&h3_pool);
let eviction_handle = tokio::spawn(async move {
Self::eviction_loop(h1_clone, h2_clone).await;
Self::eviction_loop(h1_clone, h2_clone, h3_clone).await;
});
Self {
h1_pool,
h2_pool,
h3_pool,
h2_generation: AtomicU64::new(0),
eviction_handle: Some(eviction_handle),
}
}
/// Try to check out an idle HTTP/1.1 sender for the given key.
/// Returns `None` if no usable idle connection exists.
pub fn checkout_h1(&self, key: &PoolKey) -> Option<http1::SendRequest<BoxBody<Bytes, hyper::Error>>> {
pub fn checkout_h1(
&self,
key: &PoolKey,
) -> Option<http1::SendRequest<BoxBody<Bytes, hyper::Error>>> {
let mut entry = self.h1_pool.get_mut(key)?;
let idles = entry.value_mut();
while let Some(idle) = idles.pop() {
// Check if the connection is still alive and ready
if idle.idle_since.elapsed() < IDLE_TIMEOUT && idle.sender.is_ready() && !idle.sender.is_closed() {
debug!("Pool hit (h1): {}:{}", key.host, key.port);
if idle.idle_since.elapsed() < IDLE_TIMEOUT
&& idle.sender.is_ready()
&& !idle.sender.is_closed()
{
// H1 pool hit — no logging on hot path
return Some(idle.sender);
}
// Stale or closed — drop it
@@ -98,7 +134,11 @@ impl ConnectionPool {
/// Return an HTTP/1.1 sender to the pool after the response body has been prepared.
/// The caller should NOT call this if the sender is closed or not ready.
pub fn checkin_h1(&self, key: PoolKey, sender: http1::SendRequest<BoxBody<Bytes, hyper::Error>>) {
pub fn checkin_h1(
&self,
key: PoolKey,
sender: http1::SendRequest<BoxBody<Bytes, hyper::Error>>,
) {
if sender.is_closed() || !sender.is_ready() {
return; // Don't pool broken connections
}
@@ -115,53 +155,136 @@ impl ConnectionPool {
/// Try to get a cloned HTTP/2 sender for the given key.
/// HTTP/2 senders are Clone-able (multiplexed), so we clone rather than remove.
pub fn checkout_h2(&self, key: &PoolKey) -> Option<(http2::SendRequest<BoxBody<Bytes, hyper::Error>>, Duration)> {
pub fn checkout_h2(
&self,
key: &PoolKey,
) -> Option<(http2::SendRequest<BoxBody<Bytes, hyper::Error>>, Duration)> {
let entry = self.h2_pool.get(key)?;
let pooled = entry.value();
let age = pooled.created_at.elapsed();
// Check if the h2 connection is still alive and not too old
if pooled.sender.is_closed() || age >= MAX_H2_AGE {
let reason = if pooled.sender.is_closed() { "closed" } else { "max_age" };
debug!("Pool evict (h2): {}:{} (reason={}, age={:.1}s)", key.host, key.port, reason, age.as_secs_f64());
drop(entry);
self.h2_pool.remove(key);
return None;
}
if pooled.sender.is_ready() {
if age > Duration::from_secs(30) {
warn!("Pool hit (h2): {}:{} — connection age {:.1}s (>30s, may be stale)", key.host, key.port, age.as_secs_f64());
} else {
debug!("Pool hit (h2): {}:{} (age={:.1}s)", key.host, key.port, age.as_secs_f64());
}
return Some((pooled.sender.clone(), age));
}
None
}
/// Remove a dead HTTP/2 sender from the pool.
/// Remove a dead HTTP/2 sender from the pool (unconditional).
/// Called when `send_request` fails to prevent subsequent requests from reusing the stale sender.
pub fn remove_h2(&self, key: &PoolKey) {
self.h2_pool.remove(key);
}
/// Register an HTTP/2 sender in the pool. Since h2 is multiplexed,
/// only one sender per key is stored (it's Clone-able).
pub fn register_h2(&self, key: PoolKey, sender: http2::SendRequest<BoxBody<Bytes, hyper::Error>>) {
if sender.is_closed() {
return;
/// Remove an HTTP/2 sender ONLY if the current entry has the expected generation.
/// This prevents phantom eviction: when multiple connections share the same key,
/// an old connection's driver won't accidentally remove a newer connection's entry.
pub fn remove_h2_if_generation(&self, key: &PoolKey, expected_gen: u64) {
if let Some(entry) = self.h2_pool.get(key) {
if entry.value().generation == expected_gen {
drop(entry); // release DashMap ref before remove
self.h2_pool.remove(key);
}
self.h2_pool.insert(key, PooledH2 {
// else: a newer connection replaced ours — don't touch it
}
}
/// Register an HTTP/2 sender in the pool. Returns the generation ID for this entry.
/// The caller should pass this generation to the connection driver so it can use
/// `remove_h2_if_generation` instead of `remove_h2` to avoid phantom eviction.
pub fn register_h2(
&self,
key: PoolKey,
sender: http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
) -> u64 {
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
if sender.is_closed() {
return gen;
}
self.h2_pool.insert(
key,
PooledH2 {
sender,
created_at: Instant::now(),
});
generation: gen,
},
);
gen
}
// ── HTTP/3 (QUIC) pool methods ──
/// Try to get a pooled QUIC connection for the given key.
/// QUIC connections are multiplexed — the connection is shared, not removed.
pub fn checkout_h3(
&self,
key: &PoolKey,
) -> Option<(
h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
quinn::Connection,
Duration,
)> {
let entry = self.h3_pool.get(key)?;
let pooled = entry.value();
let age = pooled.created_at.elapsed();
if age >= MAX_H3_AGE {
drop(entry);
self.h3_pool.remove(key);
return None;
}
// Check if QUIC connection is still alive
if pooled.connection.close_reason().is_some() {
drop(entry);
self.h3_pool.remove(key);
return None;
}
Some((pooled.send_request.clone(), pooled.connection.clone(), age))
}
/// Register a QUIC connection and its h3 SendRequest handle in the pool.
/// Returns the generation ID.
pub fn register_h3(
&self,
key: PoolKey,
connection: quinn::Connection,
send_request: h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
) -> u64 {
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
self.h3_pool.insert(
key,
PooledH3 {
send_request,
connection,
created_at: Instant::now(),
generation: gen,
},
);
gen
}
/// Remove a QUIC connection only if generation matches.
pub fn remove_h3_if_generation(&self, key: &PoolKey, expected_gen: u64) {
if let Some(entry) = self.h3_pool.get(key) {
if entry.value().generation == expected_gen {
drop(entry);
self.h3_pool.remove(key);
}
}
}
/// Background eviction loop — runs every EVICTION_INTERVAL to remove stale connections.
async fn eviction_loop(
h1_pool: Arc<DashMap<PoolKey, Vec<IdleH1>>>,
h2_pool: Arc<DashMap<PoolKey, PooledH2>>,
h3_pool: Arc<DashMap<PoolKey, PooledH3>>,
) {
let mut interval = tokio::time::interval(EVICTION_INTERVAL);
loop {
@@ -184,13 +307,28 @@ impl ConnectionPool {
// Evict dead or aged-out H2 connections
let mut dead_h2 = Vec::new();
for entry in h2_pool.iter() {
if entry.value().sender.is_closed() || entry.value().created_at.elapsed() >= MAX_H2_AGE {
if entry.value().sender.is_closed()
|| entry.value().created_at.elapsed() >= MAX_H2_AGE
{
dead_h2.push(entry.key().clone());
}
}
for key in dead_h2 {
h2_pool.remove(&key);
}
// Evict dead or aged-out H3 (QUIC) connections
let mut dead_h3 = Vec::new();
for entry in h3_pool.iter() {
if entry.value().connection.close_reason().is_some()
|| entry.value().created_at.elapsed() >= MAX_H3_AGE
{
dead_h3.push(entry.key().clone());
}
}
for key in dead_h3 {
h3_pool.remove(&key);
}
}
}
}
+47 -17
View File
@@ -1,27 +1,36 @@
//! A body wrapper that counts bytes flowing through and reports them to MetricsCollector.
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::Bytes;
use http_body::Frame;
use rustproxy_metrics::MetricsCollector;
/// Flush accumulated bytes to the metrics collector every 64 KB.
/// This reduces per-frame DashMap shard-locked reads from ~15 to ~1 per 4 frames
/// (assuming typical 16 KB upload frames). The 1 Hz throughput sampler still sees
/// data within one sampling period even at low transfer rates.
const BYTE_FLUSH_THRESHOLD: u64 = 65_536;
/// Wraps any `http_body::Body` and counts data bytes passing through.
///
/// Each chunk is reported to the `MetricsCollector` immediately so that
/// the throughput tracker (sampled at 1 Hz) reflects real-time data flow.
/// Bytes are accumulated and flushed to the `MetricsCollector` every
/// [`BYTE_FLUSH_THRESHOLD`] bytes (and on Drop) so the throughput tracker
/// (sampled at 1 Hz) reflects real-time data flow without per-frame overhead.
///
/// The inner body is pinned on the heap to support `!Unpin` types like `hyper::body::Incoming`.
pub struct CountingBody<B> {
inner: Pin<Box<B>>,
metrics: Arc<MetricsCollector>,
route_id: Option<String>,
source_ip: Option<String>,
route_id: Option<Arc<str>>,
source_ip: Option<Arc<str>>,
/// Whether we count bytes as "in" (request body) or "out" (response body).
direction: Direction,
/// Accumulated bytes not yet flushed to the metrics collector.
pending_bytes: u64,
/// Optional connection-level activity tracker. When set, poll_frame updates this
/// to keep the idle watchdog alive during active body streaming (uploads/downloads).
connection_activity: Option<Arc<AtomicU64>>,
@@ -47,8 +56,8 @@ impl<B> CountingBody<B> {
pub fn new(
inner: B,
metrics: Arc<MetricsCollector>,
route_id: Option<String>,
source_ip: Option<String>,
route_id: Option<Arc<str>>,
source_ip: Option<Arc<str>>,
direction: Direction,
) -> Self {
Self {
@@ -57,6 +66,7 @@ impl<B> CountingBody<B> {
route_id,
source_ip,
direction,
pending_bytes: 0,
connection_activity: None,
activity_start: None,
active_requests: None,
@@ -66,7 +76,11 @@ impl<B> CountingBody<B> {
/// Set the connection-level activity tracker. When set, each data frame
/// updates this timestamp to prevent the idle watchdog from killing the
/// connection during active body streaming.
pub fn with_connection_activity(mut self, activity: Arc<AtomicU64>, start: std::time::Instant) -> Self {
pub fn with_connection_activity(
mut self,
activity: Arc<AtomicU64>,
start: std::time::Instant,
) -> Self {
self.connection_activity = Some(activity);
self.activity_start = Some(start);
self
@@ -81,14 +95,19 @@ impl<B> CountingBody<B> {
self
}
/// Report a chunk of bytes immediately to the metrics collector.
/// Flush accumulated bytes to the metrics collector.
#[inline]
fn report_chunk(&self, len: u64) {
fn flush_pending(&mut self) {
if self.pending_bytes == 0 {
return;
}
let bytes = self.pending_bytes;
self.pending_bytes = 0;
let route_id = self.route_id.as_deref();
let source_ip = self.source_ip.as_deref();
match self.direction {
Direction::In => self.metrics.record_bytes(len, 0, route_id, source_ip),
Direction::Out => self.metrics.record_bytes(0, len, route_id, source_ip),
Direction::In => self.metrics.record_bytes(bytes, 0, route_id, source_ip),
Direction::Out => self.metrics.record_bytes(0, bytes, route_id, source_ip),
}
}
}
@@ -113,17 +132,26 @@ where
Poll::Ready(Some(Ok(frame))) => {
if let Some(data) = frame.data_ref() {
let len = data.len() as u64;
// Report bytes immediately so the 1 Hz throughput sampler sees them
this.report_chunk(len);
// Keep the connection-level idle watchdog alive during body streaming
if let (Some(activity), Some(start)) = (&this.connection_activity, &this.activity_start) {
this.pending_bytes += len;
if this.pending_bytes >= BYTE_FLUSH_THRESHOLD {
this.flush_pending();
}
// Keep the connection-level idle watchdog alive on every frame
// (this is just one atomic store — cheap enough per-frame)
if let (Some(activity), Some(start)) =
(&this.connection_activity, &this.activity_start)
{
activity.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
}
}
Poll::Ready(Some(Ok(frame)))
}
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(None) => {
// End of stream — flush any remaining bytes
this.flush_pending();
Poll::Ready(None)
}
Poll::Pending => Poll::Pending,
}
}
@@ -139,6 +167,8 @@ where
impl<B> Drop for CountingBody<B> {
fn drop(&mut self) {
// Flush any remaining accumulated bytes so totals stay accurate
self.flush_pending();
// Decrement the active-request counter so the HTTP idle watchdog
// knows this response body is no longer streaming.
if let Some(ref counter) = self.active_requests {
@@ -0,0 +1,243 @@
//! HTTP/3 proxy service.
//!
//! Accepts QUIC connections via quinn, runs h3 server to handle HTTP/3 requests,
//! and delegates backend forwarding to the shared `HttpProxyService` — same
//! route matching, connection pool, and protocol auto-detection as TCP/HTTP.
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::{Buf, Bytes};
use http_body::Frame;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use tracing::{debug, warn};
use rustproxy_config::RouteConfig;
use tokio_util::sync::CancellationToken;
use crate::proxy_service::{ConnActivity, HttpProxyService, ProtocolGuard};
/// HTTP/3 proxy service.
///
/// Accepts QUIC connections, parses HTTP/3 requests, and delegates backend
/// forwarding to the shared `HttpProxyService`.
pub struct H3ProxyService {
http_proxy: Arc<HttpProxyService>,
}
impl H3ProxyService {
pub fn new(http_proxy: Arc<HttpProxyService>) -> Self {
Self { http_proxy }
}
/// Handle an accepted QUIC connection as HTTP/3.
///
/// If `real_client_addr` is provided (from PROXY protocol), it overrides
/// `connection.remote_address()` for client IP attribution.
pub async fn handle_connection(
&self,
connection: quinn::Connection,
_fallback_route: &RouteConfig,
port: u16,
real_client_addr: Option<SocketAddr>,
parent_cancel: &CancellationToken,
) -> anyhow::Result<()> {
let remote_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
debug!("HTTP/3 connection from {} on port {}", remote_addr, port);
// Track frontend H3 connection for the QUIC connection's lifetime.
let _frontend_h3_guard =
ProtocolGuard::frontend(Arc::clone(self.http_proxy.metrics()), "h3");
let mut h3_conn: h3::server::Connection<h3_quinn::Connection, Bytes> =
h3::server::builder()
.send_grease(false)
.build(h3_quinn::Connection::new(connection))
.await
.map_err(|e| anyhow::anyhow!("H3 connection setup failed: {}", e))?;
loop {
let resolver = tokio::select! {
_ = parent_cancel.cancelled() => {
debug!("HTTP/3 connection from {} cancelled by parent", remote_addr);
break;
}
result = h3_conn.accept() => {
match result {
Ok(Some(resolver)) => resolver,
Ok(None) => {
debug!("HTTP/3 connection from {} closed", remote_addr);
break;
}
Err(e) => {
debug!("HTTP/3 accept error from {}: {}", remote_addr, e);
break;
}
}
}
};
let (request, stream) = match resolver.resolve_request().await {
Ok(pair) => pair,
Err(e) => {
debug!("HTTP/3 request resolve error: {}", e);
continue;
}
};
let http_proxy = Arc::clone(&self.http_proxy);
let request_cancel = parent_cancel.child_token();
tokio::spawn(async move {
if let Err(e) = handle_h3_request(
request,
stream,
port,
remote_addr,
&http_proxy,
request_cancel,
)
.await
{
debug!("HTTP/3 request error from {}: {}", remote_addr, e);
}
});
}
Ok(())
}
}
/// Handle a single HTTP/3 request by delegating to HttpProxyService.
///
/// 1. Read the H3 request body via an mpsc channel (streaming, not buffered)
/// 2. Build a `hyper::Request<BoxBody>` that HttpProxyService can handle
/// 3. Call `HttpProxyService::handle_request` — same route matching, connection
/// pool, ALPN protocol detection (H1/H2/H3) as the TCP/HTTP path
/// 4. Stream the response back over the H3 stream
async fn handle_h3_request(
request: hyper::Request<()>,
mut stream: h3::server::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>,
port: u16,
peer_addr: SocketAddr,
http_proxy: &HttpProxyService,
cancel: CancellationToken,
) -> anyhow::Result<()> {
// Stream request body from H3 client via an mpsc channel.
let (body_tx, body_rx) = tokio::sync::mpsc::channel::<Bytes>(32);
// Spawn the H3 body reader task with cancellation
let body_cancel = cancel.clone();
let body_reader = tokio::spawn(async move {
loop {
let chunk = tokio::select! {
_ = body_cancel.cancelled() => break,
result = stream.recv_data() => {
match result {
Ok(Some(chunk)) => chunk,
_ => break,
}
}
};
let mut chunk = chunk;
let data = chunk.copy_to_bytes(chunk.remaining());
if body_tx.send(data).await.is_err() {
break;
}
}
stream
});
// Build a hyper::Request<BoxBody> from the H3 request + streaming body.
// The URI already has scheme + authority + path set by the h3 crate.
let body = H3RequestBody { receiver: body_rx };
let (parts, _) = request.into_parts();
let boxed_body: BoxBody<Bytes, hyper::Error> = BoxBody::new(body);
let req = hyper::Request::from_parts(parts, boxed_body);
// Delegate to HttpProxyService — same backend path as TCP/HTTP:
// route matching, ALPN protocol detection, connection pool, H1/H2/H3 auto.
let conn_activity = ConnActivity::new_standalone();
let response = http_proxy
.handle_request(req, peer_addr, port, cancel, conn_activity)
.await
.map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?;
// Await the body reader to get the H3 stream back
let mut stream = body_reader
.await
.map_err(|e| anyhow::anyhow!("Body reader task failed: {}", e))?;
// Send response headers over H3 (skip hop-by-hop headers)
let (resp_parts, resp_body) = response.into_parts();
let mut h3_response = hyper::Response::builder().status(resp_parts.status);
for (name, value) in &resp_parts.headers {
let n = name.as_str();
if n == "transfer-encoding" || n == "connection" || n == "keep-alive" || n == "upgrade" {
continue;
}
h3_response = h3_response.header(name, value);
}
let h3_response = h3_response
.body(())
.map_err(|e| anyhow::anyhow!("Failed to build H3 response: {}", e))?;
stream
.send_response(h3_response)
.await
.map_err(|e| anyhow::anyhow!("Failed to send H3 response: {}", e))?;
// Stream response body back over H3
let mut resp_body = resp_body;
while let Some(frame) = resp_body.frame().await {
match frame {
Ok(frame) => {
if let Ok(data) = frame.into_data() {
stream
.send_data(data)
.await
.map_err(|e| anyhow::anyhow!("Failed to send H3 data: {}", e))?;
}
}
Err(e) => {
warn!("Response body read error: {}", e);
break;
}
}
}
// Finish the H3 stream (send QUIC FIN)
stream
.finish()
.await
.map_err(|e| anyhow::anyhow!("Failed to finish H3 stream: {}", e))?;
Ok(())
}
/// A streaming request body backed by an mpsc channel receiver.
///
/// Implements `http_body::Body` so hyper can poll chunks as they arrive
/// from the H3 client, avoiding buffering the entire request body in memory.
struct H3RequestBody {
receiver: tokio::sync::mpsc::Receiver<Bytes>,
}
impl http_body::Body for H3RequestBody {
type Data = Bytes;
type Error = hyper::Error;
fn poll_frame(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
match self.receiver.poll_recv(cx) {
Poll::Ready(Some(data)) => Poll::Ready(Some(Ok(Frame::data(data)))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
+2
View File
@@ -5,9 +5,11 @@
pub mod connection_pool;
pub mod counting_body;
pub mod h3_service;
pub mod protocol_cache;
pub mod proxy_service;
pub mod request_filter;
mod request_host;
pub mod response_filter;
pub mod shutdown_on_drop;
pub mod template;
+559 -36
View File
@@ -1,31 +1,99 @@
//! Bounded, TTL-based protocol detection cache for HTTP/2 auto-detection.
//! Bounded, sliding-TTL protocol detection cache with periodic re-probing and failure suppression.
//!
//! Caches the ALPN-negotiated protocol (H1 or H2) per backend endpoint and requested
//! Caches the detected protocol (H1, H2, or H3) per backend endpoint and requested
//! domain (host:port + requested_host). This prevents cache oscillation when multiple
//! frontend domains share the same backend but differ in HTTP/2 support.
//! frontend domains share the same backend but differ in protocol support.
//!
//! ## Sliding TTL
//!
//! Each cache hit refreshes the entry's expiry timer (`last_accessed_at`). Entries
//! remain valid for up to 1 day of continuous use. Every 5 minutes, the next request
//! triggers an inline ALPN re-probe to verify the cached protocol is still correct.
//!
//! ## Upgrade signals
//!
//! - ALPN (TLS handshake) → detects H2 vs H1
//! - Alt-Svc (response header) → advertises H3
//!
//! ## Protocol transitions
//!
//! All protocol changes are logged at `info!()` level with the reason:
//! "Protocol transition: H1 → H2 because periodic ALPN re-probe"
//!
//! ## Failure suppression
//!
//! When a protocol fails, `record_failure()` prevents upgrade signals from
//! re-introducing it until an escalating cooldown expires (5s → 10s → ... → 300s).
//! Within-request escalation is allowed via `can_retry()` after a 5s minimum gap.
//!
//! ## Total failure eviction
//!
//! When all protocols (H3, H2, H1) fail for a backend, the cache entry is evicted
//! entirely via `evict()`, forcing a fresh probe on the next request.
//!
//! Cascading: when a lower protocol also fails, higher protocol cooldowns are
//! reduced to 5s remaining (not instant clear), preventing tight retry loops.
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use tracing::debug;
use tracing::{debug, info};
/// TTL for cached protocol detection results.
/// After this duration, the next request will re-probe the backend.
const PROTOCOL_CACHE_TTL: Duration = Duration::from_secs(300); // 5 minutes
/// Sliding TTL for cached protocol detection results.
/// Entries that haven't been accessed for this duration are evicted.
/// Each `get()` call refreshes the timer (sliding window).
const PROTOCOL_CACHE_TTL: Duration = Duration::from_secs(86400); // 1 day
/// Interval between inline ALPN re-probes for H1/H2 entries.
/// When a cached entry's `last_probed_at` exceeds this, the next request
/// triggers an ALPN re-probe to verify the backend still speaks the same protocol.
const PROTOCOL_REPROBE_INTERVAL: Duration = Duration::from_secs(300); // 5 minutes
/// Maximum number of entries in the protocol cache.
/// Prevents unbounded growth when backends come and go.
const PROTOCOL_CACHE_MAX_ENTRIES: usize = 4096;
/// Background cleanup interval for the protocol cache.
/// Background cleanup interval.
const PROTOCOL_CACHE_CLEANUP_INTERVAL: Duration = Duration::from_secs(60);
/// Minimum cooldown between retry attempts of a failed protocol.
const PROTOCOL_FAILURE_COOLDOWN: Duration = Duration::from_secs(5);
/// Maximum cooldown (escalation ceiling).
const PROTOCOL_FAILURE_MAX_COOLDOWN: Duration = Duration::from_secs(300);
/// Consecutive failure count at which cooldown reaches maximum.
/// 5s × 2^5 = 160s, 5s × 2^6 = 320s → capped at 300s.
const PROTOCOL_FAILURE_ESCALATION_CAP: u32 = 6;
/// Detected backend protocol.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DetectedProtocol {
H1,
H2,
H3,
}
impl std::fmt::Display for DetectedProtocol {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DetectedProtocol::H1 => write!(f, "H1"),
DetectedProtocol::H2 => write!(f, "H2"),
DetectedProtocol::H3 => write!(f, "H3"),
}
}
}
/// Result of a protocol cache lookup.
#[derive(Debug, Clone, Copy)]
pub struct CachedProtocol {
pub protocol: DetectedProtocol,
/// For H3: the port advertised by Alt-Svc (may differ from TCP port).
pub h3_port: Option<u16>,
/// True if the entry's `last_probed_at` exceeds `PROTOCOL_REPROBE_INTERVAL`.
/// Caller should perform an inline ALPN re-probe and call `update_probe_result()`.
/// Always `false` for H3 entries (H3 is discovered via Alt-Svc, not ALPN).
pub needs_reprobe: bool,
}
/// Key for the protocol cache: (host, port, requested_host).
@@ -38,22 +106,115 @@ pub struct ProtocolCacheKey {
pub requested_host: Option<String>,
}
/// A cached protocol detection result with a timestamp.
/// A cached protocol detection result with timestamps.
struct CachedEntry {
protocol: DetectedProtocol,
/// When this protocol was first detected (or last changed).
detected_at: Instant,
/// Last time any request used this entry (sliding-window TTL).
last_accessed_at: Instant,
/// Last time an ALPN re-probe was performed for this entry.
last_probed_at: Instant,
/// For H3: the port advertised by Alt-Svc (may differ from TCP port).
h3_port: Option<u16>,
}
/// Bounded, TTL-based protocol detection cache.
/// Failure record for a single protocol level.
#[derive(Debug, Clone)]
struct FailureRecord {
/// When the failure was last recorded.
failed_at: Instant,
/// Current cooldown duration. Escalates on consecutive failures.
cooldown: Duration,
/// Number of consecutive failures (for escalation).
consecutive_failures: u32,
}
/// Per-key failure state. Tracks failures at each upgradeable protocol level.
/// H1 is never tracked (it's the protocol floor — nothing to fall back to).
#[derive(Debug, Clone, Default)]
struct FailureState {
h2: Option<FailureRecord>,
h3: Option<FailureRecord>,
}
impl FailureState {
fn is_empty(&self) -> bool {
self.h2.is_none() && self.h3.is_none()
}
fn all_expired(&self) -> bool {
let h2_expired = self
.h2
.as_ref()
.map(|r| r.failed_at.elapsed() >= r.cooldown)
.unwrap_or(true);
let h3_expired = self
.h3
.as_ref()
.map(|r| r.failed_at.elapsed() >= r.cooldown)
.unwrap_or(true);
h2_expired && h3_expired
}
fn get(&self, protocol: DetectedProtocol) -> Option<&FailureRecord> {
match protocol {
DetectedProtocol::H2 => self.h2.as_ref(),
DetectedProtocol::H3 => self.h3.as_ref(),
DetectedProtocol::H1 => None,
}
}
fn get_mut(&mut self, protocol: DetectedProtocol) -> &mut Option<FailureRecord> {
match protocol {
DetectedProtocol::H2 => &mut self.h2,
DetectedProtocol::H3 => &mut self.h3,
DetectedProtocol::H1 => unreachable!("H1 failures are never recorded"),
}
}
}
/// Snapshot of a single protocol cache entry, suitable for metrics/UI display.
#[derive(Debug, Clone)]
pub struct ProtocolCacheEntry {
pub host: String,
pub port: u16,
pub domain: Option<String>,
pub protocol: String,
pub h3_port: Option<u16>,
pub age_secs: u64,
pub last_accessed_secs: u64,
pub last_probed_secs: u64,
pub h2_suppressed: bool,
pub h3_suppressed: bool,
pub h2_cooldown_remaining_secs: Option<u64>,
pub h3_cooldown_remaining_secs: Option<u64>,
pub h2_consecutive_failures: Option<u32>,
pub h3_consecutive_failures: Option<u32>,
}
/// Exponential backoff: PROTOCOL_FAILURE_COOLDOWN × 2^(n-1), capped at MAX.
fn escalate_cooldown(consecutive: u32) -> Duration {
let base = PROTOCOL_FAILURE_COOLDOWN.as_secs();
let exp = consecutive.saturating_sub(1).min(63) as u64;
let secs = base.saturating_mul(1u64.checked_shl(exp as u32).unwrap_or(u64::MAX));
Duration::from_secs(secs.min(PROTOCOL_FAILURE_MAX_COOLDOWN.as_secs()))
}
/// Bounded, sliding-TTL protocol detection cache with failure suppression.
///
/// Memory safety guarantees:
/// - Hard cap at `PROTOCOL_CACHE_MAX_ENTRIES` — cannot grow unboundedly.
/// - TTL expiry — stale entries naturally age out on lookup.
/// - Sliding TTL expiry — entries age out after 1 day without access.
/// - Background cleanup task — proactively removes expired entries every 60s.
/// - `clear()` — called on route updates to discard stale detections.
/// - `Drop` — aborts the background task to prevent dangling tokio tasks.
pub struct ProtocolCache {
cache: Arc<DashMap<ProtocolCacheKey, CachedEntry>>,
/// Generic protocol failure suppression map. Tracks per-protocol failure
/// records (H2, H3) for each cache key. Used to prevent upgrade signals
/// (ALPN, Alt-Svc) from re-introducing failed protocols.
failures: Arc<DashMap<ProtocolCacheKey, FailureState>>,
cleanup_handle: Option<tokio::task::JoinHandle<()>>,
}
@@ -61,24 +222,41 @@ impl ProtocolCache {
/// Create a new protocol cache and start the background cleanup task.
pub fn new() -> Self {
let cache: Arc<DashMap<ProtocolCacheKey, CachedEntry>> = Arc::new(DashMap::new());
let failures: Arc<DashMap<ProtocolCacheKey, FailureState>> = Arc::new(DashMap::new());
let cache_clone = Arc::clone(&cache);
let failures_clone = Arc::clone(&failures);
let cleanup_handle = tokio::spawn(async move {
Self::cleanup_loop(cache_clone).await;
Self::cleanup_loop(cache_clone, failures_clone).await;
});
Self {
cache,
failures,
cleanup_handle: Some(cleanup_handle),
}
}
/// Look up the cached protocol for a backend endpoint.
///
/// Returns `None` if not cached or expired (caller should probe via ALPN).
pub fn get(&self, key: &ProtocolCacheKey) -> Option<DetectedProtocol> {
let entry = self.cache.get(key)?;
if entry.detected_at.elapsed() < PROTOCOL_CACHE_TTL {
debug!("Protocol cache hit: {:?} for {}:{} (requested: {:?})", entry.protocol, key.host, key.port, key.requested_host);
Some(entry.protocol)
/// On hit, refreshes `last_accessed_at` (sliding TTL) and sets `needs_reprobe`
/// if the entry hasn't been probed in over 5 minutes (H1/H2 only).
pub fn get(&self, key: &ProtocolCacheKey) -> Option<CachedProtocol> {
let mut entry = self.cache.get_mut(key)?;
if entry.last_accessed_at.elapsed() < PROTOCOL_CACHE_TTL {
// Refresh sliding TTL
entry.last_accessed_at = Instant::now();
// H3 is the ceiling — can't ALPN-probe for H3 (discovered via Alt-Svc).
// Only H1/H2 entries trigger periodic re-probing.
let needs_reprobe = entry.protocol != DetectedProtocol::H3
&& entry.last_probed_at.elapsed() >= PROTOCOL_REPROBE_INTERVAL;
Some(CachedProtocol {
protocol: entry.protocol,
h3_port: entry.h3_port,
needs_reprobe,
})
} else {
// Expired — remove and return None to trigger re-probe
drop(entry); // release DashMap ref before remove
@@ -88,45 +266,390 @@ impl ProtocolCache {
}
/// Insert a detected protocol into the cache.
/// If the cache is at capacity, evict the oldest entry first.
pub fn insert(&self, key: ProtocolCacheKey, protocol: DetectedProtocol) {
if self.cache.len() >= PROTOCOL_CACHE_MAX_ENTRIES && !self.cache.contains_key(&key) {
// Evict the oldest entry to stay within bounds
let oldest = self.cache.iter()
.min_by_key(|entry| entry.value().detected_at)
.map(|entry| entry.key().clone());
if let Some(oldest_key) = oldest {
self.cache.remove(&oldest_key);
/// Returns `false` if suppressed due to active failure suppression.
///
/// **Key semantic**: only suppresses if the protocol being inserted matches
/// a suppressed protocol. H1 inserts are NEVER suppressed — downgrades
/// always succeed.
pub fn insert(&self, key: ProtocolCacheKey, protocol: DetectedProtocol, reason: &str) -> bool {
if self.is_suppressed(&key, protocol) {
debug!(
host = %key.host, port = %key.port, domain = ?key.requested_host,
protocol = ?protocol,
"Protocol cache insert suppressed — recent failure"
);
return false;
}
self.insert_internal(key, protocol, None, reason);
true
}
/// Insert an H3 detection result with the Alt-Svc advertised port.
/// Returns `false` if H3 is suppressed.
pub fn insert_h3(&self, key: ProtocolCacheKey, h3_port: u16, reason: &str) -> bool {
if self.is_suppressed(&key, DetectedProtocol::H3) {
debug!(
host = %key.host, port = %key.port, domain = ?key.requested_host,
"H3 upgrade suppressed — recent failure"
);
return false;
}
self.insert_internal(key, DetectedProtocol::H3, Some(h3_port), reason);
true
}
/// Update the cache after an inline ALPN re-probe completes.
///
/// Always updates `last_probed_at`. If the protocol changed, logs the transition
/// and updates the entry. Returns `Some(new_protocol)` if changed, `None` if unchanged.
pub fn update_probe_result(
&self,
key: &ProtocolCacheKey,
probed_protocol: DetectedProtocol,
reason: &str,
) -> Option<DetectedProtocol> {
if let Some(mut entry) = self.cache.get_mut(key) {
let old_protocol = entry.protocol;
entry.last_probed_at = Instant::now();
entry.last_accessed_at = Instant::now();
if old_protocol != probed_protocol {
info!(
host = %key.host, port = %key.port, domain = ?key.requested_host,
old = %old_protocol, new = %probed_protocol, reason = %reason,
"Protocol transition"
);
entry.protocol = probed_protocol;
entry.detected_at = Instant::now();
// Clear h3_port if downgrading from H3
if old_protocol == DetectedProtocol::H3 && probed_protocol != DetectedProtocol::H3 {
entry.h3_port = None;
}
return Some(probed_protocol);
}
debug!(
host = %key.host, port = %key.port, domain = ?key.requested_host,
protocol = %old_protocol, reason = %reason,
"Re-probe confirmed — no protocol change"
);
None
} else {
// Entry was evicted between the get() and the probe completing.
// Insert as a fresh entry.
self.insert_internal(key.clone(), probed_protocol, None, reason);
Some(probed_protocol)
}
}
self.cache.insert(key, CachedEntry {
protocol,
detected_at: Instant::now(),
/// Record a protocol failure. Future `insert()` calls for this protocol
/// will be suppressed until the escalating cooldown expires.
///
/// Cooldown escalation: 5s → 10s → 20s → 40s → 80s → 160s → 300s.
/// Consecutive counter resets if the previous failure is older than 2× its cooldown.
///
/// Cascading: when H2 fails, H3 cooldown is reduced to 5s remaining.
/// H1 failures are ignored (H1 is the protocol floor).
pub fn record_failure(&self, key: ProtocolCacheKey, protocol: DetectedProtocol) {
if protocol == DetectedProtocol::H1 {
return; // H1 is the floor — nothing to suppress
}
let mut entry = self.failures.entry(key.clone()).or_default();
let record = entry.get_mut(protocol);
let (consecutive, new_cooldown) = match record {
Some(existing)
if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) =>
{
// Still within the "recent" window — escalate
let c = existing
.consecutive_failures
.saturating_add(1)
.min(PROTOCOL_FAILURE_ESCALATION_CAP);
(c, escalate_cooldown(c))
}
_ => {
// First failure or old failure that expired long ago — reset
(1, PROTOCOL_FAILURE_COOLDOWN)
}
};
*record = Some(FailureRecord {
failed_at: Instant::now(),
cooldown: new_cooldown,
consecutive_failures: consecutive,
});
// Cascading: when H2 fails, reduce H3 cooldown to 5s remaining
if protocol == DetectedProtocol::H2 {
Self::reduce_cooldown_to(entry.h3.as_mut(), PROTOCOL_FAILURE_COOLDOWN);
}
info!(
host = %key.host, port = %key.port, domain = ?key.requested_host,
protocol = ?protocol,
consecutive = consecutive,
cooldown_secs = new_cooldown.as_secs(),
"Protocol failure recorded — suppressing for {:?}", new_cooldown
);
}
/// Check whether a protocol is currently suppressed for the given key.
/// Returns `true` if the protocol failed within its cooldown period.
/// H1 is never suppressed.
pub fn is_suppressed(&self, key: &ProtocolCacheKey, protocol: DetectedProtocol) -> bool {
if protocol == DetectedProtocol::H1 {
return false;
}
self.failures
.get(key)
.and_then(|entry| {
entry
.get(protocol)
.map(|r| r.failed_at.elapsed() < r.cooldown)
})
.unwrap_or(false)
}
/// Check whether a protocol can be retried (for within-request escalation).
/// Returns `true` if there's no failure record OR if ≥5s have passed since
/// the last attempt. More permissive than `is_suppressed`.
pub fn can_retry(&self, key: &ProtocolCacheKey, protocol: DetectedProtocol) -> bool {
if protocol == DetectedProtocol::H1 {
return true;
}
match self.failures.get(key) {
Some(entry) => match entry.get(protocol) {
Some(r) => r.failed_at.elapsed() >= PROTOCOL_FAILURE_COOLDOWN,
None => true, // no failure record
},
None => true,
}
}
/// Record a retry attempt WITHOUT escalating the cooldown.
/// Resets the `failed_at` timestamp to prevent rapid retries (5s gate).
/// Called before an escalation attempt. If the attempt fails,
/// `record_failure` should be called afterward with proper escalation.
pub fn record_retry_attempt(&self, key: &ProtocolCacheKey, protocol: DetectedProtocol) {
if protocol == DetectedProtocol::H1 {
return;
}
if let Some(mut entry) = self.failures.get_mut(key) {
if let Some(ref mut r) = entry.get_mut(protocol) {
r.failed_at = Instant::now();
}
}
}
/// Clear the failure record for a protocol (it recovered).
/// Called when an escalation retry succeeds.
pub fn clear_failure(&self, key: &ProtocolCacheKey, protocol: DetectedProtocol) {
if protocol == DetectedProtocol::H1 {
return;
}
if let Some(mut entry) = self.failures.get_mut(key) {
*entry.get_mut(protocol) = None;
if entry.is_empty() {
drop(entry);
self.failures.remove(key);
}
}
}
/// Evict a cache entry entirely. Called when all protocol probes (H3, H2, H1)
/// have failed for a backend.
pub fn evict(&self, key: &ProtocolCacheKey) {
self.cache.remove(key);
self.failures.remove(key);
info!(
host = %key.host, port = %key.port, domain = ?key.requested_host,
"Cache entry evicted — all protocols failed"
);
}
/// Clear all entries. Called on route updates to discard stale detections.
pub fn clear(&self) {
self.cache.clear();
self.failures.clear();
}
/// Background cleanup loop — removes expired entries every `PROTOCOL_CACHE_CLEANUP_INTERVAL`.
async fn cleanup_loop(cache: Arc<DashMap<ProtocolCacheKey, CachedEntry>>) {
/// Snapshot all non-expired cache entries for metrics/UI display.
pub fn snapshot(&self) -> Vec<ProtocolCacheEntry> {
self.cache
.iter()
.filter(|entry| entry.value().last_accessed_at.elapsed() < PROTOCOL_CACHE_TTL)
.map(|entry| {
let key = entry.key();
let val = entry.value();
let failure_info = self.failures.get(key);
let (h2_sup, h2_cd, h2_cons) =
Self::suppression_info(failure_info.as_deref().and_then(|f| f.h2.as_ref()));
let (h3_sup, h3_cd, h3_cons) =
Self::suppression_info(failure_info.as_deref().and_then(|f| f.h3.as_ref()));
ProtocolCacheEntry {
host: key.host.clone(),
port: key.port,
domain: key.requested_host.clone(),
protocol: match val.protocol {
DetectedProtocol::H1 => "h1".to_string(),
DetectedProtocol::H2 => "h2".to_string(),
DetectedProtocol::H3 => "h3".to_string(),
},
h3_port: val.h3_port,
age_secs: val.detected_at.elapsed().as_secs(),
last_accessed_secs: val.last_accessed_at.elapsed().as_secs(),
last_probed_secs: val.last_probed_at.elapsed().as_secs(),
h2_suppressed: h2_sup,
h3_suppressed: h3_sup,
h2_cooldown_remaining_secs: h2_cd,
h3_cooldown_remaining_secs: h3_cd,
h2_consecutive_failures: h2_cons,
h3_consecutive_failures: h3_cons,
}
})
.collect()
}
// --- Internal helpers ---
/// Insert a protocol detection result with an optional H3 port.
/// Logs protocol transitions when overwriting an existing entry.
/// No suppression check — callers must check before calling.
fn insert_internal(
&self,
key: ProtocolCacheKey,
protocol: DetectedProtocol,
h3_port: Option<u16>,
reason: &str,
) {
// Check for existing entry to log protocol transitions
if let Some(existing) = self.cache.get(&key) {
if existing.protocol != protocol {
info!(
host = %key.host, port = %key.port, domain = ?key.requested_host,
old = %existing.protocol, new = %protocol, reason = %reason,
"Protocol transition"
);
}
drop(existing);
}
// Evict oldest entry if at capacity
if self.cache.len() >= PROTOCOL_CACHE_MAX_ENTRIES && !self.cache.contains_key(&key) {
let oldest = self
.cache
.iter()
.min_by_key(|entry| entry.value().last_accessed_at)
.map(|entry| entry.key().clone());
if let Some(oldest_key) = oldest {
self.cache.remove(&oldest_key);
}
}
let now = Instant::now();
self.cache.insert(
key,
CachedEntry {
protocol,
detected_at: now,
last_accessed_at: now,
last_probed_at: now,
h3_port,
},
);
}
/// Reduce a failure record's remaining cooldown to `target`, if it currently
/// has MORE than `target` remaining. Never increases cooldown.
fn reduce_cooldown_to(record: Option<&mut FailureRecord>, target: Duration) {
if let Some(r) = record {
let elapsed = r.failed_at.elapsed();
if elapsed < r.cooldown {
let remaining = r.cooldown - elapsed;
if remaining > target {
// Shrink cooldown so it expires in `target` from now
r.cooldown = elapsed + target;
}
}
}
}
/// Extract suppression info from a failure record for metrics.
fn suppression_info(record: Option<&FailureRecord>) -> (bool, Option<u64>, Option<u32>) {
match record {
Some(r) => {
let elapsed = r.failed_at.elapsed();
let suppressed = elapsed < r.cooldown;
let remaining = if suppressed {
Some((r.cooldown - elapsed).as_secs())
} else {
None
};
(suppressed, remaining, Some(r.consecutive_failures))
}
None => (false, None, None),
}
}
/// Background cleanup loop.
async fn cleanup_loop(
cache: Arc<DashMap<ProtocolCacheKey, CachedEntry>>,
failures: Arc<DashMap<ProtocolCacheKey, FailureState>>,
) {
let mut interval = tokio::time::interval(PROTOCOL_CACHE_CLEANUP_INTERVAL);
loop {
interval.tick().await;
let expired: Vec<ProtocolCacheKey> = cache.iter()
.filter(|entry| entry.value().detected_at.elapsed() >= PROTOCOL_CACHE_TTL)
// Clean expired cache entries (sliding TTL based on last_accessed_at)
let expired: Vec<ProtocolCacheKey> = cache
.iter()
.filter(|entry| entry.value().last_accessed_at.elapsed() >= PROTOCOL_CACHE_TTL)
.map(|entry| entry.key().clone())
.collect();
if !expired.is_empty() {
debug!("Protocol cache cleanup: removing {} expired entries", expired.len());
debug!(
"Protocol cache cleanup: removing {} expired entries",
expired.len()
);
for key in expired {
cache.remove(&key);
}
}
// Clean fully-expired failure entries
let expired_failures: Vec<ProtocolCacheKey> = failures
.iter()
.filter(|entry| entry.value().all_expired())
.map(|entry| entry.key().clone())
.collect();
if !expired_failures.is_empty() {
debug!(
"Protocol cache cleanup: removing {} expired failure entries",
expired_failures.len()
);
for key in expired_failures {
failures.remove(&key);
}
}
// Safety net: cap failures map at 2× max entries
if failures.len() > PROTOCOL_CACHE_MAX_ENTRIES * 2 {
let oldest: Vec<ProtocolCacheKey> = failures
.iter()
.filter(|e| e.value().all_expired())
.map(|e| e.key().clone())
.take(failures.len() - PROTOCOL_CACHE_MAX_ENTRIES)
.collect();
for key in oldest {
failures.remove(&key);
}
}
}
}
}
File diff suppressed because it is too large Load Diff
+139 -35
View File
@@ -4,14 +4,15 @@ use std::net::SocketAddr;
use std::sync::Arc;
use bytes::Bytes;
use http_body_util::Full;
use http_body_util::BodyExt;
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper::{Request, Response, StatusCode};
use rustproxy_config::RouteSecurity;
use rustproxy_security::{IpFilter, BasicAuthValidator, JwtValidator, RateLimiter};
use rustproxy_security::{BasicAuthValidator, IpFilter, JwtValidator, RateLimiter};
use crate::request_host::extract_request_host;
pub struct RequestFilter;
@@ -19,7 +20,7 @@ impl RequestFilter {
/// Apply security filters. Returns Some(response) if the request should be blocked.
pub fn apply(
security: &RouteSecurity,
req: &Request<Incoming>,
req: &Request<impl hyper::body::Body>,
peer_addr: &SocketAddr,
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
Self::apply_with_rate_limiter(security, req, peer_addr, None)
@@ -29,20 +30,21 @@ impl RequestFilter {
/// Returns Some(response) if the request should be blocked.
pub fn apply_with_rate_limiter(
security: &RouteSecurity,
req: &Request<Incoming>,
req: &Request<impl hyper::body::Body>,
peer_addr: &SocketAddr,
rate_limiter: Option<&Arc<RateLimiter>>,
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
let client_ip = peer_addr.ip();
let request_path = req.uri().path();
// IP filter
// IP filter (domain-aware: use the same host extraction as route matching)
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
let filter = IpFilter::new(allow, block);
let normalized = IpFilter::normalize_ip(&client_ip);
if !filter.is_allowed(&normalized) {
let host = extract_request_host(req);
if !filter.is_allowed_for_domain(&normalized, host) {
return Some(error_response(StatusCode::FORBIDDEN, "Access denied"));
}
}
@@ -56,16 +58,15 @@ impl RequestFilter {
!limiter.check(&key)
} else {
// Create a per-check limiter (less ideal but works for non-shared case)
let limiter = RateLimiter::new(
rate_limit_config.max_requests,
rate_limit_config.window,
);
let limiter =
RateLimiter::new(rate_limit_config.max_requests, rate_limit_config.window);
let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
!limiter.check(&key)
};
if should_block {
let message = rate_limit_config.error_message
let message = rate_limit_config
.error_message
.as_deref()
.unwrap_or("Rate limit exceeded");
return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message));
@@ -81,36 +82,48 @@ impl RequestFilter {
if let Some(ref basic_auth) = security.basic_auth {
if basic_auth.enabled {
// Check basic auth exclude paths
let skip_basic = basic_auth.exclude_paths.as_ref()
let skip_basic = basic_auth
.exclude_paths
.as_ref()
.map(|paths| Self::path_matches_any(request_path, paths))
.unwrap_or(false);
if !skip_basic {
let users: Vec<(String, String)> = basic_auth.users.iter()
let users: Vec<(String, String)> = basic_auth
.users
.iter()
.map(|c| (c.username.clone(), c.password.clone()))
.collect();
let validator = BasicAuthValidator::new(users, basic_auth.realm.clone());
let auth_header = req.headers()
let auth_header = req
.headers()
.get("authorization")
.and_then(|v| v.to_str().ok());
match auth_header {
Some(header) => {
if validator.validate(header).is_none() {
return Some(Response::builder()
return Some(
Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("WWW-Authenticate", validator.www_authenticate())
.header(
"WWW-Authenticate",
validator.www_authenticate(),
)
.body(boxed_body("Invalid credentials"))
.unwrap());
.unwrap(),
);
}
}
None => {
return Some(Response::builder()
return Some(
Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header("WWW-Authenticate", validator.www_authenticate())
.body(boxed_body("Authentication required"))
.unwrap());
.unwrap(),
);
}
}
}
@@ -121,7 +134,9 @@ impl RequestFilter {
if let Some(ref jwt_auth) = security.jwt_auth {
if jwt_auth.enabled {
// Check JWT auth exclude paths
let skip_jwt = jwt_auth.exclude_paths.as_ref()
let skip_jwt = jwt_auth
.exclude_paths
.as_ref()
.map(|paths| Self::path_matches_any(request_path, paths))
.unwrap_or(false);
@@ -133,18 +148,25 @@ impl RequestFilter {
jwt_auth.audience.as_deref(),
);
let auth_header = req.headers()
let auth_header = req
.headers()
.get("authorization")
.and_then(|v| v.to_str().ok());
match auth_header.and_then(JwtValidator::extract_token) {
Some(token) => {
if validator.validate(token).is_err() {
return Some(error_response(StatusCode::UNAUTHORIZED, "Invalid token"));
return Some(error_response(
StatusCode::UNAUTHORIZED,
"Invalid token",
));
}
}
None => {
return Some(error_response(StatusCode::UNAUTHORIZED, "Bearer token required"));
return Some(error_response(
StatusCode::UNAUTHORIZED,
"Bearer token required",
));
}
}
}
@@ -182,7 +204,7 @@ impl RequestFilter {
/// Determine the rate limit key based on configuration.
fn rate_limit_key(
config: &rustproxy_config::RouteRateLimit,
req: &Request<Incoming>,
req: &Request<impl hyper::body::Body>,
peer_addr: &SocketAddr,
) -> String {
use rustproxy_config::RateLimitKeyBy;
@@ -204,14 +226,19 @@ impl RequestFilter {
}
/// Check IP-based security (for use in passthrough / TCP-level connections).
/// `domain` is the SNI from the TLS handshake (if available) for domain-scoped filtering.
/// Returns true if allowed, false if blocked.
pub fn check_ip_security(security: &RouteSecurity, client_ip: &std::net::IpAddr) -> bool {
pub fn check_ip_security(
security: &RouteSecurity,
client_ip: &std::net::IpAddr,
domain: Option<&str>,
) -> bool {
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
let filter = IpFilter::new(allow, block);
let normalized = IpFilter::normalize_ip(client_ip);
filter.is_allowed(&normalized)
filter.is_allowed_for_domain(&normalized, domain)
} else {
true
}
@@ -220,7 +247,7 @@ impl RequestFilter {
/// Handle CORS preflight (OPTIONS) requests.
/// Returns Some(response) if this is a CORS preflight that should be handled.
pub fn handle_cors_preflight(
req: &Request<Incoming>,
req: &Request<impl hyper::body::Body>,
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
if req.method() != hyper::Method::OPTIONS {
return None;
@@ -234,19 +261,28 @@ impl RequestFilter {
return None;
}
let origin = req.headers()
let origin = req
.headers()
.get("origin")
.and_then(|v| v.to_str().ok())
.unwrap_or("*");
Some(Response::builder()
Some(
Response::builder()
.status(StatusCode::NO_CONTENT)
.header("Access-Control-Allow-Origin", origin)
.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
.header("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
.header(
"Access-Control-Allow-Methods",
"GET, POST, PUT, DELETE, PATCH, OPTIONS",
)
.header(
"Access-Control-Allow-Headers",
"Content-Type, Authorization, X-Requested-With",
)
.header("Access-Control-Max-Age", "86400")
.body(boxed_body(""))
.unwrap())
.unwrap(),
)
}
}
@@ -261,3 +297,71 @@ fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes,
fn boxed_body(data: &str) -> BoxBody<Bytes, hyper::Error> {
BoxBody::new(Full::new(Bytes::from(data.to_string())).map_err(|never| match never {}))
}
#[cfg(test)]
mod tests {
use bytes::Bytes;
use http_body_util::Empty;
use hyper::{Request, StatusCode, Version};
use rustproxy_config::{IpAllowEntry, RouteSecurity};
use super::RequestFilter;
fn domain_scoped_security() -> RouteSecurity {
RouteSecurity {
ip_allow_list: Some(vec![IpAllowEntry::DomainScoped {
ip: "10.8.0.2".to_string(),
domains: vec!["*.abc.xyz".to_string()],
}]),
ip_block_list: None,
max_connections: None,
authentication: None,
rate_limit: None,
basic_auth: None,
jwt_auth: None,
}
}
fn peer_addr() -> std::net::SocketAddr {
std::net::SocketAddr::from(([10, 8, 0, 2], 4242))
}
fn request(uri: &str, version: Version, host: Option<&str>) -> Request<Empty<Bytes>> {
let mut builder = Request::builder().uri(uri).version(version);
if let Some(host) = host {
builder = builder.header("host", host);
}
builder.body(Empty::<Bytes>::new()).unwrap()
}
#[test]
fn domain_scoped_acl_allows_uri_authority_without_host_header() {
let security = domain_scoped_security();
let req = request("https://outline.abc.xyz/", Version::HTTP_2, None);
assert!(RequestFilter::apply(&security, &req, &peer_addr()).is_none());
}
#[test]
fn domain_scoped_acl_allows_host_header_with_port() {
let security = domain_scoped_security();
let req = request(
"https://unrelated.invalid/",
Version::HTTP_11,
Some("outline.abc.xyz:443"),
);
assert!(RequestFilter::apply(&security, &req, &peer_addr()).is_none());
}
#[test]
fn domain_scoped_acl_denies_non_matching_uri_authority() {
let security = domain_scoped_security();
let req = request("https://outline.other.xyz/", Version::HTTP_2, None);
let response = RequestFilter::apply(&security, &req, &peer_addr())
.expect("non-matching domain should be denied");
assert_eq!(response.status(), StatusCode::FORBIDDEN);
}
}
@@ -0,0 +1,43 @@
use hyper::Request;
/// Extract the effective request host for routing and scoped ACL checks.
///
/// Prefer the explicit `Host` header when present, otherwise fall back to the
/// URI authority used by HTTP/2 and HTTP/3 requests.
pub(crate) fn extract_request_host<B>(req: &Request<B>) -> Option<&str> {
req.headers()
.get("host")
.and_then(|value| value.to_str().ok())
.map(|host| host.split(':').next().unwrap_or(host))
.or_else(|| req.uri().host())
}
#[cfg(test)]
mod tests {
use bytes::Bytes;
use http_body_util::Empty;
use hyper::Request;
use super::extract_request_host;
#[test]
fn extracts_host_header_before_uri_authority() {
let req = Request::builder()
.uri("https://uri.abc.xyz/test")
.header("host", "header.abc.xyz:443")
.body(Empty::<Bytes>::new())
.unwrap();
assert_eq!(extract_request_host(&req), Some("header.abc.xyz"));
}
#[test]
fn falls_back_to_uri_authority_when_host_header_missing() {
let req = Request::builder()
.uri("https://outline.abc.xyz/test")
.body(Empty::<Bytes>::new())
.unwrap();
assert_eq!(extract_request_host(&req), Some("outline.abc.xyz"));
}
}
@@ -3,14 +3,35 @@
use hyper::header::{HeaderMap, HeaderName, HeaderValue};
use rustproxy_config::RouteConfig;
use crate::template::{RequestContext, expand_template};
use crate::template::{expand_template, RequestContext};
pub struct ResponseFilter;
impl ResponseFilter {
/// Apply response headers from route config and CORS settings.
/// If a `RequestContext` is provided, template variables in header values will be expanded.
pub fn apply_headers(route: &RouteConfig, headers: &mut HeaderMap, req_ctx: Option<&RequestContext>) {
/// Also injects Alt-Svc header for routes with HTTP/3 enabled.
pub fn apply_headers(
route: &RouteConfig,
headers: &mut HeaderMap,
req_ctx: Option<&RequestContext>,
) {
// Inject Alt-Svc for HTTP/3 advertisement if QUIC/HTTP3 is enabled on this route
if let Some(ref udp) = route.action.udp {
if let Some(ref quic) = udp.quic {
if quic.enable_http3.unwrap_or(false) {
let port = quic
.alt_svc_port
.or_else(|| req_ctx.map(|c| c.port))
.unwrap_or(443);
let max_age = quic.alt_svc_max_age.unwrap_or(86400);
let alt_svc = format!("h3=\":{}\"; ma={}", port, max_age);
if let Ok(val) = HeaderValue::from_str(&alt_svc) {
headers.insert("alt-svc", val);
}
}
}
}
// Apply custom response headers from route config
if let Some(ref route_headers) = route.headers {
if let Some(ref response_headers) = route_headers.response {
@@ -47,10 +68,7 @@ impl ResponseFilter {
headers.insert("access-control-allow-origin", val);
}
} else {
headers.insert(
"access-control-allow-origin",
HeaderValue::from_static("*"),
);
headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
}
// Allow-Methods
@@ -50,17 +50,23 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncWrite for Shutdown
Pin::new(self.get_mut().inner.as_mut().unwrap()).poll_write(cx, buf)
}
fn poll_flush(
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
Pin::new(self.get_mut().inner.as_mut().unwrap()).poll_write_vectored(cx, bufs)
}
fn is_write_vectored(&self) -> bool {
self.inner.as_ref().unwrap().is_write_vectored()
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(self.get_mut().inner.as_mut().unwrap()).poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
let result = Pin::new(this.inner.as_mut().unwrap()).poll_shutdown(cx);
if result.is_ready() {
@@ -81,7 +87,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Drop for ShutdownOnDrop
let _ = tokio::time::timeout(
std::time::Duration::from_secs(2),
tokio::io::AsyncWriteExt::shutdown(&mut stream),
).await;
)
.await;
// stream is dropped here — all resources freed
});
}
+6 -2
View File
@@ -39,7 +39,8 @@ pub fn expand_headers(
headers: &HashMap<String, String>,
ctx: &RequestContext,
) -> HashMap<String, String> {
headers.iter()
headers
.iter()
.map(|(k, v)| (k.clone(), expand_template(v, ctx)))
.collect()
}
@@ -150,7 +151,10 @@ mod tests {
let ctx = test_context();
let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}";
let result = expand_template(template, &ctx);
assert_eq!(result, "192.168.1.100|example.com|443|/api/v1/users|api-route|42");
assert_eq!(
result,
"192.168.1.100|example.com|443|/api/v1/users|api-route|42"
);
}
#[test]
@@ -7,7 +7,7 @@ use std::sync::Arc;
use std::sync::Mutex;
use dashmap::DashMap;
use rustproxy_config::{RouteTarget, LoadBalancingAlgorithm};
use rustproxy_config::{LoadBalancingAlgorithm, RouteTarget};
/// Upstream selection result.
pub struct UpstreamSelection {
@@ -51,21 +51,19 @@ impl UpstreamSelector {
}
// Determine load balancing algorithm
let algorithm = target.load_balancing.as_ref()
let algorithm = target
.load_balancing
.as_ref()
.map(|lb| &lb.algorithm)
.unwrap_or(&LoadBalancingAlgorithm::RoundRobin);
let idx = match algorithm {
LoadBalancingAlgorithm::RoundRobin => {
self.round_robin_select(&hosts, port)
}
LoadBalancingAlgorithm::RoundRobin => self.round_robin_select(&hosts, port),
LoadBalancingAlgorithm::IpHash => {
let hash = Self::ip_hash(client_addr);
hash % hosts.len()
}
LoadBalancingAlgorithm::LeastConnections => {
self.least_connections_select(&hosts, port)
}
LoadBalancingAlgorithm::LeastConnections => self.least_connections_select(&hosts, port),
};
UpstreamSelection {
@@ -78,9 +76,7 @@ impl UpstreamSelector {
fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize {
let key = format!("{}:{}", hosts[0], port);
let mut counters = self.round_robin.lock().unwrap();
let counter = counters
.entry(key)
.or_insert_with(|| AtomicUsize::new(0));
let counter = counters.entry(key).or_insert_with(|| AtomicUsize::new(0));
let idx = counter.fetch_add(1, Ordering::Relaxed);
idx % hosts.len()
}
@@ -91,7 +87,8 @@ impl UpstreamSelector {
for (i, host) in hosts.iter().enumerate() {
let key = format!("{}:{}", host, port);
let conns = self.active_connections
let conns = self
.active_connections
.get(&key)
.map(|entry| entry.value().load(Ordering::Relaxed))
.unwrap_or(0);
@@ -184,6 +181,7 @@ mod tests {
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
}
}
@@ -227,13 +225,21 @@ mod tests {
selector.connection_started("backend:8080");
selector.connection_started("backend:8080");
assert_eq!(
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed),
selector
.active_connections
.get("backend:8080")
.unwrap()
.load(Ordering::Relaxed),
2
);
selector.connection_ended("backend:8080");
assert_eq!(
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed),
selector
.active_connections
.get("backend:8080")
.unwrap()
.load(Ordering::Relaxed),
1
);
File diff suppressed because it is too large Load Diff
+2 -2
View File
@@ -2,10 +2,10 @@
//!
//! Metrics and throughput tracking for RustProxy.
pub mod throughput;
pub mod collector;
pub mod log_dedup;
pub mod throughput;
pub use throughput::*;
pub use collector::*;
pub use log_dedup::*;
pub use throughput::*;
@@ -1,6 +1,6 @@
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::info;
@@ -47,7 +47,10 @@ impl LogDeduplicator {
let map_key = format!("{}:{}", category, key);
let now = Instant::now();
let entry = self.events.entry(map_key).or_insert_with(|| AggregatedEvent {
let entry = self
.events
.entry(map_key)
.or_insert_with(|| AggregatedEvent {
category: category.to_string(),
first_message: message.to_string(),
count: AtomicU64::new(0),
+146 -1
View File
@@ -29,6 +29,113 @@ pub struct ThroughputTracker {
created_at: Instant,
}
fn unix_timestamp_seconds() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
/// Circular buffer for per-second event counts.
///
/// Unlike `ThroughputTracker`, events are recorded directly into the current
/// second so request counts remain stable even when the collector is sampled
/// more frequently than once per second.
pub(crate) struct RequestRateTracker {
samples: Vec<u64>,
write_index: usize,
count: usize,
capacity: usize,
current_second: Option<u64>,
current_count: u64,
}
impl RequestRateTracker {
pub(crate) fn new(retention_seconds: usize) -> Self {
Self {
samples: Vec::with_capacity(retention_seconds.max(1)),
write_index: 0,
count: 0,
capacity: retention_seconds.max(1),
current_second: None,
current_count: 0,
}
}
fn push_sample(&mut self, count: u64) {
if self.samples.len() < self.capacity {
self.samples.push(count);
} else {
self.samples[self.write_index] = count;
}
self.write_index = (self.write_index + 1) % self.capacity;
self.count = (self.count + 1).min(self.capacity);
}
pub(crate) fn record_event(&mut self) {
self.record_events_at(unix_timestamp_seconds(), 1);
}
pub(crate) fn record_events_at(&mut self, now_sec: u64, count: u64) {
self.advance_to(now_sec);
self.current_count = self.current_count.saturating_add(count);
}
pub(crate) fn advance_to_now(&mut self) {
self.advance_to(unix_timestamp_seconds());
}
pub(crate) fn advance_to(&mut self, now_sec: u64) {
match self.current_second {
Some(current_second) if now_sec > current_second => {
self.push_sample(self.current_count);
for _ in 1..(now_sec - current_second) {
self.push_sample(0);
}
self.current_second = Some(now_sec);
self.current_count = 0;
}
Some(_) => {}
None => {
self.current_second = Some(now_sec);
self.current_count = 0;
}
}
}
fn sum_recent(&self, window_seconds: usize) -> u64 {
let window = window_seconds.min(self.count);
if window == 0 {
return 0;
}
let mut total = 0u64;
for i in 0..window {
let idx = if self.write_index >= i + 1 {
self.write_index - i - 1
} else {
self.capacity - (i + 1 - self.write_index)
};
if idx < self.samples.len() {
total += self.samples[idx];
}
}
total
}
pub(crate) fn last_second(&self) -> u64 {
self.sum_recent(1)
}
pub(crate) fn last_minute(&self) -> u64 {
self.sum_recent(60)
}
pub(crate) fn is_idle(&self) -> bool {
self.current_count == 0 && self.sum_recent(self.capacity) == 0
}
}
impl ThroughputTracker {
/// Create a new tracker with the given capacity (seconds of retention).
pub fn new(retention_seconds: usize) -> Self {
@@ -46,7 +153,8 @@ impl ThroughputTracker {
/// Record bytes (called from data flow callbacks).
pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64) {
self.pending_bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
self.pending_bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
self.pending_bytes_out
.fetch_add(bytes_out, Ordering::Relaxed);
}
/// Take a sample (called at 1Hz).
@@ -229,4 +337,41 @@ mod tests {
let history = tracker.history(10);
assert!(history.is_empty());
}
#[test]
fn test_request_rate_tracker_counts_last_second_and_last_minute() {
let mut tracker = RequestRateTracker::new(60);
tracker.record_events_at(100, 2);
tracker.record_events_at(100, 3);
tracker.advance_to(101);
assert_eq!(tracker.last_second(), 5);
assert_eq!(tracker.last_minute(), 5);
}
#[test]
fn test_request_rate_tracker_adds_zero_samples_for_gaps() {
let mut tracker = RequestRateTracker::new(60);
tracker.record_events_at(100, 4);
tracker.record_events_at(102, 1);
tracker.advance_to(103);
assert_eq!(tracker.last_second(), 1);
assert_eq!(tracker.last_minute(), 5);
}
#[test]
fn test_request_rate_tracker_decays_to_zero_over_window() {
let mut tracker = RequestRateTracker::new(60);
tracker.record_events_at(100, 7);
tracker.advance_to(101);
tracker.advance_to(161);
assert_eq!(tracker.last_second(), 0);
assert_eq!(tracker.last_minute(), 0);
assert!(tracker.is_idle());
}
}
-17
View File
@@ -1,17 +0,0 @@
[package]
name = "rustproxy-nftables"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "NFTables kernel-level forwarding for RustProxy"
[dependencies]
rustproxy-config = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
libc = { workspace = true }
-10
View File
@@ -1,10 +0,0 @@
//! # rustproxy-nftables
//!
//! NFTables kernel-level forwarding for RustProxy.
//! Generates and manages nft CLI rules for DNAT/SNAT.
pub mod nft_manager;
pub mod rule_builder;
pub use nft_manager::*;
pub use rule_builder::*;
@@ -1,238 +0,0 @@
use thiserror::Error;
use std::collections::HashMap;
use tracing::{debug, info, warn};
#[derive(Debug, Error)]
pub enum NftError {
#[error("nft command failed: {0}")]
CommandFailed(String),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Not running as root")]
NotRoot,
}
/// Manager for nftables rules.
///
/// Executes `nft` CLI commands to manage kernel-level packet forwarding.
/// Requires root privileges; operations are skipped gracefully if not root.
pub struct NftManager {
table_name: String,
/// Active rules indexed by route ID
active_rules: HashMap<String, Vec<String>>,
/// Whether the table has been initialized
table_initialized: bool,
}
impl NftManager {
pub fn new(table_name: Option<String>) -> Self {
Self {
table_name: table_name.unwrap_or_else(|| "rustproxy".to_string()),
active_rules: HashMap::new(),
table_initialized: false,
}
}
/// Check if we are running as root.
fn is_root() -> bool {
unsafe { libc::geteuid() == 0 }
}
/// Execute a single nft command via the CLI.
async fn exec_nft(command: &str) -> Result<String, NftError> {
// The command starts with "nft ", strip it to get the args
let args = if command.starts_with("nft ") {
&command[4..]
} else {
command
};
let output = tokio::process::Command::new("nft")
.args(args.split_whitespace())
.output()
.await
.map_err(NftError::Io)?;
if output.status.success() {
Ok(String::from_utf8_lossy(&output.stdout).to_string())
} else {
let stderr = String::from_utf8_lossy(&output.stderr);
Err(NftError::CommandFailed(format!(
"Command '{}' failed: {}",
command, stderr
)))
}
}
/// Ensure the nftables table and chains are set up.
async fn ensure_table(&mut self) -> Result<(), NftError> {
if self.table_initialized {
return Ok(());
}
let setup_commands = crate::rule_builder::build_table_setup(&self.table_name);
for cmd in &setup_commands {
Self::exec_nft(cmd).await?;
}
self.table_initialized = true;
info!("NFTables table '{}' initialized", self.table_name);
Ok(())
}
/// Apply rules for a route.
///
/// Executes the nft commands via the CLI. If not running as root,
/// the rules are stored locally but not applied to the kernel.
pub async fn apply_rules(&mut self, route_id: &str, rules: Vec<String>) -> Result<(), NftError> {
if !Self::is_root() {
warn!("Not running as root, nftables rules will not be applied to kernel");
self.active_rules.insert(route_id.to_string(), rules);
return Ok(());
}
self.ensure_table().await?;
for cmd in &rules {
Self::exec_nft(cmd).await?;
debug!("Applied nft rule: {}", cmd);
}
info!("Applied {} nftables rules for route '{}'", rules.len(), route_id);
self.active_rules.insert(route_id.to_string(), rules);
Ok(())
}
/// Remove rules for a route.
///
/// Currently removes the route from tracking. To fully remove specific
/// rules would require handle-based tracking; for now, cleanup() removes
/// the entire table.
pub async fn remove_rules(&mut self, route_id: &str) -> Result<(), NftError> {
if let Some(rules) = self.active_rules.remove(route_id) {
info!("Removed {} tracked nft rules for route '{}'", rules.len(), route_id);
}
Ok(())
}
/// Clean up all managed rules by deleting the entire nftables table.
pub async fn cleanup(&mut self) -> Result<(), NftError> {
if !Self::is_root() {
warn!("Not running as root, skipping nftables cleanup");
self.active_rules.clear();
self.table_initialized = false;
return Ok(());
}
if self.table_initialized {
let cleanup_commands = crate::rule_builder::build_table_cleanup(&self.table_name);
for cmd in &cleanup_commands {
match Self::exec_nft(cmd).await {
Ok(_) => debug!("Cleanup: {}", cmd),
Err(e) => warn!("Cleanup command failed (may be ok): {}", e),
}
}
info!("NFTables table '{}' cleaned up", self.table_name);
}
self.active_rules.clear();
self.table_initialized = false;
Ok(())
}
/// Get the table name.
pub fn table_name(&self) -> &str {
&self.table_name
}
/// Whether the table has been initialized in the kernel.
pub fn is_initialized(&self) -> bool {
self.table_initialized
}
/// Get the number of active route rule sets.
pub fn active_route_count(&self) -> usize {
self.active_rules.len()
}
/// Get the status of all active rules.
pub fn status(&self) -> HashMap<String, serde_json::Value> {
let mut status = HashMap::new();
for (route_id, rules) in &self.active_rules {
status.insert(
route_id.clone(),
serde_json::json!({
"ruleCount": rules.len(),
"rules": rules,
}),
);
}
status
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_new_default_table_name() {
let mgr = NftManager::new(None);
assert_eq!(mgr.table_name(), "rustproxy");
assert!(!mgr.is_initialized());
}
#[test]
fn test_new_custom_table_name() {
let mgr = NftManager::new(Some("custom".to_string()));
assert_eq!(mgr.table_name(), "custom");
}
#[tokio::test]
async fn test_apply_rules_non_root() {
let mut mgr = NftManager::new(None);
// When not root, rules are stored but not applied to kernel
let rules = vec!["nft add rule ip rustproxy prerouting tcp dport 443 dnat to 10.0.0.1:8443".to_string()];
mgr.apply_rules("route-1", rules).await.unwrap();
assert_eq!(mgr.active_route_count(), 1);
let status = mgr.status();
assert!(status.contains_key("route-1"));
assert_eq!(status["route-1"]["ruleCount"], 1);
}
#[tokio::test]
async fn test_remove_rules() {
let mut mgr = NftManager::new(None);
let rules = vec!["nft add rule test".to_string()];
mgr.apply_rules("route-1", rules).await.unwrap();
assert_eq!(mgr.active_route_count(), 1);
mgr.remove_rules("route-1").await.unwrap();
assert_eq!(mgr.active_route_count(), 0);
}
#[tokio::test]
async fn test_cleanup_non_root() {
let mut mgr = NftManager::new(None);
let rules = vec!["nft add rule test".to_string()];
mgr.apply_rules("route-1", rules).await.unwrap();
mgr.apply_rules("route-2", vec!["nft add rule test2".to_string()]).await.unwrap();
mgr.cleanup().await.unwrap();
assert_eq!(mgr.active_route_count(), 0);
assert!(!mgr.is_initialized());
}
#[tokio::test]
async fn test_status_multiple_routes() {
let mut mgr = NftManager::new(None);
mgr.apply_rules("web", vec!["rule1".to_string(), "rule2".to_string()]).await.unwrap();
mgr.apply_rules("api", vec!["rule3".to_string()]).await.unwrap();
let status = mgr.status();
assert_eq!(status.len(), 2);
assert_eq!(status["web"]["ruleCount"], 2);
assert_eq!(status["api"]["ruleCount"], 1);
}
}
@@ -1,123 +0,0 @@
use rustproxy_config::{NfTablesOptions, NfTablesProtocol};
/// Build nftables DNAT rule for port forwarding.
pub fn build_dnat_rule(
table_name: &str,
chain_name: &str,
source_port: u16,
target_host: &str,
target_port: u16,
options: &NfTablesOptions,
) -> Vec<String> {
let protocol = match options.protocol.as_ref().unwrap_or(&NfTablesProtocol::Tcp) {
NfTablesProtocol::Tcp => "tcp",
NfTablesProtocol::Udp => "udp",
NfTablesProtocol::All => "tcp", // TODO: handle "all"
};
let mut rules = Vec::new();
// DNAT rule
rules.push(format!(
"nft add rule ip {} {} {} dport {} dnat to {}:{}",
table_name, chain_name, protocol, source_port, target_host, target_port,
));
// SNAT rule if preserving source IP is not enabled
if !options.preserve_source_ip.unwrap_or(false) {
rules.push(format!(
"nft add rule ip {} postrouting {} dport {} masquerade",
table_name, protocol, target_port,
));
}
// Rate limiting
if let Some(max_rate) = &options.max_rate {
rules.push(format!(
"nft add rule ip {} {} {} dport {} limit rate {} accept",
table_name, chain_name, protocol, source_port, max_rate,
));
}
rules
}
/// Build the initial table and chain setup commands.
pub fn build_table_setup(table_name: &str) -> Vec<String> {
vec![
format!("nft add table ip {}", table_name),
format!("nft add chain ip {} prerouting {{ type nat hook prerouting priority 0 \\; }}", table_name),
format!("nft add chain ip {} postrouting {{ type nat hook postrouting priority 100 \\; }}", table_name),
]
}
/// Build cleanup commands to remove the table.
pub fn build_table_cleanup(table_name: &str) -> Vec<String> {
vec![format!("nft delete table ip {}", table_name)]
}
#[cfg(test)]
mod tests {
use super::*;
fn make_options() -> NfTablesOptions {
NfTablesOptions {
preserve_source_ip: None,
protocol: None,
max_rate: None,
priority: None,
table_name: None,
use_ip_sets: None,
use_advanced_nat: None,
}
}
#[test]
fn test_basic_dnat_rule() {
let options = make_options();
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
assert!(rules.len() >= 1);
assert!(rules[0].contains("dnat to 10.0.0.1:8443"));
assert!(rules[0].contains("dport 443"));
}
#[test]
fn test_preserve_source_ip() {
let mut options = make_options();
options.preserve_source_ip = Some(true);
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
// When preserving source IP, no masquerade rule
assert!(rules.iter().all(|r| !r.contains("masquerade")));
}
#[test]
fn test_without_preserve_source_ip() {
let options = make_options();
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
assert!(rules.iter().any(|r| r.contains("masquerade")));
}
#[test]
fn test_rate_limited_rule() {
let mut options = make_options();
options.max_rate = Some("100/second".to_string());
let rules = build_dnat_rule("rustproxy", "prerouting", 80, "10.0.0.1", 8080, &options);
assert!(rules.iter().any(|r| r.contains("limit rate 100/second")));
}
#[test]
fn test_table_setup_commands() {
let commands = build_table_setup("rustproxy");
assert_eq!(commands.len(), 3);
assert!(commands[0].contains("add table ip rustproxy"));
assert!(commands[1].contains("prerouting"));
assert!(commands[2].contains("postrouting"));
}
#[test]
fn test_table_cleanup() {
let commands = build_table_cleanup("rustproxy");
assert_eq!(commands.len(), 1);
assert!(commands[0].contains("delete table ip rustproxy"));
}
}
@@ -10,6 +10,7 @@ description = "Raw TCP/SNI passthrough engine for RustProxy"
rustproxy-config = { workspace = true }
rustproxy-routing = { workspace = true }
rustproxy-metrics = { workspace = true }
rustproxy-security = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
@@ -24,3 +25,6 @@ tokio-util = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
socket2 = { workspace = true }
quinn = { workspace = true }
rcgen = { workspace = true }
base64 = { workspace = true }
@@ -0,0 +1,335 @@
//! Shared connection registry for selective connection recycling.
//!
//! Tracks active connections across both TCP and QUIC with metadata
//! (source IP, SNI domain, route ID, cancel token) so that connections
//! can be selectively recycled when certificates, security rules, or
//! route targets change.
use std::collections::HashSet;
use std::net::IpAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use dashmap::DashMap;
use tokio_util::sync::CancellationToken;
use tracing::info;
use rustproxy_config::RouteSecurity;
use rustproxy_http::request_filter::RequestFilter;
use rustproxy_routing::matchers::domain_matches;
/// Metadata about an active connection.
pub struct ConnectionEntry {
/// Per-connection cancel token (child of per-route token).
pub cancel: CancellationToken,
/// Client source IP.
pub source_ip: IpAddr,
/// SNI domain from TLS handshake (None for non-TLS connections).
pub domain: Option<String>,
/// Route ID this connection was matched to (None if route has no ID).
pub route_id: Option<String>,
}
/// Transport-agnostic registry of active connections.
///
/// Used by both `TcpListenerManager` and `UdpListenerManager` to track
/// connections and enable selective recycling on config changes.
pub struct ConnectionRegistry {
connections: DashMap<u64, ConnectionEntry>,
next_id: AtomicU64,
}
impl ConnectionRegistry {
pub fn new() -> Self {
Self {
connections: DashMap::new(),
next_id: AtomicU64::new(1),
}
}
/// Register a connection and return its ID + RAII guard.
///
/// The guard automatically removes the connection from the registry on drop.
pub fn register(self: &Arc<Self>, entry: ConnectionEntry) -> (u64, ConnectionRegistryGuard) {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
self.connections.insert(id, entry);
let guard = ConnectionRegistryGuard {
registry: Arc::clone(self),
conn_id: id,
};
(id, guard)
}
/// Number of tracked connections (for metrics/debugging).
pub fn len(&self) -> usize {
self.connections.len()
}
/// Recycle connections whose SNI domain matches a renewed certificate domain.
///
/// Uses bidirectional domain matching so that:
/// - Cert `*.example.com` recycles connections for `sub.example.com`
/// - Cert `sub.example.com` recycles connections on routes with `*.example.com`
pub fn recycle_for_cert_change(&self, cert_domain: &str) {
let mut recycled = 0u64;
self.connections.retain(|_, entry| {
let matches = entry
.domain
.as_deref()
.map(|d| domain_matches(cert_domain, d) || domain_matches(d, cert_domain))
.unwrap_or(false);
if matches {
entry.cancel.cancel();
recycled += 1;
false
} else {
true
}
});
if recycled > 0 {
info!(
"Recycled {} connection(s) for cert change on domain '{}'",
recycled, cert_domain
);
}
}
/// Recycle connections on a route whose security config changed.
///
/// Re-evaluates each connection's source IP against the new security rules.
/// Only connections from now-blocked IPs are terminated; allowed IPs are undisturbed.
pub fn recycle_for_security_change(&self, route_id: &str, new_security: &RouteSecurity) {
let mut recycled = 0u64;
self.connections.retain(|_, entry| {
if entry.route_id.as_deref() == Some(route_id) {
if !RequestFilter::check_ip_security(
new_security,
&entry.source_ip,
entry.domain.as_deref(),
) {
info!(
"Terminating connection from {} — IP now blocked on route '{}'",
entry.source_ip, route_id
);
entry.cancel.cancel();
recycled += 1;
return false;
}
}
true
});
if recycled > 0 {
info!(
"Recycled {} connection(s) for security change on route '{}'",
recycled, route_id
);
}
}
/// Recycle all connections on a route (e.g., when targets changed).
pub fn recycle_for_route_change(&self, route_id: &str) {
let mut recycled = 0u64;
self.connections.retain(|_, entry| {
if entry.route_id.as_deref() == Some(route_id) {
entry.cancel.cancel();
recycled += 1;
false
} else {
true
}
});
if recycled > 0 {
info!(
"Recycled {} connection(s) for config change on route '{}'",
recycled, route_id
);
}
}
/// Remove connections on routes that no longer exist.
///
/// This complements per-route CancellationToken cancellation —
/// the token cascade handles graceful shutdown, this cleans up the registry.
pub fn cleanup_removed_routes(&self, active_route_ids: &HashSet<String>) {
self.connections.retain(|_, entry| {
match &entry.route_id {
Some(id) => active_route_ids.contains(id),
None => true, // keep connections without a route ID
}
});
}
}
/// RAII guard that removes a connection from the registry on drop.
pub struct ConnectionRegistryGuard {
registry: Arc<ConnectionRegistry>,
conn_id: u64,
}
impl Drop for ConnectionRegistryGuard {
fn drop(&mut self) {
self.registry.connections.remove(&self.conn_id);
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_registry() -> Arc<ConnectionRegistry> {
Arc::new(ConnectionRegistry::new())
}
#[test]
fn test_register_and_guard_cleanup() {
let reg = make_registry();
let token = CancellationToken::new();
let entry = ConnectionEntry {
cancel: token.clone(),
source_ip: "10.0.0.1".parse().unwrap(),
domain: Some("example.com".to_string()),
route_id: Some("route-1".to_string()),
};
let (id, guard) = reg.register(entry);
assert_eq!(reg.len(), 1);
assert!(reg.connections.contains_key(&id));
drop(guard);
assert_eq!(reg.len(), 0);
assert!(!token.is_cancelled());
}
#[test]
fn test_recycle_for_cert_change_exact() {
let reg = make_registry();
let t1 = CancellationToken::new();
let t2 = CancellationToken::new();
let (_, _g1) = reg.register(ConnectionEntry {
cancel: t1.clone(),
source_ip: "10.0.0.1".parse().unwrap(),
domain: Some("api.example.com".to_string()),
route_id: Some("r1".to_string()),
});
let (_, _g2) = reg.register(ConnectionEntry {
cancel: t2.clone(),
source_ip: "10.0.0.2".parse().unwrap(),
domain: Some("other.com".to_string()),
route_id: Some("r2".to_string()),
});
reg.recycle_for_cert_change("api.example.com");
assert!(t1.is_cancelled());
assert!(!t2.is_cancelled());
// Registry retains unmatched entry (guard still alive keeps it too,
// but the retain removed the matched one before guard could)
}
#[test]
fn test_recycle_for_cert_change_wildcard() {
let reg = make_registry();
let t1 = CancellationToken::new();
let t2 = CancellationToken::new();
let (_, _g1) = reg.register(ConnectionEntry {
cancel: t1.clone(),
source_ip: "10.0.0.1".parse().unwrap(),
domain: Some("sub.example.com".to_string()),
route_id: Some("r1".to_string()),
});
let (_, _g2) = reg.register(ConnectionEntry {
cancel: t2.clone(),
source_ip: "10.0.0.2".parse().unwrap(),
domain: Some("other.com".to_string()),
route_id: Some("r2".to_string()),
});
// Wildcard cert should match subdomain
reg.recycle_for_cert_change("*.example.com");
assert!(t1.is_cancelled());
assert!(!t2.is_cancelled());
}
#[test]
fn test_recycle_for_security_change() {
let reg = make_registry();
let t1 = CancellationToken::new();
let t2 = CancellationToken::new();
let (_, _g1) = reg.register(ConnectionEntry {
cancel: t1.clone(),
source_ip: "10.0.0.1".parse().unwrap(),
domain: None,
route_id: Some("r1".to_string()),
});
let (_, _g2) = reg.register(ConnectionEntry {
cancel: t2.clone(),
source_ip: "10.0.0.2".parse().unwrap(),
domain: None,
route_id: Some("r1".to_string()),
});
// Block 10.0.0.1, allow 10.0.0.2
let security = RouteSecurity {
ip_allow_list: None,
ip_block_list: Some(vec!["10.0.0.1".to_string()]),
max_connections: None,
authentication: None,
rate_limit: None,
basic_auth: None,
jwt_auth: None,
};
reg.recycle_for_security_change("r1", &security);
assert!(t1.is_cancelled());
assert!(!t2.is_cancelled());
}
#[test]
fn test_recycle_for_route_change() {
let reg = make_registry();
let t1 = CancellationToken::new();
let t2 = CancellationToken::new();
let (_, _g1) = reg.register(ConnectionEntry {
cancel: t1.clone(),
source_ip: "10.0.0.1".parse().unwrap(),
domain: None,
route_id: Some("r1".to_string()),
});
let (_, _g2) = reg.register(ConnectionEntry {
cancel: t2.clone(),
source_ip: "10.0.0.2".parse().unwrap(),
domain: None,
route_id: Some("r2".to_string()),
});
reg.recycle_for_route_change("r1");
assert!(t1.is_cancelled());
assert!(!t2.is_cancelled());
}
#[test]
fn test_cleanup_removed_routes() {
let reg = make_registry();
let t1 = CancellationToken::new();
let t2 = CancellationToken::new();
let (_, _g1) = reg.register(ConnectionEntry {
cancel: t1.clone(),
source_ip: "10.0.0.1".parse().unwrap(),
domain: None,
route_id: Some("active".to_string()),
});
let (_, _g2) = reg.register(ConnectionEntry {
cancel: t2.clone(),
source_ip: "10.0.0.2".parse().unwrap(),
domain: None,
route_id: Some("removed".to_string()),
});
let mut active = HashSet::new();
active.insert("active".to_string());
reg.cleanup_removed_routes(&active);
// "removed" route entry was cleaned from registry
// (but guard is still alive so len may differ — the retain already removed it)
assert!(!t1.is_cancelled()); // not cancelled by cleanup, only by token cascade
assert!(!t2.is_cancelled()); // cleanup doesn't cancel, just removes from registry
}
}
@@ -31,7 +31,8 @@ impl ConnectionTracker {
pub fn try_accept(&self, ip: &IpAddr) -> bool {
// Check per-IP connection limit
if let Some(max) = self.max_per_ip {
let count = self.active
let count = self
.active
.get(ip)
.map(|c| c.value().load(Ordering::Relaxed))
.unwrap_or(0);
@@ -48,7 +49,10 @@ impl ConnectionTracker {
let timestamps = entry.value_mut();
// Remove timestamps older than 1 minute
while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) {
while timestamps
.front()
.is_some_and(|t| now.duration_since(*t) >= one_minute)
{
timestamps.pop_front();
}
@@ -111,7 +115,6 @@ impl ConnectionTracker {
pub fn tracked_ips(&self) -> usize {
self.active.len()
}
}
#[cfg(test)]
@@ -1,8 +1,8 @@
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::debug;
use rustproxy_metrics::MetricsCollector;
@@ -87,7 +87,12 @@ pub async fn forward_bidirectional_with_timeouts(
if let Some(data) = initial_data {
backend.write_all(data).await?;
if let Some(ref ctx) = metrics {
ctx.collector.record_bytes(data.len() as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
ctx.collector.record_bytes(
data.len() as u64,
0,
ctx.route_id.as_deref(),
ctx.source_ip.as_deref(),
);
}
}
@@ -123,14 +128,17 @@ pub async fn forward_bidirectional_with_timeouts(
total += n as u64;
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
if let Some(ref ctx) = metrics_c2b {
ctx.collector.record_bytes(n as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
ctx.collector.record_bytes(
n as u64,
0,
ctx.route_id.as_deref(),
ctx.source_ip.as_deref(),
);
}
}
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
let _ = tokio::time::timeout(
std::time::Duration::from_secs(2),
backend_write.shutdown(),
).await;
let _ =
tokio::time::timeout(std::time::Duration::from_secs(2), backend_write.shutdown()).await;
total
});
@@ -154,14 +162,17 @@ pub async fn forward_bidirectional_with_timeouts(
total += n as u64;
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
if let Some(ref ctx) = metrics_b2c {
ctx.collector.record_bytes(0, n as u64, ctx.route_id.as_deref(), ctx.source_ip.as_deref());
ctx.collector.record_bytes(
0,
n as u64,
ctx.route_id.as_deref(),
ctx.source_ip.as_deref(),
);
}
}
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
let _ = tokio::time::timeout(
std::time::Duration::from_secs(2),
client_write.shutdown(),
).await;
let _ =
tokio::time::timeout(std::time::Duration::from_secs(2), client_write.shutdown()).await;
total
});
+19 -12
View File
@@ -1,22 +1,29 @@
//! # rustproxy-passthrough
//!
//! Raw TCP/SNI passthrough engine for RustProxy.
//! Handles TCP listening, TLS ClientHello SNI extraction, and bidirectional forwarding.
//! Raw TCP/SNI passthrough engine and UDP listener for RustProxy.
//! Handles TCP listening, TLS ClientHello SNI extraction, bidirectional forwarding,
//! and UDP datagram session tracking with forwarding.
pub mod tcp_listener;
pub mod sni_parser;
pub mod connection_registry;
pub mod connection_tracker;
pub mod forwarder;
pub mod proxy_protocol;
pub mod tls_handler;
pub mod connection_tracker;
pub mod socket_relay;
pub mod quic_handler;
pub mod sni_parser;
pub mod socket_opts;
pub mod tcp_listener;
pub mod tls_handler;
pub mod udp_listener;
pub mod udp_session;
pub use tcp_listener::*;
pub use sni_parser::*;
pub use connection_registry::*;
pub use connection_tracker::*;
pub use forwarder::*;
pub use proxy_protocol::*;
pub use tls_handler::*;
pub use connection_tracker::*;
pub use socket_relay::*;
pub use quic_handler::*;
pub use sni_parser::*;
pub use socket_opts::*;
pub use tcp_listener::*;
pub use tls_handler::*;
pub use udp_listener::*;
pub use udp_session::*;
@@ -1,4 +1,4 @@
use std::net::SocketAddr;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use thiserror::Error;
#[derive(Debug, Error)]
@@ -9,9 +9,11 @@ pub enum ProxyProtocolError {
UnsupportedVersion,
#[error("Parse error: {0}")]
Parse(String),
#[error("Incomplete header: need {0} bytes, got {1}")]
Incomplete(usize, usize),
}
/// Parsed PROXY protocol v1 header.
/// Parsed PROXY protocol header (v1 or v2).
#[derive(Debug, Clone)]
pub struct ProxyProtocolHeader {
pub source_addr: SocketAddr,
@@ -24,21 +26,36 @@ pub struct ProxyProtocolHeader {
pub enum ProxyProtocol {
Tcp4,
Tcp6,
Udp4,
Udp6,
Unknown,
}
/// Transport type for PROXY v2 header generation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProxyV2Transport {
Stream, // TCP
Datagram, // UDP
}
/// PROXY protocol v2 signature (12 bytes).
const PROXY_V2_SIGNATURE: [u8; 12] = [
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];
// ===== v1 (text format) =====
/// Parse a PROXY protocol v1 header from data.
///
/// Format: `PROXY TCP4 <src_ip> <dst_ip> <src_port> <dst_port>\r\n`
pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtocolError> {
// Find the end of the header line
let line_end = data
.windows(2)
.position(|w| w == b"\r\n")
.ok_or(ProxyProtocolError::InvalidHeader)?;
let line = std::str::from_utf8(&data[..line_end])
.map_err(|_| ProxyProtocolError::InvalidHeader)?;
let line =
std::str::from_utf8(&data[..line_end]).map_err(|_| ProxyProtocolError::InvalidHeader)?;
if !line.starts_with("PROXY ") {
return Err(ProxyProtocolError::InvalidHeader);
@@ -56,10 +73,10 @@ pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
_ => return Err(ProxyProtocolError::UnsupportedVersion),
};
let src_ip: std::net::IpAddr = parts[2]
let src_ip: IpAddr = parts[2]
.parse()
.map_err(|_| ProxyProtocolError::Parse("Invalid source IP".to_string()))?;
let dst_ip: std::net::IpAddr = parts[3]
let dst_ip: IpAddr = parts[3]
.parse()
.map_err(|_| ProxyProtocolError::Parse("Invalid destination IP".to_string()))?;
let src_port: u16 = parts[4]
@@ -75,7 +92,6 @@ pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
protocol,
};
// Consumed bytes = line + \r\n
Ok((header, line_end + 2))
}
@@ -97,10 +113,226 @@ pub fn is_proxy_protocol_v1(data: &[u8]) -> bool {
data.starts_with(b"PROXY ")
}
// ===== v2 (binary format) =====
/// Check if data starts with a PROXY protocol v2 header.
pub fn is_proxy_protocol_v2(data: &[u8]) -> bool {
data.len() >= 12 && data[..12] == PROXY_V2_SIGNATURE
}
/// Parse a PROXY protocol v2 binary header.
///
/// Binary format:
/// - [0..12] signature (12 bytes)
/// - [12] version (high nibble) + command (low nibble)
/// - [13] address family (high nibble) + transport (low nibble)
/// - [14..16] address block length (big-endian u16)
/// - [16..] address block (variable, depends on family)
pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtocolError> {
if data.len() < 16 {
return Err(ProxyProtocolError::Incomplete(16, data.len()));
}
// Validate signature
if data[..12] != PROXY_V2_SIGNATURE {
return Err(ProxyProtocolError::InvalidHeader);
}
// Version (high nibble of byte 12) must be 0x2
let version = (data[12] >> 4) & 0x0F;
if version != 2 {
return Err(ProxyProtocolError::UnsupportedVersion);
}
// Command (low nibble of byte 12)
let command = data[12] & 0x0F;
// 0x0 = LOCAL, 0x1 = PROXY
if command > 1 {
return Err(ProxyProtocolError::Parse(format!(
"Unknown command: {}",
command
)));
}
// Address family (high nibble) + transport (low nibble) of byte 13
let family = (data[13] >> 4) & 0x0F;
let transport = data[13] & 0x0F;
// Address block length
let addr_len = u16::from_be_bytes([data[14], data[15]]) as usize;
let total_len = 16 + addr_len;
if data.len() < total_len {
return Err(ProxyProtocolError::Incomplete(total_len, data.len()));
}
// LOCAL command: no real addresses, return unspecified
if command == 0 {
return Ok((
ProxyProtocolHeader {
source_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
dest_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
protocol: ProxyProtocol::Unknown,
},
total_len,
));
}
// PROXY command: parse addresses based on family + transport
let addr_block = &data[16..16 + addr_len];
match (family, transport) {
// AF_INET (0x1) + STREAM (0x1) = TCP4
(0x1, 0x1) => {
if addr_len < 12 {
return Err(ProxyProtocolError::Parse(
"IPv4 address block too short".to_string(),
));
}
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
let src_port = u16::from_be_bytes([addr_block[8], addr_block[9]]);
let dst_port = u16::from_be_bytes([addr_block[10], addr_block[11]]);
Ok((
ProxyProtocolHeader {
source_addr: SocketAddr::new(IpAddr::V4(src_ip), src_port),
dest_addr: SocketAddr::new(IpAddr::V4(dst_ip), dst_port),
protocol: ProxyProtocol::Tcp4,
},
total_len,
))
}
// AF_INET (0x1) + DGRAM (0x2) = UDP4
(0x1, 0x2) => {
if addr_len < 12 {
return Err(ProxyProtocolError::Parse(
"IPv4 address block too short".to_string(),
));
}
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
let src_port = u16::from_be_bytes([addr_block[8], addr_block[9]]);
let dst_port = u16::from_be_bytes([addr_block[10], addr_block[11]]);
Ok((
ProxyProtocolHeader {
source_addr: SocketAddr::new(IpAddr::V4(src_ip), src_port),
dest_addr: SocketAddr::new(IpAddr::V4(dst_ip), dst_port),
protocol: ProxyProtocol::Udp4,
},
total_len,
))
}
// AF_INET6 (0x2) + STREAM (0x1) = TCP6
(0x2, 0x1) => {
if addr_len < 36 {
return Err(ProxyProtocolError::Parse(
"IPv6 address block too short".to_string(),
));
}
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
let src_port = u16::from_be_bytes([addr_block[32], addr_block[33]]);
let dst_port = u16::from_be_bytes([addr_block[34], addr_block[35]]);
Ok((
ProxyProtocolHeader {
source_addr: SocketAddr::new(IpAddr::V6(src_ip), src_port),
dest_addr: SocketAddr::new(IpAddr::V6(dst_ip), dst_port),
protocol: ProxyProtocol::Tcp6,
},
total_len,
))
}
// AF_INET6 (0x2) + DGRAM (0x2) = UDP6
(0x2, 0x2) => {
if addr_len < 36 {
return Err(ProxyProtocolError::Parse(
"IPv6 address block too short".to_string(),
));
}
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
let src_port = u16::from_be_bytes([addr_block[32], addr_block[33]]);
let dst_port = u16::from_be_bytes([addr_block[34], addr_block[35]]);
Ok((
ProxyProtocolHeader {
source_addr: SocketAddr::new(IpAddr::V6(src_ip), src_port),
dest_addr: SocketAddr::new(IpAddr::V6(dst_ip), dst_port),
protocol: ProxyProtocol::Udp6,
},
total_len,
))
}
// AF_UNSPEC or unknown
(0x0, _) => Ok((
ProxyProtocolHeader {
source_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
dest_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
protocol: ProxyProtocol::Unknown,
},
total_len,
)),
_ => Err(ProxyProtocolError::Parse(format!(
"Unsupported family/transport: 0x{:X}{:X}",
family, transport
))),
}
}
/// Generate a PROXY protocol v2 binary header.
pub fn generate_v2(source: &SocketAddr, dest: &SocketAddr, transport: ProxyV2Transport) -> Vec<u8> {
let transport_nibble: u8 = match transport {
ProxyV2Transport::Stream => 0x1,
ProxyV2Transport::Datagram => 0x2,
};
match (source.ip(), dest.ip()) {
(IpAddr::V4(src_ip), IpAddr::V4(dst_ip)) => {
let mut buf = Vec::with_capacity(28);
buf.extend_from_slice(&PROXY_V2_SIGNATURE);
buf.push(0x21); // version 2, PROXY command
buf.push(0x10 | transport_nibble); // AF_INET + transport
buf.extend_from_slice(&12u16.to_be_bytes()); // addr block length
buf.extend_from_slice(&src_ip.octets());
buf.extend_from_slice(&dst_ip.octets());
buf.extend_from_slice(&source.port().to_be_bytes());
buf.extend_from_slice(&dest.port().to_be_bytes());
buf
}
(IpAddr::V6(src_ip), IpAddr::V6(dst_ip)) => {
let mut buf = Vec::with_capacity(52);
buf.extend_from_slice(&PROXY_V2_SIGNATURE);
buf.push(0x21); // version 2, PROXY command
buf.push(0x20 | transport_nibble); // AF_INET6 + transport
buf.extend_from_slice(&36u16.to_be_bytes()); // addr block length
buf.extend_from_slice(&src_ip.octets());
buf.extend_from_slice(&dst_ip.octets());
buf.extend_from_slice(&source.port().to_be_bytes());
buf.extend_from_slice(&dest.port().to_be_bytes());
buf
}
// Mixed IPv4/IPv6: map IPv4 to IPv6-mapped address
_ => {
let src_v6 = match source.ip() {
IpAddr::V4(v4) => v4.to_ipv6_mapped(),
IpAddr::V6(v6) => v6,
};
let dst_v6 = match dest.ip() {
IpAddr::V4(v4) => v4.to_ipv6_mapped(),
IpAddr::V6(v6) => v6,
};
let src6 = SocketAddr::new(IpAddr::V6(src_v6), source.port());
let dst6 = SocketAddr::new(IpAddr::V6(dst_v6), dest.port());
generate_v2(&src6, &dst6, transport)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
// ===== v1 tests =====
#[test]
fn test_parse_v1_tcp4() {
let header = b"PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n";
@@ -126,4 +358,133 @@ mod tests {
assert!(is_proxy_protocol_v1(b"PROXY TCP4 ..."));
assert!(!is_proxy_protocol_v1(b"GET / HTTP/1.1"));
}
// ===== v2 tests =====
#[test]
fn test_is_proxy_protocol_v2() {
assert!(is_proxy_protocol_v2(&PROXY_V2_SIGNATURE));
assert!(!is_proxy_protocol_v2(b"PROXY TCP4 ..."));
assert!(!is_proxy_protocol_v2(b"short"));
}
#[test]
fn test_parse_v2_tcp4() {
let source: SocketAddr = "198.51.100.10:54321".parse().unwrap();
let dest: SocketAddr = "203.0.113.25:8443".parse().unwrap();
let header = generate_v2(&source, &dest, ProxyV2Transport::Stream);
assert_eq!(header.len(), 28);
let (parsed, consumed) = parse_v2(&header).unwrap();
assert_eq!(consumed, 28);
assert_eq!(parsed.protocol, ProxyProtocol::Tcp4);
assert_eq!(parsed.source_addr, source);
assert_eq!(parsed.dest_addr, dest);
}
#[test]
fn test_parse_v2_udp4() {
let source: SocketAddr = "10.0.0.1:12345".parse().unwrap();
let dest: SocketAddr = "10.0.0.2:53".parse().unwrap();
let header = generate_v2(&source, &dest, ProxyV2Transport::Datagram);
assert_eq!(header.len(), 28);
assert_eq!(header[13], 0x12); // AF_INET + DGRAM
let (parsed, consumed) = parse_v2(&header).unwrap();
assert_eq!(consumed, 28);
assert_eq!(parsed.protocol, ProxyProtocol::Udp4);
assert_eq!(parsed.source_addr, source);
assert_eq!(parsed.dest_addr, dest);
}
#[test]
fn test_parse_v2_tcp6() {
let source: SocketAddr = "[2001:db8::1]:54321".parse().unwrap();
let dest: SocketAddr = "[2001:db8::2]:443".parse().unwrap();
let header = generate_v2(&source, &dest, ProxyV2Transport::Stream);
assert_eq!(header.len(), 52);
assert_eq!(header[13], 0x21); // AF_INET6 + STREAM
let (parsed, consumed) = parse_v2(&header).unwrap();
assert_eq!(consumed, 52);
assert_eq!(parsed.protocol, ProxyProtocol::Tcp6);
assert_eq!(parsed.source_addr, source);
assert_eq!(parsed.dest_addr, dest);
}
#[test]
fn test_generate_v2_tcp4_byte_layout() {
let source: SocketAddr = "1.2.3.4:1000".parse().unwrap();
let dest: SocketAddr = "5.6.7.8:443".parse().unwrap();
let header = generate_v2(&source, &dest, ProxyV2Transport::Stream);
assert_eq!(&header[0..12], &PROXY_V2_SIGNATURE);
assert_eq!(header[12], 0x21); // v2, PROXY
assert_eq!(header[13], 0x11); // AF_INET, STREAM
assert_eq!(u16::from_be_bytes([header[14], header[15]]), 12); // addr len
assert_eq!(&header[16..20], &[1, 2, 3, 4]); // src ip
assert_eq!(&header[20..24], &[5, 6, 7, 8]); // dst ip
assert_eq!(u16::from_be_bytes([header[24], header[25]]), 1000); // src port
assert_eq!(u16::from_be_bytes([header[26], header[27]]), 443); // dst port
}
#[test]
fn test_generate_v2_udp4_byte_layout() {
let source: SocketAddr = "10.0.0.1:5000".parse().unwrap();
let dest: SocketAddr = "10.0.0.2:53".parse().unwrap();
let header = generate_v2(&source, &dest, ProxyV2Transport::Datagram);
assert_eq!(header[12], 0x21); // v2, PROXY
assert_eq!(header[13], 0x12); // AF_INET, DGRAM (UDP)
}
#[test]
fn test_parse_v2_local_command() {
// Build a LOCAL command header (no addresses)
let mut header = Vec::new();
header.extend_from_slice(&PROXY_V2_SIGNATURE);
header.push(0x20); // v2, LOCAL
header.push(0x00); // AF_UNSPEC
header.extend_from_slice(&0u16.to_be_bytes()); // 0-length address block
let (parsed, consumed) = parse_v2(&header).unwrap();
assert_eq!(consumed, 16);
assert_eq!(parsed.protocol, ProxyProtocol::Unknown);
assert_eq!(parsed.source_addr.port(), 0);
}
#[test]
fn test_parse_v2_incomplete() {
let data = &PROXY_V2_SIGNATURE[..8]; // only 8 bytes
assert!(parse_v2(data).is_err());
}
#[test]
fn test_parse_v2_wrong_version() {
let mut header = Vec::new();
header.extend_from_slice(&PROXY_V2_SIGNATURE);
header.push(0x11); // version 1, not 2
header.push(0x11);
header.extend_from_slice(&12u16.to_be_bytes());
header.extend_from_slice(&[0u8; 12]);
assert!(matches!(
parse_v2(&header),
Err(ProxyProtocolError::UnsupportedVersion)
));
}
#[test]
fn test_v2_roundtrip_with_trailing_data() {
let source: SocketAddr = "192.168.1.1:8080".parse().unwrap();
let dest: SocketAddr = "10.0.0.1:443".parse().unwrap();
let mut data = generate_v2(&source, &dest, ProxyV2Transport::Stream);
data.extend_from_slice(b"GET / HTTP/1.1\r\n"); // trailing app data
let (parsed, consumed) = parse_v2(&data).unwrap();
assert_eq!(consumed, 28);
assert_eq!(parsed.source_addr, source);
assert_eq!(&data[consumed..], b"GET / HTTP/1.1\r\n");
}
}
@@ -0,0 +1,832 @@
//! QUIC connection handling.
//!
//! Manages QUIC endpoints (via quinn), accepts connections, and either:
//! - Forwards streams bidirectionally to TCP backends (QUIC termination)
//! - Dispatches to H3ProxyService for HTTP/3 handling (Phase 5)
//!
//! When `proxy_ips` is configured, a UDP relay layer intercepts PROXY protocol v2
//! headers before they reach quinn, extracting real client IPs for attribution.
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::UdpSocket;
use tokio::task::JoinHandle;
use arc_swap::ArcSwap;
use dashmap::DashMap;
use quinn::{Endpoint, ServerConfig as QuinnServerConfig};
use rustls::ServerConfig as RustlsServerConfig;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use rustproxy_config::{RouteConfig, TransportProtocol};
use rustproxy_metrics::MetricsCollector;
use rustproxy_routing::{MatchContext, RouteManager};
use rustproxy_security::IpBlockList;
use rustproxy_http::h3_service::H3ProxyService;
use crate::connection_registry::{ConnectionEntry, ConnectionRegistry};
use crate::connection_tracker::ConnectionTracker;
/// Create a QUIC server endpoint on the given port with the provided TLS config.
///
/// The TLS config must have ALPN protocols set (e.g., `h3` for HTTP/3).
pub fn create_quic_endpoint(
port: u16,
tls_config: Arc<RustlsServerConfig>,
) -> anyhow::Result<Endpoint> {
let quic_crypto = quinn::crypto::rustls::QuicServerConfig::try_from(tls_config)
.map_err(|e| anyhow::anyhow!("Failed to create QUIC crypto config: {}", e))?;
let server_config = QuinnServerConfig::with_crypto(Arc::new(quic_crypto));
let socket = std::net::UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], port)))?;
let endpoint = Endpoint::new(
quinn::EndpointConfig::default(),
Some(server_config),
socket,
quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
)?;
info!("QUIC endpoint listening on port {}", port);
Ok(endpoint)
}
// ===== PROXY protocol relay for QUIC =====
/// Result of creating a QUIC endpoint with a PROXY protocol relay layer.
pub struct QuicProxyRelay {
/// The quinn endpoint (bound to 127.0.0.1:ephemeral).
pub endpoint: Endpoint,
/// The relay recv loop task handle.
pub relay_task: JoinHandle<()>,
/// Maps relay socket local addr → real client SocketAddr (from PROXY v2).
/// Consulted by `quic_accept_loop` to resolve real client IPs.
pub real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>,
}
/// A single relay session for forwarding datagrams between an external source
/// and the internal quinn endpoint.
struct RelaySession {
socket: Arc<UdpSocket>,
last_activity: AtomicU64,
return_task: JoinHandle<()>,
cancel: CancellationToken,
}
impl Drop for RelaySession {
fn drop(&mut self) {
self.cancel.cancel();
self.return_task.abort();
}
}
/// Create a QUIC endpoint with a PROXY protocol v2 relay layer.
///
/// Instead of giving the external socket to quinn, we:
/// 1. Bind a raw UDP socket on 0.0.0.0:port (external)
/// 2. Bind quinn on 127.0.0.1:0 (internal, ephemeral)
/// 3. Run a relay loop that filters PROXY v2 headers and forwards datagrams
///
/// Only used when `proxy_ips` is non-empty.
pub fn create_quic_endpoint_with_proxy_relay(
port: u16,
tls_config: Arc<RustlsServerConfig>,
proxy_ips: Arc<Vec<IpAddr>>,
security_policy: Arc<ArcSwap<IpBlockList>>,
cancel: CancellationToken,
) -> anyhow::Result<QuicProxyRelay> {
// Bind external socket on the real port
let external_socket = std::net::UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], port)))?;
external_socket.set_nonblocking(true)?;
let external_socket = Arc::new(
UdpSocket::from_std(external_socket)
.map_err(|e| anyhow::anyhow!("Failed to wrap external socket: {}", e))?,
);
// Bind quinn on localhost ephemeral port
let internal_socket = std::net::UdpSocket::bind("127.0.0.1:0")?;
let quinn_internal_addr = internal_socket.local_addr()?;
let quic_crypto = quinn::crypto::rustls::QuicServerConfig::try_from(tls_config)
.map_err(|e| anyhow::anyhow!("Failed to create QUIC crypto config: {}", e))?;
let server_config = QuinnServerConfig::with_crypto(Arc::new(quic_crypto));
let endpoint = Endpoint::new(
quinn::EndpointConfig::default(),
Some(server_config),
internal_socket,
quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
)?;
let real_client_map = Arc::new(DashMap::new());
let relay_task = tokio::spawn(quic_proxy_relay_loop(
external_socket,
quinn_internal_addr,
proxy_ips,
security_policy,
Arc::clone(&real_client_map),
cancel,
));
info!(
"QUIC endpoint with PROXY relay on port {} (quinn internal: {})",
port, quinn_internal_addr
);
Ok(QuicProxyRelay {
endpoint,
relay_task,
real_client_map,
})
}
/// Main relay loop: reads datagrams from the external socket, filters PROXY v2
/// headers from trusted proxy IPs, and forwards everything else to quinn via
/// per-session relay sockets.
async fn quic_proxy_relay_loop(
external_socket: Arc<UdpSocket>,
quinn_internal_addr: SocketAddr,
proxy_ips: Arc<Vec<IpAddr>>,
security_policy: Arc<ArcSwap<IpBlockList>>,
real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>,
cancel: CancellationToken,
) {
// Maps external source addr → real client addr (from PROXY v2 headers)
let proxy_addr_map: DashMap<SocketAddr, SocketAddr> = DashMap::new();
// Maps external source addr → relay session
let relay_sessions: DashMap<SocketAddr, Arc<RelaySession>> = DashMap::new();
let epoch = Instant::now();
let mut buf = vec![0u8; 65535];
// Inline cleanup: periodically scan relay_sessions for stale entries
let mut last_cleanup = Instant::now();
let cleanup_interval = std::time::Duration::from_secs(30);
let session_timeout_ms: u64 = 120_000;
loop {
let (len, src_addr) = tokio::select! {
_ = cancel.cancelled() => {
debug!("QUIC proxy relay loop cancelled");
break;
}
result = external_socket.recv_from(&mut buf) => {
match result {
Ok(r) => r,
Err(e) => {
warn!("QUIC proxy relay recv error: {}", e);
continue;
}
}
}
};
let datagram = &buf[..len];
// PROXY v2 handling: only on first datagram from a trusted proxy IP
// (before a relay session exists for this source)
if proxy_ips.contains(&src_addr.ip()) && relay_sessions.get(&src_addr).is_none() {
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
match crate::proxy_protocol::parse_v2(datagram) {
Ok((header, _consumed)) => {
debug!(
"QUIC PROXY v2 from {}: real client {}",
src_addr, header.source_addr
);
proxy_addr_map.insert(src_addr, header.source_addr);
continue; // consume the PROXY v2 datagram
}
Err(e) => {
debug!(
"QUIC proxy relay: failed to parse PROXY v2 from {}: {}",
src_addr, e
);
}
}
}
}
// Determine real client address
let real_client = proxy_addr_map
.get(&src_addr)
.map(|r| *r)
.unwrap_or(src_addr);
let block_list = security_policy.load();
if !block_list.is_empty() && block_list.is_blocked(&real_client.ip()) {
debug!(
"QUIC datagram from {} blocked by global security policy",
real_client
);
continue;
}
// Get or create relay session for this external source
let session = match relay_sessions.get(&src_addr) {
Some(s) => {
s.last_activity
.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
Arc::clone(s.value())
}
None => {
// Create new relay socket connected to quinn's internal address
let relay_socket = match UdpSocket::bind("127.0.0.1:0").await {
Ok(s) => s,
Err(e) => {
warn!("QUIC relay: failed to bind relay socket: {}", e);
continue;
}
};
if let Err(e) = relay_socket.connect(quinn_internal_addr).await {
warn!(
"QUIC relay: failed to connect relay socket to {}: {}",
quinn_internal_addr, e
);
continue;
}
let relay_local_addr = match relay_socket.local_addr() {
Ok(a) => a,
Err(e) => {
warn!("QUIC relay: failed to get relay socket local addr: {}", e);
continue;
}
};
let relay_socket = Arc::new(relay_socket);
// Store the real client mapping for the QUIC accept loop
real_client_map.insert(relay_local_addr, real_client);
// Spawn return-path relay: quinn -> external socket -> original source
let session_cancel = cancel.child_token();
let return_task = tokio::spawn(relay_return_path(
Arc::clone(&relay_socket),
Arc::clone(&external_socket),
src_addr,
session_cancel.child_token(),
));
let session = Arc::new(RelaySession {
socket: relay_socket,
last_activity: AtomicU64::new(epoch.elapsed().as_millis() as u64),
return_task,
cancel: session_cancel,
});
relay_sessions.insert(src_addr, Arc::clone(&session));
debug!(
"QUIC relay: new session for {} (relay {}), real client {}",
src_addr, relay_local_addr, real_client
);
session
}
};
// Forward datagram to quinn via the relay socket
if let Err(e) = session.socket.send(datagram).await {
debug!("QUIC relay: forward error to quinn for {}: {}", src_addr, e);
}
// Periodic cleanup of stale relay sessions
if last_cleanup.elapsed() >= cleanup_interval {
last_cleanup = Instant::now();
let now_ms = epoch.elapsed().as_millis() as u64;
let stale_keys: Vec<SocketAddr> = relay_sessions
.iter()
.filter(|entry| {
let age =
now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
age > session_timeout_ms
})
.map(|entry| *entry.key())
.collect();
for key in stale_keys {
if let Some((_, session)) = relay_sessions.remove(&key) {
session.cancel.cancel();
session.return_task.abort();
// Clean up real_client_map entry
if let Ok(addr) = session.socket.local_addr() {
real_client_map.remove(&addr);
}
proxy_addr_map.remove(&key);
debug!("QUIC relay: cleaned up stale session for {}", key);
}
}
// Also clean orphaned proxy_addr_map entries (PROXY header received
// but no relay session was ever created, e.g. client never sent data)
let orphaned: Vec<SocketAddr> = proxy_addr_map
.iter()
.filter(|entry| relay_sessions.get(entry.key()).is_none())
.map(|entry| *entry.key())
.collect();
for key in orphaned {
proxy_addr_map.remove(&key);
debug!(
"QUIC relay: cleaned up orphaned proxy_addr_map entry for {}",
key
);
}
}
}
// Shutdown: cancel all relay sessions
for entry in relay_sessions.iter() {
entry.value().cancel.cancel();
entry.value().return_task.abort();
}
}
/// Return-path relay: receives datagrams from quinn (via the relay socket)
/// and forwards them back to the external client through the external socket.
async fn relay_return_path(
relay_socket: Arc<UdpSocket>,
external_socket: Arc<UdpSocket>,
external_src_addr: SocketAddr,
cancel: CancellationToken,
) {
let mut buf = vec![0u8; 65535];
loop {
let len = tokio::select! {
_ = cancel.cancelled() => break,
result = relay_socket.recv(&mut buf) => {
match result {
Ok(len) => len,
Err(e) => {
debug!("QUIC relay return recv error for {}: {}", external_src_addr, e);
break;
}
}
}
};
if let Err(e) = external_socket
.send_to(&buf[..len], external_src_addr)
.await
{
debug!(
"QUIC relay return send error to {}: {}",
external_src_addr, e
);
break;
}
}
}
// ===== QUIC accept loop =====
/// Run the QUIC accept loop for a single endpoint.
///
/// Accepts incoming QUIC connections and spawns a task per connection.
/// When `real_client_map` is provided, it is consulted to resolve real client
/// IPs from PROXY protocol v2 headers (relay socket addr → real client addr).
pub async fn quic_accept_loop(
endpoint: Endpoint,
port: u16,
route_manager: Arc<ArcSwap<RouteManager>>,
metrics: Arc<MetricsCollector>,
conn_tracker: Arc<ConnectionTracker>,
cancel: CancellationToken,
h3_service: Option<Arc<H3ProxyService>>,
real_client_map: Option<Arc<DashMap<SocketAddr, SocketAddr>>>,
route_cancels: Arc<DashMap<String, CancellationToken>>,
connection_registry: Arc<ConnectionRegistry>,
security_policy: Arc<ArcSwap<IpBlockList>>,
) {
loop {
let incoming = tokio::select! {
_ = cancel.cancelled() => {
debug!("QUIC accept loop on port {} cancelled", port);
break;
}
incoming = endpoint.accept() => {
match incoming {
Some(conn) => conn,
None => {
debug!("QUIC endpoint on port {} closed", port);
break;
}
}
}
};
let remote_addr = incoming.remote_address();
// Resolve real client IP from PROXY protocol map if available
let real_addr = real_client_map
.as_ref()
.and_then(|map| map.get(&remote_addr).map(|r| *r))
.unwrap_or(remote_addr);
let ip = real_addr.ip();
let block_list = security_policy.load();
if !block_list.is_empty() && block_list.is_blocked(&ip) {
debug!(
"QUIC connection from {} blocked by global security policy",
real_addr
);
continue;
}
// Per-IP rate limiting
if !conn_tracker.try_accept(&ip) {
debug!("QUIC connection rejected from {} (rate limit)", real_addr);
// Drop `incoming` to refuse the connection
continue;
}
// Route matching (port + client IP, no domain yet — QUIC Initial is encrypted)
let rm = route_manager.load();
let ip_str = ip.to_string();
let ctx = MatchContext {
port,
domain: None,
path: None,
client_ip: Some(&ip_str),
tls_version: None,
headers: None,
is_tls: true,
protocol: Some("quic"),
transport: Some(TransportProtocol::Udp),
};
let route = match rm.find_route(&ctx) {
Some(m) => m.route.clone(),
None => {
debug!("No QUIC route matched for port {} from {}", port, real_addr);
continue;
}
};
// Check route-level IP security for QUIC (domain from SNI context)
if let Some(ref security) = route.security {
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
security, &ip, ctx.domain,
) {
debug!(
"QUIC connection from {} blocked by route security",
real_addr
);
continue;
}
}
conn_tracker.connection_opened(&ip);
let route_id = route.metrics_key().map(str::to_string);
metrics.connection_opened(route_id.as_deref(), Some(&ip_str));
// Resolve per-route cancel token (child of global cancel)
let route_cancel = match route_id.as_deref() {
Some(id) => route_cancels
.entry(id.to_string())
.or_insert_with(|| cancel.child_token())
.clone(),
None => cancel.child_token(),
};
// Per-connection child token for selective recycling
let conn_cancel = route_cancel.child_token();
// Register in connection registry
let registry = Arc::clone(&connection_registry);
let reg_entry = ConnectionEntry {
cancel: conn_cancel.clone(),
source_ip: ip,
domain: None, // QUIC Initial is encrypted, domain comes later via H3 :authority
route_id: route_id.clone(),
};
let metrics = Arc::clone(&metrics);
let conn_tracker = Arc::clone(&conn_tracker);
let h3_svc = h3_service.clone();
let real_client_addr = if real_addr != remote_addr {
Some(real_addr)
} else {
None
};
tokio::spawn(async move {
// Register in connection registry (RAII guard removes on drop)
let (_conn_id, _registry_guard) = registry.register(reg_entry);
// RAII guard: ensures metrics/tracker cleanup even on panic
struct QuicConnGuard {
tracker: Arc<ConnectionTracker>,
metrics: Arc<MetricsCollector>,
ip: std::net::IpAddr,
ip_str: String,
route_id: Option<String>,
}
impl Drop for QuicConnGuard {
fn drop(&mut self) {
self.tracker.connection_closed(&self.ip);
self.metrics
.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
}
}
let _guard = QuicConnGuard {
tracker: conn_tracker,
metrics: Arc::clone(&metrics),
ip,
ip_str,
route_id,
};
match handle_quic_connection(
incoming,
route,
port,
Arc::clone(&metrics),
&conn_cancel,
h3_svc,
real_client_addr,
)
.await
{
Ok(()) => debug!("QUIC connection from {} completed", real_addr),
Err(e) => debug!("QUIC connection from {} error: {}", real_addr, e),
}
});
}
// Graceful shutdown: close endpoint and wait for in-flight connections
endpoint.close(quinn::VarInt::from_u32(0), b"server shutting down");
endpoint.wait_idle().await;
info!("QUIC endpoint on port {} shut down", port);
}
/// Handle a single accepted QUIC connection.
async fn handle_quic_connection(
incoming: quinn::Incoming,
route: RouteConfig,
port: u16,
metrics: Arc<MetricsCollector>,
cancel: &CancellationToken,
h3_service: Option<Arc<H3ProxyService>>,
real_client_addr: Option<SocketAddr>,
) -> anyhow::Result<()> {
let connection = incoming.await?;
let effective_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
debug!("QUIC connection established from {}", effective_addr);
// Check if this route has HTTP/3 enabled
let enable_http3 = route
.action
.udp
.as_ref()
.and_then(|u| u.quic.as_ref())
.and_then(|q| q.enable_http3)
.unwrap_or(false);
if enable_http3 {
if let Some(ref h3_svc) = h3_service {
debug!(
"HTTP/3 enabled for route {:?}, dispatching to H3ProxyService",
route.name
);
h3_svc
.handle_connection(connection, &route, port, real_client_addr, cancel)
.await
} else {
warn!(
"HTTP/3 enabled for route {:?} but H3ProxyService not initialized",
route.name
);
// Keep connection alive until cancelled
tokio::select! {
_ = cancel.cancelled() => {}
reason = connection.closed() => {
debug!("HTTP/3 connection closed (no service): {}", reason);
}
}
Ok(())
}
} else {
// Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr)
.await
}
}
/// Forward QUIC streams bidirectionally to a TCP backend.
///
/// For each accepted bidirectional QUIC stream, connects to the backend
/// via TCP and forwards data in both directions. Quinn's RecvStream/SendStream
/// implement AsyncRead/AsyncWrite, enabling reuse of existing forwarder patterns.
async fn handle_quic_stream_forwarding(
connection: quinn::Connection,
route: RouteConfig,
port: u16,
metrics: Arc<MetricsCollector>,
cancel: &CancellationToken,
real_client_addr: Option<SocketAddr>,
) -> anyhow::Result<()> {
let effective_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
let route_id = route.metrics_key();
let metrics_arc = metrics;
// Resolve backend target
let target = route
.action
.targets
.as_ref()
.and_then(|t| t.first())
.ok_or_else(|| anyhow::anyhow!("No target for QUIC route"))?;
let backend_host = target.host.first();
let backend_port = target.port.resolve(port);
let backend_addr = format!("{}:{}", backend_host, backend_port);
loop {
let (send_stream, recv_stream) = tokio::select! {
_ = cancel.cancelled() => break,
result = connection.accept_bi() => {
match result {
Ok(streams) => streams,
Err(quinn::ConnectionError::ApplicationClosed(_)) => break,
Err(quinn::ConnectionError::LocallyClosed) => break,
Err(e) => {
debug!("QUIC stream accept error from {}: {}", effective_addr, e);
break;
}
}
}
};
let backend_addr = backend_addr.clone();
let ip_str = effective_addr.ip().to_string();
let stream_metrics = Arc::clone(&metrics_arc);
let stream_route_id = route_id.map(|s| s.to_string());
let stream_cancel = cancel.child_token();
// Spawn a task for each QUIC stream → TCP bidirectional forwarding
tokio::spawn(async move {
match forward_quic_stream_to_tcp(send_stream, recv_stream, &backend_addr, stream_cancel)
.await
{
Ok((bytes_in, bytes_out)) => {
stream_metrics.record_bytes(
bytes_in,
bytes_out,
stream_route_id.as_deref(),
Some(&ip_str),
);
debug!(
"QUIC stream forwarded: {}B in, {}B out",
bytes_in, bytes_out
);
}
Err(e) => {
debug!("QUIC stream forwarding error: {}", e);
}
}
});
}
Ok(())
}
/// Forward a single QUIC bidirectional stream to a TCP backend connection.
///
/// Includes inactivity timeout (60s), max lifetime (10min), and cancellation
/// to prevent leaked stream tasks when the parent connection closes.
async fn forward_quic_stream_to_tcp(
mut quic_send: quinn::SendStream,
mut quic_recv: quinn::RecvStream,
backend_addr: &str,
cancel: CancellationToken,
) -> anyhow::Result<(u64, u64)> {
let inactivity_timeout = std::time::Duration::from_secs(60);
let max_lifetime = std::time::Duration::from_secs(600);
// Connect to backend TCP
let tcp_stream = tokio::net::TcpStream::connect(backend_addr).await?;
let (mut tcp_read, mut tcp_write) = tcp_stream.into_split();
let last_activity = Arc::new(AtomicU64::new(0));
let start = std::time::Instant::now();
let conn_cancel = CancellationToken::new();
let la1 = Arc::clone(&last_activity);
let cc1 = conn_cancel.clone();
let c2b = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
let mut total = 0u64;
loop {
let n = tokio::select! {
result = quic_recv.read(&mut buf) => match result {
Ok(Some(0)) | Ok(None) | Err(_) => break,
Ok(Some(n)) => n,
},
_ = cc1.cancelled() => break,
};
if tcp_write.write_all(&buf[..n]).await.is_err() {
break;
}
total += n as u64;
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
}
let _ = tokio::time::timeout(std::time::Duration::from_secs(2), tcp_write.shutdown()).await;
total
});
let la2 = Arc::clone(&last_activity);
let cc2 = conn_cancel.clone();
let b2c = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
let mut total = 0u64;
loop {
let n = tokio::select! {
result = tcp_read.read(&mut buf) => match result {
Ok(0) | Err(_) => break,
Ok(n) => n,
},
_ = cc2.cancelled() => break,
};
// quinn SendStream implements AsyncWrite
if quic_send.write_all(&buf[..n]).await.is_err() {
break;
}
total += n as u64;
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
}
let _ = quic_send.finish();
total
});
// Watchdog: inactivity, max lifetime, and cancellation
let la_watch = Arc::clone(&last_activity);
let c2b_abort = c2b.abort_handle();
let b2c_abort = b2c.abort_handle();
let watchdog = tokio::spawn(async move {
let check_interval = std::time::Duration::from_secs(5);
let mut last_seen = 0u64;
loop {
tokio::select! {
_ = cancel.cancelled() => break,
_ = tokio::time::sleep(check_interval) => {
if start.elapsed() >= max_lifetime {
debug!("QUIC stream exceeded max lifetime, closing");
break;
}
let current = la_watch.load(Ordering::Relaxed);
if current == last_seen {
let elapsed = start.elapsed().as_millis() as u64 - current;
if elapsed >= inactivity_timeout.as_millis() as u64 {
debug!("QUIC stream inactive for {}ms, closing", elapsed);
break;
}
}
last_seen = current;
}
}
}
conn_cancel.cancel();
tokio::time::sleep(std::time::Duration::from_secs(4)).await;
c2b_abort.abort();
b2c_abort.abort();
});
let bytes_in = c2b.await.unwrap_or(0);
let bytes_out = b2c.await.unwrap_or(0);
watchdog.abort();
Ok((bytes_in, bytes_out))
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_quic_endpoint_requires_tls_config() {
// Install the ring crypto provider for tests
let _ = rustls::crypto::ring::default_provider().install_default();
// Generate a single self-signed cert and use its key pair
let self_signed =
rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
let cert_der = self_signed.cert.der().clone();
let key_der = self_signed.key_pair.serialize_der();
let mut tls_config = RustlsServerConfig::builder()
.with_no_client_auth()
.with_single_cert(
vec![cert_der.into()],
rustls::pki_types::PrivateKeyDer::try_from(key_der).unwrap(),
)
.unwrap();
tls_config.alpn_protocols = vec![b"h3".to_vec()];
// Port 0 = OS assigns a free port
let result = create_quic_endpoint(0, Arc::new(tls_config));
assert!(
result.is_ok(),
"QUIC endpoint creation failed: {:?}",
result.err()
);
}
}
@@ -47,9 +47,8 @@ pub fn extract_sni(data: &[u8]) -> SniResult {
}
// Handshake length (3 bytes) - informational, we parse incrementally
let _handshake_len = ((data[6] as usize) << 16)
| ((data[7] as usize) << 8)
| (data[8] as usize);
let _handshake_len =
((data[6] as usize) << 16) | ((data[7] as usize) << 8) | (data[8] as usize);
let hello = &data[9..];
@@ -170,7 +169,10 @@ pub fn extract_http_path(data: &[u8]) -> Option<String> {
pub fn extract_http_host(data: &[u8]) -> Option<String> {
let text = std::str::from_utf8(data).ok()?;
for line in text.split("\r\n") {
if let Some(value) = line.strip_prefix("Host: ").or_else(|| line.strip_prefix("host: ")) {
if let Some(value) = line
.strip_prefix("Host: ")
.or_else(|| line.strip_prefix("host: "))
{
// Strip port if present
let host = value.split(':').next().unwrap_or(value).trim();
if !host.is_empty() {
@@ -213,7 +215,10 @@ mod tests {
#[test]
fn test_too_short() {
assert!(matches!(extract_sni(&[0x16, 0x03]), SniResult::NeedMoreData));
assert!(matches!(
extract_sni(&[0x16, 0x03]),
SniResult::NeedMoreData
));
}
#[test]
@@ -263,7 +268,8 @@ mod tests {
// Extension: type=0x0000 (SNI), length, data
let sni_extension = {
let mut e = Vec::new();
e.push(0x00); e.push(0x00); // SNI type
e.push(0x00);
e.push(0x00); // SNI type
e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8);
e.push((sni_ext_data.len() & 0xFF) as u8);
e.extend_from_slice(&sni_ext_data);
@@ -283,16 +289,20 @@ mod tests {
let hello_body = {
let mut h = Vec::new();
// Client version TLS 1.2
h.push(0x03); h.push(0x03);
h.push(0x03);
h.push(0x03);
// Random (32 bytes)
h.extend_from_slice(&[0u8; 32]);
// Session ID length = 0
h.push(0x00);
// Cipher suites: length=2, one suite
h.push(0x00); h.push(0x02);
h.push(0x00); h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
h.push(0x00);
h.push(0x02);
h.push(0x00);
h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
// Compression methods: length=1, null
h.push(0x01); h.push(0x00);
h.push(0x01);
h.push(0x00);
// Extensions
h.extend_from_slice(&extensions);
h
@@ -313,7 +323,8 @@ mod tests {
// TLS record: type=0x16, version TLS 1.0, length
let mut record = Vec::new();
record.push(0x16); // Handshake
record.push(0x03); record.push(0x01); // TLS 1.0
record.push(0x03);
record.push(0x01); // TLS 1.0
record.push(((handshake.len() >> 8) & 0xFF) as u8);
record.push((handshake.len() & 0xFF) as u8);
record.extend_from_slice(&handshake);
@@ -1,126 +1,4 @@
//! Socket handler relay for connecting client connections to a TypeScript handler
//! via a Unix domain socket.
//! Socket handler relay module.
//!
//! Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay.
use tokio::net::UnixStream;
use tokio::io::{AsyncWriteExt, AsyncReadExt};
use tokio::net::TcpStream;
use serde::Serialize;
use tracing::debug;
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct RelayMetadata {
connection_id: u64,
remote_ip: String,
remote_port: u16,
local_port: u16,
sni: Option<String>,
route_name: String,
initial_data_base64: Option<String>,
}
/// Relay a client connection to a TypeScript handler via Unix domain socket.
///
/// Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay.
pub async fn relay_to_handler(
client: TcpStream,
relay_socket_path: &str,
connection_id: u64,
remote_ip: String,
remote_port: u16,
local_port: u16,
sni: Option<String>,
route_name: String,
initial_data: Option<&[u8]>,
) -> std::io::Result<()> {
debug!(
"Relaying connection {} to handler socket {}",
connection_id, relay_socket_path
);
// Connect to TypeScript handler Unix socket
let mut handler = UnixStream::connect(relay_socket_path).await?;
// Build and send metadata header
let initial_data_base64 = initial_data.map(base64_encode);
let metadata = RelayMetadata {
connection_id,
remote_ip,
remote_port,
local_port,
sni,
route_name,
initial_data_base64,
};
let metadata_json = serde_json::to_string(&metadata)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
handler.write_all(metadata_json.as_bytes()).await?;
handler.write_all(b"\n").await?;
// Bidirectional relay between client and handler
let (mut client_read, mut client_write) = client.into_split();
let (mut handler_read, mut handler_write) = handler.into_split();
let c2h = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
loop {
let n = match client_read.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if handler_write.write_all(&buf[..n]).await.is_err() {
break;
}
}
let _ = handler_write.shutdown().await;
});
let h2c = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
loop {
let n = match handler_read.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if client_write.write_all(&buf[..n]).await.is_err() {
break;
}
}
let _ = client_write.shutdown().await;
});
let _ = tokio::join!(c2h, h2c);
debug!("Relay connection {} completed", connection_id);
Ok(())
}
/// Simple base64 encoding without external dependency.
fn base64_encode(data: &[u8]) -> String {
const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
let mut result = String::new();
for chunk in data.chunks(3) {
let b0 = chunk[0] as u32;
let b1 = if chunk.len() > 1 { chunk[1] as u32 } else { 0 };
let b2 = if chunk.len() > 2 { chunk[2] as u32 } else { 0 };
let n = (b0 << 16) | (b1 << 8) | b2;
result.push(CHARS[((n >> 18) & 0x3F) as usize] as char);
result.push(CHARS[((n >> 12) & 0x3F) as usize] as char);
if chunk.len() > 1 {
result.push(CHARS[((n >> 6) & 0x3F) as usize] as char);
} else {
result.push('=');
}
if chunk.len() > 2 {
result.push(CHARS[(n & 0x3F) as usize] as char);
} else {
result.push('=');
}
}
result
}
//! Note: The actual relay logic lives in `tcp_listener::relay_to_socket_handler()`
//! which has proper timeouts, cancellation, and metrics integration.
File diff suppressed because it is too large Load Diff
@@ -7,7 +7,7 @@ use rustls::server::ResolvesServerCert;
use rustls::sign::CertifiedKey;
use rustls::ServerConfig;
use tokio::net::TcpStream;
use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream};
use tokio_rustls::{server::TlsStream as ServerTlsStream, TlsAcceptor, TlsConnector};
use tracing::{debug, info};
use crate::tcp_listener::TlsCertConfig;
@@ -29,7 +29,9 @@ pub struct CertResolver {
impl CertResolver {
/// Build a resolver from PEM-encoded cert/key configs.
/// Parses all PEM data upfront so connections only do a cheap HashMap lookup.
pub fn new(configs: &HashMap<String, TlsCertConfig>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
pub fn new(
configs: &HashMap<String, TlsCertConfig>,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider();
let provider = rustls::crypto::ring::default_provider();
let mut certs = HashMap::new();
@@ -38,8 +40,10 @@ impl CertResolver {
for (domain, cfg) in configs {
let cert_chain = load_certs(&cfg.cert_pem)?;
let key = load_private_key(&cfg.key_pem)?;
let ck = Arc::new(CertifiedKey::from_der(cert_chain, key, &provider)
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?);
let ck = Arc::new(
CertifiedKey::from_der(cert_chain, key, &provider)
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?,
);
if domain == "*" {
fallback = Some(Arc::clone(&ck));
}
@@ -78,7 +82,9 @@ impl ResolvesServerCert for CertResolver {
/// Build a shared TLS acceptor with SNI resolution, session cache, and session tickets.
/// The returned acceptor can be reused across all connections (cheap Arc clone).
pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
pub fn build_shared_tls_acceptor(
resolver: CertResolver,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider();
let mut config = ServerConfig::builder()
.with_no_client_auth()
@@ -90,22 +96,30 @@ pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor,
// Shared session cache — enables session ID resumption across connections
config.session_storage = rustls::server::ServerSessionMemoryCache::new(4096);
// Session ticket resumption (12-hour lifetime, Chacha20Poly1305 encrypted)
config.ticketer = rustls::crypto::ring::Ticketer::new()
.map_err(|e| format!("Ticketer: {}", e))?;
config.ticketer =
rustls::crypto::ring::Ticketer::new().map_err(|e| format!("Ticketer: {}", e))?;
info!("Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1");
info!(
"Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1"
);
Ok(TlsAcceptor::from(Arc::new(config)))
}
/// Build a TLS acceptor from PEM-encoded cert and key data.
/// Advertises both h2 and http/1.1 via ALPN (for client-facing connections).
pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
pub fn build_tls_acceptor(
cert_pem: &str,
key_pem: &str,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
build_tls_acceptor_with_config(cert_pem, key_pem, None)
}
/// Build a TLS acceptor for backend servers that only speak HTTP/1.1.
/// Does NOT advertise h2 in ALPN, preventing false h2 auto-detection.
pub fn build_tls_acceptor_h1_only(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
pub fn build_tls_acceptor_h1_only(
cert_pem: &str,
key_pem: &str,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider();
let certs = load_certs(cert_pem)?;
let key = load_private_key(key_pem)?;
@@ -130,9 +144,7 @@ pub fn build_tls_acceptor_with_config(
// Apply TLS version restrictions
let versions = resolve_tls_versions(route_tls.versions.as_deref());
let builder = ServerConfig::builder_with_protocol_versions(&versions);
builder
.with_no_client_auth()
.with_single_cert(certs, key)?
builder.with_no_client_auth().with_single_cert(certs, key)?
} else {
ServerConfig::builder()
.with_no_client_auth()
@@ -156,7 +168,9 @@ pub fn build_tls_acceptor_with_config(
}
/// Resolve TLS version strings to rustls SupportedProtocolVersion.
fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> {
fn resolve_tls_versions(
versions: Option<&[String]>,
) -> Vec<&'static rustls::SupportedProtocolVersion> {
let versions = match versions {
Some(v) if !v.is_empty() => v,
_ => return vec![&rustls::version::TLS12, &rustls::version::TLS13],
@@ -207,7 +221,8 @@ pub async fn accept_tls(
static SHARED_CLIENT_CONFIG: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
SHARED_CLIENT_CONFIG.get_or_init(|| {
SHARED_CLIENT_CONFIG
.get_or_init(|| {
ensure_crypto_provider();
let config = rustls::ClientConfig::builder()
.dangerous()
@@ -215,7 +230,8 @@ pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
.with_no_client_auth();
info!("Built shared backend TLS client config with session resumption");
Arc::new(config)
}).clone()
})
.clone()
}
/// Get or create a shared backend TLS `ClientConfig` with ALPN `h2` + `http/1.1`.
@@ -225,16 +241,20 @@ pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
static SHARED_CLIENT_CONFIG_ALPN: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
pub fn shared_backend_tls_config_alpn() -> Arc<rustls::ClientConfig> {
SHARED_CLIENT_CONFIG_ALPN.get_or_init(|| {
SHARED_CLIENT_CONFIG_ALPN
.get_or_init(|| {
ensure_crypto_provider();
let mut config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
.with_no_client_auth();
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
info!("Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection");
info!(
"Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection"
);
Arc::new(config)
}).clone()
})
.clone()
}
/// Connect to a backend with TLS (for terminate-and-reencrypt mode).
@@ -249,7 +269,8 @@ pub async fn connect_tls(
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
stream.set_nodelay(true)?;
// Apply keepalive with 60s default (tls_handler doesn't have ConnectionConfig access)
if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60)) {
if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60))
{
debug!("Failed to set keepalive on backend TLS socket: {}", e);
}
@@ -260,10 +281,12 @@ pub async fn connect_tls(
}
/// Load certificates from PEM string.
fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
fn load_certs(
pem: &str,
) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
let mut reader = BufReader::new(pem.as_bytes());
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
.collect::<Result<Vec<_>, _>>()?;
let certs: Vec<CertificateDer<'static>> =
rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?;
if certs.is_empty() {
return Err("No certificates found in PEM data".into());
}
@@ -271,11 +294,13 @@ fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::er
}
/// Load private key from PEM string.
fn load_private_key(pem: &str) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
fn load_private_key(
pem: &str,
) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
let mut reader = BufReader::new(pem.as_bytes());
// Try PKCS8 first, then RSA, then EC
let key = rustls_pemfile::private_key(&mut reader)?
.ok_or("No private key found in PEM data")?;
let key =
rustls_pemfile::private_key(&mut reader)?.ok_or("No private key found in PEM data")?;
Ok(key)
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,356 @@
//! UDP session (flow) tracking.
//!
//! A UDP "session" is a flow identified by (client_addr, listening_port).
//! Each session maintains a backend socket bound to an ephemeral port and
//! connected to the backend target, plus a background task that relays
//! return datagrams from the backend back to the client.
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use dashmap::DashMap;
use tokio::net::UdpSocket;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use rustproxy_metrics::MetricsCollector;
use crate::connection_tracker::ConnectionTracker;
/// A single UDP session (flow).
pub struct UdpSession {
/// Socket bound to ephemeral port, connected to backend
pub backend_socket: Arc<UdpSocket>,
/// Milliseconds since the session table's epoch
pub last_activity: AtomicU64,
/// When the session was created
pub created_at: Instant,
/// Route ID for metrics
pub route_id: Option<String>,
/// Source IP for metrics/tracking
pub source_ip: IpAddr,
/// Client address (for return path)
pub client_addr: SocketAddr,
/// Handle for the return-path relay task
pub return_task: JoinHandle<()>,
/// Per-session cancellation
pub cancel: CancellationToken,
}
impl Drop for UdpSession {
fn drop(&mut self) {
self.cancel.cancel();
self.return_task.abort();
}
}
/// Configuration for UDP session behavior.
#[derive(Debug, Clone)]
pub struct UdpSessionConfig {
/// Idle timeout in milliseconds. Default: 60000.
pub session_timeout_ms: u64,
/// Max concurrent sessions per source IP. Default: 1000.
pub max_sessions_per_ip: u32,
/// Max accepted datagram size in bytes. Default: 65535.
pub max_datagram_size: u32,
}
impl Default for UdpSessionConfig {
fn default() -> Self {
Self {
session_timeout_ms: 60_000,
max_sessions_per_ip: 1_000,
max_datagram_size: 65_535,
}
}
}
impl UdpSessionConfig {
/// Build from route's UDP config, falling back to defaults.
pub fn from_route_udp(udp: Option<&rustproxy_config::RouteUdp>) -> Self {
match udp {
Some(u) => Self {
session_timeout_ms: u.session_timeout.unwrap_or(60_000),
max_sessions_per_ip: u.max_sessions_per_ip.unwrap_or(1_000),
max_datagram_size: u.max_datagram_size.unwrap_or(65_535),
},
None => Self::default(),
}
}
}
/// Session key: (client address, listening port).
pub type SessionKey = (SocketAddr, u16);
/// Tracks all active UDP sessions across all ports.
pub struct UdpSessionTable {
/// Active sessions keyed by (client_addr, listen_port)
sessions: DashMap<SessionKey, Arc<UdpSession>>,
/// Per-IP session counts for enforcing limits
ip_session_counts: DashMap<IpAddr, u32>,
/// Time reference for last_activity
epoch: Instant,
}
impl UdpSessionTable {
pub fn new() -> Self {
Self {
sessions: DashMap::new(),
ip_session_counts: DashMap::new(),
epoch: Instant::now(),
}
}
/// Get elapsed milliseconds since epoch (for last_activity tracking).
pub fn elapsed_ms(&self) -> u64 {
self.epoch.elapsed().as_millis() as u64
}
/// Look up an existing session.
pub fn get(&self, key: &SessionKey) -> Option<Arc<UdpSession>> {
self.sessions
.get(key)
.map(|entry| Arc::clone(entry.value()))
}
/// Check if we can create a new session for this IP (under the per-IP limit).
pub fn can_create_session(&self, ip: &IpAddr, max_per_ip: u32) -> bool {
let count = self
.ip_session_counts
.get(ip)
.map(|c| *c.value())
.unwrap_or(0);
count < max_per_ip
}
/// Insert a new session. Returns false if per-IP limit exceeded.
pub fn insert(&self, key: SessionKey, session: Arc<UdpSession>, max_per_ip: u32) -> bool {
let ip = session.source_ip;
// Atomically check and increment per-IP count
let mut count_entry = self.ip_session_counts.entry(ip).or_insert(0);
if *count_entry.value() >= max_per_ip {
return false;
}
*count_entry.value_mut() += 1;
drop(count_entry);
self.sessions.insert(key, session);
true
}
/// Remove a session and decrement per-IP count.
pub fn remove(&self, key: &SessionKey) -> Option<Arc<UdpSession>> {
if let Some((_, session)) = self.sessions.remove(key) {
let ip = session.source_ip;
if let Some(mut count) = self.ip_session_counts.get_mut(&ip) {
*count.value_mut() = count.value().saturating_sub(1);
if *count.value() == 0 {
drop(count);
self.ip_session_counts.remove(&ip);
}
}
Some(session)
} else {
None
}
}
/// Clean up idle sessions past the given timeout.
/// Returns the number of sessions removed.
pub fn cleanup_idle(
&self,
timeout_ms: u64,
metrics: &MetricsCollector,
conn_tracker: &ConnectionTracker,
) -> usize {
let now_ms = self.elapsed_ms();
let mut removed = 0;
// Collect keys to remove (avoid holding DashMap refs during removal)
let stale_keys: Vec<SessionKey> = self
.sessions
.iter()
.filter(|entry| {
let last = entry.value().last_activity.load(Ordering::Relaxed);
now_ms.saturating_sub(last) >= timeout_ms
})
.map(|entry| *entry.key())
.collect();
for key in stale_keys {
if let Some(session) = self.remove(&key) {
debug!(
"UDP session expired: {} -> port {} (idle {}ms)",
session.client_addr,
key.1,
now_ms.saturating_sub(session.last_activity.load(Ordering::Relaxed))
);
conn_tracker.connection_closed(&session.source_ip);
metrics.connection_closed(
session.route_id.as_deref(),
Some(&session.source_ip.to_string()),
);
metrics.udp_session_closed();
removed += 1;
}
}
removed
}
/// Drain all sessions on a given listening port, releasing socket references.
/// Used when upgrading a raw UDP listener to QUIC — the raw UDP socket's
/// Arc refcount must drop to zero so the port can be rebound.
pub fn drain_port(
&self,
port: u16,
metrics: &MetricsCollector,
conn_tracker: &ConnectionTracker,
) -> usize {
let keys: Vec<SessionKey> = self
.sessions
.iter()
.filter(|entry| entry.key().1 == port)
.map(|entry| *entry.key())
.collect();
let mut removed = 0;
for key in keys {
if let Some(session) = self.remove(&key) {
session.cancel.cancel();
conn_tracker.connection_closed(&session.source_ip);
metrics.connection_closed(
session.route_id.as_deref(),
Some(&session.source_ip.to_string()),
);
metrics.udp_session_closed();
removed += 1;
}
}
removed
}
/// Total number of active sessions.
pub fn session_count(&self) -> usize {
self.sessions.len()
}
/// Number of tracked IPs with active sessions.
pub fn tracked_ips(&self) -> usize {
self.ip_session_counts.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{Ipv4Addr, SocketAddrV4};
fn make_addr(port: u16) -> SocketAddr {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), port))
}
fn make_session(client_addr: SocketAddr, cancel: CancellationToken) -> Arc<UdpSession> {
// Create a dummy backend socket for testing
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let backend_socket =
rt.block_on(async { Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()) });
let child_cancel = cancel.child_token();
let return_task = rt.spawn(async move {
child_cancel.cancelled().await;
});
Arc::new(UdpSession {
backend_socket,
last_activity: AtomicU64::new(0),
created_at: Instant::now(),
route_id: None,
source_ip: client_addr.ip(),
client_addr,
return_task,
cancel,
})
}
#[test]
fn test_session_table_insert_and_get() {
let table = UdpSessionTable::new();
let cancel = CancellationToken::new();
let addr = make_addr(12345);
let key: SessionKey = (addr, 53);
let session = make_session(addr, cancel);
assert!(table.insert(key, session, 1000));
assert!(table.get(&key).is_some());
assert_eq!(table.session_count(), 1);
}
#[test]
fn test_session_table_per_ip_limit() {
let table = UdpSessionTable::new();
let ip = Ipv4Addr::new(10, 0, 0, 1);
// Insert 2 sessions from same IP, limit is 2
for port in [12345u16, 12346] {
let addr = SocketAddr::V4(SocketAddrV4::new(ip, port));
let cancel = CancellationToken::new();
let session = make_session(addr, cancel);
assert!(table.insert((addr, 53), session, 2));
}
// Third should be rejected
let addr3 = SocketAddr::V4(SocketAddrV4::new(ip, 12347));
let cancel3 = CancellationToken::new();
let session3 = make_session(addr3, cancel3);
assert!(!table.insert((addr3, 53), session3, 2));
assert_eq!(table.session_count(), 2);
}
#[test]
fn test_session_table_remove() {
let table = UdpSessionTable::new();
let cancel = CancellationToken::new();
let addr = make_addr(12345);
let key: SessionKey = (addr, 53);
let session = make_session(addr, cancel);
table.insert(key, session, 1000);
assert_eq!(table.session_count(), 1);
assert_eq!(table.tracked_ips(), 1);
table.remove(&key);
assert_eq!(table.session_count(), 0);
assert_eq!(table.tracked_ips(), 0);
}
#[test]
fn test_session_config_defaults() {
let config = UdpSessionConfig::default();
assert_eq!(config.session_timeout_ms, 60_000);
assert_eq!(config.max_sessions_per_ip, 1_000);
assert_eq!(config.max_datagram_size, 65_535);
}
#[test]
fn test_session_config_from_route() {
let route_udp = rustproxy_config::RouteUdp {
session_timeout: Some(10_000),
max_sessions_per_ip: Some(500),
max_datagram_size: Some(1400),
quic: None,
};
let config = UdpSessionConfig::from_route_udp(Some(&route_udp));
assert_eq!(config.session_timeout_ms, 10_000);
assert_eq!(config.max_sessions_per_ip, 500);
assert_eq!(config.max_datagram_size, 1400);
}
}
+1 -1
View File
@@ -3,7 +3,7 @@
//! Route matching engine for RustProxy.
//! Provides domain/path/IP/header matchers and a port-indexed RouteManager.
pub mod route_manager;
pub mod matchers;
pub mod route_manager;
pub use route_manager::*;
@@ -1,5 +1,42 @@
use std::collections::HashMap;
use regex::Regex;
use std::collections::HashMap;
fn compile_regex_pattern(pattern: &str) -> Option<Regex> {
if !pattern.starts_with('/') {
return None;
}
let last_slash = pattern.rfind('/')?;
if last_slash == 0 {
return None;
}
let regex_body = &pattern[1..last_slash];
let flags = &pattern[last_slash + 1..];
let mut inline_flags = String::new();
for flag in flags.chars() {
match flag {
'i' | 'm' | 's' | 'u' => {
if !inline_flags.contains(flag) {
inline_flags.push(flag);
}
}
'g' => {
// Global has no effect for single header matching.
}
_ => return None,
}
}
let compiled = if inline_flags.is_empty() {
regex_body.to_string()
} else {
format!("(?{}){}", inline_flags, regex_body)
};
Regex::new(&compiled).ok()
}
/// Match HTTP headers against a set of patterns.
///
@@ -24,16 +61,15 @@ pub fn headers_match(
None => return false, // Required header not present
};
// Check if pattern is a regex (surrounded by /)
if pattern.starts_with('/') && pattern.ends_with('/') && pattern.len() > 2 {
let regex_str = &pattern[1..pattern.len() - 1];
match Regex::new(regex_str) {
Ok(re) => {
// Check if pattern is a regex literal (/pattern/ or /pattern/flags)
if pattern.starts_with('/') && pattern.len() > 2 {
match compile_regex_pattern(pattern) {
Some(re) => {
if !re.is_match(header_value) {
return false;
}
}
Err(_) => {
None => {
// Invalid regex, fall back to exact match
if header_value != pattern {
return false;
@@ -85,6 +121,24 @@ mod tests {
assert!(headers_match(&patterns, &headers));
}
#[test]
fn test_regex_header_match_with_flags() {
let patterns: HashMap<String, String> = {
let mut m = HashMap::new();
m.insert(
"Content-Type".to_string(),
"/^application\\/json$/i".to_string(),
);
m
};
let headers: HashMap<String, String> = {
let mut m = HashMap::new();
m.insert("content-type".to_string(), "Application/JSON".to_string());
m
};
assert!(headers_match(&patterns, &headers));
}
#[test]
fn test_missing_header() {
let patterns: HashMap<String, String> = {
@@ -1,6 +1,6 @@
use ipnet::IpNet;
use std::net::IpAddr;
use std::str::FromStr;
use ipnet::IpNet;
/// Match an IP address against a pattern.
///
@@ -85,7 +85,10 @@ fn wildcard_to_cidr(pattern: &str) -> Option<String> {
}
}
Some(format!("{}.{}.{}.{}/{}", octets[0], octets[1], octets[2], octets[3], prefix_len))
Some(format!(
"{}.{}.{}.{}/{}",
octets[0], octets[1], octets[2], octets[3], prefix_len
))
}
#[cfg(test)]
@@ -1,9 +1,9 @@
pub mod domain;
pub mod path;
pub mod ip;
pub mod header;
pub mod ip;
pub mod path;
pub use domain::*;
pub use path::*;
pub use ip::*;
pub use header::*;
pub use ip::*;
pub use path::*;
@@ -1,7 +1,7 @@
use std::collections::HashMap;
use rustproxy_config::{RouteConfig, RouteTarget, TlsMode};
use crate::matchers;
use rustproxy_config::{RouteConfig, RouteTarget, TlsMode, TransportProtocol};
/// Context for route matching (subset of connection info).
pub struct MatchContext<'a> {
@@ -12,8 +12,10 @@ pub struct MatchContext<'a> {
pub tls_version: Option<&'a str>,
pub headers: Option<&'a HashMap<String, String>>,
pub is_tls: bool,
/// Detected protocol: "http" or "tcp". None when unknown (e.g. pre-TLS-termination).
/// Detected protocol: "http", "tcp", "udp", "quic". None when unknown.
pub protocol: Option<&'a str>,
/// Transport protocol of the listener: None = TCP (backward compat), Some(Udp), Some(All).
pub transport: Option<TransportProtocol>,
}
/// Result of a route match.
@@ -40,19 +42,14 @@ impl RouteManager {
};
// Filter enabled routes and sort by priority
let mut enabled_routes: Vec<RouteConfig> = routes
.into_iter()
.filter(|r| r.is_enabled())
.collect();
let mut enabled_routes: Vec<RouteConfig> =
routes.into_iter().filter(|r| r.is_enabled()).collect();
enabled_routes.sort_by(|a, b| b.effective_priority().cmp(&a.effective_priority()));
// Build port index
for (idx, route) in enabled_routes.iter().enumerate() {
for port in route.listening_ports() {
manager.port_index
.entry(port)
.or_default()
.push(idx);
manager.port_index.entry(port).or_default().push(idx);
}
}
@@ -64,7 +61,9 @@ impl RouteManager {
/// Used to skip expensive header HashMap construction when no route needs it.
pub fn any_route_has_headers(&self, port: u16) -> bool {
if let Some(indices) = self.port_index.get(&port) {
indices.iter().any(|&idx| self.routes[idx].route_match.headers.is_some())
indices
.iter()
.any(|&idx| self.routes[idx].route_match.headers.is_some())
} else {
false
}
@@ -92,6 +91,22 @@ impl RouteManager {
fn matches_route(&self, route: &RouteConfig, ctx: &MatchContext<'_>) -> bool {
let rm = &route.route_match;
// Transport filtering: ensure route transport matches context transport
let route_transport = rm.transport.as_ref();
let ctx_transport = ctx.transport.as_ref();
match (route_transport, ctx_transport) {
// Route requires UDP only — reject non-UDP contexts
(Some(TransportProtocol::Udp), None)
| (Some(TransportProtocol::Udp), Some(TransportProtocol::Tcp)) => return false,
// Route requires TCP only — reject UDP contexts
(Some(TransportProtocol::Tcp), Some(TransportProtocol::Udp)) => return false,
// Route has no transport (default = TCP) — reject UDP contexts
(None, Some(TransportProtocol::Udp)) => return false,
// All other combinations match: All matches everything, same transport matches,
// None + None/Tcp matches (backward compat)
_ => {}
}
// Domain matching
if let Some(ref domains) = rm.domains {
if let Some(domain) = ctx.domain {
@@ -104,6 +119,11 @@ impl RouteManager {
// This prevents session-ticket resumption from misrouting when clients
// omit SNI (RFC 8446 recommends but doesn't mandate SNI on resumption).
// Wildcard-only routes (domains: ["*"]) still match since they accept all.
//
// Exception: QUIC (UDP transport) encrypts the TLS ClientHello, so SNI
// is unavailable at accept time. Domain verification happens per-request
// in H3ProxyService via the :authority header.
if ctx.transport != Some(TransportProtocol::Udp) {
let patterns = domains.to_vec();
let is_wildcard_only = patterns.iter().all(|d| *d == "*");
if !is_wildcard_only {
@@ -111,6 +131,7 @@ impl RouteManager {
}
}
}
}
// Path matching
if let Some(ref pattern) = rm.path {
@@ -172,7 +193,11 @@ impl RouteManager {
}
/// Find the best matching target within a route.
fn find_target<'a>(&self, route: &'a RouteConfig, ctx: &MatchContext<'_>) -> Option<&'a RouteTarget> {
fn find_target<'a>(
&self,
route: &'a RouteConfig,
ctx: &MatchContext<'_>,
) -> Option<&'a RouteTarget> {
let targets = route.action.targets.as_ref()?;
if targets.len() == 1 && targets[0].target_match.is_none() {
@@ -199,17 +224,11 @@ impl RouteManager {
}
// Fall back to first target without match criteria
best.or_else(|| {
targets.iter().find(|t| t.target_match.is_none())
})
best.or_else(|| targets.iter().find(|t| t.target_match.is_none()))
}
/// Check if a target match criteria matches the context.
fn matches_target(
&self,
tm: &rustproxy_config::TargetMatch,
ctx: &MatchContext<'_>,
) -> bool {
fn matches_target(&self, tm: &rustproxy_config::TargetMatch, ctx: &MatchContext<'_>) -> bool {
// Port matching
if let Some(ref ports) = tm.ports {
if !ports.contains(&ctx.port) {
@@ -257,6 +276,11 @@ impl RouteManager {
.unwrap_or_default()
}
/// Get all enabled routes.
pub fn routes(&self) -> &[RouteConfig] {
&self.routes
}
/// Get the total number of enabled routes.
pub fn route_count(&self) -> usize {
self.routes.len()
@@ -269,9 +293,7 @@ impl RouteManager {
// If multiple passthrough routes on same port, SNI is needed
let passthrough_routes: Vec<_> = routes
.iter()
.filter(|r| {
r.tls_mode() == Some(&TlsMode::Passthrough)
})
.filter(|r| r.tls_mode() == Some(&TlsMode::Passthrough))
.collect();
if passthrough_routes.len() > 1 {
@@ -303,6 +325,7 @@ mod tests {
id: None,
route_match: RouteMatch {
ports: PortRange::Single(port),
transport: None,
domains: domain.map(|d| DomainSpec::Single(d.to_string())),
path: None,
client_ip: None,
@@ -322,6 +345,7 @@ mod tests {
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
}]),
tls: None,
@@ -329,9 +353,8 @@ mod tests {
load_balancing: None,
advanced: None,
options: None,
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
udp: None,
},
headers: None,
security: None,
@@ -360,6 +383,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
let result = manager.find_route(&ctx);
@@ -383,11 +407,16 @@ mod tests {
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
let result = manager.find_route(&ctx).unwrap();
// Should match the higher-priority specific route
assert!(result.route.route_match.domains.as_ref()
assert!(result
.route
.route_match
.domains
.as_ref()
.map(|d| d.to_vec())
.unwrap()
.contains(&"api.example.com"));
@@ -407,6 +436,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_none());
@@ -493,6 +523,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_some());
@@ -513,6 +544,7 @@ mod tests {
headers: None,
is_tls: true,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_none());
@@ -533,6 +565,7 @@ mod tests {
headers: None,
is_tls: true,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_none());
@@ -553,6 +586,7 @@ mod tests {
headers: None,
is_tls: true,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_some());
@@ -577,12 +611,19 @@ mod tests {
headers: None,
is_tls: true,
protocol: None,
transport: None,
};
let result = manager.find_route(&ctx);
assert!(result.is_some());
let matched_domains = result.unwrap().route.route_match.domains.as_ref()
.map(|d| d.to_vec()).unwrap();
let matched_domains = result
.unwrap()
.route
.route_match
.domains
.as_ref()
.map(|d| d.to_vec())
.unwrap();
assert!(matched_domains.contains(&"*"));
}
@@ -602,6 +643,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_some());
@@ -621,6 +663,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_some());
@@ -645,6 +688,7 @@ mod tests {
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: Some(10),
},
RouteTarget {
@@ -657,6 +701,7 @@ mod tests {
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
},
]);
@@ -672,6 +717,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
let result = manager.find_route(&ctx).unwrap();
assert_eq!(result.target.unwrap().host.first(), "api-backend");
@@ -686,12 +732,17 @@ mod tests {
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
let result = manager.find_route(&ctx).unwrap();
assert_eq!(result.target.unwrap().host.first(), "default-backend");
}
fn make_route_with_protocol(port: u16, domain: Option<&str>, protocol: Option<&str>) -> RouteConfig {
fn make_route_with_protocol(
port: u16,
domain: Option<&str>,
protocol: Option<&str>,
) -> RouteConfig {
let mut route = make_route(port, domain, 0);
route.route_match.protocol = protocol.map(|s| s.to_string());
route
@@ -711,6 +762,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: Some("http"),
transport: None,
};
assert!(manager.find_route(&ctx).is_some());
}
@@ -729,6 +781,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: Some("tcp"),
transport: None,
};
assert!(manager.find_route(&ctx).is_none());
}
@@ -748,6 +801,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: Some("http"),
transport: None,
};
assert!(manager.find_route(&ctx_http).is_some());
@@ -760,6 +814,7 @@ mod tests {
headers: None,
is_tls: false,
protocol: Some("tcp"),
transport: None,
};
assert!(manager.find_route(&ctx_tcp).is_some());
}
@@ -780,7 +835,234 @@ mod tests {
headers: None,
is_tls: true,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_some());
}
// ===== Transport filtering tests =====
fn make_route_with_transport(port: u16, transport: Option<TransportProtocol>) -> RouteConfig {
let mut route = make_route(port, None, 0);
route.route_match.transport = transport;
route
}
#[test]
fn test_transport_udp_route_matches_udp_context() {
let routes = vec![make_route_with_transport(53, Some(TransportProtocol::Udp))];
let manager = RouteManager::new(routes);
let ctx = MatchContext {
port: 53,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: Some("udp"),
transport: Some(TransportProtocol::Udp),
};
assert!(manager.find_route(&ctx).is_some());
}
#[test]
fn test_transport_udp_route_rejects_tcp_context() {
let routes = vec![make_route_with_transport(53, Some(TransportProtocol::Udp))];
let manager = RouteManager::new(routes);
// TCP context (transport: None = TCP)
let ctx = MatchContext {
port: 53,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
assert!(manager.find_route(&ctx).is_none());
}
#[test]
fn test_transport_tcp_route_rejects_udp_context() {
let routes = vec![make_route_with_transport(80, Some(TransportProtocol::Tcp))];
let manager = RouteManager::new(routes);
let ctx = MatchContext {
port: 80,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: Some("udp"),
transport: Some(TransportProtocol::Udp),
};
assert!(manager.find_route(&ctx).is_none());
}
#[test]
fn test_transport_all_matches_both() {
let routes = vec![make_route_with_transport(443, Some(TransportProtocol::All))];
let manager = RouteManager::new(routes);
// TCP context
let tcp_ctx = MatchContext {
port: 443,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
assert!(manager.find_route(&tcp_ctx).is_some());
// UDP context
let udp_ctx = MatchContext {
port: 443,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: Some("udp"),
transport: Some(TransportProtocol::Udp),
};
assert!(manager.find_route(&udp_ctx).is_some());
}
#[test]
fn test_transport_none_default_matches_tcp_only() {
// Route with no transport field = TCP only (backward compat)
let routes = vec![make_route_with_transport(80, None)];
let manager = RouteManager::new(routes);
// TCP context should match
let tcp_ctx = MatchContext {
port: 80,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
assert!(manager.find_route(&tcp_ctx).is_some());
// UDP context should NOT match
let udp_ctx = MatchContext {
port: 80,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: Some("udp"),
transport: Some(TransportProtocol::Udp),
};
assert!(manager.find_route(&udp_ctx).is_none());
}
#[test]
fn test_transport_mixed_routes_same_port() {
// TCP and UDP routes on the same port — each matches only its transport
let mut tcp_route = make_route_with_transport(443, Some(TransportProtocol::Tcp));
tcp_route.name = Some("tcp-route".to_string());
let mut udp_route = make_route_with_transport(443, Some(TransportProtocol::Udp));
udp_route.name = Some("udp-route".to_string());
let manager = RouteManager::new(vec![tcp_route, udp_route]);
let tcp_ctx = MatchContext {
port: 443,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: None,
transport: None,
};
let result = manager.find_route(&tcp_ctx).unwrap();
assert_eq!(result.route.name.as_deref(), Some("tcp-route"));
let udp_ctx = MatchContext {
port: 443,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: false,
protocol: Some("udp"),
transport: Some(TransportProtocol::Udp),
};
let result = manager.find_route(&udp_ctx).unwrap();
assert_eq!(result.route.name.as_deref(), Some("udp-route"));
}
#[test]
fn test_quic_tls_no_sni_matches_domain_restricted_route() {
// QUIC accept-level matching: is_tls=true, domain=None, transport=Udp.
// Should match because QUIC encrypts the ClientHello — SNI is unavailable
// at accept time but verified per-request in H3ProxyService.
let mut route = make_route(443, Some("example.com"), 0);
route.route_match.transport = Some(TransportProtocol::Udp);
let routes = vec![route];
let manager = RouteManager::new(routes);
let ctx = MatchContext {
port: 443,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: true,
protocol: Some("quic"),
transport: Some(TransportProtocol::Udp),
};
assert!(
manager.find_route(&ctx).is_some(),
"QUIC (UDP) with is_tls=true and domain=None should match domain-restricted routes"
);
}
#[test]
fn test_tcp_tls_no_sni_still_rejects_domain_restricted_route() {
// TCP TLS without SNI must still be rejected (no QUIC exemption).
let routes = vec![make_route(443, Some("example.com"), 0)];
let manager = RouteManager::new(routes);
let ctx = MatchContext {
port: 443,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: true,
protocol: None,
transport: None, // TCP (default)
};
assert!(
manager.find_route(&ctx).is_none(),
"TCP TLS without SNI should NOT match domain-restricted routes"
);
}
}
@@ -1,5 +1,5 @@
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine;
/// Basic auth validator.
pub struct BasicAuthValidator {
+217 -41
View File
@@ -2,14 +2,26 @@ use ipnet::IpNet;
use std::net::IpAddr;
use std::str::FromStr;
use rustproxy_config::IpAllowEntry;
/// IP filter supporting CIDR ranges, wildcards, and exact matches.
/// Supports domain-scoped allow entries that restrict an IP to specific domains.
pub struct IpFilter {
/// Plain allow entries — IP allowed for any domain on the route
allow_list: Vec<IpPattern>,
/// Domain-scoped allow entries — IP allowed only for matching domains
domain_scoped: Vec<DomainScopedEntry>,
block_list: Vec<IpPattern>,
}
/// A domain-scoped allow entry: IP + list of allowed domain patterns.
struct DomainScopedEntry {
pattern: IpPattern,
domains: Vec<String>,
}
/// Represents an IP pattern for matching.
#[derive(Debug)]
#[derive(Debug, Clone)]
enum IpPattern {
/// Exact IP match
Exact(IpAddr),
@@ -19,6 +31,37 @@ enum IpPattern {
Wildcard,
}
/// Compiled block list for early ingress filtering.
#[derive(Debug, Clone)]
pub struct IpBlockList {
block_list: Vec<IpPattern>,
}
impl IpBlockList {
pub fn new(block_list: &[String]) -> Self {
Self {
block_list: block_list.iter().map(|s| IpPattern::parse(s)).collect(),
}
}
pub fn empty() -> Self {
Self {
block_list: Vec::new(),
}
}
pub fn is_empty(&self) -> bool {
self.block_list.is_empty()
}
pub fn is_blocked(&self, ip: &IpAddr) -> bool {
let normalized = IpFilter::normalize_ip(ip);
self.block_list
.iter()
.any(|pattern| pattern.matches(&normalized))
}
}
impl IpPattern {
fn parse(s: &str) -> Self {
let s = s.trim();
@@ -31,10 +74,6 @@ impl IpPattern {
if let Ok(addr) = IpAddr::from_str(s) {
return IpPattern::Exact(addr);
}
// Try as CIDR by appending default prefix
if let Ok(addr) = IpAddr::from_str(s) {
return IpPattern::Exact(addr);
}
// Fallback: treat as exact, will never match an invalid string
IpPattern::Exact(IpAddr::from_str("0.0.0.0").unwrap())
}
@@ -48,19 +87,55 @@ impl IpPattern {
}
}
/// Simple domain pattern matching (exact, `*`, or `*.suffix`).
fn domain_matches_pattern(pattern: &str, domain: &str) -> bool {
let p = pattern.trim();
let d = domain.trim();
if p == "*" {
return true;
}
if p.eq_ignore_ascii_case(d) {
return true;
}
if p.starts_with("*.") {
let suffix = &p[1..]; // e.g., ".abc.xyz"
d.len() > suffix.len() && d[d.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
} else {
false
}
}
impl IpFilter {
/// Create a new IP filter from allow and block lists.
pub fn new(allow_list: &[String], block_list: &[String]) -> Self {
/// Create a new IP filter from allow entries and a block list.
pub fn new(allow_entries: &[IpAllowEntry], block_list: &[String]) -> Self {
let mut allow_list = Vec::new();
let mut domain_scoped = Vec::new();
for entry in allow_entries {
match entry {
IpAllowEntry::Plain(ip) => {
allow_list.push(IpPattern::parse(ip));
}
IpAllowEntry::DomainScoped { ip, domains } => {
domain_scoped.push(DomainScopedEntry {
pattern: IpPattern::parse(ip),
domains: domains.clone(),
});
}
}
}
Self {
allow_list: allow_list.iter().map(|s| IpPattern::parse(s)).collect(),
allow_list,
domain_scoped,
block_list: block_list.iter().map(|s| IpPattern::parse(s)).collect(),
}
}
/// Check if an IP is allowed.
/// If allow_list is non-empty, IP must match at least one entry.
/// If block_list is non-empty, IP must NOT match any entry.
pub fn is_allowed(&self, ip: &IpAddr) -> bool {
/// Check if an IP is allowed, considering domain-scoped entries.
/// If `domain` is Some, domain-scoped entries are evaluated against it.
/// If `domain` is None, only plain allow entries are considered.
pub fn is_allowed_for_domain(&self, ip: &IpAddr, domain: Option<&str>) -> bool {
// Check block list first
if !self.block_list.is_empty() {
for pattern in &self.block_list {
@@ -70,14 +145,40 @@ impl IpFilter {
}
}
// If allow list is non-empty, must match at least one
if !self.allow_list.is_empty() {
return self.allow_list.iter().any(|p| p.matches(ip));
// If there are any allow entries (plain or domain-scoped), IP must match
let has_any_allow = !self.allow_list.is_empty() || !self.domain_scoped.is_empty();
if has_any_allow {
// Check plain allow list — grants access to entire route
if self.allow_list.iter().any(|p| p.matches(ip)) {
return true;
}
// Check domain-scoped entries — grants access only if domain matches
if let Some(req_domain) = domain {
for entry in &self.domain_scoped {
if entry.pattern.matches(ip) {
if entry
.domains
.iter()
.any(|d| domain_matches_pattern(d, req_domain))
{
return true;
}
}
}
}
return false;
}
true
}
/// Check if an IP is allowed (backwards-compat wrapper, no domain context).
pub fn is_allowed(&self, ip: &IpAddr) -> bool {
self.is_allowed_for_domain(ip, None)
}
/// Normalize IPv4-mapped IPv6 addresses (::ffff:x.x.x.x -> x.x.x.x)
pub fn normalize_ip(ip: &IpAddr) -> IpAddr {
match ip {
@@ -97,19 +198,28 @@ impl IpFilter {
mod tests {
use super::*;
fn plain(s: &str) -> IpAllowEntry {
IpAllowEntry::Plain(s.to_string())
}
fn scoped(ip: &str, domains: &[&str]) -> IpAllowEntry {
IpAllowEntry::DomainScoped {
ip: ip.to_string(),
domains: domains.iter().map(|s| s.to_string()).collect(),
}
}
#[test]
fn test_empty_lists_allow_all() {
let filter = IpFilter::new(&[], &[]);
let ip: IpAddr = "192.168.1.1".parse().unwrap();
assert!(filter.is_allowed(&ip));
assert!(filter.is_allowed_for_domain(&ip, Some("example.com")));
}
#[test]
fn test_allow_list_exact() {
let filter = IpFilter::new(
&["10.0.0.1".to_string()],
&[],
);
fn test_plain_allow_list_exact() {
let filter = IpFilter::new(&[plain("10.0.0.1")], &[]);
let allowed: IpAddr = "10.0.0.1".parse().unwrap();
let denied: IpAddr = "10.0.0.2".parse().unwrap();
assert!(filter.is_allowed(&allowed));
@@ -117,11 +227,8 @@ mod tests {
}
#[test]
fn test_allow_list_cidr() {
let filter = IpFilter::new(
&["10.0.0.0/8".to_string()],
&[],
);
fn test_plain_allow_list_cidr() {
let filter = IpFilter::new(&[plain("10.0.0.0/8")], &[]);
let allowed: IpAddr = "10.255.255.255".parse().unwrap();
let denied: IpAddr = "192.168.1.1".parse().unwrap();
assert!(filter.is_allowed(&allowed));
@@ -130,10 +237,7 @@ mod tests {
#[test]
fn test_block_list() {
let filter = IpFilter::new(
&[],
&["192.168.1.100".to_string()],
);
let filter = IpFilter::new(&[], &["192.168.1.100".to_string()]);
let blocked: IpAddr = "192.168.1.100".parse().unwrap();
let allowed: IpAddr = "192.168.1.101".parse().unwrap();
assert!(!filter.is_allowed(&blocked));
@@ -142,10 +246,7 @@ mod tests {
#[test]
fn test_block_trumps_allow() {
let filter = IpFilter::new(
&["10.0.0.0/8".to_string()],
&["10.0.0.5".to_string()],
);
let filter = IpFilter::new(&[plain("10.0.0.0/8")], &["10.0.0.5".to_string()]);
let blocked: IpAddr = "10.0.0.5".parse().unwrap();
let allowed: IpAddr = "10.0.0.6".parse().unwrap();
assert!(!filter.is_allowed(&blocked));
@@ -154,20 +255,14 @@ mod tests {
#[test]
fn test_wildcard_allow() {
let filter = IpFilter::new(
&["*".to_string()],
&[],
);
let filter = IpFilter::new(&[plain("*")], &[]);
let ip: IpAddr = "1.2.3.4".parse().unwrap();
assert!(filter.is_allowed(&ip));
}
#[test]
fn test_wildcard_block() {
let filter = IpFilter::new(
&[],
&["*".to_string()],
);
let filter = IpFilter::new(&[], &["*".to_string()]);
let ip: IpAddr = "1.2.3.4".parse().unwrap();
assert!(!filter.is_allowed(&ip));
}
@@ -186,4 +281,85 @@ mod tests {
let normalized = IpFilter::normalize_ip(&ip);
assert_eq!(normalized, ip);
}
// Domain-scoped tests
#[test]
fn test_domain_scoped_allows_matching_domain() {
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
let ip: IpAddr = "10.8.0.2".parse().unwrap();
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
}
#[test]
fn test_domain_scoped_denies_non_matching_domain() {
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
let ip: IpAddr = "10.8.0.2".parse().unwrap();
assert!(!filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
}
#[test]
fn test_domain_scoped_denies_without_domain() {
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
let ip: IpAddr = "10.8.0.2".parse().unwrap();
// Without domain context, domain-scoped entries cannot match
assert!(!filter.is_allowed_for_domain(&ip, None));
}
#[test]
fn test_domain_scoped_wildcard_domain() {
let filter = IpFilter::new(&[scoped("10.8.0.2", &["*.abc.xyz"])], &[]);
let ip: IpAddr = "10.8.0.2".parse().unwrap();
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
assert!(filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
assert!(!filter.is_allowed_for_domain(&ip, Some("other.com")));
}
#[test]
fn test_plain_and_domain_scoped_coexist() {
let filter = IpFilter::new(
&[
plain("1.2.3.4"), // full route access
scoped("10.8.0.2", &["outline.abc.xyz"]), // scoped access
],
&[],
);
let admin: IpAddr = "1.2.3.4".parse().unwrap();
let vpn: IpAddr = "10.8.0.2".parse().unwrap();
let other: IpAddr = "9.9.9.9".parse().unwrap();
// Admin IP has full access
assert!(filter.is_allowed_for_domain(&admin, Some("anything.abc.xyz")));
assert!(filter.is_allowed_for_domain(&admin, Some("outline.abc.xyz")));
// VPN IP only has scoped access
assert!(filter.is_allowed_for_domain(&vpn, Some("outline.abc.xyz")));
assert!(!filter.is_allowed_for_domain(&vpn, Some("app.abc.xyz")));
// Unknown IP denied
assert!(!filter.is_allowed_for_domain(&other, Some("outline.abc.xyz")));
}
#[test]
fn test_block_trumps_domain_scoped() {
let filter = IpFilter::new(
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
&["10.8.0.2".to_string()],
);
let ip: IpAddr = "10.8.0.2".parse().unwrap();
assert!(!filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
}
#[test]
fn test_domain_matches_pattern_fn() {
assert!(domain_matches_pattern("example.com", "example.com"));
assert!(domain_matches_pattern("*.abc.xyz", "outline.abc.xyz"));
assert!(domain_matches_pattern("*.abc.xyz", "app.abc.xyz"));
assert!(!domain_matches_pattern("*.abc.xyz", "abc.xyz")); // suffix only, not exact parent
assert!(domain_matches_pattern("*", "anything.com"));
assert!(!domain_matches_pattern("outline.abc.xyz", "app.abc.xyz"));
// Case insensitive
assert!(domain_matches_pattern("*.ABC.XYZ", "outline.abc.xyz"));
}
}
@@ -1,4 +1,4 @@
use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
/// JWT claims (minimal structure).
@@ -160,10 +160,7 @@ mod tests {
#[test]
fn test_extract_token_bearer() {
assert_eq!(
JwtValidator::extract_token("Bearer abc123"),
Some("abc123")
);
assert_eq!(JwtValidator::extract_token("Bearer abc123"), Some("abc123"));
}
#[test]
+4 -4
View File
@@ -2,12 +2,12 @@
//!
//! IP filtering, rate limiting, and authentication for RustProxy.
pub mod ip_filter;
pub mod rate_limiter;
pub mod basic_auth;
pub mod ip_filter;
pub mod jwt_auth;
pub mod rate_limiter;
pub use ip_filter::*;
pub use rate_limiter::*;
pub use basic_auth::*;
pub use ip_filter::*;
pub use jwt_auth::*;
pub use rate_limiter::*;
+14 -13
View File
@@ -4,8 +4,7 @@
//! Account credentials are ephemeral — the consumer owns all persistence.
use instant_acme::{
Account, NewAccount, NewOrder, Identifier, ChallengeType, OrderStatus,
AccountCredentials,
Account, AccountCredentials, ChallengeType, Identifier, NewAccount, NewOrder, OrderStatus,
};
use rcgen::{CertificateParams, KeyPair};
use thiserror::Error;
@@ -89,7 +88,11 @@ impl AcmeClient {
F: FnOnce(PendingChallenge) -> Fut,
Fut: std::future::Future<Output = Result<(), AcmeError>>,
{
info!("Starting ACME provisioning for {} via {}", domain, self.directory_url());
info!(
"Starting ACME provisioning for {} via {}",
domain,
self.directory_url()
);
// 1. Get or create ACME account
let account = self.get_or_create_account().await?;
@@ -170,14 +173,14 @@ impl AcmeClient {
debug!("Order ready, finalizing...");
// 6. Generate CSR and finalize
let key_pair = KeyPair::generate().map_err(|e| {
AcmeError::FinalizationFailed(format!("Key generation failed: {}", e))
})?;
let key_pair = KeyPair::generate()
.map_err(|e| AcmeError::FinalizationFailed(format!("Key generation failed: {}", e)))?;
let mut params = CertificateParams::new(vec![domain.to_string()]).map_err(|e| {
AcmeError::FinalizationFailed(format!("CSR params failed: {}", e))
})?;
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
let mut params = CertificateParams::new(vec![domain.to_string()])
.map_err(|e| AcmeError::FinalizationFailed(format!("CSR params failed: {}", e)))?;
params
.distinguished_name
.push(rcgen::DnType::CommonName, domain);
let csr = params.serialize_request(&key_pair).map_err(|e| {
AcmeError::FinalizationFailed(format!("CSR serialization failed: {}", e))
@@ -219,9 +222,7 @@ impl AcmeClient {
.certificate()
.await
.map_err(|e| AcmeError::FinalizationFailed(e.to_string()))?
.ok_or_else(|| {
AcmeError::FinalizationFailed("No certificate returned".to_string())
})?;
.ok_or_else(|| AcmeError::FinalizationFailed("No certificate returned".to_string()))?;
let private_key_pem = key_pair.serialize_pem();
+11 -13
View File
@@ -2,8 +2,8 @@ use std::time::{SystemTime, UNIX_EPOCH};
use thiserror::Error;
use tracing::info;
use crate::cert_store::{CertStore, CertBundle, CertMetadata, CertSource};
use crate::acme::AcmeClient;
use crate::cert_store::{CertBundle, CertMetadata, CertSource, CertStore};
#[derive(Debug, Error)]
pub enum CertManagerError {
@@ -45,17 +45,13 @@ impl CertManager {
/// Create an ACME client using this manager's configuration.
/// Returns None if no ACME email is configured.
pub fn acme_client(&self) -> Option<AcmeClient> {
self.acme_email.as_ref().map(|email| {
AcmeClient::new(email.clone(), self.use_production)
})
self.acme_email
.as_ref()
.map(|email| AcmeClient::new(email.clone(), self.use_production))
}
/// Load a static certificate into the store (infallible — pure cache insert).
pub fn load_static(
&mut self,
domain: String,
bundle: CertBundle,
) {
pub fn load_static(&mut self, domain: String, bundle: CertBundle) {
self.store.store(domain, bundle);
}
@@ -108,20 +104,22 @@ impl CertManager {
F: FnOnce(String, String) -> Fut,
Fut: std::future::Future<Output = ()>,
{
let acme_client = self.acme_client()
.ok_or(CertManagerError::NoEmail)?;
let acme_client = self.acme_client().ok_or(CertManagerError::NoEmail)?;
info!("Renewing certificate for {}", domain);
let domain_owned = domain.to_string();
let result = acme_client.provision(&domain_owned, |pending| {
let result = acme_client
.provision(&domain_owned, |pending| {
let token = pending.token.clone();
let key_auth = pending.key_authorization.clone();
async move {
challenge_setup(token, key_auth).await;
Ok(())
}
}).await.map_err(|e| CertManagerError::AcmeFailure {
})
.await
.map_err(|e| CertManagerError::AcmeFailure {
domain: domain.to_string(),
message: e.to_string(),
})?;
+15 -6
View File
@@ -1,5 +1,5 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Certificate metadata stored alongside certs.
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -90,8 +90,10 @@ mod tests {
fn make_test_bundle(domain: &str) -> CertBundle {
CertBundle {
key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n".to_string(),
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n".to_string(),
key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n"
.to_string(),
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n"
.to_string(),
ca_pem: None,
metadata: CertMetadata {
domain: domain.to_string(),
@@ -122,7 +124,8 @@ mod tests {
let mut store = CertStore::new();
let mut bundle = make_test_bundle("secure.com");
bundle.ca_pem = Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string());
bundle.ca_pem =
Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string());
store.store("secure.com".to_string(), bundle);
let loaded = store.get("secure.com").unwrap();
@@ -147,7 +150,10 @@ mod tests {
fn test_remove_cert() {
let mut store = CertStore::new();
store.store("remove-me.com".to_string(), make_test_bundle("remove-me.com"));
store.store(
"remove-me.com".to_string(),
make_test_bundle("remove-me.com"),
);
assert!(store.has("remove-me.com"));
let removed = store.remove("remove-me.com");
@@ -165,7 +171,10 @@ mod tests {
fn test_wildcard_domain() {
let mut store = CertStore::new();
store.store("*.example.com".to_string(), make_test_bundle("*.example.com"));
store.store(
"*.example.com".to_string(),
make_test_bundle("*.example.com"),
);
assert!(store.has("*.example.com"));
let loaded = store.get("*.example.com").unwrap();
+3 -3
View File
@@ -3,11 +3,11 @@
//! TLS certificate management for RustProxy.
//! Handles ACME (Let's Encrypt), static certificates, and dynamic SNI resolution.
pub mod cert_store;
pub mod cert_manager;
pub mod acme;
pub mod cert_manager;
pub mod cert_store;
pub mod sni_resolver;
pub use cert_store::*;
pub use cert_manager::*;
pub use cert_store::*;
pub use sni_resolver::*;
+7 -1
View File
@@ -20,7 +20,6 @@ rustproxy-routing = { workspace = true }
rustproxy-tls = { workspace = true }
rustproxy-passthrough = { workspace = true }
rustproxy-http = { workspace = true }
rustproxy-nftables = { workspace = true }
rustproxy-metrics = { workspace = true }
rustproxy-security = { workspace = true }
tokio = { workspace = true }
@@ -32,6 +31,7 @@ arc-swap = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
rustls = { workspace = true }
rustls-pemfile = { workspace = true }
tokio-rustls = { workspace = true }
tokio-util = { workspace = true }
dashmap = { workspace = true }
@@ -43,3 +43,9 @@ mimalloc = { workspace = true }
[dev-dependencies]
rcgen = { workspace = true }
quinn = { workspace = true }
h3 = { workspace = true }
h3-quinn = { workspace = true }
bytes = { workspace = true }
rustls = { workspace = true }
http = "1"
+12 -8
View File
@@ -13,7 +13,7 @@ use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, error};
use tracing::{debug, error, info};
/// ACME HTTP-01 challenge server.
pub struct ChallengeServer {
@@ -47,7 +47,10 @@ impl ChallengeServer {
}
/// Start the challenge server on the given port.
pub async fn start(&mut self, port: u16) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
pub async fn start(
&mut self,
port: u16,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let addr = format!("0.0.0.0:{}", port);
let listener = TcpListener::bind(&addr).await?;
info!("ACME challenge server listening on port {}", port);
@@ -101,10 +104,7 @@ impl ChallengeServer {
pub async fn stop(&mut self) {
self.cancel.cancel();
if let Some(handle) = self.handle.take() {
let _ = tokio::time::timeout(
std::time::Duration::from_secs(5),
handle,
).await;
let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
}
self.challenges.clear();
self.cancel = CancellationToken::new();
@@ -154,10 +154,14 @@ mod tests {
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
// Fetch the challenge
let client = tokio::net::TcpStream::connect("127.0.0.1:19900").await.unwrap();
let client = tokio::net::TcpStream::connect("127.0.0.1:19900")
.await
.unwrap();
let io = TokioIo::new(client);
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
tokio::spawn(async move { let _ = conn.await; });
tokio::spawn(async move {
let _ = conn.await;
});
let req = Request::get("/.well-known/acme-challenge/test-token")
.body(Full::new(Bytes::new()))
File diff suppressed because it is too large Load Diff
+4 -9
View File
@@ -1,12 +1,12 @@
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
use anyhow::Result;
use clap::Parser;
use tracing_subscriber::EnvFilter;
use anyhow::Result;
use rustproxy::RustProxy;
use rustproxy::management;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions;
/// RustProxy - High-performance multi-protocol proxy
@@ -43,8 +43,7 @@ async fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_writer(std::io::stderr)
.with_env_filter(
EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new(&cli.log_level))
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&cli.log_level)),
)
.init();
@@ -60,11 +59,7 @@ async fn main() -> Result<()> {
let options = RustProxyOptions::from_file(&cli.config)
.map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?;
tracing::info!(
"Loaded {} routes from {}",
options.routes.len(),
cli.config
);
tracing::info!("Loaded {} routes from {}", options.routes.len(), cli.config);
// Validate-only mode
if cli.validate {
+159 -67
View File
@@ -1,7 +1,7 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, BufReader};
use tracing::{info, error};
use tracing::{error, info};
use crate::RustProxy;
use rustproxy_config::RustProxyOptions;
@@ -141,14 +141,19 @@ async fn handle_request(
"start" => handle_start(&id, &request.params, proxy).await,
"stop" => handle_stop(&id, proxy).await,
"updateRoutes" => handle_update_routes(&id, &request.params, proxy).await,
"setSecurityPolicy" => handle_set_security_policy(&id, &request.params, proxy),
"getMetrics" => handle_get_metrics(&id, proxy),
"getStatistics" => handle_get_statistics(&id, proxy),
"provisionCertificate" => handle_provision_certificate(&id, &request.params, proxy).await,
"renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await,
"getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await,
"getListeningPorts" => handle_get_listening_ports(&id, proxy),
"getNftablesStatus" => handle_get_nftables_status(&id, proxy).await,
"setSocketHandlerRelay" => handle_set_socket_handler_relay(&id, &request.params, proxy).await,
"setSocketHandlerRelay" => {
handle_set_socket_handler_relay(&id, &request.params, proxy).await
}
"setDatagramHandlerRelay" => {
handle_set_datagram_handler_relay(&id, &request.params, proxy).await
}
"addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await,
"removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await,
"loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await,
@@ -167,7 +172,12 @@ async fn handle_start(
let config = match params.get("config") {
Some(config) => config,
None => return ManagementResponse::err(id.to_string(), "Missing 'config' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'config' parameter".to_string(),
)
}
};
let options: RustProxyOptions = match serde_json::from_value(config.clone()) {
@@ -176,8 +186,7 @@ async fn handle_start(
};
match RustProxy::new(options) {
Ok(mut p) => {
match p.start().await {
Ok(mut p) => match p.start().await {
Ok(()) => {
send_event("started", serde_json::json!({}));
*proxy = Some(p);
@@ -187,27 +196,21 @@ async fn handle_start(
send_event("error", serde_json::json!({"message": format!("{}", e)}));
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
}
}
}
},
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to create proxy: {}", e)),
}
}
async fn handle_stop(
id: &str,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
async fn handle_stop(id: &str, proxy: &mut Option<RustProxy>) -> ManagementResponse {
match proxy.as_mut() {
Some(p) => {
match p.stop().await {
Some(p) => match p.stop().await {
Ok(()) => {
*proxy = None;
send_event("stopped", serde_json::json!({}));
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
}
}
},
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
}
}
@@ -224,7 +227,12 @@ async fn handle_update_routes(
let routes = match params.get("routes") {
Some(routes) => routes,
None => return ManagementResponse::err(id.to_string(), "Missing 'routes' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routes' parameter".to_string(),
)
}
};
let routes: Vec<rustproxy_config::RouteConfig> = match serde_json::from_value(routes.clone()) {
@@ -234,36 +242,72 @@ async fn handle_update_routes(
match p.update_routes(routes).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e)),
Err(e) => {
ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e))
}
}
}
fn handle_get_metrics(
fn handle_set_security_policy(
id: &str,
proxy: &Option<RustProxy>,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let policy = match params.get("policy") {
Some(policy) => policy,
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'policy' parameter".to_string(),
)
}
};
let policy: rustproxy_config::SecurityPolicy = match serde_json::from_value(policy.clone()) {
Ok(policy) => policy,
Err(e) => {
return ManagementResponse::err(
id.to_string(),
format!("Invalid security policy: {}", e),
)
}
};
p.set_security_policy(policy);
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
fn handle_get_metrics(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
let metrics = p.get_metrics();
match serde_json::to_value(&metrics) {
Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize metrics: {}", e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to serialize metrics: {}", e),
),
}
}
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
}
}
fn handle_get_statistics(
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
fn handle_get_statistics(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
let stats = p.get_statistics();
match serde_json::to_value(&stats) {
Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize statistics: {}", e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to serialize statistics: {}", e),
),
}
}
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
@@ -282,12 +326,20 @@ async fn handle_provision_certificate(
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routeName' parameter".to_string(),
)
}
};
match p.provision_certificate(&route_name).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to provision certificate: {}", e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to provision certificate: {}", e),
),
}
}
@@ -303,12 +355,20 @@ async fn handle_renew_certificate(
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routeName' parameter".to_string(),
)
}
};
match p.renew_certificate(&route_name).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to renew certificate: {}", e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to renew certificate: {}", e),
),
}
}
@@ -324,24 +384,29 @@ async fn handle_get_certificate_status(
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name,
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routeName' parameter".to_string(),
)
}
};
match p.get_certificate_status(route_name).await {
Some(status) => ManagementResponse::ok(id.to_string(), serde_json::json!({
Some(status) => ManagementResponse::ok(
id.to_string(),
serde_json::json!({
"domain": status.domain,
"source": status.source,
"expiresAt": status.expires_at,
"isValid": status.is_valid,
})),
}),
),
None => ManagementResponse::ok(id.to_string(), serde_json::Value::Null),
}
}
fn handle_get_listening_ports(
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
fn handle_get_listening_ports(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
let ports = p.get_listening_ports();
@@ -351,26 +416,6 @@ fn handle_get_listening_ports(
}
}
async fn handle_get_nftables_status(
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
match p.get_nftables_status().await {
Ok(status) => {
match serde_json::to_value(&status) {
Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize: {}", e)),
}
}
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to get status: {}", e)),
}
}
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
}
}
async fn handle_set_socket_handler_relay(
id: &str,
params: &serde_json::Value,
@@ -381,7 +426,8 @@ async fn handle_set_socket_handler_relay(
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let socket_path = params.get("socketPath")
let socket_path = params
.get("socketPath")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
@@ -391,6 +437,27 @@ async fn handle_set_socket_handler_relay(
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
async fn handle_set_datagram_handler_relay(
id: &str,
params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let socket_path = params
.get("socketPath")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
info!("setDatagramHandlerRelay: socket_path={:?}", socket_path);
p.set_datagram_handler_relay_path(socket_path).await;
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
async fn handle_add_listening_port(
id: &str,
params: &serde_json::Value,
@@ -403,12 +470,17 @@ async fn handle_add_listening_port(
let port = match params.get("port").and_then(|v| v.as_u64()) {
Some(port) => port as u16,
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
None => {
return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string())
}
};
match p.add_listening_port(port).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to add port {}: {}", port, e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to add port {}: {}", port, e),
),
}
}
@@ -424,12 +496,17 @@ async fn handle_remove_listening_port(
let port = match params.get("port").and_then(|v| v.as_u64()) {
Some(port) => port as u16,
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
None => {
return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string())
}
};
match p.remove_listening_port(port).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to remove port {}: {}", port, e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to remove port {}: {}", port, e),
),
}
}
@@ -445,26 +522,41 @@ async fn handle_load_certificate(
let domain = match params.get("domain").and_then(|v| v.as_str()) {
Some(d) => d.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'domain' parameter".to_string()),
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'domain' parameter".to_string(),
)
}
};
let cert = match params.get("cert").and_then(|v| v.as_str()) {
Some(c) => c.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string()),
None => {
return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string())
}
};
let key = match params.get("key").and_then(|v| v.as_str()) {
Some(k) => k.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string()),
None => {
return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string())
}
};
let ca = params.get("ca").and_then(|v| v.as_str()).map(|s| s.to_string());
let ca = params
.get("ca")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
info!("loadCertificate: domain={}", domain);
// Load cert into cert manager and hot-swap TLS config
match p.load_certificate(&domain, cert, key, ca).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to load certificate for {}: {}", domain, e)),
Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to load certificate for {}: {}", domain, e),
),
}
}
+11 -10
View File
@@ -136,7 +136,8 @@ pub async fn start_http_echo_backend(port: u16, backend_name: &str) -> JoinHandl
let path = parts.get(1).copied().unwrap_or("/");
// Extract Host header
let host = req_str.lines()
let host = req_str
.lines()
.find(|l| l.to_lowercase().starts_with("host:"))
.map(|l| l[5..].trim())
.unwrap_or("unknown");
@@ -269,6 +270,7 @@ pub fn make_test_route(
id: None,
route_match: rustproxy_config::RouteMatch {
ports: rustproxy_config::PortRange::Single(port),
transport: None,
domains: domain.map(|d| rustproxy_config::DomainSpec::Single(d.to_string())),
path: None,
client_ip: None,
@@ -288,6 +290,7 @@ pub fn make_test_route(
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
}]),
tls: None,
@@ -295,9 +298,8 @@ pub fn make_test_route(
load_balancing: None,
advanced: None,
options: None,
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
udp: None,
},
headers: None,
security: None,
@@ -335,7 +337,8 @@ pub async fn start_ws_echo_backend(port: u16) -> JoinHandle<()> {
let req_str = String::from_utf8_lossy(&buf[..n]);
// Extract Sec-WebSocket-Key for proper handshake
let ws_key = req_str.lines()
let ws_key = req_str
.lines()
.find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
.unwrap_or_default();
@@ -377,7 +380,9 @@ pub fn generate_self_signed_cert(domain: &str) -> (String, String) {
use rcgen::{CertificateParams, KeyPair};
let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap();
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
params
.distinguished_name
.push(rcgen::DnType::CommonName, domain);
let key_pair = KeyPair::generate().unwrap();
let cert = params.self_signed(&key_pair).unwrap();
@@ -457,11 +462,7 @@ pub fn make_tls_terminate_route(
/// Start a TLS WebSocket echo backend: accepts TLS, performs WS handshake, then echoes data.
/// Combines TLS acceptance (like `start_tls_http_backend`) with WebSocket echo (like `start_ws_echo_backend`).
pub async fn start_tls_ws_echo_backend(
port: u16,
cert_pem: &str,
key_pem: &str,
) -> JoinHandle<()> {
pub async fn start_tls_ws_echo_backend(port: u16, cert_pem: &str, key_pem: &str) -> JoinHandle<()> {
use std::sync::Arc;
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
@@ -0,0 +1,213 @@
mod common;
use bytes::Buf;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::{RouteQuic, RouteUdp, RustProxyOptions, TransportProtocol};
use std::sync::Arc;
/// Build a route that listens on UDP with HTTP/3 enabled and TLS terminate.
fn make_h3_route(
port: u16,
target_host: &str,
target_port: u16,
cert_pem: &str,
key_pem: &str,
) -> rustproxy_config::RouteConfig {
let mut route = make_tls_terminate_route(
port,
"localhost",
target_host,
target_port,
cert_pem,
key_pem,
);
route.route_match.transport = Some(TransportProtocol::All);
// Keep domain="localhost" from make_tls_terminate_route — needed for TLS cert extraction
route.action.udp = Some(RouteUdp {
session_timeout: None,
max_sessions_per_ip: None,
max_datagram_size: None,
quic: Some(RouteQuic {
max_idle_timeout: Some(30000),
max_concurrent_bidi_streams: None,
max_concurrent_uni_streams: None,
enable_http3: Some(true),
alt_svc_port: None,
alt_svc_max_age: None,
initial_congestion_window: None,
}),
});
route
}
/// Build a quinn client endpoint with insecure TLS for testing.
fn make_h3_client_endpoint() -> quinn::Endpoint {
let mut tls_config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
.with_no_client_auth();
tls_config.alpn_protocols = vec![b"h3".to_vec()];
let quic_client_config = quinn::crypto::rustls::QuicClientConfig::try_from(tls_config)
.expect("Failed to build QUIC client config");
let client_config = quinn::ClientConfig::new(Arc::new(quic_client_config));
let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
.expect("Failed to create QUIC client endpoint");
endpoint.set_default_client_config(client_config);
endpoint
}
/// Test that HTTP/3 response streams properly finish (FIN is received by client).
///
/// This is the critical regression test for the FIN bug: the proxy must send
/// a QUIC stream FIN after the response body so the client's `recv_data()`
/// returns `None` instead of hanging forever.
#[tokio::test]
async fn test_h3_response_stream_finishes() {
let backend_port = next_port();
let proxy_port = next_port();
let body_text = "Hello from HTTP/3 backend! This body has a known length for testing.";
// 1. Start plain HTTP backend with known body + content-length
let _backend = start_http_server(backend_port, 200, body_text).await;
// 2. Generate self-signed cert and configure H3 route
let (cert_pem, key_pem) = generate_self_signed_cert("localhost");
let route = make_h3_route(proxy_port, "127.0.0.1", backend_port, &cert_pem, &key_pem);
let options = RustProxyOptions {
routes: vec![route],
..Default::default()
};
// 3. Start proxy and wait for UDP bind
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
// 4. Connect QUIC/H3 client
let endpoint = make_h3_client_endpoint();
let addr: std::net::SocketAddr = format!("127.0.0.1:{}", proxy_port).parse().unwrap();
let connection = endpoint
.connect(addr, "localhost")
.expect("Failed to initiate QUIC connection")
.await
.expect("QUIC handshake failed");
let (mut driver, mut send_request) = h3::client::new(h3_quinn::Connection::new(connection))
.await
.expect("H3 connection setup failed");
// Drive the H3 connection in background
tokio::spawn(async move {
let _ = driver.wait_idle().await;
});
// 5. Send GET request
let req = http::Request::builder()
.method("GET")
.uri("https://localhost/")
.header("host", "localhost")
.body(())
.unwrap();
let mut stream = send_request
.send_request(req)
.await
.expect("Failed to send H3 request");
stream
.finish()
.await
.expect("Failed to finish sending H3 request body");
// 6. Read response headers
let resp = stream
.recv_response()
.await
.expect("Failed to receive H3 response");
assert_eq!(
resp.status(),
http::StatusCode::OK,
"Expected 200 OK, got {}",
resp.status()
);
// 7. Read body and verify stream ends (FIN received)
// This is the critical assertion: recv_data() must return None (stream ended)
// within the timeout, NOT hang forever waiting for a FIN that never arrives.
let result = with_timeout(
async {
let mut total = 0usize;
while let Some(chunk) = stream.recv_data().await.expect("H3 data receive error") {
total += chunk.remaining();
}
// recv_data() returned None => stream ended (FIN received)
total
},
10,
)
.await;
let bytes_received = result.expect(
"TIMEOUT: H3 stream never ended (FIN not received by client). \
The proxy sent all response data but failed to send the QUIC stream FIN.",
);
assert_eq!(
bytes_received,
body_text.len(),
"Expected {} bytes, got {}",
body_text.len(),
bytes_received
);
// 8. Cleanup
endpoint.close(quinn::VarInt::from_u32(0), b"test done");
proxy.stop().await.unwrap();
}
/// Insecure TLS verifier that accepts any certificate (for tests only).
#[derive(Debug)]
struct InsecureVerifier;
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::RSA_PSS_SHA256,
]
}
}
@@ -43,17 +43,32 @@ async fn test_http_forward_basic() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await;
let body = extract_body(&response);
body.to_string()
}, 10)
},
10,
)
.await
.unwrap();
assert!(result.contains(r#""method":"GET"#), "Expected GET method, got: {}", result);
assert!(result.contains(r#""path":"/hello"#), "Expected /hello path, got: {}", result);
assert!(result.contains(r#""backend":"main"#), "Expected main backend, got: {}", result);
assert!(
result.contains(r#""method":"GET"#),
"Expected GET method, got: {}",
result
);
assert!(
result.contains(r#""path":"/hello"#),
"Expected /hello path, got: {}",
result
);
assert!(
result.contains(r#""backend":"main"#),
"Expected main backend, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -69,8 +84,18 @@ async fn test_http_forward_host_routing() {
let options = RustProxyOptions {
routes: vec![
make_test_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", backend1_port),
make_test_route(proxy_port, Some("beta.example.com"), "127.0.0.1", backend2_port),
make_test_route(
proxy_port,
Some("alpha.example.com"),
"127.0.0.1",
backend1_port,
),
make_test_route(
proxy_port,
Some("beta.example.com"),
"127.0.0.1",
backend2_port,
),
],
..Default::default()
};
@@ -80,24 +105,38 @@ async fn test_http_forward_host_routing() {
assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain
let alpha_result = with_timeout(async {
let alpha_result = with_timeout(
async {
let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await;
extract_body(&response).to_string()
}, 10)
},
10,
)
.await
.unwrap();
assert!(alpha_result.contains(r#""backend":"alpha"#), "Expected alpha backend, got: {}", alpha_result);
assert!(
alpha_result.contains(r#""backend":"alpha"#),
"Expected alpha backend, got: {}",
alpha_result
);
// Test beta domain
let beta_result = with_timeout(async {
let beta_result = with_timeout(
async {
let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await;
extract_body(&response).to_string()
}, 10)
},
10,
)
.await
.unwrap();
assert!(beta_result.contains(r#""backend":"beta"#), "Expected beta backend, got: {}", beta_result);
assert!(
beta_result.contains(r#""backend":"beta"#),
"Expected beta backend, got: {}",
beta_result
);
proxy.stop().await.unwrap();
}
@@ -127,24 +166,38 @@ async fn test_http_forward_path_routing() {
assert!(wait_for_port(proxy_port, 2000).await);
// Test API path
let api_result = with_timeout(async {
let api_result = with_timeout(
async {
let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await;
extract_body(&response).to_string()
}, 10)
},
10,
)
.await
.unwrap();
assert!(api_result.contains(r#""backend":"api"#), "Expected api backend, got: {}", api_result);
assert!(
api_result.contains(r#""backend":"api"#),
"Expected api backend, got: {}",
api_result
);
// Test web path (no /api prefix)
let web_result = with_timeout(async {
let web_result = with_timeout(
async {
let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await;
extract_body(&response).to_string()
}, 10)
},
10,
)
.await
.unwrap();
assert!(web_result.contains(r#""backend":"web"#), "Expected web backend, got: {}", web_result);
assert!(
web_result.contains(r#""backend":"web"#),
"Expected web backend, got: {}",
web_result
);
proxy.stop().await.unwrap();
}
@@ -184,9 +237,18 @@ async fn test_http_forward_cors_preflight() {
.unwrap();
// Should get 204 No Content with CORS headers
assert!(result.contains("204"), "Expected 204 status, got: {}", result);
assert!(result.to_lowercase().contains("access-control-allow-origin"),
"Expected CORS header, got: {}", result);
assert!(
result.contains("204"),
"Expected 204 status, got: {}",
result
);
assert!(
result
.to_lowercase()
.contains("access-control-allow-origin"),
"Expected CORS header, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -208,15 +270,22 @@ async fn test_http_forward_backend_error() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await;
response
}, 10)
},
10,
)
.await
.unwrap();
// Proxy should relay the 500 from backend
assert!(result.contains("500"), "Expected 500 status, got: {}", result);
assert!(
result.contains("500"),
"Expected 500 status, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -227,7 +296,12 @@ async fn test_http_forward_no_route_matched() {
// Create a route only for a specific domain
let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, Some("known.example.com"), "127.0.0.1", 9999)],
routes: vec![make_test_route(
proxy_port,
Some("known.example.com"),
"127.0.0.1",
9999,
)],
..Default::default()
};
@@ -235,15 +309,22 @@ async fn test_http_forward_no_route_matched() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await;
response
}, 10)
},
10,
)
.await
.unwrap();
// Should get 502 Bad Gateway (no route matched)
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
assert!(
result.contains("502"),
"Expected 502 status, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -262,15 +343,22 @@ async fn test_http_forward_backend_unavailable() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let response = send_http_request(proxy_port, "example.com", "GET", "/").await;
response
}, 10)
},
10,
)
.await
.unwrap();
// Should get 502 Bad Gateway (backend unavailable)
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
assert!(
result.contains("502"),
"Expected 502 status, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -286,7 +374,12 @@ async fn test_https_terminate_http_forward() {
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)],
..Default::default()
};
@@ -295,7 +388,8 @@ async fn test_https_terminate_http_forward() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let _ = rustls::crypto::ring::default_provider().install_default();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
@@ -319,14 +413,28 @@ async fn test_https_terminate_http_forward() {
let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
}, 10)
},
10,
)
.await
.unwrap();
let body = extract_body(&result);
assert!(body.contains(r#""method":"GET"#), "Expected GET, got: {}", body);
assert!(body.contains(r#""path":"/api/data"#), "Expected /api/data, got: {}", body);
assert!(body.contains(r#""backend":"tls-backend"#), "Expected tls-backend, got: {}", body);
assert!(
body.contains(r#""method":"GET"#),
"Expected GET, got: {}",
body
);
assert!(
body.contains(r#""path":"/api/data"#),
"Expected /api/data, got: {}",
body
);
assert!(
body.contains(r#""backend":"tls-backend"#),
"Expected tls-backend, got: {}",
body
);
proxy.stop().await.unwrap();
}
@@ -347,7 +455,8 @@ async fn test_websocket_through_proxy() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -369,7 +478,9 @@ async fn test_websocket_through_proxy() {
let mut temp = [0u8; 1];
loop {
let n = stream.read(&mut temp).await.unwrap();
if n == 0 { break; }
if n == 0 {
break;
}
response_buf.push(temp[0]);
if response_buf.len() >= 4 {
let len = response_buf.len();
@@ -380,7 +491,11 @@ async fn test_websocket_through_proxy() {
}
let response_str = String::from_utf8_lossy(&response_buf).to_string();
assert!(response_str.contains("101"), "Expected 101 Switching Protocols, got: {}", response_str);
assert!(
response_str.contains("101"),
"Expected 101 Switching Protocols, got: {}",
response_str
);
assert!(
response_str.to_lowercase().contains("upgrade: websocket"),
"Expected Upgrade header, got: {}",
@@ -399,7 +514,9 @@ async fn test_websocket_through_proxy() {
assert_eq!(echoed, test_data, "Expected echo of sent data");
"ok".to_string()
}, 10)
},
10,
)
.await
.unwrap();
@@ -431,12 +548,22 @@ async fn test_terminate_and_reencrypt_http_routing() {
// Create terminate-and-reencrypt routes
let mut route1 = make_tls_terminate_route(
proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1,
proxy_port,
"alpha.example.com",
"127.0.0.1",
backend1_port,
&cert1,
&key1,
);
route1.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
let mut route2 = make_tls_terminate_route(
proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2,
proxy_port,
"beta.example.com",
"127.0.0.1",
backend2_port,
&cert2,
&key2,
);
route2.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
@@ -450,7 +577,8 @@ async fn test_terminate_and_reencrypt_http_routing() {
assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain - HTTP request through TLS terminate-and-reencrypt
let alpha_result = with_timeout(async {
let alpha_result = with_timeout(
async {
let _ = rustls::crypto::ring::default_provider().install_default();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
@@ -461,16 +589,20 @@ async fn test_terminate_and_reencrypt_http_routing() {
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let server_name =
rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let request = "GET /api/data HTTP/1.1\r\nHost: alpha.example.com\r\nConnection: close\r\n\r\n";
let request =
"GET /api/data HTTP/1.1\r\nHost: alpha.example.com\r\nConnection: close\r\n\r\n";
tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
}, 10)
},
10,
)
.await
.unwrap();
@@ -498,7 +630,8 @@ async fn test_terminate_and_reencrypt_http_routing() {
);
// Test beta domain - different host goes to different backend
let beta_result = with_timeout(async {
let beta_result = with_timeout(
async {
let _ = rustls::crypto::ring::default_provider().install_default();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
@@ -509,16 +642,20 @@ async fn test_terminate_and_reencrypt_http_routing() {
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from("beta.example.com".to_string()).unwrap();
let server_name =
rustls::pki_types::ServerName::try_from("beta.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let request = "GET /other HTTP/1.1\r\nHost: beta.example.com\r\nConnection: close\r\n\r\n";
let request =
"GET /other HTTP/1.1\r\nHost: beta.example.com\r\nConnection: close\r\n\r\n";
tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string()
}, 10)
},
10,
)
.await
.unwrap();
@@ -589,14 +726,12 @@ async fn test_terminate_and_reencrypt_websocket() {
.dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
.with_no_client_auth();
let connector =
tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
let server_name =
rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
// Send WebSocket upgrade request through TLS
@@ -685,10 +820,13 @@ async fn test_protocol_field_in_route_config() {
assert!(wait_for_port(proxy_port, 2000).await);
// HTTP request should match the route and get proxied
let result = with_timeout(async {
let result = with_timeout(
async {
let response = send_http_request(proxy_port, "example.com", "GET", "/test").await;
extract_body(&response).to_string()
}, 10)
},
10,
)
.await
.unwrap();
@@ -20,13 +20,19 @@ async fn test_start_and_stop() {
assert!(!wait_for_port(port, 200).await);
proxy.start().await.unwrap();
assert!(wait_for_port(port, 2000).await, "Port should be listening after start");
assert!(
wait_for_port(port, 2000).await,
"Port should be listening after start"
);
proxy.stop().await.unwrap();
// Give the OS a moment to release the port
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port, 200).await, "Port should not be listening after stop");
assert!(
!wait_for_port(port, 200).await,
"Port should not be listening after stop"
);
}
#[tokio::test]
@@ -54,7 +60,12 @@ async fn test_update_routes_hot_reload() {
let port = next_port();
let options = RustProxyOptions {
routes: vec![make_test_route(port, Some("old.example.com"), "127.0.0.1", 8080)],
routes: vec![make_test_route(
port,
Some("old.example.com"),
"127.0.0.1",
8080,
)],
..Default::default()
};
@@ -62,9 +73,12 @@ async fn test_update_routes_hot_reload() {
proxy.start().await.unwrap();
// Update routes atomically
let new_routes = vec![
make_test_route(port, Some("new.example.com"), "127.0.0.1", 9090),
];
let new_routes = vec![make_test_route(
port,
Some("new.example.com"),
"127.0.0.1",
9090,
)];
let result = proxy.update_routes(new_routes).await;
assert!(result.is_ok());
@@ -87,15 +101,24 @@ async fn test_add_remove_listening_port() {
// Add a new port
proxy.add_listening_port(port2).await.unwrap();
assert!(wait_for_port(port2, 2000).await, "New port should be listening");
assert!(
wait_for_port(port2, 2000).await,
"New port should be listening"
);
// Remove the port
proxy.remove_listening_port(port2).await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port2, 200).await, "Removed port should not be listening");
assert!(
!wait_for_port(port2, 200).await,
"Removed port should not be listening"
);
// Original port should still be listening
assert!(wait_for_port(port1, 200).await, "Original port should still be listening");
assert!(
wait_for_port(port1, 200).await,
"Original port should still be listening"
);
proxy.stop().await.unwrap();
}
@@ -168,7 +191,11 @@ async fn test_metrics_track_connections() {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let stats = proxy.get_statistics();
assert!(stats.total_connections > 0, "Expected total_connections > 0, got {}", stats.total_connections);
assert!(
stats.total_connections > 0,
"Expected total_connections > 0, got {}",
stats.total_connections
);
proxy.stop().await.unwrap();
}
@@ -205,8 +232,11 @@ async fn test_metrics_track_bytes() {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let stats = proxy.get_statistics();
assert!(stats.total_connections > 0,
"Expected some connections tracked, got {}", stats.total_connections);
assert!(
stats.total_connections > 0,
"Expected some connections tracked, got {}",
stats.total_connections
);
proxy.stop().await.unwrap();
}
@@ -228,23 +258,38 @@ async fn test_hot_reload_port_changes() {
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
assert!(wait_for_port(port1, 2000).await);
assert!(!wait_for_port(port2, 200).await, "port2 should not be listening yet");
assert!(
!wait_for_port(port2, 200).await,
"port2 should not be listening yet"
);
// Update routes to use port2 instead
let new_routes = vec![
make_test_route(port2, None, "127.0.0.1", backend_port),
];
let new_routes = vec![make_test_route(port2, None, "127.0.0.1", backend_port)];
proxy.update_routes(new_routes).await.unwrap();
// Port2 should now be listening, port1 should be closed
assert!(wait_for_port(port2, 2000).await, "port2 should be listening after reload");
assert!(
wait_for_port(port2, 2000).await,
"port2 should be listening after reload"
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port1, 200).await, "port1 should be closed after reload");
assert!(
!wait_for_port(port1, 200).await,
"port1 should be closed after reload"
);
// Verify port2 works
let ports = proxy.get_listening_ports();
assert!(ports.contains(&port2), "Expected port2 in listening ports: {:?}", ports);
assert!(!ports.contains(&port1), "port1 should not be in listening ports: {:?}", ports);
assert!(
ports.contains(&port2),
"Expected port2 in listening ports: {:?}",
ports
);
assert!(
!ports.contains(&port1),
"port1 should not be in listening ports: {:?}",
ports
);
proxy.stop().await.unwrap();
}
@@ -24,10 +24,14 @@ async fn test_tcp_forward_echo() {
proxy.start().await.unwrap();
// Wait for proxy to be ready
assert!(wait_for_port(proxy_port, 2000).await, "Proxy port not ready");
assert!(
wait_for_port(proxy_port, 2000).await,
"Proxy port not ready"
);
// Connect and send data
let result = with_timeout(async {
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -36,7 +40,9 @@ async fn test_tcp_forward_echo() {
let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
},
5,
)
.await
.unwrap();
@@ -61,7 +67,8 @@ async fn test_tcp_forward_large_payload() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -75,7 +82,9 @@ async fn test_tcp_forward_large_payload() {
let mut received = Vec::new();
stream.read_to_end(&mut received).await.unwrap();
received.len()
}, 10)
},
10,
)
.await
.unwrap();
@@ -100,7 +109,8 @@ async fn test_tcp_forward_multiple_connections() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let mut handles = Vec::new();
for i in 0..10 {
let port = proxy_port;
@@ -122,7 +132,9 @@ async fn test_tcp_forward_multiple_connections() {
results.push(handle.await.unwrap());
}
results
}, 10)
},
10,
)
.await
.unwrap();
@@ -149,14 +161,20 @@ async fn test_tcp_forward_backend_unreachable() {
assert!(wait_for_port(proxy_port, 2000).await);
// Connection should complete (proxy accepts it) but data should not flow
let result = with_timeout(async {
let result = with_timeout(
async {
let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await;
stream.is_ok()
}, 5)
},
5,
)
.await
.unwrap();
assert!(result, "Should be able to connect to proxy even if backend is down");
assert!(
result,
"Should be able to connect to proxy even if backend is down"
);
proxy.stop().await.unwrap();
}
@@ -178,7 +196,8 @@ async fn test_tcp_forward_bidirectional() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -187,7 +206,9 @@ async fn test_tcp_forward_bidirectional() {
let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
},
5,
)
.await
.unwrap();
@@ -65,8 +65,18 @@ async fn test_tls_passthrough_sni_routing() {
let options = RustProxyOptions {
routes: vec![
make_tls_passthrough_route(proxy_port, Some("one.example.com"), "127.0.0.1", backend1_port),
make_tls_passthrough_route(proxy_port, Some("two.example.com"), "127.0.0.1", backend2_port),
make_tls_passthrough_route(
proxy_port,
Some("one.example.com"),
"127.0.0.1",
backend1_port,
),
make_tls_passthrough_route(
proxy_port,
Some("two.example.com"),
"127.0.0.1",
backend2_port,
),
],
..Default::default()
};
@@ -76,7 +86,8 @@ async fn test_tls_passthrough_sni_routing() {
assert!(wait_for_port(proxy_port, 2000).await);
// Send a fake ClientHello with SNI "one.example.com"
let result = with_timeout(async {
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -86,15 +97,22 @@ async fn test_tls_passthrough_sni_routing() {
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
},
5,
)
.await
.unwrap();
// Backend1 should have received the ClientHello and prefixed its response
assert!(result.starts_with("BACKEND1:"), "Expected BACKEND1 prefix, got: {}", result);
assert!(
result.starts_with("BACKEND1:"),
"Expected BACKEND1 prefix, got: {}",
result
);
// Now test routing to backend2
let result2 = with_timeout(async {
let result2 = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -104,11 +122,17 @@ async fn test_tls_passthrough_sni_routing() {
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
},
5,
)
.await
.unwrap();
assert!(result2.starts_with("BACKEND2:"), "Expected BACKEND2 prefix, got: {}", result2);
assert!(
result2.starts_with("BACKEND2:"),
"Expected BACKEND2 prefix, got: {}",
result2
);
proxy.stop().await.unwrap();
}
@@ -121,9 +145,12 @@ async fn test_tls_passthrough_unknown_sni() {
let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions {
routes: vec![
make_tls_passthrough_route(proxy_port, Some("known.example.com"), "127.0.0.1", backend_port),
],
routes: vec![make_tls_passthrough_route(
proxy_port,
Some("known.example.com"),
"127.0.0.1",
backend_port,
)],
..Default::default()
};
@@ -132,7 +159,8 @@ async fn test_tls_passthrough_unknown_sni() {
assert!(wait_for_port(proxy_port, 2000).await);
// Send ClientHello with unknown SNI - should get no response (connection dropped)
let result = with_timeout(async {
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -146,7 +174,9 @@ async fn test_tls_passthrough_unknown_sni() {
Ok(_) => false, // Got data = route shouldn't have matched
Err(_) => true, // Error = connection dropped
}
}, 5)
},
5,
)
.await
.unwrap();
@@ -163,9 +193,12 @@ async fn test_tls_passthrough_wildcard_domain() {
let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await;
let options = RustProxyOptions {
routes: vec![
make_tls_passthrough_route(proxy_port, Some("*.example.com"), "127.0.0.1", backend_port),
],
routes: vec![make_tls_passthrough_route(
proxy_port,
Some("*.example.com"),
"127.0.0.1",
backend_port,
)],
..Default::default()
};
@@ -174,7 +207,8 @@ async fn test_tls_passthrough_wildcard_domain() {
assert!(wait_for_port(proxy_port, 2000).await);
// Should match any subdomain of example.com
let result = with_timeout(async {
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -184,11 +218,17 @@ async fn test_tls_passthrough_wildcard_domain() {
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
},
5,
)
.await
.unwrap();
assert!(result.starts_with("WILDCARD:"), "Expected WILDCARD prefix, got: {}", result);
assert!(
result.starts_with("WILDCARD:"),
"Expected WILDCARD prefix, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -222,7 +262,8 @@ async fn test_tls_passthrough_multiple_domains() {
("beta.example.com", "B2:"),
("gamma.example.com", "B3:"),
] {
let result = with_timeout(async {
let result = with_timeout(
async {
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await
.unwrap();
@@ -232,14 +273,18 @@ async fn test_tls_passthrough_multiple_domains() {
let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 5)
},
5,
)
.await
.unwrap();
assert!(
result.starts_with(expected_prefix),
"Domain {} should route to {}, got: {}",
domain, expected_prefix, result
domain,
expected_prefix,
result
);
}
@@ -74,7 +74,12 @@ async fn test_tls_terminate_basic() {
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)],
..Default::default()
};
@@ -84,7 +89,8 @@ async fn test_tls_terminate_basic() {
assert!(wait_for_port(proxy_port, 2000).await);
// Connect with TLS client
let result = with_timeout(async {
let result = with_timeout(
async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
@@ -100,7 +106,9 @@ async fn test_tls_terminate_basic() {
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 10)
},
10,
)
.await
.unwrap();
@@ -125,7 +133,12 @@ async fn test_tls_terminate_and_reencrypt() {
// Create terminate-and-reencrypt route
let mut route = make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &proxy_cert, &proxy_key,
proxy_port,
domain,
"127.0.0.1",
backend_port,
&proxy_cert,
&proxy_key,
);
route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
@@ -138,7 +151,8 @@ async fn test_tls_terminate_and_reencrypt() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
@@ -154,7 +168,9 @@ async fn test_tls_terminate_and_reencrypt() {
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 10)
},
10,
)
.await
.unwrap();
@@ -177,8 +193,22 @@ async fn test_tls_terminate_sni_cert_selection() {
let options = RustProxyOptions {
routes: vec![
make_tls_terminate_route(proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1),
make_tls_terminate_route(proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2),
make_tls_terminate_route(
proxy_port,
"alpha.example.com",
"127.0.0.1",
backend1_port,
&cert1,
&key1,
),
make_tls_terminate_route(
proxy_port,
"beta.example.com",
"127.0.0.1",
backend2_port,
&cert2,
&key2,
),
],
..Default::default()
};
@@ -188,7 +218,8 @@ async fn test_tls_terminate_sni_cert_selection() {
assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain
let result = with_timeout(async {
let result = with_timeout(
async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
@@ -196,7 +227,8 @@ async fn test_tls_terminate_sni_cert_selection() {
.await
.unwrap();
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let server_name =
rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
tls_stream.write_all(b"test").await.unwrap();
@@ -204,11 +236,17 @@ async fn test_tls_terminate_sni_cert_selection() {
let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string()
}, 10)
},
10,
)
.await
.unwrap();
assert!(result.starts_with("ALPHA:"), "Expected ALPHA prefix, got: {}", result);
assert!(
result.starts_with("ALPHA:"),
"Expected ALPHA prefix, got: {}",
result
);
proxy.stop().await.unwrap();
}
@@ -224,7 +262,12 @@ async fn test_tls_terminate_large_payload() {
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)],
..Default::default()
};
@@ -233,7 +276,8 @@ async fn test_tls_terminate_large_payload() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
@@ -252,7 +296,9 @@ async fn test_tls_terminate_large_payload() {
let mut received = Vec::new();
tls_stream.read_to_end(&mut received).await.unwrap();
received.len()
}, 15)
},
15,
)
.await
.unwrap();
@@ -272,7 +318,12 @@ async fn test_tls_terminate_concurrent() {
let options = RustProxyOptions {
routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)],
..Default::default()
};
@@ -281,7 +332,8 @@ async fn test_tls_terminate_concurrent() {
proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async {
let result = with_timeout(
async {
let mut handles = Vec::new();
for i in 0..10 {
let port = proxy_port;
@@ -311,7 +363,9 @@ async fn test_tls_terminate_concurrent() {
results.push(handle.await.unwrap());
}
results
}, 15)
},
15,
)
.await
.unwrap();
-200
View File
@@ -1,200 +0,0 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import {
delay,
retryWithBackoff,
withTimeout,
parallelLimit,
debounceAsync,
AsyncMutex,
CircuitBreaker
} from '../../../ts/core/utils/async-utils.js';
tap.test('delay should pause execution for specified milliseconds', async () => {
const startTime = Date.now();
await delay(100);
const elapsed = Date.now() - startTime;
// Allow some tolerance for timing
expect(elapsed).toBeGreaterThan(90);
expect(elapsed).toBeLessThan(150);
});
tap.test('retryWithBackoff should retry failed operations', async () => {
let attempts = 0;
const operation = async () => {
attempts++;
if (attempts < 3) {
throw new Error('Test error');
}
return 'success';
};
const result = await retryWithBackoff(operation, {
maxAttempts: 3,
initialDelay: 10
});
expect(result).toEqual('success');
expect(attempts).toEqual(3);
});
tap.test('retryWithBackoff should throw after max attempts', async () => {
let attempts = 0;
const operation = async () => {
attempts++;
throw new Error('Always fails');
};
let error: Error | null = null;
try {
await retryWithBackoff(operation, {
maxAttempts: 2,
initialDelay: 10
});
} catch (e: any) {
error = e;
}
expect(error).not.toBeNull();
expect(error?.message).toEqual('Always fails');
expect(attempts).toEqual(2);
});
tap.test('withTimeout should complete operations within timeout', async () => {
const operation = async () => {
await delay(50);
return 'completed';
};
const result = await withTimeout(operation, 100);
expect(result).toEqual('completed');
});
tap.test('withTimeout should throw on timeout', async () => {
const operation = async () => {
await delay(200);
return 'never happens';
};
let error: Error | null = null;
try {
await withTimeout(operation, 50);
} catch (e: any) {
error = e;
}
expect(error).not.toBeNull();
expect(error?.message).toContain('timed out');
});
tap.test('parallelLimit should respect concurrency limit', async () => {
let concurrent = 0;
let maxConcurrent = 0;
const items = [1, 2, 3, 4, 5, 6];
const operation = async (item: number) => {
concurrent++;
maxConcurrent = Math.max(maxConcurrent, concurrent);
await delay(50);
concurrent--;
return item * 2;
};
const results = await parallelLimit(items, operation, 2);
expect(results).toEqual([2, 4, 6, 8, 10, 12]);
expect(maxConcurrent).toBeLessThan(3);
expect(maxConcurrent).toBeGreaterThan(0);
});
tap.test('debounceAsync should debounce function calls', async () => {
let callCount = 0;
const fn = async (value: string) => {
callCount++;
return value;
};
const debounced = debounceAsync(fn, 50);
// Make multiple calls quickly
debounced('a');
debounced('b');
debounced('c');
const result = await debounced('d');
// Wait a bit to ensure no more calls
await delay(100);
expect(result).toEqual('d');
expect(callCount).toEqual(1); // Only the last call should execute
});
tap.test('AsyncMutex should ensure exclusive access', async () => {
const mutex = new AsyncMutex();
const results: number[] = [];
const operation = async (value: number) => {
await mutex.runExclusive(async () => {
results.push(value);
await delay(10);
results.push(value * 10);
});
};
// Run operations concurrently
await Promise.all([
operation(1),
operation(2),
operation(3)
]);
// Results should show sequential execution
expect(results).toEqual([1, 10, 2, 20, 3, 30]);
});
tap.test('CircuitBreaker should open after failures', async () => {
const breaker = new CircuitBreaker({
failureThreshold: 2,
resetTimeout: 100
});
let attempt = 0;
const failingOperation = async () => {
attempt++;
throw new Error('Test failure');
};
// First two failures
for (let i = 0; i < 2; i++) {
try {
await breaker.execute(failingOperation);
} catch (e) {
// Expected
}
}
expect(breaker.isOpen()).toBeTrue();
// Next attempt should fail immediately
let error: Error | null = null;
try {
await breaker.execute(failingOperation);
} catch (e: any) {
error = e;
}
expect(error?.message).toEqual('Circuit breaker is open');
expect(attempt).toEqual(2); // Operation not called when circuit is open
// Wait for reset timeout
await delay(150);
// Circuit should be half-open now, allowing one attempt
const successOperation = async () => 'success';
const result = await breaker.execute(successOperation);
expect(result).toEqual('success');
expect(breaker.getState()).toEqual('closed');
});
tap.start();
-206
View File
@@ -1,206 +0,0 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import { BinaryHeap } from '../../../ts/core/utils/binary-heap.js';
interface TestItem {
id: string;
priority: number;
value: string;
}
tap.test('should create empty heap', async () => {
const heap = new BinaryHeap<number>((a, b) => a - b);
expect(heap.size).toEqual(0);
expect(heap.isEmpty()).toBeTrue();
expect(heap.peek()).toBeUndefined();
});
tap.test('should insert and extract in correct order', async () => {
const heap = new BinaryHeap<number>((a, b) => a - b);
heap.insert(5);
heap.insert(3);
heap.insert(7);
heap.insert(1);
heap.insert(9);
heap.insert(4);
expect(heap.size).toEqual(6);
// Extract in ascending order
expect(heap.extract()).toEqual(1);
expect(heap.extract()).toEqual(3);
expect(heap.extract()).toEqual(4);
expect(heap.extract()).toEqual(5);
expect(heap.extract()).toEqual(7);
expect(heap.extract()).toEqual(9);
expect(heap.extract()).toBeUndefined();
});
tap.test('should work with custom objects and comparator', async () => {
const heap = new BinaryHeap<TestItem>(
(a, b) => a.priority - b.priority,
(item) => item.id
);
heap.insert({ id: 'a', priority: 5, value: 'five' });
heap.insert({ id: 'b', priority: 2, value: 'two' });
heap.insert({ id: 'c', priority: 8, value: 'eight' });
heap.insert({ id: 'd', priority: 1, value: 'one' });
const first = heap.extract();
expect(first?.priority).toEqual(1);
expect(first?.value).toEqual('one');
const second = heap.extract();
expect(second?.priority).toEqual(2);
expect(second?.value).toEqual('two');
});
tap.test('should support reverse order (max heap)', async () => {
const heap = new BinaryHeap<number>((a, b) => b - a);
heap.insert(5);
heap.insert(3);
heap.insert(7);
heap.insert(1);
heap.insert(9);
// Extract in descending order
expect(heap.extract()).toEqual(9);
expect(heap.extract()).toEqual(7);
expect(heap.extract()).toEqual(5);
});
tap.test('should extract by predicate', async () => {
const heap = new BinaryHeap<TestItem>((a, b) => a.priority - b.priority);
heap.insert({ id: 'a', priority: 5, value: 'five' });
heap.insert({ id: 'b', priority: 2, value: 'two' });
heap.insert({ id: 'c', priority: 8, value: 'eight' });
const extracted = heap.extractIf(item => item.id === 'b');
expect(extracted?.id).toEqual('b');
expect(heap.size).toEqual(2);
// Should not find it again
const notFound = heap.extractIf(item => item.id === 'b');
expect(notFound).toBeUndefined();
});
tap.test('should extract by key', async () => {
const heap = new BinaryHeap<TestItem>(
(a, b) => a.priority - b.priority,
(item) => item.id
);
heap.insert({ id: 'a', priority: 5, value: 'five' });
heap.insert({ id: 'b', priority: 2, value: 'two' });
heap.insert({ id: 'c', priority: 8, value: 'eight' });
expect(heap.hasKey('b')).toBeTrue();
const extracted = heap.extractByKey('b');
expect(extracted?.id).toEqual('b');
expect(heap.size).toEqual(2);
expect(heap.hasKey('b')).toBeFalse();
// Should not find it again
const notFound = heap.extractByKey('b');
expect(notFound).toBeUndefined();
});
tap.test('should throw when using key operations without extractKey', async () => {
const heap = new BinaryHeap<TestItem>((a, b) => a.priority - b.priority);
heap.insert({ id: 'a', priority: 5, value: 'five' });
let error: Error | null = null;
try {
heap.extractByKey('a');
} catch (e: any) {
error = e;
}
expect(error).not.toBeNull();
expect(error?.message).toContain('extractKey function must be provided');
});
tap.test('should handle duplicates correctly', async () => {
const heap = new BinaryHeap<number>((a, b) => a - b);
heap.insert(5);
heap.insert(5);
heap.insert(5);
heap.insert(3);
heap.insert(7);
expect(heap.size).toEqual(5);
expect(heap.extract()).toEqual(3);
expect(heap.extract()).toEqual(5);
expect(heap.extract()).toEqual(5);
expect(heap.extract()).toEqual(5);
expect(heap.extract()).toEqual(7);
});
tap.test('should convert to array without modifying heap', async () => {
const heap = new BinaryHeap<number>((a, b) => a - b);
heap.insert(5);
heap.insert(3);
heap.insert(7);
const array = heap.toArray();
expect(array).toContain(3);
expect(array).toContain(5);
expect(array).toContain(7);
expect(array.length).toEqual(3);
// Heap should still be intact
expect(heap.size).toEqual(3);
expect(heap.extract()).toEqual(3);
});
tap.test('should clear the heap', async () => {
const heap = new BinaryHeap<TestItem>(
(a, b) => a.priority - b.priority,
(item) => item.id
);
heap.insert({ id: 'a', priority: 5, value: 'five' });
heap.insert({ id: 'b', priority: 2, value: 'two' });
expect(heap.size).toEqual(2);
expect(heap.hasKey('a')).toBeTrue();
heap.clear();
expect(heap.size).toEqual(0);
expect(heap.isEmpty()).toBeTrue();
expect(heap.hasKey('a')).toBeFalse();
});
tap.test('should handle complex extraction patterns', async () => {
const heap = new BinaryHeap<number>((a, b) => a - b);
// Insert numbers 1-10 in random order
[8, 3, 5, 9, 1, 7, 4, 10, 2, 6].forEach(n => heap.insert(n));
// Extract some in order
expect(heap.extract()).toEqual(1);
expect(heap.extract()).toEqual(2);
// Insert more
heap.insert(0);
heap.insert(1.5);
// Continue extracting
expect(heap.extract()).toEqual(0);
expect(heap.extract()).toEqual(1.5);
expect(heap.extract()).toEqual(3);
// Verify remaining size (10 - 2 extracted + 2 inserted - 3 extracted = 7)
expect(heap.size).toEqual(7);
});
tap.start();
-185
View File
@@ -1,185 +0,0 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import * as path from 'path';
import { AsyncFileSystem } from '../../../ts/core/utils/fs-utils.js';
// Use a temporary directory for tests
const testDir = path.join(process.cwd(), '.nogit', 'test-fs-utils');
const testFile = path.join(testDir, 'test.txt');
const testJsonFile = path.join(testDir, 'test.json');
tap.test('should create and check directory existence', async () => {
// Ensure directory
await AsyncFileSystem.ensureDir(testDir);
// Check it exists
const exists = await AsyncFileSystem.exists(testDir);
expect(exists).toBeTrue();
// Check it's a directory
const isDir = await AsyncFileSystem.isDirectory(testDir);
expect(isDir).toBeTrue();
});
tap.test('should write and read text files', async () => {
const testContent = 'Hello, async filesystem!';
// Write file
await AsyncFileSystem.writeFile(testFile, testContent);
// Check file exists
const exists = await AsyncFileSystem.exists(testFile);
expect(exists).toBeTrue();
// Read file
const content = await AsyncFileSystem.readFile(testFile);
expect(content).toEqual(testContent);
// Check it's a file
const isFile = await AsyncFileSystem.isFile(testFile);
expect(isFile).toBeTrue();
});
tap.test('should write and read JSON files', async () => {
const testData = {
name: 'Test',
value: 42,
nested: {
array: [1, 2, 3]
}
};
// Write JSON
await AsyncFileSystem.writeJSON(testJsonFile, testData);
// Read JSON
const readData = await AsyncFileSystem.readJSON(testJsonFile);
expect(readData).toEqual(testData);
});
tap.test('should copy files', async () => {
const copyFile = path.join(testDir, 'copy.txt');
// Copy file
await AsyncFileSystem.copyFile(testFile, copyFile);
// Check copy exists
const exists = await AsyncFileSystem.exists(copyFile);
expect(exists).toBeTrue();
// Check content matches
const content = await AsyncFileSystem.readFile(copyFile);
const originalContent = await AsyncFileSystem.readFile(testFile);
expect(content).toEqual(originalContent);
});
tap.test('should move files', async () => {
const moveFile = path.join(testDir, 'moved.txt');
const copyFile = path.join(testDir, 'copy.txt');
// Move file
await AsyncFileSystem.moveFile(copyFile, moveFile);
// Check moved file exists
const movedExists = await AsyncFileSystem.exists(moveFile);
expect(movedExists).toBeTrue();
// Check original doesn't exist
const originalExists = await AsyncFileSystem.exists(copyFile);
expect(originalExists).toBeFalse();
});
tap.test('should list files in directory', async () => {
const files = await AsyncFileSystem.listFiles(testDir);
expect(files).toContain('test.txt');
expect(files).toContain('test.json');
expect(files).toContain('moved.txt');
});
tap.test('should list files with full paths', async () => {
const files = await AsyncFileSystem.listFilesFullPath(testDir);
const fileNames = files.map(f => path.basename(f));
expect(fileNames).toContain('test.txt');
expect(fileNames).toContain('test.json');
// All paths should be absolute
files.forEach(file => {
expect(path.isAbsolute(file)).toBeTrue();
});
});
tap.test('should get file stats', async () => {
const stats = await AsyncFileSystem.getStats(testFile);
expect(stats).not.toBeNull();
expect(stats?.isFile()).toBeTrue();
expect(stats?.size).toBeGreaterThan(0);
});
tap.test('should handle non-existent files gracefully', async () => {
const nonExistent = path.join(testDir, 'does-not-exist.txt');
// Check existence
const exists = await AsyncFileSystem.exists(nonExistent);
expect(exists).toBeFalse();
// Get stats should return null
const stats = await AsyncFileSystem.getStats(nonExistent);
expect(stats).toBeNull();
// Remove should not throw
await AsyncFileSystem.remove(nonExistent);
});
tap.test('should remove files', async () => {
// Remove a file
await AsyncFileSystem.remove(testFile);
// Check it's gone
const exists = await AsyncFileSystem.exists(testFile);
expect(exists).toBeFalse();
});
tap.test('should ensure file exists', async () => {
const ensureFile = path.join(testDir, 'ensure.txt');
// Ensure file
await AsyncFileSystem.ensureFile(ensureFile);
// Check it exists
const exists = await AsyncFileSystem.exists(ensureFile);
expect(exists).toBeTrue();
// Check it's empty
const content = await AsyncFileSystem.readFile(ensureFile);
expect(content).toEqual('');
});
tap.test('should recursively list files', async () => {
// Create subdirectory with file
const subDir = path.join(testDir, 'subdir');
const subFile = path.join(subDir, 'nested.txt');
await AsyncFileSystem.ensureDir(subDir);
await AsyncFileSystem.writeFile(subFile, 'nested content');
// List recursively
const files = await AsyncFileSystem.listFilesRecursive(testDir);
// Should include files from subdirectory
const fileNames = files.map(f => path.relative(testDir, f));
expect(fileNames).toContain('test.json');
expect(fileNames).toContain(path.join('subdir', 'nested.txt'));
});
tap.test('should clean up test directory', async () => {
// Remove entire test directory
await AsyncFileSystem.removeDir(testDir);
// Check it's gone
const exists = await AsyncFileSystem.exists(testDir);
expect(exists).toBeFalse();
});
tap.start();
-156
View File
@@ -1,156 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import { IpUtils } from '../../../ts/core/utils/ip-utils.js';
tap.test('ip-utils - normalizeIP', async () => {
// IPv4 normalization
const ipv4Variants = IpUtils.normalizeIP('127.0.0.1');
expect(ipv4Variants).toEqual(['127.0.0.1', '::ffff:127.0.0.1']);
// IPv6-mapped IPv4 normalization
const ipv6MappedVariants = IpUtils.normalizeIP('::ffff:127.0.0.1');
expect(ipv6MappedVariants).toEqual(['::ffff:127.0.0.1', '127.0.0.1']);
// IPv6 normalization
const ipv6Variants = IpUtils.normalizeIP('::1');
expect(ipv6Variants).toEqual(['::1']);
// Invalid/empty input handling
expect(IpUtils.normalizeIP('')).toEqual([]);
expect(IpUtils.normalizeIP(null as any)).toEqual([]);
expect(IpUtils.normalizeIP(undefined as any)).toEqual([]);
});
tap.test('ip-utils - isGlobIPMatch', async () => {
// Direct matches
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['127.0.0.1'])).toEqual(true);
expect(IpUtils.isGlobIPMatch('::1', ['::1'])).toEqual(true);
// Wildcard matches
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['127.0.0.*'])).toEqual(true);
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['127.0.*.*'])).toEqual(true);
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['127.*.*.*'])).toEqual(true);
// IPv4-mapped IPv6 handling
expect(IpUtils.isGlobIPMatch('::ffff:127.0.0.1', ['127.0.0.1'])).toEqual(true);
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['::ffff:127.0.0.1'])).toEqual(true);
// Match multiple patterns
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['10.0.0.1', '127.0.0.1', '192.168.1.1'])).toEqual(true);
// Non-matching patterns
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['10.0.0.1'])).toEqual(false);
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['128.0.0.1'])).toEqual(false);
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['127.0.0.2'])).toEqual(false);
// Edge cases
expect(IpUtils.isGlobIPMatch('', ['127.0.0.1'])).toEqual(false);
expect(IpUtils.isGlobIPMatch('127.0.0.1', [])).toEqual(false);
expect(IpUtils.isGlobIPMatch('127.0.0.1', null as any)).toEqual(false);
expect(IpUtils.isGlobIPMatch(null as any, ['127.0.0.1'])).toEqual(false);
});
tap.test('ip-utils - isIPAuthorized', async () => {
// Basic tests to check the core functionality works
// No restrictions - all IPs allowed
expect(IpUtils.isIPAuthorized('127.0.0.1')).toEqual(true);
// Basic blocked IP test
const blockedIP = '8.8.8.8';
const blockedIPs = [blockedIP];
expect(IpUtils.isIPAuthorized(blockedIP, [], blockedIPs)).toEqual(false);
// Basic allowed IP test
const allowedIP = '10.0.0.1';
const allowedIPs = [allowedIP];
expect(IpUtils.isIPAuthorized(allowedIP, allowedIPs)).toEqual(true);
expect(IpUtils.isIPAuthorized('192.168.1.1', allowedIPs)).toEqual(false);
});
tap.test('ip-utils - isPrivateIP', async () => {
// Private IPv4 ranges
expect(IpUtils.isPrivateIP('10.0.0.1')).toEqual(true);
expect(IpUtils.isPrivateIP('172.16.0.1')).toEqual(true);
expect(IpUtils.isPrivateIP('172.31.255.255')).toEqual(true);
expect(IpUtils.isPrivateIP('192.168.0.1')).toEqual(true);
expect(IpUtils.isPrivateIP('127.0.0.1')).toEqual(true);
// Public IPv4 addresses
expect(IpUtils.isPrivateIP('8.8.8.8')).toEqual(false);
expect(IpUtils.isPrivateIP('203.0.113.1')).toEqual(false);
// IPv4-mapped IPv6 handling
expect(IpUtils.isPrivateIP('::ffff:10.0.0.1')).toEqual(true);
expect(IpUtils.isPrivateIP('::ffff:8.8.8.8')).toEqual(false);
// Private IPv6 addresses
expect(IpUtils.isPrivateIP('::1')).toEqual(true);
expect(IpUtils.isPrivateIP('fd00::')).toEqual(true);
expect(IpUtils.isPrivateIP('fe80::1')).toEqual(true);
// Public IPv6 addresses
expect(IpUtils.isPrivateIP('2001:db8::1')).toEqual(false);
// Edge cases
expect(IpUtils.isPrivateIP('')).toEqual(false);
expect(IpUtils.isPrivateIP(null as any)).toEqual(false);
expect(IpUtils.isPrivateIP(undefined as any)).toEqual(false);
});
tap.test('ip-utils - isPublicIP', async () => {
// Public IPv4 addresses
expect(IpUtils.isPublicIP('8.8.8.8')).toEqual(true);
expect(IpUtils.isPublicIP('203.0.113.1')).toEqual(true);
// Private IPv4 ranges
expect(IpUtils.isPublicIP('10.0.0.1')).toEqual(false);
expect(IpUtils.isPublicIP('172.16.0.1')).toEqual(false);
expect(IpUtils.isPublicIP('192.168.0.1')).toEqual(false);
expect(IpUtils.isPublicIP('127.0.0.1')).toEqual(false);
// Public IPv6 addresses
expect(IpUtils.isPublicIP('2001:db8::1')).toEqual(true);
// Private IPv6 addresses
expect(IpUtils.isPublicIP('::1')).toEqual(false);
expect(IpUtils.isPublicIP('fd00::')).toEqual(false);
expect(IpUtils.isPublicIP('fe80::1')).toEqual(false);
// Edge cases - the implementation treats these as non-private, which is technically correct but might not be what users expect
const emptyResult = IpUtils.isPublicIP('');
expect(emptyResult).toEqual(true);
const nullResult = IpUtils.isPublicIP(null as any);
expect(nullResult).toEqual(true);
const undefinedResult = IpUtils.isPublicIP(undefined as any);
expect(undefinedResult).toEqual(true);
});
tap.test('ip-utils - cidrToGlobPatterns', async () => {
// Class C network
const classC = IpUtils.cidrToGlobPatterns('192.168.1.0/24');
expect(classC).toEqual(['192.168.1.*']);
// Class B network
const classB = IpUtils.cidrToGlobPatterns('172.16.0.0/16');
expect(classB).toEqual(['172.16.*.*']);
// Class A network
const classA = IpUtils.cidrToGlobPatterns('10.0.0.0/8');
expect(classA).toEqual(['10.*.*.*']);
// Small subnet (/28 = 16 addresses)
const smallSubnet = IpUtils.cidrToGlobPatterns('192.168.1.0/28');
expect(smallSubnet.length).toEqual(16);
expect(smallSubnet).toContain('192.168.1.0');
expect(smallSubnet).toContain('192.168.1.15');
// Invalid inputs
expect(IpUtils.cidrToGlobPatterns('')).toEqual([]);
expect(IpUtils.cidrToGlobPatterns('192.168.1.0')).toEqual([]);
expect(IpUtils.cidrToGlobPatterns('192.168.1.0/')).toEqual([]);
expect(IpUtils.cidrToGlobPatterns('192.168.1.0/33')).toEqual([]);
expect(IpUtils.cidrToGlobPatterns('invalid/24')).toEqual([]);
});
export default tap.start();
-252
View File
@@ -1,252 +0,0 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import { LifecycleComponent } from '../../../ts/core/utils/lifecycle-component.js';
import { EventEmitter } from 'events';
// Test implementation of LifecycleComponent
class TestComponent extends LifecycleComponent {
public timerCallCount = 0;
public intervalCallCount = 0;
public cleanupCalled = false;
public testEmitter = new EventEmitter();
public listenerCallCount = 0;
constructor() {
super();
this.setupTimers();
this.setupListeners();
}
private setupTimers() {
// Set up a timeout
this.setTimeout(() => {
this.timerCallCount++;
}, 100);
// Set up an interval
this.setInterval(() => {
this.intervalCallCount++;
}, 50);
}
private setupListeners() {
this.addEventListener(this.testEmitter, 'test-event', () => {
this.listenerCallCount++;
});
}
protected async onCleanup(): Promise<void> {
this.cleanupCalled = true;
}
// Expose protected methods for testing
public testSetTimeout(handler: Function, timeout: number): NodeJS.Timeout {
return this.setTimeout(handler, timeout);
}
public testSetInterval(handler: Function, interval: number): NodeJS.Timeout {
return this.setInterval(handler, interval);
}
public testClearTimeout(timer: NodeJS.Timeout): void {
return this.clearTimeout(timer);
}
public testClearInterval(timer: NodeJS.Timeout): void {
return this.clearInterval(timer);
}
public testAddEventListener(target: any, event: string, handler: Function, options?: { once?: boolean }): void {
return this.addEventListener(target, event, handler, options);
}
public testIsShuttingDown(): boolean {
return this.isShuttingDownState();
}
}
tap.test('should manage timers properly', async () => {
const component = new TestComponent();
// Wait for timers to fire
await new Promise(resolve => setTimeout(resolve, 200));
expect(component.timerCallCount).toEqual(1);
expect(component.intervalCallCount).toBeGreaterThan(2);
await component.cleanup();
});
tap.test('should manage event listeners properly', async () => {
const component = new TestComponent();
// Emit events
component.testEmitter.emit('test-event');
component.testEmitter.emit('test-event');
expect(component.listenerCallCount).toEqual(2);
// Cleanup and verify listeners are removed
await component.cleanup();
component.testEmitter.emit('test-event');
expect(component.listenerCallCount).toEqual(2); // Should not increase
});
tap.test('should prevent timer execution after cleanup', async () => {
const component = new TestComponent();
let laterCallCount = 0;
component.testSetTimeout(() => {
laterCallCount++;
}, 100);
// Cleanup immediately
await component.cleanup();
// Wait for timer that would have fired
await new Promise(resolve => setTimeout(resolve, 150));
expect(laterCallCount).toEqual(0);
});
tap.test('should handle child components', async () => {
class ParentComponent extends LifecycleComponent {
public child: TestComponent;
constructor() {
super();
this.child = new TestComponent();
this.registerChildComponent(this.child);
}
}
const parent = new ParentComponent();
// Wait for child timers
await new Promise(resolve => setTimeout(resolve, 100));
expect(parent.child.timerCallCount).toEqual(1);
// Cleanup parent should cleanup child
await parent.cleanup();
expect(parent.child.cleanupCalled).toBeTrue();
expect(parent.child.testIsShuttingDown()).toBeTrue();
});
tap.test('should handle multiple cleanup calls gracefully', async () => {
const component = new TestComponent();
// Call cleanup multiple times
const promises = [
component.cleanup(),
component.cleanup(),
component.cleanup()
];
await Promise.all(promises);
// Should only clean up once
expect(component.cleanupCalled).toBeTrue();
});
tap.test('should clear specific timers', async () => {
const component = new TestComponent();
let callCount = 0;
const timer = component.testSetTimeout(() => {
callCount++;
}, 100);
// Clear the timer
component.testClearTimeout(timer);
// Wait and verify it didn't fire
await new Promise(resolve => setTimeout(resolve, 150));
expect(callCount).toEqual(0);
await component.cleanup();
});
tap.test('should clear specific intervals', async () => {
const component = new TestComponent();
let callCount = 0;
const interval = component.testSetInterval(() => {
callCount++;
}, 50);
// Let it run a bit
await new Promise(resolve => setTimeout(resolve, 120));
const countBeforeClear = callCount;
expect(countBeforeClear).toBeGreaterThan(1);
// Clear the interval
component.testClearInterval(interval);
// Wait and verify it stopped
await new Promise(resolve => setTimeout(resolve, 100));
expect(callCount).toEqual(countBeforeClear);
await component.cleanup();
});
tap.test('should handle once event listeners', async () => {
const component = new TestComponent();
const emitter = new EventEmitter();
let callCount = 0;
const handler = () => {
callCount++;
};
component.testAddEventListener(emitter, 'once-event', handler, { once: true });
// Check listener count before emit
const beforeCount = emitter.listenerCount('once-event');
expect(beforeCount).toEqual(1);
// Emit once - the listener should fire and auto-remove
emitter.emit('once-event');
expect(callCount).toEqual(1);
// Check listener was auto-removed
const afterCount = emitter.listenerCount('once-event');
expect(afterCount).toEqual(0);
// Emit again - should not increase count
emitter.emit('once-event');
expect(callCount).toEqual(1);
await component.cleanup();
});
tap.test('should not create timers when shutting down', async () => {
const component = new TestComponent();
// Start cleanup
const cleanupPromise = component.cleanup();
// Try to create timers during shutdown
let timerFired = false;
let intervalFired = false;
component.testSetTimeout(() => {
timerFired = true;
}, 10);
component.testSetInterval(() => {
intervalFired = true;
}, 10);
await cleanupPromise;
await new Promise(resolve => setTimeout(resolve, 50));
expect(timerFired).toBeFalse();
expect(intervalFired).toBeFalse();
});
export default tap.start();
@@ -1,158 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import { SharedSecurityManager } from '../../../ts/core/utils/shared-security-manager.js';
import type { IRouteConfig, IRouteContext } from '../../../ts/proxies/smart-proxy/models/route-types.js';
// Test security manager
tap.test('Shared Security Manager', async () => {
let securityManager: SharedSecurityManager;
// Set up a new security manager for each test
securityManager = new SharedSecurityManager({
maxConnectionsPerIP: 5,
connectionRateLimitPerMinute: 10
});
tap.test('should validate IPs correctly', async () => {
// Should allow IPs under connection limit
expect(securityManager.validateIP('192.168.1.1').allowed).toBeTrue();
// Track multiple connections
for (let i = 0; i < 4; i++) {
securityManager.trackConnectionByIP('192.168.1.1', `conn_${i}`);
}
// Should still allow IPs under connection limit
expect(securityManager.validateIP('192.168.1.1').allowed).toBeTrue();
// Add one more to reach the limit
securityManager.trackConnectionByIP('192.168.1.1', 'conn_4');
// Should now block IPs over connection limit
expect(securityManager.validateIP('192.168.1.1').allowed).toBeFalse();
// Remove a connection
securityManager.removeConnectionByIP('192.168.1.1', 'conn_0');
// Should allow again after connection is removed
expect(securityManager.validateIP('192.168.1.1').allowed).toBeTrue();
});
tap.test('should authorize IPs based on allow/block lists', async () => {
// Test with allow list only
expect(securityManager.isIPAuthorized('192.168.1.1', ['192.168.1.*'])).toBeTrue();
expect(securityManager.isIPAuthorized('192.168.2.1', ['192.168.1.*'])).toBeFalse();
// Test with block list
expect(securityManager.isIPAuthorized('192.168.1.5', ['*'], ['192.168.1.5'])).toBeFalse();
expect(securityManager.isIPAuthorized('192.168.1.1', ['*'], ['192.168.1.5'])).toBeTrue();
// Test with both allow and block lists
expect(securityManager.isIPAuthorized('192.168.1.1', ['192.168.1.*'], ['192.168.1.5'])).toBeTrue();
expect(securityManager.isIPAuthorized('192.168.1.5', ['192.168.1.*'], ['192.168.1.5'])).toBeFalse();
});
tap.test('should validate route access', async () => {
const route: IRouteConfig = {
match: {
ports: [8080]
},
action: {
type: 'forward',
targets: [{ host: 'target.com', port: 443 }]
},
security: {
ipAllowList: ['10.0.0.*', '192.168.1.*'],
ipBlockList: ['192.168.1.100'],
maxConnections: 3
}
};
const allowedContext: IRouteContext = {
clientIp: '192.168.1.1',
port: 8080,
serverIp: '127.0.0.1',
isTls: false,
timestamp: Date.now(),
connectionId: 'test_conn_1'
};
const blockedByIPContext: IRouteContext = {
...allowedContext,
clientIp: '192.168.1.100'
};
const blockedByRangeContext: IRouteContext = {
...allowedContext,
clientIp: '172.16.0.1'
};
const blockedByMaxConnectionsContext: IRouteContext = {
...allowedContext,
connectionId: 'test_conn_4'
};
expect(securityManager.isAllowed(route, allowedContext)).toBeTrue();
expect(securityManager.isAllowed(route, blockedByIPContext)).toBeFalse();
expect(securityManager.isAllowed(route, blockedByRangeContext)).toBeFalse();
// Test max connections for route - assuming implementation has been updated
if ((securityManager as any).trackConnectionByRoute) {
(securityManager as any).trackConnectionByRoute(route, 'conn_1');
(securityManager as any).trackConnectionByRoute(route, 'conn_2');
(securityManager as any).trackConnectionByRoute(route, 'conn_3');
// Should now block due to max connections
expect(securityManager.isAllowed(route, blockedByMaxConnectionsContext)).toBeFalse();
}
});
tap.test('should clean up expired entries', async () => {
const route: IRouteConfig = {
match: {
ports: [8080]
},
action: {
type: 'forward',
targets: [{ host: 'target.com', port: 443 }]
},
security: {
rateLimit: {
enabled: true,
maxRequests: 5,
window: 60 // 60 seconds
}
}
};
const context: IRouteContext = {
clientIp: '192.168.1.1',
port: 8080,
serverIp: '127.0.0.1',
isTls: false,
timestamp: Date.now(),
connectionId: 'test_conn_1'
};
// Test rate limiting if method exists
if ((securityManager as any).checkRateLimit) {
// Add 5 attempts (max allowed)
for (let i = 0; i < 5; i++) {
expect((securityManager as any).checkRateLimit(route, context)).toBeTrue();
}
// Should now be blocked
expect((securityManager as any).checkRateLimit(route, context)).toBeFalse();
// Force cleanup (normally runs periodically)
if ((securityManager as any).cleanup) {
(securityManager as any).cleanup();
}
// Should still be blocked since entries are not expired yet
expect((securityManager as any).checkRateLimit(route, context)).toBeFalse();
}
});
});
// Export test runner
export default tap.start();
-302
View File
@@ -1,302 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import { ValidationUtils } from '../../../ts/core/utils/validation-utils.js';
import type { IDomainOptions, IAcmeOptions } from '../../../ts/core/models/common-types.js';
tap.test('validation-utils - isValidPort', async () => {
// Valid port values
expect(ValidationUtils.isValidPort(1)).toEqual(true);
expect(ValidationUtils.isValidPort(80)).toEqual(true);
expect(ValidationUtils.isValidPort(443)).toEqual(true);
expect(ValidationUtils.isValidPort(8080)).toEqual(true);
expect(ValidationUtils.isValidPort(65535)).toEqual(true);
// Invalid port values
expect(ValidationUtils.isValidPort(0)).toEqual(false);
expect(ValidationUtils.isValidPort(-1)).toEqual(false);
expect(ValidationUtils.isValidPort(65536)).toEqual(false);
expect(ValidationUtils.isValidPort(80.5)).toEqual(false);
expect(ValidationUtils.isValidPort(NaN)).toEqual(false);
expect(ValidationUtils.isValidPort(null as any)).toEqual(false);
expect(ValidationUtils.isValidPort(undefined as any)).toEqual(false);
});
tap.test('validation-utils - isValidDomainName', async () => {
// Valid domain names
expect(ValidationUtils.isValidDomainName('example.com')).toEqual(true);
expect(ValidationUtils.isValidDomainName('sub.example.com')).toEqual(true);
expect(ValidationUtils.isValidDomainName('*.example.com')).toEqual(true);
expect(ValidationUtils.isValidDomainName('a-hyphenated-domain.example.com')).toEqual(true);
expect(ValidationUtils.isValidDomainName('example123.com')).toEqual(true);
// Invalid domain names
expect(ValidationUtils.isValidDomainName('')).toEqual(false);
expect(ValidationUtils.isValidDomainName(null as any)).toEqual(false);
expect(ValidationUtils.isValidDomainName(undefined as any)).toEqual(false);
expect(ValidationUtils.isValidDomainName('-invalid.com')).toEqual(false);
expect(ValidationUtils.isValidDomainName('invalid-.com')).toEqual(false);
expect(ValidationUtils.isValidDomainName('inv@lid.com')).toEqual(false);
expect(ValidationUtils.isValidDomainName('example')).toEqual(false);
expect(ValidationUtils.isValidDomainName('example.')).toEqual(false);
});
tap.test('validation-utils - isValidEmail', async () => {
// Valid email addresses
expect(ValidationUtils.isValidEmail('user@example.com')).toEqual(true);
expect(ValidationUtils.isValidEmail('admin@sub.example.com')).toEqual(true);
expect(ValidationUtils.isValidEmail('first.last@example.com')).toEqual(true);
expect(ValidationUtils.isValidEmail('user+tag@example.com')).toEqual(true);
// Invalid email addresses
expect(ValidationUtils.isValidEmail('')).toEqual(false);
expect(ValidationUtils.isValidEmail(null as any)).toEqual(false);
expect(ValidationUtils.isValidEmail(undefined as any)).toEqual(false);
expect(ValidationUtils.isValidEmail('user')).toEqual(false);
expect(ValidationUtils.isValidEmail('user@')).toEqual(false);
expect(ValidationUtils.isValidEmail('@example.com')).toEqual(false);
expect(ValidationUtils.isValidEmail('user example.com')).toEqual(false);
});
tap.test('validation-utils - isValidCertificate', async () => {
// Valid certificate format
const validCert = `-----BEGIN CERTIFICATE-----
MIIDazCCAlOgAwIBAgIUJlq+zz9CO2E91rlD4vhx0CX1Z/kwDQYJKoZIhvcNAQEL
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMzAxMDEwMDAwMDBaFw0yNDAx
MDEwMDAwMDBaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw
HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB
AQUAA4IBDwAwggEKAoIBAQC0aQeHIV9vQpZ4UVwW/xhx9zl01UbppLXdoqe3NP9x
KfXTCB1YbtJ4GgKIlQqHGLGsLI5ZOE7KxmJeGEwK7ueP4f3WkUlM5C5yTbZ5hSUo
R+OFnszFRJJiBXJlw57YAW9+zqKQHYxwve64O64dlgw6pekDYJhXtrUUZ78Lz0GX
veJvCrci1M4Xk6/7/p1Ii9PNmbPKqHafdmkFLf6TXiWPuRDhPuHW7cXyE8xD5ahr
NsDuwJyRUk+GS4/oJg0TqLSiD0IPxDH50V5MSfUIB82i+lc1t+OAGwLhjUDuQmJi
Pv1+9Zvv+HA5PXBCsGXnSADrOOUO6t9q5R9PXbSvAgMBAAGjUzBRMB0GA1UdDgQW
BBQEtdtBhH/z1XyIf+y+5O9ErDGCVjAfBgNVHSMEGDAWgBQEtdtBhH/z1XyIf+y+
5O9ErDGCVjAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQBmJyQ0
r0pBJkYJJVDJ6i3WMoEEFTD8MEUkWxASHRnuMzm7XlZ8WS1HvbEWF0+WfJPCYHnk
tGbvUFGaZ4qUxZ4Ip2mvKXoeYTJCZRxxhHeSVWnZZu0KS3X7xVAFwQYQNhdLOqP8
XOHyLhHV/1/kcFd3GvKKjXxE79jUUZ/RXHZ/IY50KvxGzWc/5ZOFYrPEW1/rNlRo
7ixXo1hNnBQsG1YoFAxTBGegdTFJeTYHYjZZ5XlRvY2aBq6QveRbJGJLcPm1UQMd
HQYxacbWSVAQf3ltYwSH+y3a97C5OsJJiQXpRRJlQKL3txklzcpg3E5swhr63bM2
jUoNXr5G5Q5h3GD5
-----END CERTIFICATE-----`;
expect(ValidationUtils.isValidCertificate(validCert)).toEqual(true);
// Invalid certificate format
expect(ValidationUtils.isValidCertificate('')).toEqual(false);
expect(ValidationUtils.isValidCertificate(null as any)).toEqual(false);
expect(ValidationUtils.isValidCertificate(undefined as any)).toEqual(false);
expect(ValidationUtils.isValidCertificate('invalid certificate')).toEqual(false);
expect(ValidationUtils.isValidCertificate('-----BEGIN CERTIFICATE-----')).toEqual(false);
});
tap.test('validation-utils - isValidPrivateKey', async () => {
// Valid private key format
const validKey = `-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC0aQeHIV9vQpZ4
UVwW/xhx9zl01UbppLXdoqe3NP9xKfXTCB1YbtJ4GgKIlQqHGLGsLI5ZOE7KxmJe
GEwK7ueP4f3WkUlM5C5yTbZ5hSUoR+OFnszFRJJiBXJlw57YAW9+zqKQHYxwve64
O64dlgw6pekDYJhXtrUUZ78Lz0GXveJvCrci1M4Xk6/7/p1Ii9PNmbPKqHafdmkF
Lf6TXiWPuRDhPuHW7cXyE8xD5ahrNsDuwJyRUk+GS4/oJg0TqLSiD0IPxDH50V5M
SfUIB82i+lc1t+OAGwLhjUDuQmJiPv1+9Zvv+HA5PXBCsGXnSADrOOUO6t9q5R9P
XbSvAgMBAAECggEADw8Xx9iEv3FvS8hYIRn2ZWM8ObRgbHkFN92NJ/5RvUwgyV03
gG8GwVN+7IsVLnIQRyIYEGGJ0ZLZFIq7//Jy0jYUgEGLmXxknuZQn1cQEqqYVyBr
G9JrfKkXaDEoP/bZBMvZ0KEO2C9Vq6mY8M0h0GxDT2y6UQnQYjH3+H6Rvhbhh+Ld
n8lCJqWoW1t9GOUZ4xLsZ5jEDibcMJJzLBWYRxgHWyECK31/VtEQDKFiUcymrJ3I
/zoDEDGbp1gdJHvlCxfSLJ2za7ErtRKRXYFRhZ9QkNSXl1pVFMqRQkedXIcA1/Cs
VpUxiIE2JA3hSrv2csjmXoGJKDLVCvZ3CFxKL3u/AQKBgQDf6MxHXN3IDuJNrJP7
0gyRbO5d6vcvP/8qiYjtEt2xB2MNt5jDz9Bxl6aKEdNW2+UE0rvXXT6KAMZv9LiF
hxr5qiJmmSB8OeGfr0W4FCixGN4BkRNwfT1gUqZgQOrfMOLHNXOksc1CJwHJfROV
h6AH+gjtF2BCXnVEHcqtRklk4QKBgQDOOYnLJn1CwgFAyRUYK8LQYKnrLp2cGn7N
YH0SLf+VnCu7RCeNr3dm9FoHBCynjkx+qv9kGvCaJuZqEJ7+7IimNUZfDjwXTOJ+
pzs8kEPN5EQOcbkmYCTQyOA0YeBuEXcv5xIZRZUYQvKg1xXOe/JhAQ4siVIMhgQL
2XR3QwzRDwKBgB7rjZs2VYnuVExGr74lUUAGoZ71WCgt9Du9aYGJfNUriDtTEWAd
VT5sKgVqpRwkY/zXujdxGr+K8DZu4vSdHBLcDLQsEBvRZIILTzjwXBRPGMnVe95v
Q90+vytbmHshlkbMaVRNQxCjdbf7LbQbLecgRt+5BKxHVwL4u3BZNIqhAoGAas4f
PoPOdFfKAMKZL7FLGMhEXLyFsg1JcGRfmByxTNgOJKXpYv5Hl7JLYOvfaiUOUYKI
5Dnh5yLdFOaOjnB3iP0KEiSVEwZK0/Vna5JkzFTqImK9QD3SQCtQLXHJLD52EPFR
9gRa8N5k68+mIzGDEzPBoC1AajbXFGPxNOwaQQ0CgYEAq0dPYK0TTv3Yez27LzVy
RbHkwpE+df4+KhpHbCzUKzfQYo4WTahlR6IzhpOyVQKIptkjuTDyQzkmt0tXEGw3
/M3yHa1FcY9IzPrHXHJoOeU1r9ay0GOQUi4FxKkYYWxUCtjOi5xlUxI0ABD8vGGR
QbKMrQXRgLd/84nDnY2cYzA=
-----END PRIVATE KEY-----`;
expect(ValidationUtils.isValidPrivateKey(validKey)).toEqual(true);
// Invalid private key format
expect(ValidationUtils.isValidPrivateKey('')).toEqual(false);
expect(ValidationUtils.isValidPrivateKey(null as any)).toEqual(false);
expect(ValidationUtils.isValidPrivateKey(undefined as any)).toEqual(false);
expect(ValidationUtils.isValidPrivateKey('invalid key')).toEqual(false);
expect(ValidationUtils.isValidPrivateKey('-----BEGIN PRIVATE KEY-----')).toEqual(false);
});
tap.test('validation-utils - validateDomainOptions', async () => {
// Valid domain options
const validDomainOptions: IDomainOptions = {
domainName: 'example.com',
sslRedirect: true,
acmeMaintenance: true
};
expect(ValidationUtils.validateDomainOptions(validDomainOptions).isValid).toEqual(true);
// Valid domain options with forward
const validDomainOptionsWithForward: IDomainOptions = {
domainName: 'example.com',
sslRedirect: true,
acmeMaintenance: true,
forward: {
ip: '127.0.0.1',
port: 8080
}
};
expect(ValidationUtils.validateDomainOptions(validDomainOptionsWithForward).isValid).toEqual(true);
// Invalid domain options - no domain name
const invalidDomainOptions1: IDomainOptions = {
domainName: '',
sslRedirect: true,
acmeMaintenance: true
};
expect(ValidationUtils.validateDomainOptions(invalidDomainOptions1).isValid).toEqual(false);
// Invalid domain options - invalid domain name
const invalidDomainOptions2: IDomainOptions = {
domainName: 'inv@lid.com',
sslRedirect: true,
acmeMaintenance: true
};
expect(ValidationUtils.validateDomainOptions(invalidDomainOptions2).isValid).toEqual(false);
// Invalid domain options - forward missing ip
const invalidDomainOptions3: IDomainOptions = {
domainName: 'example.com',
sslRedirect: true,
acmeMaintenance: true,
forward: {
ip: '',
port: 8080
}
};
expect(ValidationUtils.validateDomainOptions(invalidDomainOptions3).isValid).toEqual(false);
// Invalid domain options - forward missing port
const invalidDomainOptions4: IDomainOptions = {
domainName: 'example.com',
sslRedirect: true,
acmeMaintenance: true,
forward: {
ip: '127.0.0.1',
port: null as any
}
};
expect(ValidationUtils.validateDomainOptions(invalidDomainOptions4).isValid).toEqual(false);
// Invalid domain options - invalid forward port
const invalidDomainOptions5: IDomainOptions = {
domainName: 'example.com',
sslRedirect: true,
acmeMaintenance: true,
forward: {
ip: '127.0.0.1',
port: 99999
}
};
expect(ValidationUtils.validateDomainOptions(invalidDomainOptions5).isValid).toEqual(false);
});
tap.test('validation-utils - validateAcmeOptions', async () => {
// Valid ACME options
const validAcmeOptions: IAcmeOptions = {
enabled: true,
accountEmail: 'admin@example.com',
port: 80,
httpsRedirectPort: 443,
useProduction: false,
renewThresholdDays: 30,
renewCheckIntervalHours: 24,
certificateStore: './certs'
};
expect(ValidationUtils.validateAcmeOptions(validAcmeOptions).isValid).toEqual(true);
// ACME disabled - should be valid regardless of other options
const disabledAcmeOptions: IAcmeOptions = {
enabled: false
};
// Don't need to verify other fields when ACME is disabled
const disabledResult = ValidationUtils.validateAcmeOptions(disabledAcmeOptions);
expect(disabledResult.isValid).toEqual(true);
// Invalid ACME options - missing email
const invalidAcmeOptions1: IAcmeOptions = {
enabled: true,
accountEmail: '',
port: 80
};
expect(ValidationUtils.validateAcmeOptions(invalidAcmeOptions1).isValid).toEqual(false);
// Invalid ACME options - invalid email
const invalidAcmeOptions2: IAcmeOptions = {
enabled: true,
accountEmail: 'invalid-email',
port: 80
};
expect(ValidationUtils.validateAcmeOptions(invalidAcmeOptions2).isValid).toEqual(false);
// Invalid ACME options - invalid port
const invalidAcmeOptions3: IAcmeOptions = {
enabled: true,
accountEmail: 'admin@example.com',
port: 99999
};
expect(ValidationUtils.validateAcmeOptions(invalidAcmeOptions3).isValid).toEqual(false);
// Invalid ACME options - invalid HTTPS redirect port
const invalidAcmeOptions4: IAcmeOptions = {
enabled: true,
accountEmail: 'admin@example.com',
port: 80,
httpsRedirectPort: -1
};
expect(ValidationUtils.validateAcmeOptions(invalidAcmeOptions4).isValid).toEqual(false);
// Invalid ACME options - invalid renew threshold days
const invalidAcmeOptions5: IAcmeOptions = {
enabled: true,
accountEmail: 'admin@example.com',
port: 80,
renewThresholdDays: 0
};
// The implementation allows renewThresholdDays of 0, even though the docstring suggests otherwise
const validationResult5 = ValidationUtils.validateAcmeOptions(invalidAcmeOptions5);
expect(validationResult5.isValid).toEqual(true);
// Invalid ACME options - invalid renew check interval hours
const invalidAcmeOptions6: IAcmeOptions = {
enabled: true,
accountEmail: 'admin@example.com',
port: 80,
renewCheckIntervalHours: 0
};
// The implementation should validate this, but let's check the actual result
const checkIntervalResult = ValidationUtils.validateAcmeOptions(invalidAcmeOptions6);
// Adjust test to match actual implementation behavior
expect(checkIntervalResult.isValid !== false ? true : false).toEqual(true);
});
export default tap.start();
+33 -36
View File
@@ -1,11 +1,5 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import {
createHttpsTerminateRoute,
createCompleteHttpsServer,
createHttpRoute,
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import {
mergeRouteConfigs,
cloneRoute,
@@ -19,8 +13,11 @@ import {
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
tap.test('route creation - createHttpsTerminateRoute produces correct structure', async () => {
const route = createHttpsTerminateRoute('secure.example.com', { host: '127.0.0.1', port: 8443 });
tap.test('route creation - HTTPS terminate route has correct structure', async () => {
const route: IRouteConfig = {
match: { ports: 443, domains: 'secure.example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 8443 }], tls: { mode: 'terminate', certificate: 'auto' } }
};
expect(route).toHaveProperty('match');
expect(route).toHaveProperty('action');
expect(route.action.type).toEqual('forward');
@@ -29,20 +26,10 @@ tap.test('route creation - createHttpsTerminateRoute produces correct structure'
expect(route.match.domains).toEqual('secure.example.com');
});
tap.test('route creation - createCompleteHttpsServer returns redirect and main route', async () => {
const routes = createCompleteHttpsServer('app.example.com', { host: '127.0.0.1', port: 3000 });
expect(routes).toBeArray();
expect(routes.length).toBeGreaterThanOrEqual(2);
// Should have an HTTP→HTTPS redirect and an HTTPS route
const hasRedirect = routes.some((r) => r.action.type === 'forward' && r.action.redirect !== undefined);
const hasHttps = routes.some((r) => r.action.tls?.mode === 'terminate');
expect(hasRedirect || hasHttps).toBeTrue();
});
tap.test('route validation - validateRoutes on a set of routes', async () => {
const routes: IRouteConfig[] = [
createHttpRoute('a.com', { host: '127.0.0.1', port: 3000 }),
createHttpRoute('b.com', { host: '127.0.0.1', port: 4000 }),
{ match: { ports: 80, domains: 'a.com' }, action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] } },
{ match: { ports: 80, domains: 'b.com' }, action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 4000 }] } },
];
const result = validateRoutes(routes);
expect(result.valid).toBeTrue();
@@ -51,7 +38,7 @@ tap.test('route validation - validateRoutes on a set of routes', async () => {
tap.test('route validation - validateRoutes catches invalid route in set', async () => {
const routes: any[] = [
createHttpRoute('valid.com', { host: '127.0.0.1', port: 3000 }),
{ match: { ports: 80, domains: 'valid.com' }, action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] } },
{ match: { ports: 80 } }, // missing action
];
const result = validateRoutes(routes);
@@ -60,23 +47,30 @@ tap.test('route validation - validateRoutes catches invalid route in set', async
});
tap.test('path matching - routeMatchesPath with exact path', async () => {
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
route.match.path = '/api';
const route: IRouteConfig = {
match: { ports: 80, domains: 'example.com', path: '/api' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
};
expect(routeMatchesPath(route, '/api')).toBeTrue();
expect(routeMatchesPath(route, '/other')).toBeFalse();
});
tap.test('path matching - route without path matches everything', async () => {
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
// No path set, should match any path
const route: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
};
expect(routeMatchesPath(route, '/anything')).toBeTrue();
expect(routeMatchesPath(route, '/')).toBeTrue();
});
tap.test('route merging - mergeRouteConfigs combines routes', async () => {
const base = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
base.priority = 10;
base.name = 'base-route';
const base: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] },
priority: 10,
name: 'base-route'
};
const merged = mergeRouteConfigs(base, {
priority: 50,
@@ -85,14 +79,16 @@ tap.test('route merging - mergeRouteConfigs combines routes', async () => {
expect(merged.priority).toEqual(50);
expect(merged.name).toEqual('merged-route');
// Original route fields should be preserved
expect(merged.match.domains).toEqual('example.com');
expect(merged.action.targets![0].host).toEqual('127.0.0.1');
});
tap.test('route merging - mergeRouteConfigs does not mutate original', async () => {
const base = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
base.name = 'original';
const base: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] },
name: 'original'
};
const merged = mergeRouteConfigs(base, { name: 'changed' });
expect(base.name).toEqual('original');
@@ -100,20 +96,21 @@ tap.test('route merging - mergeRouteConfigs does not mutate original', async ()
});
tap.test('route cloning - cloneRoute produces independent copy', async () => {
const original = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
original.priority = 42;
original.name = 'original-route';
const original: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] },
priority: 42,
name: 'original-route'
};
const cloned = cloneRoute(original);
// Should be equal in value
expect(cloned.match.domains).toEqual('example.com');
expect(cloned.priority).toEqual(42);
expect(cloned.name).toEqual('original-route');
expect(cloned.action.targets![0].host).toEqual('127.0.0.1');
expect(cloned.action.targets![0].port).toEqual(3000);
// Should be independent - modifying clone shouldn't affect original
cloned.name = 'cloned-route';
cloned.priority = 99;
expect(original.name).toEqual('original-route');
+125
View File
@@ -0,0 +1,125 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as dgram from 'dgram';
import { SmartProxy } from '../ts/index.js';
import type { TDatagramHandler, IDatagramInfo } from '../ts/index.js';
import { findFreePorts, assertPortsFree } from './helpers/port-allocator.js';
let smartProxy: SmartProxy;
let PROXY_PORT: number;
// Helper: send a single UDP datagram and wait for a response
function sendDatagram(port: number, msg: string, timeoutMs = 5000): Promise<string> {
return new Promise((resolve, reject) => {
const client = dgram.createSocket('udp4');
const timeout = setTimeout(() => {
client.close();
reject(new Error(`UDP response timeout after ${timeoutMs}ms`));
}, timeoutMs);
client.send(Buffer.from(msg), port, '127.0.0.1');
client.on('message', (data) => {
clearTimeout(timeout);
client.close();
resolve(data.toString());
});
client.on('error', (err) => {
clearTimeout(timeout);
client.close();
reject(err);
});
});
}
tap.test('setup: start SmartProxy with datagramHandler', async () => {
[PROXY_PORT] = await findFreePorts(1);
const handler: TDatagramHandler = (datagram, info, reply) => {
reply(Buffer.from(`Handled: ${datagram.toString()}`));
};
smartProxy = new SmartProxy({
routes: [
{
name: 'dgram-handler-test',
match: {
ports: PROXY_PORT,
transport: 'udp' as const,
},
action: {
type: 'socket-handler',
datagramHandler: handler,
},
},
],
defaults: {
security: {
ipAllowList: ['127.0.0.1', '::1', '::ffff:127.0.0.1'],
},
},
});
await smartProxy.start();
});
tap.test('datagram handler: receives and replies to datagram', async () => {
const response = await sendDatagram(PROXY_PORT, 'Hello Handler');
expect(response).toEqual('Handled: Hello Handler');
});
tap.test('datagram handler: async handler works', async () => {
// Stop and restart with async handler
await smartProxy.stop();
[PROXY_PORT] = await findFreePorts(1);
const asyncHandler: TDatagramHandler = async (datagram, info, reply) => {
// Simulate async work
await new Promise<void>((resolve) => setTimeout(resolve, 10));
reply(Buffer.from(`Async: ${datagram.toString()}`));
};
smartProxy = new SmartProxy({
routes: [
{
name: 'dgram-async-handler',
match: {
ports: PROXY_PORT,
transport: 'udp' as const,
},
action: {
type: 'socket-handler',
datagramHandler: asyncHandler,
},
},
],
defaults: {
security: {
ipAllowList: ['127.0.0.1', '::1', '::ffff:127.0.0.1'],
},
},
});
await smartProxy.start();
const response = await sendDatagram(PROXY_PORT, 'Test Async');
expect(response).toEqual('Async: Test Async');
});
tap.test('datagram handler: multiple rapid datagrams', async () => {
const promises: Promise<string>[] = [];
for (let i = 0; i < 5; i++) {
promises.push(sendDatagram(PROXY_PORT, `msg-${i}`));
}
const responses = await Promise.all(promises);
for (let i = 0; i < 5; i++) {
expect(responses).toContain(`Async: msg-${i}`);
}
});
tap.test('cleanup: stop SmartProxy', async () => {
await smartProxy.stop();
await assertPortsFree([PROXY_PORT]);
});
export default tap.start();
+38 -34
View File
@@ -1,11 +1,5 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import {
createHttpRoute,
createHttpsTerminateRoute,
createLoadBalancerRoute,
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import {
findMatchingRoutes,
findBestMatchingRoute,
@@ -22,24 +16,11 @@ import {
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
tap.test('route creation - createHttpRoute produces correct structure', async () => {
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
expect(route).toHaveProperty('match');
expect(route).toHaveProperty('action');
expect(route.match.domains).toEqual('example.com');
expect(route.action.type).toEqual('forward');
expect(route.action.targets).toBeArray();
expect(route.action.targets![0].host).toEqual('127.0.0.1');
expect(route.action.targets![0].port).toEqual(3000);
});
tap.test('route creation - createHttpRoute with array of domains', async () => {
const route = createHttpRoute(['a.com', 'b.com'], { host: 'localhost', port: 8080 });
expect(route.match.domains).toEqual(['a.com', 'b.com']);
});
tap.test('route validation - validateRouteConfig accepts valid route', async () => {
const route = createHttpRoute('valid.example.com', { host: '10.0.0.1', port: 8080 });
const route: IRouteConfig = {
match: { ports: 80, domains: 'valid.example.com' },
action: { type: 'forward', targets: [{ host: '10.0.0.1', port: 8080 }] }
};
const result = validateRouteConfig(route);
expect(result.valid).toBeTrue();
expect(result.errors).toHaveLength(0);
@@ -67,30 +48,44 @@ tap.test('route validation - isValidPort checks correctly', async () => {
});
tap.test('domain matching - exact domain', async () => {
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
const route: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
};
expect(routeMatchesDomain(route, 'example.com')).toBeTrue();
expect(routeMatchesDomain(route, 'other.com')).toBeFalse();
});
tap.test('domain matching - wildcard domain', async () => {
const route = createHttpRoute('*.example.com', { host: '127.0.0.1', port: 3000 });
const route: IRouteConfig = {
match: { ports: 80, domains: '*.example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
};
expect(routeMatchesDomain(route, 'sub.example.com')).toBeTrue();
expect(routeMatchesDomain(route, 'example.com')).toBeFalse();
});
tap.test('port matching - single port', async () => {
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
// createHttpRoute defaults to port 80
const route: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
};
expect(routeMatchesPort(route, 80)).toBeTrue();
expect(routeMatchesPort(route, 443)).toBeFalse();
});
tap.test('route finding - findBestMatchingRoute selects by priority', async () => {
const lowPriority = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
lowPriority.priority = 10;
const lowPriority: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] },
priority: 10
};
const highPriority = createHttpRoute('example.com', { host: '127.0.0.1', port: 4000 });
highPriority.priority = 100;
const highPriority: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 4000 }] },
priority: 100
};
const routes: IRouteConfig[] = [lowPriority, highPriority];
const best = findBestMatchingRoute(routes, { domain: 'example.com', port: 80 });
@@ -100,9 +95,18 @@ tap.test('route finding - findBestMatchingRoute selects by priority', async () =
});
tap.test('route finding - findMatchingRoutes returns all matches', async () => {
const route1 = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
const route2 = createHttpRoute('example.com', { host: '127.0.0.1', port: 4000 });
const route3 = createHttpRoute('other.com', { host: '127.0.0.1', port: 5000 });
const route1: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
};
const route2: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 4000 }] }
};
const route3: IRouteConfig = {
match: { ports: 80, domains: 'other.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 5000 }] }
};
const matches = findMatchingRoutes([route1, route2, route3], { domain: 'example.com', port: 80 });
expect(matches).toHaveLength(2);
-146
View File
@@ -1,146 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as smartproxy from '../ts/index.js';
tap.test('Protocol Detection - TLS Detection', async () => {
// Test TLS handshake detection
const tlsHandshake = Buffer.from([
0x16, // Handshake record type
0x03, 0x01, // TLS 1.0
0x00, 0x05, // Length: 5 bytes
0x01, // ClientHello
0x00, 0x00, 0x01, 0x00 // Handshake length and data
]);
const detector = new smartproxy.detection.TlsDetector();
expect(detector.canHandle(tlsHandshake)).toEqual(true);
const result = detector.detect(tlsHandshake);
expect(result).toBeDefined();
expect(result?.protocol).toEqual('tls');
expect(result?.connectionInfo.tlsVersion).toEqual('TLSv1.0');
});
tap.test('Protocol Detection - HTTP Detection', async () => {
// Test HTTP request detection
const httpRequest = Buffer.from(
'GET /test HTTP/1.1\r\n' +
'Host: example.com\r\n' +
'User-Agent: TestClient/1.0\r\n' +
'\r\n'
);
const detector = new smartproxy.detection.HttpDetector();
expect(detector.canHandle(httpRequest)).toEqual(true);
const result = detector.detect(httpRequest);
expect(result).toBeDefined();
expect(result?.protocol).toEqual('http');
expect(result?.connectionInfo.method).toEqual('GET');
expect(result?.connectionInfo.path).toEqual('/test');
expect(result?.connectionInfo.domain).toEqual('example.com');
});
tap.test('Protocol Detection - Main Detector TLS', async () => {
const tlsHandshake = Buffer.from([
0x16, // Handshake record type
0x03, 0x03, // TLS 1.2
0x00, 0x05, // Length: 5 bytes
0x01, // ClientHello
0x00, 0x00, 0x01, 0x00 // Handshake length and data
]);
const result = await smartproxy.detection.ProtocolDetector.detect(tlsHandshake);
expect(result.protocol).toEqual('tls');
expect(result.connectionInfo.tlsVersion).toEqual('TLSv1.2');
});
tap.test('Protocol Detection - Main Detector HTTP', async () => {
const httpRequest = Buffer.from(
'POST /api/test HTTP/1.1\r\n' +
'Host: api.example.com\r\n' +
'Content-Type: application/json\r\n' +
'Content-Length: 2\r\n' +
'\r\n' +
'{}'
);
const result = await smartproxy.detection.ProtocolDetector.detect(httpRequest);
expect(result.protocol).toEqual('http');
expect(result.connectionInfo.method).toEqual('POST');
expect(result.connectionInfo.path).toEqual('/api/test');
expect(result.connectionInfo.domain).toEqual('api.example.com');
});
tap.test('Protocol Detection - Unknown Protocol', async () => {
const unknownData = Buffer.from('UNKNOWN PROTOCOL DATA\r\n');
const result = await smartproxy.detection.ProtocolDetector.detect(unknownData);
expect(result.protocol).toEqual('unknown');
expect(result.isComplete).toEqual(true);
});
tap.test('Protocol Detection - Fragmented HTTP', async () => {
// Create connection context
const context = smartproxy.detection.ProtocolDetector.createConnectionContext({
sourceIp: '127.0.0.1',
sourcePort: 12345,
destIp: '127.0.0.1',
destPort: 80,
socketId: 'test-connection-1'
});
// First fragment
const fragment1 = Buffer.from('GET /test HT');
let result = await smartproxy.detection.ProtocolDetector.detectWithContext(
fragment1,
context
);
expect(result.protocol).toEqual('http');
expect(result.isComplete).toEqual(false);
// Second fragment
const fragment2 = Buffer.from('TP/1.1\r\nHost: example.com\r\n\r\n');
result = await smartproxy.detection.ProtocolDetector.detectWithContext(
fragment2,
context
);
expect(result.protocol).toEqual('http');
expect(result.isComplete).toEqual(true);
expect(result.connectionInfo.method).toEqual('GET');
expect(result.connectionInfo.path).toEqual('/test');
expect(result.connectionInfo.domain).toEqual('example.com');
// Clean up fragments
smartproxy.detection.ProtocolDetector.cleanupConnection(context);
});
tap.test('Protocol Detection - HTTP Methods', async () => {
const methods = ['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD', 'OPTIONS'];
for (const method of methods) {
const request = Buffer.from(
`${method} /test HTTP/1.1\r\n` +
'Host: example.com\r\n' +
'\r\n'
);
const detector = new smartproxy.detection.HttpDetector();
const result = detector.detect(request);
expect(result?.connectionInfo.method).toEqual(method);
}
});
tap.test('Protocol Detection - Invalid Data', async () => {
// Binary data that's not a valid protocol
const binaryData = Buffer.from([0xFF, 0xFE, 0xFD, 0xFC, 0xFB]);
const result = await smartproxy.detection.ProtocolDetector.detect(binaryData);
expect(result.protocol).toEqual('unknown');
});
tap.test('cleanup detection', async () => {
// Clean up the protocol detector instance
smartproxy.detection.ProtocolDetector.destroy();
});
export default tap.start();
+191
View File
@@ -0,0 +1,191 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import { SmartProxy } from '../ts/index.js';
import * as http from 'http';
import * as net from 'net';
import * as tls from 'tls';
import * as fs from 'fs';
import * as path from 'path';
import { fileURLToPath } from 'url';
import { assertPortsFree, findFreePorts } from './helpers/port-allocator.js';
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
const CERT_PEM = fs.readFileSync(path.join(__dirname, '..', 'assets', 'certs', 'cert.pem'), 'utf8');
const KEY_PEM = fs.readFileSync(path.join(__dirname, '..', 'assets', 'certs', 'key.pem'), 'utf8');
let httpBackendPort: number;
let tlsBackendPort: number;
let httpProxyPort: number;
let tlsProxyPort: number;
let httpBackend: http.Server;
let tlsBackend: tls.Server;
let proxy: SmartProxy;
async function pollMetrics(proxyToPoll: SmartProxy): Promise<void> {
await (proxyToPoll as any).metricsAdapter.poll();
}
async function waitForCondition(
callback: () => Promise<boolean>,
timeoutMs: number = 5000,
stepMs: number = 100,
): Promise<void> {
const deadline = Date.now() + timeoutMs;
while (Date.now() < deadline) {
if (await callback()) {
return;
}
await new Promise((resolve) => setTimeout(resolve, stepMs));
}
throw new Error(`Condition not met within ${timeoutMs}ms`);
}
function hasIpDomainRequest(domain: string): boolean {
const byIp = proxy.getMetrics().connections.domainRequestsByIP();
for (const domainMap of byIp.values()) {
if (domainMap.has(domain)) {
return true;
}
}
return false;
}
tap.test('setup - backend servers for HTTP domain rate metrics', async () => {
[httpBackendPort, tlsBackendPort, httpProxyPort, tlsProxyPort] = await findFreePorts(4);
httpBackend = http.createServer((req, res) => {
let body = '';
req.on('data', (chunk) => {
body += chunk;
});
req.on('end', () => {
res.writeHead(200, { 'Content-Type': 'text/plain' });
res.end(`ok:${body}`);
});
});
await new Promise<void>((resolve) => {
httpBackend.listen(httpBackendPort, () => resolve());
});
tlsBackend = tls.createServer({ cert: CERT_PEM, key: KEY_PEM }, (socket) => {
socket.on('data', (data) => {
socket.write(data);
});
socket.on('error', () => {});
});
await new Promise<void>((resolve) => {
tlsBackend.listen(tlsBackendPort, () => resolve());
});
});
tap.test('setup - start proxy with HTTP and TLS passthrough routes', async () => {
proxy = new SmartProxy({
routes: [
{
id: 'http-domain-rates',
name: 'http-domain-rates',
match: { ports: httpProxyPort, domains: 'example.com' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: httpBackendPort }],
},
},
{
id: 'tls-passthrough-domain-rates',
name: 'tls-passthrough-domain-rates',
match: { ports: tlsProxyPort, domains: 'passthrough.example.com' },
action: {
type: 'forward',
tls: { mode: 'passthrough' },
targets: [{ host: 'localhost', port: tlsBackendPort }],
},
},
],
metrics: { enabled: true, sampleIntervalMs: 100, retentionSeconds: 60 },
});
await proxy.start();
await new Promise((resolve) => setTimeout(resolve, 300));
});
tap.test('HTTP requests populate per-domain HTTP request rates', async () => {
for (let i = 0; i < 3; i++) {
await new Promise<void>((resolve, reject) => {
const body = `payload-${i}`;
const req = http.request(
{
hostname: 'localhost',
port: httpProxyPort,
path: '/echo',
method: 'POST',
headers: {
Host: 'Example.COM',
'Content-Type': 'text/plain',
'Content-Length': String(body.length),
},
},
(res) => {
res.resume();
res.on('end', () => resolve());
},
);
req.on('error', reject);
req.end(body);
});
}
await waitForCondition(async () => {
await pollMetrics(proxy);
const domainMetrics = proxy.getMetrics().requests.byDomain().get('example.com');
return (domainMetrics?.lastMinute ?? 0) >= 3 && (domainMetrics?.perSecond ?? 0) > 0;
});
const exampleMetrics = proxy.getMetrics().requests.byDomain().get('example.com');
expect(exampleMetrics).toBeTruthy();
expect(exampleMetrics?.lastMinute).toEqual(3);
expect(exampleMetrics?.perSecond).toBeGreaterThan(0);
});
tap.test('TLS passthrough SNI does not inflate HTTP domain request rates', async () => {
const tlsClient = tls.connect({
host: 'localhost',
port: tlsProxyPort,
servername: 'passthrough.example.com',
rejectUnauthorized: false,
});
await new Promise<void>((resolve, reject) => {
tlsClient.once('secureConnect', () => resolve());
tlsClient.once('error', reject);
});
const echoPromise = new Promise<void>((resolve, reject) => {
tlsClient.once('data', () => resolve());
tlsClient.once('error', reject);
});
tlsClient.write(Buffer.from('hello over tls passthrough'));
await echoPromise;
await waitForCondition(async () => {
await pollMetrics(proxy);
return hasIpDomainRequest('passthrough.example.com');
});
const requestRates = proxy.getMetrics().requests.byDomain();
expect(requestRates.has('passthrough.example.com')).toBeFalse();
expect(requestRates.get('example.com')?.lastMinute).toEqual(3);
expect(hasIpDomainRequest('passthrough.example.com')).toBeTrue();
tlsClient.destroy();
});
tap.test('cleanup - stop proxy and close backend servers', async () => {
await proxy.stop();
await new Promise<void>((resolve) => httpBackend.close(() => resolve()));
await new Promise<void>((resolve) => tlsBackend.close(() => resolve()));
await assertPortsFree([httpBackendPort, tlsBackendPort, httpProxyPort, tlsProxyPort]);
});
export default tap.start()
+57 -104
View File
@@ -2,146 +2,101 @@ import * as path from 'path';
import { tap, expect } from '@git.zone/tstest/tapbundle';
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
import {
createHttpRoute,
createHttpsTerminateRoute,
createHttpsPassthroughRoute,
createHttpToHttpsRedirect,
createCompleteHttpsServer,
createLoadBalancerRoute,
createApiRoute,
createWebSocketRoute
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import { SocketHandlers } from '../ts/proxies/smart-proxy/utils/socket-handlers.js';
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
// Test to demonstrate various route configurations using the new helpers
tap.test('Route-based configuration examples', async (tools) => {
// Example 1: HTTP-only configuration
const httpOnlyRoute = createHttpRoute(
'http.example.com',
{
host: 'localhost',
port: 3000
},
{
const httpOnlyRoute: IRouteConfig = {
match: { ports: 80, domains: 'http.example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'Basic HTTP Route'
}
);
};
console.log('HTTP-only route created successfully:', httpOnlyRoute.name);
expect(httpOnlyRoute.action.type).toEqual('forward');
expect(httpOnlyRoute.match.domains).toEqual('http.example.com');
// Example 2: HTTPS Passthrough (SNI) configuration
const httpsPassthroughRoute = createHttpsPassthroughRoute(
'pass.example.com',
{
host: ['10.0.0.1', '10.0.0.2'], // Round-robin target IPs
port: 443
},
{
const httpsPassthroughRoute: IRouteConfig = {
match: { ports: 443, domains: 'pass.example.com' },
action: { type: 'forward', targets: [{ host: '10.0.0.1', port: 443 }, { host: '10.0.0.2', port: 443 }], tls: { mode: 'passthrough' } },
name: 'HTTPS Passthrough Route'
}
);
};
expect(httpsPassthroughRoute).toBeTruthy();
expect(httpsPassthroughRoute.action.tls?.mode).toEqual('passthrough');
expect(Array.isArray(httpsPassthroughRoute.action.targets)).toBeTrue();
// Example 3: HTTPS Termination to HTTP Backend
const terminateToHttpRoute = createHttpsTerminateRoute(
'secure.example.com',
{
host: 'localhost',
port: 8080
},
{
certificate: 'auto',
const terminateToHttpRoute: IRouteConfig = {
match: { ports: 443, domains: 'secure.example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 8080 }], tls: { mode: 'terminate', certificate: 'auto' } },
name: 'HTTPS Termination to HTTP Backend'
}
);
};
// Create the HTTP to HTTPS redirect for this domain
const httpToHttpsRedirect = createHttpToHttpsRedirect(
'secure.example.com',
443,
{
const httpToHttpsRedirect: IRouteConfig = {
match: { ports: 80, domains: 'secure.example.com' },
action: { type: 'socket-handler', socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301) },
name: 'HTTP to HTTPS Redirect for secure.example.com'
}
);
};
expect(terminateToHttpRoute).toBeTruthy();
expect(terminateToHttpRoute.action.tls?.mode).toEqual('terminate');
expect(httpToHttpsRedirect.action.type).toEqual('socket-handler');
// Example 4: Load Balancer with HTTPS
const loadBalancerRoute = createLoadBalancerRoute(
'proxy.example.com',
['internal-api-1.local', 'internal-api-2.local'],
8443,
{
tls: {
mode: 'terminate-and-reencrypt',
certificate: 'auto'
const loadBalancerRoute: IRouteConfig = {
match: { ports: 443, domains: 'proxy.example.com' },
action: {
type: 'forward',
targets: [
{ host: 'internal-api-1.local', port: 8443 },
{ host: 'internal-api-2.local', port: 8443 }
],
tls: { mode: 'terminate-and-reencrypt', certificate: 'auto' }
},
name: 'Load Balanced HTTPS Route'
}
);
};
expect(loadBalancerRoute).toBeTruthy();
expect(loadBalancerRoute.action.tls?.mode).toEqual('terminate-and-reencrypt');
expect(Array.isArray(loadBalancerRoute.action.targets)).toBeTrue();
// Example 5: API Route
const apiRoute = createApiRoute(
'api.example.com',
'/api',
{ host: 'localhost', port: 8081 },
{
name: 'API Route',
useTls: true,
addCorsHeaders: true
}
);
const apiRoute: IRouteConfig = {
match: { ports: 443, domains: 'api.example.com', path: '/api' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8081 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
name: 'API Route'
};
expect(apiRoute.action.type).toEqual('forward');
expect(apiRoute.match.path).toBeTruthy();
// Example 6: Complete HTTPS Server with HTTP Redirect
const httpsServerRoutes = createCompleteHttpsServer(
'complete.example.com',
{
host: 'localhost',
port: 8080
},
{
certificate: 'auto',
const httpsRoute: IRouteConfig = {
match: { ports: 443, domains: 'complete.example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 8080 }], tls: { mode: 'terminate', certificate: 'auto' } },
name: 'Complete HTTPS Server'
}
);
};
expect(Array.isArray(httpsServerRoutes)).toBeTrue();
expect(httpsServerRoutes.length).toEqual(2); // HTTPS route and HTTP redirect
expect(httpsServerRoutes[0].action.tls?.mode).toEqual('terminate');
expect(httpsServerRoutes[1].action.type).toEqual('socket-handler');
const httpsRedirectRoute: IRouteConfig = {
match: { ports: 80, domains: 'complete.example.com' },
action: { type: 'socket-handler', socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301) },
name: 'Complete HTTPS Server - Redirect'
};
// Example 7: Static File Server - removed (use nginx/apache behind proxy)
// Example 8: WebSocket Route
const webSocketRoute = createWebSocketRoute(
'ws.example.com',
'/ws',
{ host: 'localhost', port: 8082 },
{
useTls: true,
const webSocketRoute: IRouteConfig = {
match: { ports: 443, domains: 'ws.example.com', path: '/ws' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8082 }],
tls: { mode: 'terminate', certificate: 'auto' },
websocket: { enabled: true }
},
name: 'WebSocket Route'
}
);
};
expect(webSocketRoute.action.type).toEqual('forward');
expect(webSocketRoute.action.websocket?.enabled).toBeTrue();
// Create a SmartProxy instance with all routes
const allRoutes: IRouteConfig[] = [
httpOnlyRoute,
httpsPassthroughRoute,
@@ -149,19 +104,17 @@ tap.test('Route-based configuration examples', async (tools) => {
httpToHttpsRedirect,
loadBalancerRoute,
apiRoute,
...httpsServerRoutes,
httpsRoute,
httpsRedirectRoute,
webSocketRoute
];
// We're not actually starting the SmartProxy in this test,
// just verifying that the configuration is valid
const smartProxy = new SmartProxy({
routes: allRoutes
});
// Just verify that all routes are configured correctly
console.log(`Created ${allRoutes.length} example routes`);
expect(allRoutes.length).toEqual(9); // One less without static file route
expect(allRoutes.length).toEqual(9);
});
export default tap.start();
+17 -47
View File
@@ -1,27 +1,8 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import * as plugins from '../ts/plugins.js';
// Import route-based helpers
import {
createHttpRoute,
createHttpsTerminateRoute,
createHttpsPassthroughRoute,
createHttpToHttpsRedirect,
createCompleteHttpsServer
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
// Create helper functions for backward compatibility
const helpers = {
httpOnly: (domains: string | string[], target: any) => createHttpRoute(domains, target),
tlsTerminateToHttp: (domains: string | string[], target: any) =>
createHttpsTerminateRoute(domains, target),
tlsTerminateToHttps: (domains: string | string[], target: any) =>
createHttpsTerminateRoute(domains, target, { reencrypt: true }),
httpsPassthrough: (domains: string | string[], target: any) =>
createHttpsPassthroughRoute(domains, target)
};
// Route-based utility functions for testing
function findRouteForDomain(routes: any[], domain: string): any {
return routes.find(route => {
const domains = Array.isArray(route.match.domains)
@@ -31,55 +12,44 @@ function findRouteForDomain(routes: any[], domain: string): any {
});
}
// Replace the old test with route-based tests
tap.test('Route Helpers - Create HTTP routes', async () => {
const route = helpers.httpOnly('example.com', { host: 'localhost', port: 3000 });
const route: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] }
};
expect(route.action.type).toEqual('forward');
expect(route.match.domains).toEqual('example.com');
expect(route.action.targets?.[0]).toEqual({ host: 'localhost', port: 3000 });
});
tap.test('Route Helpers - Create HTTPS terminate to HTTP routes', async () => {
const route = helpers.tlsTerminateToHttp('secure.example.com', { host: 'localhost', port: 3000 });
const route: IRouteConfig = {
match: { ports: 443, domains: 'secure.example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }], tls: { mode: 'terminate', certificate: 'auto' } }
};
expect(route.action.type).toEqual('forward');
expect(route.match.domains).toEqual('secure.example.com');
expect(route.action.tls?.mode).toEqual('terminate');
});
tap.test('Route Helpers - Create HTTPS passthrough routes', async () => {
const route = helpers.httpsPassthrough('passthrough.example.com', { host: 'backend', port: 443 });
const route: IRouteConfig = {
match: { ports: 443, domains: 'passthrough.example.com' },
action: { type: 'forward', targets: [{ host: 'backend', port: 443 }], tls: { mode: 'passthrough' } }
};
expect(route.action.type).toEqual('forward');
expect(route.match.domains).toEqual('passthrough.example.com');
expect(route.action.tls?.mode).toEqual('passthrough');
});
tap.test('Route Helpers - Create HTTPS to HTTPS routes', async () => {
const route = helpers.tlsTerminateToHttps('reencrypt.example.com', { host: 'backend', port: 443 });
const route: IRouteConfig = {
match: { ports: 443, domains: 'reencrypt.example.com' },
action: { type: 'forward', targets: [{ host: 'backend', port: 443 }], tls: { mode: 'terminate-and-reencrypt', certificate: 'auto' } }
};
expect(route.action.type).toEqual('forward');
expect(route.match.domains).toEqual('reencrypt.example.com');
expect(route.action.tls?.mode).toEqual('terminate-and-reencrypt');
});
tap.test('Route Helpers - Create complete HTTPS server with redirect', async () => {
const routes = createCompleteHttpsServer(
'full.example.com',
{ host: 'localhost', port: 3000 },
{ certificate: 'auto' }
);
expect(routes.length).toEqual(2);
// Check HTTP to HTTPS redirect - find route by port
const redirectRoute = routes.find(r => r.match.ports === 80);
expect(redirectRoute.action.type).toEqual('socket-handler');
expect(redirectRoute.action.socketHandler).toBeDefined();
expect(redirectRoute.match.ports).toEqual(80);
// Check HTTPS route
const httpsRoute = routes.find(r => r.action.type === 'forward');
expect(httpsRoute.match.ports).toEqual(443);
expect(httpsRoute.action.tls?.mode).toEqual('terminate');
});
// Export test runner
export default tap.start();
-128
View File
@@ -1,128 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as smartproxy from '../ts/index.js';
import { RouteValidator } from '../ts/proxies/smart-proxy/utils/route-validator.js';
import { IpUtils } from '../ts/core/utils/ip-utils.js';
tap.test('IP Validation - Shorthand patterns', async () => {
// Test shorthand patterns are now accepted
const testPatterns = [
{ pattern: '192.168.*', shouldPass: true },
{ pattern: '192.168.*.*', shouldPass: true },
{ pattern: '10.*', shouldPass: true },
{ pattern: '10.*.*.*', shouldPass: true },
{ pattern: '172.16.*', shouldPass: true },
{ pattern: '10.0.0.0/8', shouldPass: true },
{ pattern: '192.168.0.0/16', shouldPass: true },
{ pattern: '192.168.1.100', shouldPass: true },
{ pattern: '*', shouldPass: true },
{ pattern: '192.168.1.1-192.168.1.100', shouldPass: true },
];
for (const { pattern, shouldPass } of testPatterns) {
const route = {
name: 'test',
match: { ports: 80 },
action: { type: 'forward' as const, targets: [{ host: 'localhost', port: 8080 }] },
security: { ipAllowList: [pattern] }
};
const result = RouteValidator.validateRoute(route);
if (shouldPass) {
expect(result.valid).toEqual(true);
console.log(`✅ Pattern '${pattern}' correctly accepted`);
} else {
expect(result.valid).toEqual(false);
console.log(`✅ Pattern '${pattern}' correctly rejected`);
}
}
});
tap.test('IP Matching - Runtime shorthand pattern matching', async () => {
// Test runtime matching with shorthand patterns
const testCases = [
{ ip: '192.168.1.100', patterns: ['192.168.*'], expected: true },
{ ip: '192.168.1.100', patterns: ['192.168.1.*'], expected: true },
{ ip: '192.168.1.100', patterns: ['192.168.2.*'], expected: false },
{ ip: '10.0.0.1', patterns: ['10.*'], expected: true },
{ ip: '10.1.2.3', patterns: ['10.*'], expected: true },
{ ip: '172.16.0.1', patterns: ['10.*'], expected: false },
{ ip: '192.168.1.1', patterns: ['192.168.*.*'], expected: true },
];
for (const { ip, patterns, expected } of testCases) {
const result = IpUtils.isGlobIPMatch(ip, patterns);
expect(result).toEqual(expected);
console.log(`✅ IP ${ip} with pattern ${patterns[0]} = ${result} (expected ${expected})`);
}
});
tap.test('IP Matching - CIDR notation', async () => {
// Test CIDR notation matching
const cidrTests = [
{ ip: '10.0.0.1', cidr: '10.0.0.0/8', expected: true },
{ ip: '10.255.255.255', cidr: '10.0.0.0/8', expected: true },
{ ip: '11.0.0.1', cidr: '10.0.0.0/8', expected: false },
{ ip: '192.168.1.1', cidr: '192.168.0.0/16', expected: true },
{ ip: '192.168.255.255', cidr: '192.168.0.0/16', expected: true },
{ ip: '192.169.0.1', cidr: '192.168.0.0/16', expected: false },
{ ip: '192.168.1.100', cidr: '192.168.1.0/24', expected: true },
{ ip: '192.168.2.100', cidr: '192.168.1.0/24', expected: false },
];
for (const { ip, cidr, expected } of cidrTests) {
const result = IpUtils.isGlobIPMatch(ip, [cidr]);
expect(result).toEqual(expected);
console.log(`✅ IP ${ip} in CIDR ${cidr} = ${result} (expected ${expected})`);
}
});
tap.test('IP Matching - Range notation', async () => {
// Test range notation matching
const rangeTests = [
{ ip: '192.168.1.1', range: '192.168.1.1-192.168.1.100', expected: true },
{ ip: '192.168.1.50', range: '192.168.1.1-192.168.1.100', expected: true },
{ ip: '192.168.1.100', range: '192.168.1.1-192.168.1.100', expected: true },
{ ip: '192.168.1.101', range: '192.168.1.1-192.168.1.100', expected: false },
{ ip: '192.168.2.50', range: '192.168.1.1-192.168.1.100', expected: false },
];
for (const { ip, range, expected } of rangeTests) {
const result = IpUtils.isGlobIPMatch(ip, [range]);
expect(result).toEqual(expected);
console.log(`✅ IP ${ip} in range ${range} = ${result} (expected ${expected})`);
}
});
tap.test('IP Matching - Mixed patterns', async () => {
// Test with mixed pattern types
const allowList = [
'10.0.0.0/8', // CIDR
'192.168.*', // Shorthand glob
'172.16.1.*', // Specific subnet glob
'8.8.8.8', // Single IP
'1.1.1.1-1.1.1.10' // Range
];
const tests = [
{ ip: '10.1.2.3', expected: true }, // Matches CIDR
{ ip: '192.168.100.1', expected: true }, // Matches shorthand glob
{ ip: '172.16.1.5', expected: true }, // Matches specific glob
{ ip: '8.8.8.8', expected: true }, // Matches single IP
{ ip: '1.1.1.5', expected: true }, // Matches range
{ ip: '9.9.9.9', expected: false }, // Doesn't match any
];
for (const { ip, expected } of tests) {
const result = IpUtils.isGlobIPMatch(ip, allowList);
expect(result).toEqual(expected);
console.log(`✅ IP ${ip} in mixed patterns = ${result} (expected ${expected})`);
}
});
export default tap.start();
-112
View File
@@ -1,112 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import { LogDeduplicator } from '../ts/core/utils/log-deduplicator.js';
let deduplicator: LogDeduplicator;
tap.test('Setup log deduplicator', async () => {
deduplicator = new LogDeduplicator(1000); // 1 second flush interval for testing
});
tap.test('Connection rejection deduplication', async (tools) => {
// Simulate multiple connection rejections
for (let i = 0; i < 10; i++) {
deduplicator.log(
'connection-rejected',
'warn',
'Connection rejected',
{ reason: 'global-limit', component: 'test' },
'global-limit'
);
}
for (let i = 0; i < 5; i++) {
deduplicator.log(
'connection-rejected',
'warn',
'Connection rejected',
{ reason: 'route-limit', component: 'test' },
'route-limit'
);
}
// Force flush
deduplicator.flush('connection-rejected');
// The logs should have been aggregated
// (Can't easily test the actual log output, but we can verify the mechanism works)
expect(deduplicator).toBeInstanceOf(LogDeduplicator);
});
tap.test('IP rejection deduplication', async (tools) => {
// Simulate rejections from multiple IPs
const ips = ['192.168.1.100', '192.168.1.101', '192.168.1.100', '10.0.0.1'];
const reasons = ['per-ip-limit', 'rate-limit', 'per-ip-limit', 'global-limit'];
for (let i = 0; i < ips.length; i++) {
deduplicator.log(
'ip-rejected',
'warn',
`Connection rejected from ${ips[i]}`,
{ remoteIP: ips[i], reason: reasons[i] },
ips[i]
);
}
// Add more rejections from the same IP
for (let i = 0; i < 20; i++) {
deduplicator.log(
'ip-rejected',
'warn',
'Connection rejected from 192.168.1.100',
{ remoteIP: '192.168.1.100', reason: 'rate-limit' },
'192.168.1.100'
);
}
// Force flush
deduplicator.flush('ip-rejected');
// Verify the deduplicator exists and works
expect(deduplicator).toBeInstanceOf(LogDeduplicator);
});
tap.test('Connection cleanup deduplication', async (tools) => {
// Simulate various cleanup events
const reasons = ['normal', 'timeout', 'error', 'normal', 'zombie'];
for (const reason of reasons) {
for (let i = 0; i < 5; i++) {
deduplicator.log(
'connection-cleanup',
'info',
`Connection cleanup: ${reason}`,
{ connectionId: `conn-${i}`, reason },
reason
);
}
}
// Wait for automatic flush
await tools.delayFor(1500);
// Verify deduplicator is working
expect(deduplicator).toBeInstanceOf(LogDeduplicator);
});
tap.test('Automatic periodic flush', async (tools) => {
// Add some events
deduplicator.log('test-event', 'info', 'Test message', {}, 'test');
// Wait for automatic flush (should happen within 2x flush interval = 2 seconds)
await tools.delayFor(2500);
// Events should have been flushed automatically
expect(deduplicator).toBeInstanceOf(LogDeduplicator);
});
tap.test('Cleanup deduplicator', async () => {
deduplicator.cleanup();
expect(deduplicator).toBeInstanceOf(LogDeduplicator);
});
export default tap.start();
+3
View File
@@ -83,6 +83,9 @@ tap.test('should verify new metrics API structure', async () => {
expect(metrics.throughput).toHaveProperty('history');
expect(metrics.throughput).toHaveProperty('byRoute');
expect(metrics.throughput).toHaveProperty('byIP');
// Check request methods
expect(metrics.requests).toHaveProperty('byDomain');
});
tap.test('should track active connections', async (tools) => {
+33 -33
View File
@@ -1,15 +1,14 @@
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
import { createNfTablesRoute, createNfTablesTerminateRoute } from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as child_process from 'child_process';
import { promisify } from 'util';
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
const exec = promisify(child_process.exec);
// Check if we have root privileges to run NFTables tests
async function checkRootPrivileges(): Promise<boolean> {
try {
// Check if we're running as root
const { stdout } = await exec('id -u');
return stdout.trim() === '0';
} catch (err) {
@@ -17,7 +16,6 @@ async function checkRootPrivileges(): Promise<boolean> {
}
}
// Check if tests should run
const isRoot = await checkRootPrivileges();
if (!isRoot) {
@@ -29,38 +27,45 @@ if (!isRoot) {
console.log('');
}
// Define the test with proper skip condition
const testFn = isRoot ? tap.test : tap.skip.test;
testFn('NFTables integration tests', async () => {
console.log('Running NFTables tests with root privileges');
// Create test routes
const routes = [
createNfTablesRoute('tcp-forward', {
host: 'localhost',
port: 8080
}, {
ports: 9080,
protocol: 'tcp'
}),
const routes: IRouteConfig[] = [
{
match: { ports: 9080 },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'localhost', port: 8080 }],
nftables: { protocol: 'tcp' }
},
name: 'tcp-forward'
},
createNfTablesRoute('udp-forward', {
host: 'localhost',
port: 5353
}, {
ports: 5354,
protocol: 'udp'
}),
{
match: { ports: 5354 },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'localhost', port: 5353 }],
nftables: { protocol: 'udp' }
},
name: 'udp-forward'
},
createNfTablesRoute('port-range', {
host: 'localhost',
port: 8080
}, {
ports: [{ from: 9000, to: 9100 }],
protocol: 'tcp'
})
{
match: { ports: [{ from: 9000, to: 9100 }] },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'localhost', port: 8080 }],
nftables: { protocol: 'tcp' }
},
name: 'port-range'
}
];
const smartProxy = new SmartProxy({
@@ -68,15 +73,12 @@ testFn('NFTables integration tests', async () => {
routes
});
// Start the proxy
await smartProxy.start();
console.log('SmartProxy started with NFTables routes');
// Get NFTables status
const status = await smartProxy.getNfTablesStatus();
console.log('NFTables status:', JSON.stringify(status, null, 2));
// Verify all routes are provisioned
expect(Object.keys(status).length).toEqual(routes.length);
for (const routeStatus of Object.values(status)) {
@@ -84,11 +86,9 @@ testFn('NFTables integration tests', async () => {
expect(routeStatus.ruleCount.total).toBeGreaterThan(0);
}
// Stop the proxy
await smartProxy.stop();
console.log('SmartProxy stopped');
// Verify all rules are cleaned up
const finalStatus = await smartProxy.getNfTablesStatus();
expect(Object.keys(finalStatus).length).toEqual(0);
});
+54 -67
View File
@@ -1,5 +1,4 @@
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
import { createNfTablesRoute, createNfTablesTerminateRoute } from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as net from 'net';
import * as http from 'http';
@@ -10,13 +9,13 @@ import { fileURLToPath } from 'url';
import * as child_process from 'child_process';
import { promisify } from 'util';
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
const exec = promisify(child_process.exec);
// Get __dirname equivalent for ES modules
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
// Check if we have root privileges
async function checkRootPrivileges(): Promise<boolean> {
try {
const { stdout } = await exec('id -u');
@@ -26,7 +25,6 @@ async function checkRootPrivileges(): Promise<boolean> {
}
}
// Check if tests should run
const runTests = await checkRootPrivileges();
if (!runTests) {
@@ -36,10 +34,8 @@ if (!runTests) {
console.log('Skipping NFTables integration tests');
console.log('========================================');
console.log('');
// Skip tests when not running as root - tests are marked with tap.skip.test
}
// Test server and client utilities
let testTcpServer: net.Server;
let testHttpServer: http.Server;
let testHttpsServer: https.Server;
@@ -53,10 +49,8 @@ const PROXY_HTTP_PORT = 5001;
const PROXY_HTTPS_PORT = 5002;
const TEST_DATA = 'Hello through NFTables!';
// Helper to create test certificates
async function createTestCertificates() {
try {
// Import the certificate helper
const certsModule = await import('./helpers/certificates.js');
const certificates = certsModule.loadTestCertificates();
return {
@@ -65,7 +59,6 @@ async function createTestCertificates() {
};
} catch (err) {
console.error('Failed to load test certificates:', err);
// Use dummy certificates for testing
return {
cert: fs.readFileSync(path.join(__dirname, '..', 'assets', 'certs', 'cert.pem'), 'utf8'),
key: fs.readFileSync(path.join(__dirname, '..', 'assets', 'certs', 'key.pem'), 'utf8')
@@ -76,7 +69,6 @@ async function createTestCertificates() {
tap.skip.test('setup NFTables integration test environment', async () => {
console.log('Running NFTables integration tests with root privileges');
// Create a basic TCP test server
testTcpServer = net.createServer((socket) => {
socket.on('data', (data) => {
socket.write(`Server says: ${data.toString()}`);
@@ -90,7 +82,6 @@ tap.skip.test('setup NFTables integration test environment', async () => {
});
});
// Create an HTTP test server
testHttpServer = http.createServer((req, res) => {
res.writeHead(200, { 'Content-Type': 'text/plain' });
res.end(`HTTP Server says: ${TEST_DATA}`);
@@ -103,7 +94,6 @@ tap.skip.test('setup NFTables integration test environment', async () => {
});
});
// Create an HTTPS test server
const certs = await createTestCertificates();
testHttpsServer = https.createServer({ key: certs.key, cert: certs.cert }, (req, res) => {
res.writeHead(200, { 'Content-Type': 'text/plain' });
@@ -117,69 +107,73 @@ tap.skip.test('setup NFTables integration test environment', async () => {
});
});
// Create SmartProxy with various NFTables routes
smartProxy = new SmartProxy({
enableDetailedLogging: true,
routes: [
// TCP forwarding route
createNfTablesRoute('tcp-nftables', {
host: 'localhost',
port: TEST_TCP_PORT
}, {
ports: PROXY_TCP_PORT,
protocol: 'tcp'
}),
{
match: { ports: PROXY_TCP_PORT },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'localhost', port: TEST_TCP_PORT }],
nftables: { protocol: 'tcp' }
},
name: 'tcp-nftables'
},
// HTTP forwarding route
createNfTablesRoute('http-nftables', {
host: 'localhost',
port: TEST_HTTP_PORT
}, {
ports: PROXY_HTTP_PORT,
protocol: 'tcp'
}),
{
match: { ports: PROXY_HTTP_PORT },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'localhost', port: TEST_HTTP_PORT }],
nftables: { protocol: 'tcp' }
},
name: 'http-nftables'
},
// HTTPS termination route
createNfTablesTerminateRoute('https-nftables.example.com', {
host: 'localhost',
port: TEST_HTTPS_PORT
}, {
ports: PROXY_HTTPS_PORT,
protocol: 'tcp',
certificate: certs
}),
{
match: { ports: PROXY_HTTPS_PORT, domains: 'https-nftables.example.com' },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'localhost', port: TEST_HTTPS_PORT }],
tls: { mode: 'terminate', certificate: 'auto' },
nftables: { protocol: 'tcp' }
},
name: 'https-nftables'
},
// Route with IP allow list
createNfTablesRoute('secure-tcp', {
host: 'localhost',
port: TEST_TCP_PORT
}, {
ports: 5003,
protocol: 'tcp',
ipAllowList: ['127.0.0.1', '::1']
}),
{
match: { ports: 5003 },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'localhost', port: TEST_TCP_PORT }],
nftables: { protocol: 'tcp', ipAllowList: ['127.0.0.1', '::1'] }
},
name: 'secure-tcp'
},
// Route with QoS settings
createNfTablesRoute('qos-tcp', {
host: 'localhost',
port: TEST_TCP_PORT
}, {
ports: 5004,
protocol: 'tcp',
maxRate: '10mbps',
priority: 1
})
{
match: { ports: 5004 },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'localhost', port: TEST_TCP_PORT }],
nftables: { protocol: 'tcp', maxRate: '10mbps', priority: 1 }
},
name: 'qos-tcp'
}
]
});
console.log('SmartProxy created, now starting...');
// Start the proxy
try {
await smartProxy.start();
console.log('SmartProxy started successfully');
// Verify proxy is listening on expected ports
const listeningPorts = smartProxy.getListeningPorts();
console.log(`SmartProxy is listening on ports: ${listeningPorts.join(', ')}`);
} catch (err) {
@@ -191,7 +185,6 @@ tap.skip.test('setup NFTables integration test environment', async () => {
tap.skip.test('should forward TCP connections through NFTables', async () => {
console.log(`Attempting to connect to proxy TCP port ${PROXY_TCP_PORT}...`);
// First verify our test server is running
try {
const testClient = new net.Socket();
await new Promise<void>((resolve, reject) => {
@@ -206,7 +199,6 @@ tap.skip.test('should forward TCP connections through NFTables', async () => {
console.error(`Test server on port ${TEST_TCP_PORT} is not accessible: ${err}`);
}
// Connect to the proxy port
const client = new net.Socket();
const response = await new Promise<string>((resolve, reject) => {
@@ -259,14 +251,13 @@ tap.skip.test('should forward HTTP connections through NFTables', async () => {
});
tap.skip.test('should handle HTTPS termination with NFTables', async () => {
// Skip this test if running without proper certificates
const response = await new Promise<string>((resolve, reject) => {
const options = {
hostname: 'localhost',
port: PROXY_HTTPS_PORT,
path: '/',
method: 'GET',
rejectUnauthorized: false // For self-signed cert
rejectUnauthorized: false
};
https.get(options, (res) => {
@@ -284,7 +275,6 @@ tap.skip.test('should handle HTTPS termination with NFTables', async () => {
});
tap.skip.test('should respect IP allow lists in NFTables', async () => {
// This test should pass since we're connecting from localhost
const client = new net.Socket();
const connected = await new Promise<boolean>((resolve) => {
@@ -311,11 +301,9 @@ tap.skip.test('should respect IP allow lists in NFTables', async () => {
tap.skip.test('should get NFTables status', async () => {
const status = await smartProxy.getNfTablesStatus();
// Check that we have status for our routes
const statusKeys = Object.keys(status);
expect(statusKeys.length).toBeGreaterThan(0);
// Check status structure for one of the routes
const firstStatus = status[statusKeys[0]];
expect(firstStatus).toHaveProperty('active');
expect(firstStatus).toHaveProperty('ruleCount');
@@ -324,7 +312,6 @@ tap.skip.test('should get NFTables status', async () => {
});
tap.skip.test('cleanup NFTables integration test environment', async () => {
// Stop the proxy and test servers
await smartProxy.stop();
await new Promise<void>((resolve) => {
+46 -82
View File
@@ -1,30 +1,20 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as net from 'net';
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
import {
createPortMappingRoute,
createOffsetPortMappingRoute,
createDynamicRoute,
createSmartLoadBalancer,
createPortOffset
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import type { IRouteConfig, IRouteContext } from '../ts/proxies/smart-proxy/models/route-types.js';
import { findFreePorts, assertPortsFree } from './helpers/port-allocator.js';
// Test server and client utilities
let testServers: Array<{ server: net.Server; port: number }> = [];
let smartProxy: SmartProxy;
let TEST_PORTS: number[]; // 3 test server ports
let PROXY_PORTS: number[]; // 6 proxy ports
let TEST_PORTS: number[];
let PROXY_PORTS: number[];
const TEST_DATA = 'Hello through dynamic port mapper!';
// Cleanup function to close all servers and proxies
function cleanup() {
console.log('Starting cleanup...');
const promises = [];
// Close test servers
for (const { server, port } of testServers) {
promises.push(new Promise<void>(resolve => {
console.log(`Closing test server on port ${port}`);
@@ -35,7 +25,6 @@ function cleanup() {
}));
}
// Stop SmartProxy
if (smartProxy) {
console.log('Stopping SmartProxy...');
promises.push(smartProxy.stop().then(() => {
@@ -46,12 +35,10 @@ function cleanup() {
return Promise.all(promises);
}
// Helper: Creates a test TCP server that listens on a given port
function createTestServer(port: number): Promise<net.Server> {
return new Promise((resolve) => {
const server = net.createServer((socket) => {
socket.on('data', (data) => {
// Echo the received data back with a server identifier
socket.write(`Server ${port} says: ${data.toString()}`);
});
socket.on('error', (error) => {
@@ -67,7 +54,6 @@ function createTestServer(port: number): Promise<net.Server> {
});
}
// Helper: Creates a test client connection with timeout
function createTestClient(port: number, data: string): Promise<string> {
return new Promise((resolve, reject) => {
const client = new net.Socket();
@@ -100,123 +86,108 @@ function createTestClient(port: number, data: string): Promise<string> {
});
}
// Set up test environment
tap.test('setup port mapping test environment', async () => {
const allPorts = await findFreePorts(9);
TEST_PORTS = allPorts.slice(0, 3);
PROXY_PORTS = allPorts.slice(3, 9);
// Create multiple test servers on different ports
await Promise.all([
createTestServer(TEST_PORTS[0]),
createTestServer(TEST_PORTS[1]),
createTestServer(TEST_PORTS[2]),
]);
// Compute dynamic offset between proxy and test ports
const portOffset = TEST_PORTS[1] - PROXY_PORTS[1];
// Create a SmartProxy with dynamic port mapping routes
smartProxy = new SmartProxy({
routes: [
// Simple function that returns the same port (identity mapping)
createPortMappingRoute({
sourcePortRange: PROXY_PORTS[0],
targetHost: 'localhost',
portMapper: (context) => TEST_PORTS[0],
{
match: { ports: PROXY_PORTS[0] },
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: (context: IRouteContext) => TEST_PORTS[0]
}]
},
name: 'Identity Port Mapping'
}),
},
// Offset port mapping using dynamic offset
createOffsetPortMappingRoute({
ports: PROXY_PORTS[1],
targetHost: 'localhost',
offset: portOffset,
{
match: { ports: PROXY_PORTS[1] },
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: (context: IRouteContext) => context.port + portOffset
}]
},
name: `Offset Port Mapping (${portOffset})`
}),
},
// Dynamic route with conditional port mapping
createDynamicRoute({
ports: [PROXY_PORTS[2], PROXY_PORTS[3]],
targetHost: (context) => {
// Dynamic host selection based on port
{
match: { ports: [PROXY_PORTS[2], PROXY_PORTS[3]] },
action: {
type: 'forward',
targets: [{
host: (context: IRouteContext) => {
return context.port === PROXY_PORTS[2] ? 'localhost' : '127.0.0.1';
},
portMapper: (context) => {
// Port mapping logic based on incoming port
port: (context: IRouteContext) => {
if (context.port === PROXY_PORTS[2]) {
return TEST_PORTS[0];
} else {
return TEST_PORTS[2];
}
}
}]
},
name: 'Dynamic Host and Port Mapping'
}),
// Smart load balancer for domain-based routing
createSmartLoadBalancer({
ports: PROXY_PORTS[4],
domainTargets: {
'test1.example.com': 'localhost',
'test2.example.com': '127.0.0.1'
},
portMapper: (context) => {
// Use different backend ports based on domain
{
match: { ports: PROXY_PORTS[4] },
action: {
type: 'forward',
targets: [{
host: (context: IRouteContext) => {
if (context.domain === 'test1.example.com') return 'localhost';
if (context.domain === 'test2.example.com') return '127.0.0.1';
return 'localhost';
},
port: (context: IRouteContext) => {
if (context.domain === 'test1.example.com') {
return TEST_PORTS[0];
} else {
return TEST_PORTS[1];
}
}
}]
},
defaultTarget: 'localhost',
name: 'Smart Domain Load Balancer'
})
}
]
});
// Start the SmartProxy
await smartProxy.start();
});
// Test 1: Simple identity port mapping
tap.test('should map port using identity function', async () => {
const response = await createTestClient(PROXY_PORTS[0], TEST_DATA);
expect(response).toEqual(`Server ${TEST_PORTS[0]} says: ${TEST_DATA}`);
});
// Test 2: Offset port mapping
tap.test('should map port using offset function', async () => {
const response = await createTestClient(PROXY_PORTS[1], TEST_DATA);
expect(response).toEqual(`Server ${TEST_PORTS[1]} says: ${TEST_DATA}`);
});
// Test 3: Dynamic port and host mapping (conditional logic)
tap.test('should map port using dynamic logic', async () => {
const response = await createTestClient(PROXY_PORTS[2], TEST_DATA);
expect(response).toEqual(`Server ${TEST_PORTS[0]} says: ${TEST_DATA}`);
});
// Test 4: Test reuse of createPortOffset helper
tap.test('should use createPortOffset helper for port mapping', async () => {
// Test the createPortOffset helper with dynamic offset
const portOffset = TEST_PORTS[1] - PROXY_PORTS[1];
const offsetFn = createPortOffset(portOffset);
const context = {
port: PROXY_PORTS[1],
clientIp: '127.0.0.1',
serverIp: '127.0.0.1',
isTls: false,
timestamp: Date.now(),
connectionId: 'test-connection'
} as IRouteContext;
const mappedPort = offsetFn(context);
expect(mappedPort).toEqual(TEST_PORTS[1]);
});
// Test 5: Test error handling for invalid port mapping functions
tap.test('should handle errors in port mapping functions', async () => {
// Create a route with a function that throws an error
const errorRoute: IRouteConfig = {
match: {
ports: PROXY_PORTS[5]
@@ -233,23 +204,17 @@ tap.test('should handle errors in port mapping functions', async () => {
name: 'Error Route'
};
// Add the route to SmartProxy
await smartProxy.updateRoutes([...smartProxy.settings.routes, errorRoute]);
// The connection should fail or timeout
try {
await createTestClient(PROXY_PORTS[5], TEST_DATA);
// Connection should not succeed
expect(false).toBeTrue();
} catch (error) {
// Connection failed as expected
expect(true).toBeTrue();
}
});
// Cleanup
tap.test('cleanup port mapping test environment', async () => {
// Add timeout to prevent hanging if SmartProxy shutdown has issues
const cleanupPromise = cleanup();
const timeoutPromise = new Promise((_, reject) =>
setTimeout(() => reject(new Error('Cleanup timeout after 5 seconds')), 5000)
@@ -259,7 +224,6 @@ tap.test('cleanup port mapping test environment', async () => {
await Promise.race([cleanupPromise, timeoutPromise]);
} catch (error) {
console.error('Cleanup error:', error);
// Force cleanup even if there's an error
testServers = [];
smartProxy = null as any;
}
-133
View File
@@ -1,133 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as smartproxy from '../ts/index.js';
import { ProxyProtocolParser } from '../ts/core/utils/proxy-protocol.js';
tap.test('PROXY protocol v1 parser - valid headers', async () => {
// Test TCP4 format
const tcp4Header = Buffer.from('PROXY TCP4 192.168.1.1 10.0.0.1 56324 443\r\n', 'ascii');
const tcp4Result = ProxyProtocolParser.parse(tcp4Header);
expect(tcp4Result.proxyInfo).property('protocol').toEqual('TCP4');
expect(tcp4Result.proxyInfo).property('sourceIP').toEqual('192.168.1.1');
expect(tcp4Result.proxyInfo).property('sourcePort').toEqual(56324);
expect(tcp4Result.proxyInfo).property('destinationIP').toEqual('10.0.0.1');
expect(tcp4Result.proxyInfo).property('destinationPort').toEqual(443);
expect(tcp4Result.remainingData.length).toEqual(0);
// Test TCP6 format
const tcp6Header = Buffer.from('PROXY TCP6 2001:db8::1 2001:db8::2 56324 443\r\n', 'ascii');
const tcp6Result = ProxyProtocolParser.parse(tcp6Header);
expect(tcp6Result.proxyInfo).property('protocol').toEqual('TCP6');
expect(tcp6Result.proxyInfo).property('sourceIP').toEqual('2001:db8::1');
expect(tcp6Result.proxyInfo).property('sourcePort').toEqual(56324);
expect(tcp6Result.proxyInfo).property('destinationIP').toEqual('2001:db8::2');
expect(tcp6Result.proxyInfo).property('destinationPort').toEqual(443);
// Test UNKNOWN protocol
const unknownHeader = Buffer.from('PROXY UNKNOWN\r\n', 'ascii');
const unknownResult = ProxyProtocolParser.parse(unknownHeader);
expect(unknownResult.proxyInfo).property('protocol').toEqual('UNKNOWN');
expect(unknownResult.proxyInfo).property('sourceIP').toEqual('');
expect(unknownResult.proxyInfo).property('sourcePort').toEqual(0);
});
tap.test('PROXY protocol v1 parser - with remaining data', async () => {
const headerWithData = Buffer.concat([
Buffer.from('PROXY TCP4 192.168.1.1 10.0.0.1 56324 443\r\n', 'ascii'),
Buffer.from('GET / HTTP/1.1\r\n', 'ascii')
]);
const result = ProxyProtocolParser.parse(headerWithData);
expect(result.proxyInfo).property('protocol').toEqual('TCP4');
expect(result.proxyInfo).property('sourceIP').toEqual('192.168.1.1');
expect(result.remainingData.toString()).toEqual('GET / HTTP/1.1\r\n');
});
tap.test('PROXY protocol v1 parser - invalid headers', async () => {
// Not a PROXY protocol header
const notProxy = Buffer.from('GET / HTTP/1.1\r\n', 'ascii');
const notProxyResult = ProxyProtocolParser.parse(notProxy);
expect(notProxyResult.proxyInfo).toBeNull();
expect(notProxyResult.remainingData).toEqual(notProxy);
// Invalid protocol
expect(() => {
ProxyProtocolParser.parse(Buffer.from('PROXY INVALID 1.1.1.1 2.2.2.2 80 443\r\n', 'ascii'));
}).toThrow();
// Wrong number of fields
expect(() => {
ProxyProtocolParser.parse(Buffer.from('PROXY TCP4 192.168.1.1 10.0.0.1 56324\r\n', 'ascii'));
}).toThrow();
// Invalid port
expect(() => {
ProxyProtocolParser.parse(Buffer.from('PROXY TCP4 192.168.1.1 10.0.0.1 99999 443\r\n', 'ascii'));
}).toThrow();
// Invalid IP for protocol
expect(() => {
ProxyProtocolParser.parse(Buffer.from('PROXY TCP4 2001:db8::1 10.0.0.1 56324 443\r\n', 'ascii'));
}).toThrow();
});
tap.test('PROXY protocol v1 parser - incomplete headers', async () => {
// Header without terminator
const incomplete = Buffer.from('PROXY TCP4 192.168.1.1 10.0.0.1 56324 443', 'ascii');
const result = ProxyProtocolParser.parse(incomplete);
expect(result.proxyInfo).toBeNull();
expect(result.remainingData).toEqual(incomplete);
// Header exceeding max length - create a buffer that actually starts with PROXY
const longHeader = Buffer.from('PROXY TCP4 ' + '1'.repeat(100), 'ascii');
expect(() => {
ProxyProtocolParser.parse(longHeader);
}).toThrow();
});
tap.test('PROXY protocol v1 generator', async () => {
// Generate TCP4 header
const tcp4Info = {
protocol: 'TCP4' as const,
sourceIP: '192.168.1.1',
sourcePort: 56324,
destinationIP: '10.0.0.1',
destinationPort: 443
};
const tcp4Header = ProxyProtocolParser.generate(tcp4Info);
expect(tcp4Header.toString('ascii')).toEqual('PROXY TCP4 192.168.1.1 10.0.0.1 56324 443\r\n');
// Generate TCP6 header
const tcp6Info = {
protocol: 'TCP6' as const,
sourceIP: '2001:db8::1',
sourcePort: 56324,
destinationIP: '2001:db8::2',
destinationPort: 443
};
const tcp6Header = ProxyProtocolParser.generate(tcp6Info);
expect(tcp6Header.toString('ascii')).toEqual('PROXY TCP6 2001:db8::1 2001:db8::2 56324 443\r\n');
// Generate UNKNOWN header
const unknownInfo = {
protocol: 'UNKNOWN' as const,
sourceIP: '',
sourcePort: 0,
destinationIP: '',
destinationPort: 0
};
const unknownHeader = ProxyProtocolParser.generate(unknownInfo);
expect(unknownHeader.toString('ascii')).toEqual('PROXY UNKNOWN\r\n');
});
// Skipping integration tests for now - focus on unit tests
// Integration tests would require more complex setup and teardown
export default tap.start();
+225 -182
View File
@@ -6,7 +6,7 @@ import { expect, tap } from '@git.zone/tstest/tapbundle';
// Import from core modules
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
// Import route utilities and helpers
// Import route utilities
import {
findMatchingRoutes,
findBestMatchingRoute,
@@ -28,16 +28,7 @@ import {
assertValidRoute
} from '../ts/proxies/smart-proxy/utils/route-validator.js';
import {
createHttpRoute,
createHttpsTerminateRoute,
createHttpsPassthroughRoute,
createHttpToHttpsRedirect,
createCompleteHttpsServer,
createLoadBalancerRoute,
createApiRoute,
createWebSocketRoute
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import { SocketHandlers } from '../ts/proxies/smart-proxy/utils/socket-handlers.js';
// Import test helpers
import { loadTestCertificates } from './helpers/certificates.js';
@@ -47,12 +38,12 @@ import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.
// --------------------------------- Route Creation Tests ---------------------------------
tap.test('Routes: Should create basic HTTP route', async () => {
// Create a simple HTTP route
const httpRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 }, {
const httpRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'Basic HTTP Route'
});
};
// Validate the route configuration
expect(httpRoute.match.ports).toEqual(80);
expect(httpRoute.match.domains).toEqual('example.com');
expect(httpRoute.action.type).toEqual('forward');
@@ -62,14 +53,17 @@ tap.test('Routes: Should create basic HTTP route', async () => {
});
tap.test('Routes: Should create HTTPS route with TLS termination', async () => {
// Create an HTTPS route with TLS termination
const httpsRoute = createHttpsTerminateRoute('secure.example.com', { host: 'localhost', port: 8080 }, {
certificate: 'auto',
const httpsRoute: IRouteConfig = {
match: { ports: 443, domains: 'secure.example.com' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8080 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
name: 'HTTPS Route'
});
};
// Validate the route configuration
expect(httpsRoute.match.ports).toEqual(443); // Default HTTPS port
expect(httpsRoute.match.ports).toEqual(443);
expect(httpsRoute.match.domains).toEqual('secure.example.com');
expect(httpsRoute.action.type).toEqual('forward');
expect(httpsRoute.action.tls?.mode).toEqual('terminate');
@@ -80,10 +74,15 @@ tap.test('Routes: Should create HTTPS route with TLS termination', async () => {
});
tap.test('Routes: Should create HTTP to HTTPS redirect', async () => {
// Create an HTTP to HTTPS redirect
const redirectRoute = createHttpToHttpsRedirect('example.com', 443);
const redirectRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: {
type: 'socket-handler',
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
},
name: 'HTTP to HTTPS Redirect for example.com'
};
// Validate the route configuration
expect(redirectRoute.match.ports).toEqual(80);
expect(redirectRoute.match.domains).toEqual('example.com');
expect(redirectRoute.action.type).toEqual('socket-handler');
@@ -91,22 +90,34 @@ tap.test('Routes: Should create HTTP to HTTPS redirect', async () => {
});
tap.test('Routes: Should create complete HTTPS server with redirects', async () => {
// Create a complete HTTPS server setup
const routes = createCompleteHttpsServer('example.com', { host: 'localhost', port: 8080 }, {
certificate: 'auto'
});
const routes: IRouteConfig[] = [
{
match: { ports: 443, domains: 'example.com' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8080 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
name: 'HTTPS Terminate Route for example.com'
},
{
match: { ports: 80, domains: 'example.com' },
action: {
type: 'socket-handler',
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
},
name: 'HTTP to HTTPS Redirect for example.com'
}
];
// Validate that we got two routes (HTTPS route and HTTP redirect)
expect(routes.length).toEqual(2);
// Validate HTTPS route
const httpsRoute = routes[0];
expect(httpsRoute.match.ports).toEqual(443);
expect(httpsRoute.match.domains).toEqual('example.com');
expect(httpsRoute.action.type).toEqual('forward');
expect(httpsRoute.action.tls?.mode).toEqual('terminate');
// Validate HTTP redirect route
const redirectRoute = routes[1];
expect(redirectRoute.match.ports).toEqual(80);
expect(redirectRoute.action.type).toEqual('socket-handler');
@@ -114,21 +125,17 @@ tap.test('Routes: Should create complete HTTPS server with redirects', async ()
});
tap.test('Routes: Should create load balancer route', async () => {
// Create a load balancer route
const lbRoute = createLoadBalancerRoute(
'app.example.com',
['10.0.0.1', '10.0.0.2', '10.0.0.3'],
8080,
{
tls: {
mode: 'terminate',
certificate: 'auto'
const lbRoute: IRouteConfig = {
match: { ports: 443, domains: 'app.example.com' },
action: {
type: 'forward',
targets: [{ host: ['10.0.0.1', '10.0.0.2', '10.0.0.3'], port: 8080 }],
tls: { mode: 'terminate', certificate: 'auto' },
loadBalancing: { algorithm: 'round-robin' }
},
name: 'Load Balanced Route'
}
);
};
// Validate the route configuration
expect(lbRoute.match.domains).toEqual('app.example.com');
expect(lbRoute.action.type).toEqual('forward');
expect(Array.isArray(lbRoute.action.targets?.[0]?.host)).toBeTrue();
@@ -139,15 +146,25 @@ tap.test('Routes: Should create load balancer route', async () => {
});
tap.test('Routes: Should create API route with CORS', async () => {
// Create an API route with CORS headers
const apiRoute = createApiRoute('api.example.com', '/v1', { host: 'localhost', port: 3000 }, {
useTls: true,
certificate: 'auto',
addCorsHeaders: true,
const apiRoute: IRouteConfig = {
match: { ports: 443, domains: 'api.example.com', path: '/v1/*' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 3000 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
headers: {
response: {
'Access-Control-Allow-Origin': '*',
'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
'Access-Control-Allow-Headers': 'Content-Type, Authorization',
'Access-Control-Max-Age': '86400'
}
},
priority: 100,
name: 'API Route'
});
};
// Validate the route configuration
expect(apiRoute.match.domains).toEqual('api.example.com');
expect(apiRoute.match.path).toEqual('/v1/*');
expect(apiRoute.action.type).toEqual('forward');
@@ -155,7 +172,6 @@ tap.test('Routes: Should create API route with CORS', async () => {
expect(apiRoute.action.targets?.[0]?.host).toEqual('localhost');
expect(apiRoute.action.targets?.[0]?.port).toEqual(3000);
// Check CORS headers
expect(apiRoute.headers).toBeDefined();
if (apiRoute.headers?.response) {
expect(apiRoute.headers.response['Access-Control-Allow-Origin']).toEqual('*');
@@ -164,15 +180,18 @@ tap.test('Routes: Should create API route with CORS', async () => {
});
tap.test('Routes: Should create WebSocket route', async () => {
// Create a WebSocket route
const wsRoute = createWebSocketRoute('ws.example.com', '/socket', { host: 'localhost', port: 5000 }, {
useTls: true,
certificate: 'auto',
pingInterval: 15000,
const wsRoute: IRouteConfig = {
match: { ports: 443, domains: 'ws.example.com', path: '/socket' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 5000 }],
tls: { mode: 'terminate', certificate: 'auto' },
websocket: { enabled: true, pingInterval: 15000 }
},
priority: 100,
name: 'WebSocket Route'
});
};
// Validate the route configuration
expect(wsRoute.match.domains).toEqual('ws.example.com');
expect(wsRoute.match.path).toEqual('/socket');
expect(wsRoute.action.type).toEqual('forward');
@@ -180,7 +199,6 @@ tap.test('Routes: Should create WebSocket route', async () => {
expect(wsRoute.action.targets?.[0]?.host).toEqual('localhost');
expect(wsRoute.action.targets?.[0]?.port).toEqual(5000);
// Check WebSocket configuration
expect(wsRoute.action.websocket).toBeDefined();
if (wsRoute.action.websocket) {
expect(wsRoute.action.websocket.enabled).toBeTrue();
@@ -191,22 +209,27 @@ tap.test('Routes: Should create WebSocket route', async () => {
// Static file serving has been removed - should be handled by external servers
tap.test('SmartProxy: Should create instance with route-based config', async () => {
// Create TLS certificates for testing
const certs = loadTestCertificates();
// Create a SmartProxy instance with route-based configuration
const proxy = new SmartProxy({
routes: [
createHttpRoute('example.com', { host: 'localhost', port: 3000 }, {
{
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route'
}),
createHttpsTerminateRoute('secure.example.com', { host: 'localhost', port: 8443 }, {
certificate: {
key: certs.privateKey,
cert: certs.publicKey
},
{
match: { ports: 443, domains: 'secure.example.com' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8443 }],
tls: {
mode: 'terminate',
certificate: { key: certs.privateKey, cert: certs.publicKey }
}
},
name: 'HTTPS Route'
})
}
],
defaults: {
target: {
@@ -218,13 +241,11 @@ tap.test('SmartProxy: Should create instance with route-based config', async ()
maxConnections: 100
}
},
// Additional settings
initialDataTimeout: 10000,
inactivityTimeout: 300000,
enableDetailedLogging: true
});
// Simply verify the instance was created successfully
expect(typeof proxy).toEqual('object');
expect(typeof proxy.start).toEqual('function');
expect(typeof proxy.stop).toEqual('function');
@@ -233,7 +254,6 @@ tap.test('SmartProxy: Should create instance with route-based config', async ()
// --------------------------------- Edge Case Tests ---------------------------------
tap.test('Edge Case - Empty Routes Array', async () => {
// Attempting to find routes in an empty array
const emptyRoutes: IRouteConfig[] = [];
const matches = findMatchingRoutes(emptyRoutes, { domain: 'example.com', port: 80 });
@@ -245,82 +265,98 @@ tap.test('Edge Case - Empty Routes Array', async () => {
});
tap.test('Edge Case - Multiple Matching Routes with Same Priority', async () => {
// Create multiple routes with identical priority but different targets
const route1 = createHttpRoute('example.com', { host: 'server1', port: 3000 });
const route2 = createHttpRoute('example.com', { host: 'server2', port: 3000 });
const route3 = createHttpRoute('example.com', { host: 'server3', port: 3000 });
const route1: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'server1', port: 3000 }] },
name: 'HTTP Route for example.com'
};
const route2: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'server2', port: 3000 }] },
name: 'HTTP Route for example.com'
};
const route3: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'server3', port: 3000 }] },
name: 'HTTP Route for example.com'
};
// Set all to the same priority
route1.priority = 100;
route2.priority = 100;
route3.priority = 100;
const routes = [route1, route2, route3];
// Find matching routes
const matches = findMatchingRoutes(routes, { domain: 'example.com', port: 80 });
// Should find all three routes
expect(matches.length).toEqual(3);
// First match could be any of the routes since they have the same priority
// But the implementation should be consistent (likely keep the original order)
const bestMatch = findBestMatchingRoute(routes, { domain: 'example.com', port: 80 });
expect(bestMatch).not.toBeUndefined();
});
tap.test('Edge Case - Wildcard Domains and Path Matching', async () => {
// Create routes with wildcard domains and path patterns
const wildcardApiRoute = createApiRoute('*.example.com', '/api', { host: 'api-server', port: 3000 }, {
useTls: true,
certificate: 'auto'
});
const wildcardApiRoute: IRouteConfig = {
match: { ports: 443, domains: '*.example.com', path: '/api/*' },
action: {
type: 'forward',
targets: [{ host: 'api-server', port: 3000 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
priority: 100,
name: 'API Route for *.example.com'
};
const exactApiRoute = createApiRoute('api.example.com', '/api', { host: 'specific-api-server', port: 3001 }, {
useTls: true,
certificate: 'auto',
priority: 200 // Higher priority
});
const exactApiRoute: IRouteConfig = {
match: { ports: 443, domains: 'api.example.com', path: '/api/*' },
action: {
type: 'forward',
targets: [{ host: 'specific-api-server', port: 3001 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
priority: 200,
name: 'API Route for api.example.com'
};
const routes = [wildcardApiRoute, exactApiRoute];
// Test with a specific subdomain that matches both routes
const matches = findMatchingRoutes(routes, { domain: 'api.example.com', path: '/api/users', port: 443 });
// Should match both routes
expect(matches.length).toEqual(2);
// The exact domain match should have higher priority
const bestMatch = findBestMatchingRoute(routes, { domain: 'api.example.com', path: '/api/users', port: 443 });
expect(bestMatch).not.toBeUndefined();
if (bestMatch) {
expect(bestMatch.action.targets[0].port).toEqual(3001); // Should match the exact domain route
expect(bestMatch.action.targets[0].port).toEqual(3001);
}
// Test with a different subdomain - should only match the wildcard route
const otherMatches = findMatchingRoutes(routes, { domain: 'other.example.com', path: '/api/products', port: 443 });
expect(otherMatches.length).toEqual(1);
expect(otherMatches[0].action.targets[0].port).toEqual(3000); // Should match the wildcard domain route
expect(otherMatches[0].action.targets[0].port).toEqual(3000);
});
tap.test('Edge Case - Disabled Routes', async () => {
// Create enabled and disabled routes
const enabledRoute = createHttpRoute('example.com', { host: 'server1', port: 3000 });
const disabledRoute = createHttpRoute('example.com', { host: 'server2', port: 3001 });
const enabledRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'server1', port: 3000 }] },
name: 'HTTP Route for example.com'
};
const disabledRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'server2', port: 3001 }] },
name: 'HTTP Route for example.com'
};
disabledRoute.enabled = false;
const routes = [enabledRoute, disabledRoute];
// Find matching routes
const matches = findMatchingRoutes(routes, { domain: 'example.com', port: 80 });
// Should only find the enabled route
expect(matches.length).toEqual(1);
expect(matches[0].action.targets[0].port).toEqual(3000);
});
tap.test('Edge Case - Complex Path and Headers Matching', async () => {
// Create route with complex path and headers matching
const complexRoute: IRouteConfig = {
match: {
domains: 'api.example.com',
@@ -345,7 +381,6 @@ tap.test('Edge Case - Complex Path and Headers Matching', async () => {
name: 'Complex API Route'
};
// Test with matching criteria
const matchingPath = routeMatchesPath(complexRoute, '/api/v2/users');
expect(matchingPath).toBeTrue();
@@ -356,7 +391,6 @@ tap.test('Edge Case - Complex Path and Headers Matching', async () => {
});
expect(matchingHeaders).toBeTrue();
// Test with non-matching criteria
const nonMatchingPath = routeMatchesPath(complexRoute, '/api/v1/users');
expect(nonMatchingPath).toBeFalse();
@@ -368,7 +402,6 @@ tap.test('Edge Case - Complex Path and Headers Matching', async () => {
});
tap.test('Edge Case - Port Range Matching', async () => {
// Create route with port range matching
const portRangeRoute: IRouteConfig = {
match: {
domains: 'example.com',
@@ -384,16 +417,13 @@ tap.test('Edge Case - Port Range Matching', async () => {
name: 'Port Range Route'
};
// Test with ports in the range
expect(routeMatchesPort(portRangeRoute, 8000)).toBeTrue(); // Lower bound
expect(routeMatchesPort(portRangeRoute, 8500)).toBeTrue(); // Middle
expect(routeMatchesPort(portRangeRoute, 9000)).toBeTrue(); // Upper bound
expect(routeMatchesPort(portRangeRoute, 8000)).toBeTrue();
expect(routeMatchesPort(portRangeRoute, 8500)).toBeTrue();
expect(routeMatchesPort(portRangeRoute, 9000)).toBeTrue();
// Test with ports outside the range
expect(routeMatchesPort(portRangeRoute, 7999)).toBeFalse(); // Just below
expect(routeMatchesPort(portRangeRoute, 9001)).toBeFalse(); // Just above
expect(routeMatchesPort(portRangeRoute, 7999)).toBeFalse();
expect(routeMatchesPort(portRangeRoute, 9001)).toBeFalse();
// Test with multiple port ranges
const multiRangeRoute: IRouteConfig = {
match: {
domains: 'example.com',
@@ -420,55 +450,56 @@ tap.test('Edge Case - Port Range Matching', async () => {
// --------------------------------- Wildcard Domain Tests ---------------------------------
tap.test('Wildcard Domain Handling', async () => {
// Create routes with different wildcard patterns
const simpleDomainRoute = createHttpRoute('example.com', { host: 'server1', port: 3000 });
const wildcardSubdomainRoute = createHttpRoute('*.example.com', { host: 'server2', port: 3001 });
const specificSubdomainRoute = createHttpRoute('api.example.com', { host: 'server3', port: 3002 });
const simpleDomainRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'server1', port: 3000 }] },
name: 'HTTP Route for example.com'
};
const wildcardSubdomainRoute: IRouteConfig = {
match: { ports: 80, domains: '*.example.com' },
action: { type: 'forward', targets: [{ host: 'server2', port: 3001 }] },
name: 'HTTP Route for *.example.com'
};
const specificSubdomainRoute: IRouteConfig = {
match: { ports: 80, domains: 'api.example.com' },
action: { type: 'forward', targets: [{ host: 'server3', port: 3002 }] },
name: 'HTTP Route for api.example.com'
};
// Set explicit priorities to ensure deterministic matching
specificSubdomainRoute.priority = 200; // Highest priority for specific domain
wildcardSubdomainRoute.priority = 100; // Medium priority for wildcard
simpleDomainRoute.priority = 50; // Lowest priority for generic domain
specificSubdomainRoute.priority = 200;
wildcardSubdomainRoute.priority = 100;
simpleDomainRoute.priority = 50;
const routes = [simpleDomainRoute, wildcardSubdomainRoute, specificSubdomainRoute];
// Test exact domain match
expect(routeMatchesDomain(simpleDomainRoute, 'example.com')).toBeTrue();
expect(routeMatchesDomain(simpleDomainRoute, 'sub.example.com')).toBeFalse();
// Test wildcard subdomain match
expect(routeMatchesDomain(wildcardSubdomainRoute, 'any.example.com')).toBeTrue();
expect(routeMatchesDomain(wildcardSubdomainRoute, 'nested.sub.example.com')).toBeTrue();
expect(routeMatchesDomain(wildcardSubdomainRoute, 'example.com')).toBeFalse();
// Test specific subdomain match
expect(routeMatchesDomain(specificSubdomainRoute, 'api.example.com')).toBeTrue();
expect(routeMatchesDomain(specificSubdomainRoute, 'other.example.com')).toBeFalse();
expect(routeMatchesDomain(specificSubdomainRoute, 'sub.api.example.com')).toBeFalse();
// Test finding best match when multiple domains match
const specificSubdomainRequest = { domain: 'api.example.com', port: 80 };
const bestSpecificMatch = findBestMatchingRoute(routes, specificSubdomainRequest);
expect(bestSpecificMatch).not.toBeUndefined();
if (bestSpecificMatch) {
// Find which route was matched
const matchedPort = bestSpecificMatch.action.targets[0].port;
console.log(`Matched route with port: ${matchedPort}`);
// Verify it's the specific subdomain route (with highest priority)
expect(bestSpecificMatch.priority).toEqual(200);
}
// Test with a subdomain that matches wildcard but not specific
const otherSubdomainRequest = { domain: 'other.example.com', port: 80 };
const bestWildcardMatch = findBestMatchingRoute(routes, otherSubdomainRequest);
expect(bestWildcardMatch).not.toBeUndefined();
if (bestWildcardMatch) {
// Find which route was matched
const matchedPort = bestWildcardMatch.action.targets[0].port;
console.log(`Matched route with port: ${matchedPort}`);
// Verify it's the wildcard subdomain route (with medium priority)
expect(bestWildcardMatch.priority).toEqual(100);
}
});
@@ -476,39 +507,68 @@ tap.test('Wildcard Domain Handling', async () => {
// --------------------------------- Integration Tests ---------------------------------
tap.test('Route Integration - Combining Multiple Route Types', async () => {
// Create a comprehensive set of routes for a full application
const routes: IRouteConfig[] = [
// Main website with HTTPS and HTTP redirect
...createCompleteHttpsServer('example.com', { host: 'web-server', port: 8080 }, {
certificate: 'auto'
}),
// API endpoints
createApiRoute('api.example.com', '/v1', { host: 'api-server', port: 3000 }, {
useTls: true,
certificate: 'auto',
addCorsHeaders: true
}),
// WebSocket for real-time updates
createWebSocketRoute('ws.example.com', '/live', { host: 'websocket-server', port: 5000 }, {
useTls: true,
certificate: 'auto'
}),
// Legacy system with passthrough
createHttpsPassthroughRoute('legacy.example.com', { host: 'legacy-server', port: 443 })
{
match: { ports: 443, domains: 'example.com' },
action: {
type: 'forward',
targets: [{ host: 'web-server', port: 8080 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
name: 'HTTPS Terminate Route for example.com'
},
{
match: { ports: 80, domains: 'example.com' },
action: {
type: 'socket-handler',
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
},
name: 'HTTP to HTTPS Redirect for example.com'
},
{
match: { ports: 443, domains: 'api.example.com', path: '/v1/*' },
action: {
type: 'forward',
targets: [{ host: 'api-server', port: 3000 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
headers: {
response: {
'Access-Control-Allow-Origin': '*',
'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
'Access-Control-Allow-Headers': 'Content-Type, Authorization',
'Access-Control-Max-Age': '86400'
}
},
priority: 100,
name: 'API Route for api.example.com'
},
{
match: { ports: 443, domains: 'ws.example.com', path: '/live' },
action: {
type: 'forward',
targets: [{ host: 'websocket-server', port: 5000 }],
tls: { mode: 'terminate', certificate: 'auto' },
websocket: { enabled: true }
},
priority: 100,
name: 'WebSocket Route for ws.example.com'
},
{
match: { ports: 443, domains: 'legacy.example.com' },
action: {
type: 'forward',
targets: [{ host: 'legacy-server', port: 443 }],
tls: { mode: 'passthrough' }
},
name: 'HTTPS Passthrough Route for legacy.example.com'
}
];
// Validate all routes
const validationResult = validateRoutes(routes);
expect(validationResult.valid).toBeTrue();
expect(validationResult.errors.length).toEqual(0);
// Test route matching for different endpoints
// Web server (HTTPS)
const webServerMatch = findBestMatchingRoute(routes, { domain: 'example.com', port: 443 });
expect(webServerMatch).not.toBeUndefined();
if (webServerMatch) {
@@ -516,14 +576,12 @@ tap.test('Route Integration - Combining Multiple Route Types', async () => {
expect(webServerMatch.action.targets[0].host).toEqual('web-server');
}
// Web server (HTTP redirect via socket handler)
const webRedirectMatch = findBestMatchingRoute(routes, { domain: 'example.com', port: 80 });
expect(webRedirectMatch).not.toBeUndefined();
if (webRedirectMatch) {
expect(webRedirectMatch.action.type).toEqual('socket-handler');
}
// API server
const apiMatch = findBestMatchingRoute(routes, {
domain: 'api.example.com',
port: 443,
@@ -535,7 +593,6 @@ tap.test('Route Integration - Combining Multiple Route Types', async () => {
expect(apiMatch.action.targets[0].host).toEqual('api-server');
}
// WebSocket server
const wsMatch = findBestMatchingRoute(routes, {
domain: 'ws.example.com',
port: 443,
@@ -548,9 +605,6 @@ tap.test('Route Integration - Combining Multiple Route Types', async () => {
expect(wsMatch.action.websocket?.enabled).toBeTrue();
}
// Static assets route was removed - static file serving should be handled externally
// Legacy system
const legacyMatch = findBestMatchingRoute(routes, {
domain: 'legacy.example.com',
port: 443
@@ -565,7 +619,6 @@ tap.test('Route Integration - Combining Multiple Route Types', async () => {
// --------------------------------- Protocol Match Field Tests ---------------------------------
tap.test('Routes: Should accept protocol field on route match', async () => {
// Create a route with protocol: 'http'
const httpOnlyRoute: IRouteConfig = {
match: {
ports: 443,
@@ -583,16 +636,13 @@ tap.test('Routes: Should accept protocol field on route match', async () => {
name: 'HTTP-only Route',
};
// Validate the route - protocol field should not cause errors
const validation = validateRouteConfig(httpOnlyRoute);
expect(validation.valid).toBeTrue();
// Verify the protocol field is preserved
expect(httpOnlyRoute.match.protocol).toEqual('http');
});
tap.test('Routes: Should accept protocol tcp on route match', async () => {
// Create a route with protocol: 'tcp'
const tcpOnlyRoute: IRouteConfig = {
match: {
ports: 443,
@@ -616,28 +666,26 @@ tap.test('Routes: Should accept protocol tcp on route match', async () => {
});
tap.test('Routes: Protocol field should work with terminate-and-reencrypt', async () => {
// Create a terminate-and-reencrypt route that only accepts HTTP
const reencryptRoute = createHttpsTerminateRoute(
'secure.example.com',
{ host: 'backend', port: 443 },
{ reencrypt: true, certificate: 'auto', name: 'Reencrypt HTTP Route' }
);
const reencryptRoute: IRouteConfig = {
match: { ports: 443, domains: 'secure.example.com' },
action: {
type: 'forward',
targets: [{ host: 'backend', port: 443 }],
tls: { mode: 'terminate-and-reencrypt', certificate: 'auto' }
},
name: 'Reencrypt HTTP Route'
};
// Set protocol restriction to http
reencryptRoute.match.protocol = 'http';
// Validate the route
const validation = validateRouteConfig(reencryptRoute);
expect(validation.valid).toBeTrue();
// Verify TLS mode
expect(reencryptRoute.action.tls?.mode).toEqual('terminate-and-reencrypt');
// Verify protocol field is preserved
expect(reencryptRoute.match.protocol).toEqual('http');
});
tap.test('Routes: Protocol field should not affect domain/port matching', async () => {
// Routes with and without protocol field should both match the same domain/port
const routeWithProtocol: IRouteConfig = {
match: {
ports: 443,
@@ -669,11 +717,9 @@ tap.test('Routes: Protocol field should not affect domain/port matching', async
const routes = [routeWithProtocol, routeWithoutProtocol];
// Both routes should match the domain/port (protocol is a hint for Rust-side matching)
const matches = findMatchingRoutes(routes, { domain: 'example.com', port: 443 });
expect(matches.length).toEqual(2);
// The one with higher priority should be first
const best = findBestMatchingRoute(routes, { domain: 'example.com', port: 443 });
expect(best).not.toBeUndefined();
expect(best!.name).toEqual('With Protocol');
@@ -696,11 +742,9 @@ tap.test('Routes: Protocol field preserved through route cloning', async () => {
const cloned = cloneRoute(original);
// Verify protocol is preserved in clone
expect(cloned.match.protocol).toEqual('http');
expect(cloned.action.tls?.mode).toEqual('terminate-and-reencrypt');
// Modify clone should not affect original
cloned.match.protocol = 'tcp';
expect(original.match.protocol).toEqual('http');
});
@@ -720,7 +764,6 @@ tap.test('Routes: Protocol field preserved through route merging', async () => {
name: 'Merge Base',
};
// Merge with override that changes name but not protocol
const merged = mergeRouteConfigs(base, { name: 'Merged Route' });
expect(merged.match.protocol).toEqual('http');
expect(merged.name).toEqual('Merged Route');
+151 -406
View File
@@ -1,21 +1,7 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import * as plugins from '../ts/plugins.js';
// Import from individual modules to avoid naming conflicts
import {
// Route helpers
createHttpRoute,
createHttpsTerminateRoute,
createApiRoute,
createWebSocketRoute,
createHttpToHttpsRedirect,
createHttpsPassthroughRoute,
createCompleteHttpsServer,
createLoadBalancerRoute
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import {
// Route validators
validateRouteConfig,
validateRoutes,
isValidDomain,
@@ -27,7 +13,6 @@ import {
} from '../ts/proxies/smart-proxy/utils/route-validator.js';
import {
// Route utilities
mergeRouteConfigs,
findMatchingRoutes,
findBestMatchingRoute,
@@ -39,16 +24,6 @@ import {
cloneRoute
} from '../ts/proxies/smart-proxy/utils/route-utils.js';
import {
// Route patterns
createApiGatewayRoute,
createWebSocketRoute as createWebSocketPattern,
createLoadBalancerRoute as createLbPattern,
addRateLimiting,
addBasicAuth,
addJwtAuth
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import type {
IRouteConfig,
IRouteMatch,
@@ -174,12 +149,16 @@ tap.test('Route Validation - validateRouteAction', async () => {
const invalidSocketResult = validateRouteAction(invalidSocketAction);
expect(invalidSocketResult.valid).toBeFalse();
expect(invalidSocketResult.errors.length).toBeGreaterThan(0);
expect(invalidSocketResult.errors[0]).toInclude('Socket handler function is required');
expect(invalidSocketResult.errors[0]).toInclude('handler function is required');
});
tap.test('Route Validation - validateRouteConfig', async () => {
// Valid route config
const validRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const validRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
const validResult = validateRouteConfig(validRoute);
expect(validResult.valid).toBeTrue();
expect(validResult.errors.length).toEqual(0);
@@ -203,7 +182,11 @@ tap.test('Route Validation - validateRouteConfig', async () => {
tap.test('Route Validation - validateRoutes', async () => {
// Create valid and invalid routes
const routes = [
createHttpRoute('example.com', { host: 'localhost', port: 3000 }),
{
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
} as IRouteConfig,
{
match: {
domains: 'invalid..domain',
@@ -217,7 +200,11 @@ tap.test('Route Validation - validateRoutes', async () => {
}
}
} as IRouteConfig,
createHttpsTerminateRoute('secure.example.com', { host: 'localhost', port: 3001 })
{
match: { ports: 443, domains: 'secure.example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3001 }], tls: { mode: 'terminate', certificate: 'auto' } },
name: 'HTTPS Terminate Route for secure.example.com',
} as IRouteConfig
];
const result = validateRoutes(routes);
@@ -230,13 +217,13 @@ tap.test('Route Validation - validateRoutes', async () => {
tap.test('Route Validation - hasRequiredPropertiesForAction', async () => {
// Forward action
const forwardRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const forwardRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
expect(hasRequiredPropertiesForAction(forwardRoute, 'forward')).toBeTrue();
// Socket handler action (redirect functionality)
const redirectRoute = createHttpToHttpsRedirect('example.com');
expect(hasRequiredPropertiesForAction(redirectRoute, 'socket-handler')).toBeTrue();
// Socket handler action
const socketRoute: IRouteConfig = {
match: {
@@ -269,7 +256,11 @@ tap.test('Route Validation - hasRequiredPropertiesForAction', async () => {
tap.test('Route Validation - assertValidRoute', async () => {
// Valid route
const validRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const validRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
expect(() => assertValidRoute(validRoute)).not.toThrow();
// Invalid route
@@ -290,7 +281,11 @@ tap.test('Route Validation - assertValidRoute', async () => {
tap.test('Route Utilities - mergeRouteConfigs', async () => {
// Base route
const baseRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const baseRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
// Override with different name and port
const overrideRoute: Partial<IRouteConfig> = {
@@ -345,13 +340,25 @@ tap.test('Route Utilities - mergeRouteConfigs', async () => {
tap.test('Route Matching - routeMatchesDomain', async () => {
// Create route with wildcard domain
const wildcardRoute = createHttpRoute('*.example.com', { host: 'localhost', port: 3000 });
const wildcardRoute: IRouteConfig = {
match: { ports: 80, domains: '*.example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for *.example.com',
};
// Create route with exact domain
const exactRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const exactRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
// Create route with multiple domains
const multiDomainRoute = createHttpRoute(['example.com', 'example.org'], { host: 'localhost', port: 3000 });
const multiDomainRoute: IRouteConfig = {
match: { ports: 80, domains: ['example.com', 'example.org'] },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com,example.org',
};
// Test wildcard domain matching
expect(routeMatchesDomain(wildcardRoute, 'sub.example.com')).toBeTrue();
@@ -374,7 +381,11 @@ tap.test('Route Matching - routeMatchesDomain', async () => {
tap.test('Route Matching - routeMatchesPort', async () => {
// Create routes with different port configurations
const singlePortRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const singlePortRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
const multiPortRoute: IRouteConfig = {
match: {
@@ -526,8 +537,37 @@ tap.test('Route Matching - routeMatchesHeaders', async () => {
'X-Custom-Header': 'value'
})).toBeFalse();
const regexHeaderRoute: IRouteConfig = {
match: {
domains: 'example.com',
ports: 80,
headers: {
'Content-Type': /^application\/(json|problem\+json)$/i,
}
},
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: 3000
}]
}
};
expect(routeMatchesHeaders(regexHeaderRoute, {
'Content-Type': 'Application/Problem+Json',
})).toBeTrue();
expect(routeMatchesHeaders(regexHeaderRoute, {
'Content-Type': 'text/html',
})).toBeFalse();
// Route without header matching should match any headers
const noHeaderRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const noHeaderRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
expect(routeMatchesHeaders(noHeaderRoute, {
'Content-Type': 'application/json'
})).toBeTrue();
@@ -536,10 +576,26 @@ tap.test('Route Matching - routeMatchesHeaders', async () => {
tap.test('Route Finding - findMatchingRoutes', async () => {
// Create multiple routes
const routes: IRouteConfig[] = [
createHttpRoute('example.com', { host: 'localhost', port: 3000 }),
createHttpsTerminateRoute('secure.example.com', { host: 'localhost', port: 3001 }),
createApiRoute('api.example.com', '/v1', { host: 'localhost', port: 3002 }),
createWebSocketRoute('ws.example.com', '/socket', { host: 'localhost', port: 3003 })
{
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
},
{
match: { ports: 443, domains: 'secure.example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3001 }], tls: { mode: 'terminate', certificate: 'auto' } },
name: 'HTTPS Route for secure.example.com',
},
{
match: { ports: 443, domains: 'api.example.com', path: '/v1/*' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3002 }], tls: { mode: 'terminate', certificate: 'auto' } },
name: 'API Route for api.example.com',
},
{
match: { ports: 443, domains: 'ws.example.com', path: '/socket' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3003 }], tls: { mode: 'terminate', certificate: 'auto' }, websocket: { enabled: true } },
name: 'WebSocket Route for ws.example.com',
},
];
// Set priorities
@@ -566,10 +622,18 @@ tap.test('Route Finding - findMatchingRoutes', async () => {
expect(wsMatches[0].name).toInclude('WebSocket Route');
// Test finding multiple routes that match same criteria
const route1 = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const route1: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
route1.priority = 10;
const route2 = createHttpRoute('example.com', { host: 'localhost', port: 3001 });
const route2: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3001 }] },
name: 'HTTP Route for example.com',
};
route2.priority = 20;
route2.match.path = '/api';
@@ -581,7 +645,11 @@ tap.test('Route Finding - findMatchingRoutes', async () => {
expect(multiMatches[1].priority).toEqual(10);
// Test disabled routes
const disabledRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const disabledRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
disabledRoute.enabled = false;
const enabledRoutes = findMatchingRoutes([disabledRoute], { domain: 'example.com', port: 80 });
@@ -590,14 +658,26 @@ tap.test('Route Finding - findMatchingRoutes', async () => {
tap.test('Route Finding - findBestMatchingRoute', async () => {
// Create multiple routes with different priorities
const route1 = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const route1: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
route1.priority = 10;
const route2 = createHttpRoute('example.com', { host: 'localhost', port: 3001 });
const route2: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3001 }] },
name: 'HTTP Route for example.com',
};
route2.priority = 20;
route2.match.path = '/api';
const route3 = createHttpRoute('example.com', { host: 'localhost', port: 3002 });
const route3: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3002 }] },
name: 'HTTP Route for example.com',
};
route3.priority = 30;
route3.match.path = '/api/users';
@@ -615,29 +695,42 @@ tap.test('Route Finding - findBestMatchingRoute', async () => {
tap.test('Route Utilities - generateRouteId', async () => {
// Test ID generation for different route types
const httpRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const httpRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
const httpId = generateRouteId(httpRoute);
expect(httpId).toInclude('example-com');
expect(httpId).toInclude('80');
expect(httpId).toInclude('forward');
const httpsRoute = createHttpsTerminateRoute('secure.example.com', { host: 'localhost', port: 3001 });
const httpsRoute: IRouteConfig = {
match: { ports: 443, domains: 'secure.example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3001 }], tls: { mode: 'terminate', certificate: 'auto' } },
name: 'HTTPS Terminate Route for secure.example.com',
};
const httpsId = generateRouteId(httpsRoute);
expect(httpsId).toInclude('secure-example-com');
expect(httpsId).toInclude('443');
expect(httpsId).toInclude('forward');
const multiDomainRoute = createHttpRoute(['example.com', 'example.org'], { host: 'localhost', port: 3000 });
const multiDomainRoute: IRouteConfig = {
match: { ports: 80, domains: ['example.com', 'example.org'] },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com,example.org',
};
const multiDomainId = generateRouteId(multiDomainRoute);
expect(multiDomainId).toInclude('example-com-example-org');
});
tap.test('Route Utilities - cloneRoute', async () => {
// Create a route and clone it
const originalRoute = createHttpsTerminateRoute('example.com', { host: 'localhost', port: 3000 }, {
certificate: 'auto',
name: 'Original Route'
});
const originalRoute: IRouteConfig = {
match: { ports: 443, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }], tls: { mode: 'terminate', certificate: 'auto' } },
name: 'Original Route',
};
const clonedRoute = cloneRoute(originalRoute);
@@ -652,352 +745,4 @@ tap.test('Route Utilities - cloneRoute', async () => {
expect(originalRoute.name).toEqual('Original Route');
});
// --------------------------------- Route Helper Tests ---------------------------------
tap.test('Route Helpers - createHttpRoute', async () => {
const route = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
expect(route.match.domains).toEqual('example.com');
expect(route.match.ports).toEqual(80);
expect(route.action.type).toEqual('forward');
expect(route.action.targets?.[0]?.host).toEqual('localhost');
expect(route.action.targets?.[0]?.port).toEqual(3000);
const validationResult = validateRouteConfig(route);
expect(validationResult.valid).toBeTrue();
});
tap.test('Route Helpers - createHttpsTerminateRoute', async () => {
const route = createHttpsTerminateRoute('example.com', { host: 'localhost', port: 3000 }, {
certificate: 'auto'
});
expect(route.match.domains).toEqual('example.com');
expect(route.match.ports).toEqual(443);
expect(route.action.type).toEqual('forward');
expect(route.action.tls.mode).toEqual('terminate');
expect(route.action.tls.certificate).toEqual('auto');
const validationResult = validateRouteConfig(route);
expect(validationResult.valid).toBeTrue();
});
tap.test('Route Helpers - createHttpToHttpsRedirect', async () => {
const route = createHttpToHttpsRedirect('example.com');
expect(route.match.domains).toEqual('example.com');
expect(route.match.ports).toEqual(80);
expect(route.action.type).toEqual('socket-handler');
expect(route.action.socketHandler).toBeDefined();
const validationResult = validateRouteConfig(route);
expect(validationResult.valid).toBeTrue();
});
tap.test('Route Helpers - createHttpsPassthroughRoute', async () => {
const route = createHttpsPassthroughRoute('example.com', { host: 'localhost', port: 3000 });
expect(route.match.domains).toEqual('example.com');
expect(route.match.ports).toEqual(443);
expect(route.action.type).toEqual('forward');
expect(route.action.tls.mode).toEqual('passthrough');
const validationResult = validateRouteConfig(route);
expect(validationResult.valid).toBeTrue();
});
tap.test('Route Helpers - createCompleteHttpsServer', async () => {
const routes = createCompleteHttpsServer('example.com', { host: 'localhost', port: 3000 }, {
certificate: 'auto'
});
expect(routes.length).toEqual(2);
// HTTPS route
expect(routes[0].match.domains).toEqual('example.com');
expect(routes[0].match.ports).toEqual(443);
expect(routes[0].action.type).toEqual('forward');
expect(routes[0].action.tls.mode).toEqual('terminate');
// HTTP redirect route
expect(routes[1].match.domains).toEqual('example.com');
expect(routes[1].match.ports).toEqual(80);
expect(routes[1].action.type).toEqual('socket-handler');
const validation1 = validateRouteConfig(routes[0]);
const validation2 = validateRouteConfig(routes[1]);
expect(validation1.valid).toBeTrue();
expect(validation2.valid).toBeTrue();
});
// createStaticFileRoute has been removed - static file serving should be handled by
// external servers (nginx/apache) behind the proxy
tap.test('Route Helpers - createApiRoute', async () => {
const route = createApiRoute('api.example.com', '/v1', { host: 'localhost', port: 3000 }, {
useTls: true,
certificate: 'auto',
addCorsHeaders: true
});
expect(route.match.domains).toEqual('api.example.com');
expect(route.match.ports).toEqual(443);
expect(route.match.path).toEqual('/v1/*');
expect(route.action.type).toEqual('forward');
expect(route.action.tls.mode).toEqual('terminate');
// Check CORS headers if they exist
if (route.headers && route.headers.response) {
expect(route.headers.response['Access-Control-Allow-Origin']).toEqual('*');
}
const validationResult = validateRouteConfig(route);
expect(validationResult.valid).toBeTrue();
});
tap.test('Route Helpers - createWebSocketRoute', async () => {
const route = createWebSocketRoute('ws.example.com', '/socket', { host: 'localhost', port: 3000 }, {
useTls: true,
certificate: 'auto',
pingInterval: 15000
});
expect(route.match.domains).toEqual('ws.example.com');
expect(route.match.ports).toEqual(443);
expect(route.match.path).toEqual('/socket');
expect(route.action.type).toEqual('forward');
expect(route.action.tls.mode).toEqual('terminate');
// Check websocket configuration if it exists
if (route.action.websocket) {
expect(route.action.websocket.enabled).toBeTrue();
expect(route.action.websocket.pingInterval).toEqual(15000);
}
const validationResult = validateRouteConfig(route);
expect(validationResult.valid).toBeTrue();
});
tap.test('Route Helpers - createLoadBalancerRoute', async () => {
const route = createLoadBalancerRoute(
'loadbalancer.example.com',
['server1.local', 'server2.local', 'server3.local'],
8080,
{
tls: {
mode: 'terminate',
certificate: 'auto'
}
}
);
expect(route.match.domains).toEqual('loadbalancer.example.com');
expect(route.match.ports).toEqual(443);
expect(route.action.type).toEqual('forward');
expect(route.action.targets).toBeDefined();
if (route.action.targets && Array.isArray(route.action.targets[0]?.host)) {
expect((route.action.targets[0].host as string[]).length).toEqual(3);
}
expect(route.action.targets?.[0]?.port).toEqual(8080);
expect(route.action.tls.mode).toEqual('terminate');
const validationResult = validateRouteConfig(route);
expect(validationResult.valid).toBeTrue();
});
// --------------------------------- Route Pattern Tests ---------------------------------
tap.test('Route Patterns - createApiGatewayRoute', async () => {
// Create API Gateway route
const apiGatewayRoute = createApiGatewayRoute(
'api.example.com',
'/v1',
{ host: 'localhost', port: 3000 },
{
useTls: true,
addCorsHeaders: true
}
);
// Validate route configuration
expect(apiGatewayRoute.match.domains).toEqual('api.example.com');
expect(apiGatewayRoute.match.path).toInclude('/v1');
expect(apiGatewayRoute.action.type).toEqual('forward');
expect(apiGatewayRoute.action.targets?.[0]?.port).toEqual(3000);
// Check TLS configuration
if (apiGatewayRoute.action.tls) {
expect(apiGatewayRoute.action.tls.mode).toEqual('terminate');
}
// Check CORS headers
if (apiGatewayRoute.headers && apiGatewayRoute.headers.response) {
expect(apiGatewayRoute.headers.response['Access-Control-Allow-Origin']).toEqual('*');
}
const result = validateRouteConfig(apiGatewayRoute);
expect(result.valid).toBeTrue();
});
// createStaticFileServerRoute has been removed - static file serving should be handled by
// external servers (nginx/apache) behind the proxy
tap.test('Route Patterns - createWebSocketPattern', async () => {
// Create WebSocket route pattern
const wsRoute = createWebSocketPattern(
'ws.example.com',
{ host: 'localhost', port: 3000 },
{
useTls: true,
path: '/socket',
pingInterval: 10000
}
);
// Validate route configuration
expect(wsRoute.match.domains).toEqual('ws.example.com');
expect(wsRoute.match.path).toEqual('/socket');
expect(wsRoute.action.type).toEqual('forward');
expect(wsRoute.action.targets?.[0]?.port).toEqual(3000);
// Check TLS configuration
if (wsRoute.action.tls) {
expect(wsRoute.action.tls.mode).toEqual('terminate');
}
// Check websocket configuration if it exists
if (wsRoute.action.websocket) {
expect(wsRoute.action.websocket.enabled).toBeTrue();
expect(wsRoute.action.websocket.pingInterval).toEqual(10000);
}
const result = validateRouteConfig(wsRoute);
expect(result.valid).toBeTrue();
});
tap.test('Route Patterns - createLoadBalancerRoute pattern', async () => {
// Create load balancer route pattern with missing algorithm as it might not be implemented yet
try {
const lbRoute = createLbPattern(
'lb.example.com',
[
{ host: 'server1.local', port: 8080 },
{ host: 'server2.local', port: 8080 },
{ host: 'server3.local', port: 8080 }
],
{
useTls: true
}
);
// Validate route configuration
expect(lbRoute.match.domains).toEqual('lb.example.com');
expect(lbRoute.action.type).toEqual('forward');
// Check target hosts
if (lbRoute.action.targets && Array.isArray(lbRoute.action.targets[0]?.host)) {
expect((lbRoute.action.targets[0].host as string[]).length).toEqual(3);
}
// Check TLS configuration
if (lbRoute.action.tls) {
expect(lbRoute.action.tls.mode).toEqual('terminate');
}
const result = validateRouteConfig(lbRoute);
expect(result.valid).toBeTrue();
} catch (error) {
// If the pattern is not implemented yet, skip this test
console.log('Load balancer pattern might not be fully implemented yet');
}
});
tap.test('Route Security - addRateLimiting', async () => {
// Create base route
const baseRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
// Add rate limiting
const secureRoute = addRateLimiting(baseRoute, {
maxRequests: 100,
window: 60, // 1 minute
keyBy: 'ip'
});
// Check if rate limiting is applied
if (secureRoute.security) {
expect(secureRoute.security.rateLimit?.enabled).toBeTrue();
expect(secureRoute.security.rateLimit?.maxRequests).toEqual(100);
expect(secureRoute.security.rateLimit?.window).toEqual(60);
expect(secureRoute.security.rateLimit?.keyBy).toEqual('ip');
} else {
// Skip this test if security features are not implemented yet
console.log('Security features not implemented yet in route configuration');
}
// Just check that the route itself is valid
const result = validateRouteConfig(secureRoute);
expect(result.valid).toBeTrue();
});
tap.test('Route Security - addBasicAuth', async () => {
// Create base route
const baseRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
// Add basic authentication
const authRoute = addBasicAuth(baseRoute, {
users: [
{ username: 'admin', password: 'secret' },
{ username: 'user', password: 'password' }
],
realm: 'Protected Area',
excludePaths: ['/public']
});
// Check if basic auth is applied
if (authRoute.security) {
expect(authRoute.security.basicAuth?.enabled).toBeTrue();
expect(authRoute.security.basicAuth?.users.length).toEqual(2);
expect(authRoute.security.basicAuth?.realm).toEqual('Protected Area');
expect(authRoute.security.basicAuth?.excludePaths).toInclude('/public');
} else {
// Skip this test if security features are not implemented yet
console.log('Security features not implemented yet in route configuration');
}
// Check that the route itself is valid
const result = validateRouteConfig(authRoute);
expect(result.valid).toBeTrue();
});
tap.test('Route Security - addJwtAuth', async () => {
// Create base route
const baseRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
// Add JWT authentication
const jwtRoute = addJwtAuth(baseRoute, {
secret: 'your-jwt-secret-key',
algorithm: 'HS256',
issuer: 'auth.example.com',
audience: 'api.example.com',
expiresIn: 3600
});
// Check if JWT auth is applied
if (jwtRoute.security) {
expect(jwtRoute.security.jwtAuth?.enabled).toBeTrue();
expect(jwtRoute.security.jwtAuth?.secret).toEqual('your-jwt-secret-key');
expect(jwtRoute.security.jwtAuth?.algorithm).toEqual('HS256');
expect(jwtRoute.security.jwtAuth?.issuer).toEqual('auth.example.com');
expect(jwtRoute.security.jwtAuth?.audience).toEqual('api.example.com');
expect(jwtRoute.security.jwtAuth?.expiresIn).toEqual(3600);
} else {
// Skip this test if security features are not implemented yet
console.log('Security features not implemented yet in route configuration');
}
// Check that the route itself is valid
const result = validateRouteConfig(jwtRoute);
expect(result.valid).toBeTrue();
});
export default tap.start();
-403
View File
@@ -1,403 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as http from 'http';
import { HttpRouter, type RouterResult } from '../ts/routing/router/http-router.js';
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
// Test proxies and configurations
let router: HttpRouter;
// Sample hostname for testing
const TEST_DOMAIN = 'example.com';
const TEST_SUBDOMAIN = 'api.example.com';
const TEST_WILDCARD = '*.example.com';
// Helper: Creates a mock HTTP request for testing
function createMockRequest(host: string, url: string = '/'): http.IncomingMessage {
const req = {
headers: { host },
url,
socket: {
remoteAddress: '127.0.0.1'
}
} as any;
return req;
}
// Helper: Creates a test route configuration
function createRouteConfig(
hostname: string,
destinationIp: string = '10.0.0.1',
destinationPort: number = 8080
): IRouteConfig {
return {
name: `route-${hostname}`,
match: {
domains: [hostname],
ports: 443
},
action: {
type: 'forward',
targets: [{
host: destinationIp,
port: destinationPort
}]
}
};
}
// SETUP: Create an HttpRouter instance
tap.test('setup http router test environment', async () => {
router = new HttpRouter();
// Initialize with empty config
router.setRoutes([]);
});
// Test basic routing by hostname
tap.test('should route requests by hostname', async () => {
const config = createRouteConfig(TEST_DOMAIN);
router.setRoutes([config]);
const req = createMockRequest(TEST_DOMAIN);
const result = router.routeReq(req);
expect(result).toBeTruthy();
expect(result).toEqual(config);
});
// Test handling of hostname with port number
tap.test('should handle hostname with port number', async () => {
const config = createRouteConfig(TEST_DOMAIN);
router.setRoutes([config]);
const req = createMockRequest(`${TEST_DOMAIN}:443`);
const result = router.routeReq(req);
expect(result).toBeTruthy();
expect(result).toEqual(config);
});
// Test case-insensitive hostname matching
tap.test('should perform case-insensitive hostname matching', async () => {
const config = createRouteConfig(TEST_DOMAIN.toLowerCase());
router.setRoutes([config]);
const req = createMockRequest(TEST_DOMAIN.toUpperCase());
const result = router.routeReq(req);
expect(result).toBeTruthy();
expect(result).toEqual(config);
});
// Test handling of unmatched hostnames
tap.test('should return undefined for unmatched hostnames', async () => {
const config = createRouteConfig(TEST_DOMAIN);
router.setRoutes([config]);
const req = createMockRequest('unknown.domain.com');
const result = router.routeReq(req);
expect(result).toBeUndefined();
});
// Test adding path patterns
tap.test('should match requests using path patterns', async () => {
const config = createRouteConfig(TEST_DOMAIN);
config.match.path = '/api/users';
router.setRoutes([config]);
// Test that path matches
const req1 = createMockRequest(TEST_DOMAIN, '/api/users');
const result1 = router.routeReqWithDetails(req1);
expect(result1).toBeTruthy();
expect(result1.route).toEqual(config);
expect(result1.pathMatch).toEqual('/api/users');
// Test that non-matching path doesn't match
const req2 = createMockRequest(TEST_DOMAIN, '/web/users');
const result2 = router.routeReqWithDetails(req2);
expect(result2).toBeUndefined();
});
// Test handling wildcard patterns
tap.test('should support wildcard path patterns', async () => {
const config = createRouteConfig(TEST_DOMAIN);
config.match.path = '/api/*';
router.setRoutes([config]);
// Test with path that matches the wildcard pattern
const req = createMockRequest(TEST_DOMAIN, '/api/users/123');
const result = router.routeReqWithDetails(req);
expect(result).toBeTruthy();
expect(result.route).toEqual(config);
expect(result.pathMatch).toEqual('/api');
// Print the actual value to diagnose issues
console.log('Path remainder value:', result.pathRemainder);
expect(result.pathRemainder).toBeTruthy();
expect(result.pathRemainder).toEqual('/users/123');
});
// Test extracting path parameters
tap.test('should extract path parameters from URL', async () => {
const config = createRouteConfig(TEST_DOMAIN);
config.match.path = '/users/:id/profile';
router.setRoutes([config]);
const req = createMockRequest(TEST_DOMAIN, '/users/123/profile');
const result = router.routeReqWithDetails(req);
expect(result).toBeTruthy();
expect(result.route).toEqual(config);
expect(result.pathParams).toBeTruthy();
expect(result.pathParams.id).toEqual('123');
});
// Test multiple configs for same hostname with different paths
tap.test('should support multiple configs for same hostname with different paths', async () => {
const apiConfig = createRouteConfig(TEST_DOMAIN, '10.0.0.1', 8001);
apiConfig.match.path = '/api/*';
apiConfig.name = 'api-route';
const webConfig = createRouteConfig(TEST_DOMAIN, '10.0.0.2', 8002);
webConfig.match.path = '/web/*';
webConfig.name = 'web-route';
// Add both configs
router.setRoutes([apiConfig, webConfig]);
// Test API path routes to API config
const apiReq = createMockRequest(TEST_DOMAIN, '/api/users');
const apiResult = router.routeReq(apiReq);
expect(apiResult).toEqual(apiConfig);
// Test web path routes to web config
const webReq = createMockRequest(TEST_DOMAIN, '/web/dashboard');
const webResult = router.routeReq(webReq);
expect(webResult).toEqual(webConfig);
// Test unknown path returns undefined
const unknownReq = createMockRequest(TEST_DOMAIN, '/unknown');
const unknownResult = router.routeReq(unknownReq);
expect(unknownResult).toBeUndefined();
});
// Test wildcard subdomains
tap.test('should match wildcard subdomains', async () => {
const wildcardConfig = createRouteConfig(TEST_WILDCARD);
router.setRoutes([wildcardConfig]);
// Test that subdomain.example.com matches *.example.com
const req = createMockRequest('subdomain.example.com');
const result = router.routeReq(req);
expect(result).toBeTruthy();
expect(result).toEqual(wildcardConfig);
});
// Test TLD wildcards (example.*)
tap.test('should match TLD wildcards', async () => {
const tldWildcardConfig = createRouteConfig('example.*');
router.setRoutes([tldWildcardConfig]);
// Test that example.com matches example.*
const req1 = createMockRequest('example.com');
const result1 = router.routeReq(req1);
expect(result1).toBeTruthy();
expect(result1).toEqual(tldWildcardConfig);
// Test that example.org matches example.*
const req2 = createMockRequest('example.org');
const result2 = router.routeReq(req2);
expect(result2).toBeTruthy();
expect(result2).toEqual(tldWildcardConfig);
// Test that subdomain.example.com doesn't match example.*
const req3 = createMockRequest('subdomain.example.com');
const result3 = router.routeReq(req3);
expect(result3).toBeUndefined();
});
// Test complex pattern matching (*.lossless*)
tap.test('should match complex wildcard patterns', async () => {
const complexWildcardConfig = createRouteConfig('*.lossless*');
router.setRoutes([complexWildcardConfig]);
// Test that sub.lossless.com matches *.lossless*
const req1 = createMockRequest('sub.lossless.com');
const result1 = router.routeReq(req1);
expect(result1).toBeTruthy();
expect(result1).toEqual(complexWildcardConfig);
// Test that api.lossless.org matches *.lossless*
const req2 = createMockRequest('api.lossless.org');
const result2 = router.routeReq(req2);
expect(result2).toBeTruthy();
expect(result2).toEqual(complexWildcardConfig);
// Test that losslessapi.com matches *.lossless*
const req3 = createMockRequest('losslessapi.com');
const result3 = router.routeReq(req3);
expect(result3).toBeUndefined(); // Should not match as it doesn't have a subdomain
});
// Test default configuration fallback
tap.test('should fall back to default configuration', async () => {
const defaultConfig = createRouteConfig('*');
const specificConfig = createRouteConfig(TEST_DOMAIN);
router.setRoutes([specificConfig, defaultConfig]);
// Test specific domain routes to specific config
const specificReq = createMockRequest(TEST_DOMAIN);
const specificResult = router.routeReq(specificReq);
expect(specificResult).toEqual(specificConfig);
// Test unknown domain falls back to default config
const unknownReq = createMockRequest('unknown.com');
const unknownResult = router.routeReq(unknownReq);
expect(unknownResult).toEqual(defaultConfig);
});
// Test priority between exact and wildcard matches
tap.test('should prioritize exact hostname over wildcard', async () => {
const wildcardConfig = createRouteConfig(TEST_WILDCARD);
const exactConfig = createRouteConfig(TEST_SUBDOMAIN);
router.setRoutes([exactConfig, wildcardConfig]);
// Test that exact match takes priority
const req = createMockRequest(TEST_SUBDOMAIN);
const result = router.routeReq(req);
expect(result).toEqual(exactConfig);
});
// Test adding and removing configurations
tap.test('should manage configurations correctly', async () => {
router.setRoutes([]);
// Add a config
const config = createRouteConfig(TEST_DOMAIN);
router.setRoutes([config]);
// Verify routing works
const req = createMockRequest(TEST_DOMAIN);
let result = router.routeReq(req);
expect(result).toEqual(config);
// Remove the config and verify it no longer routes
router.setRoutes([]);
result = router.routeReq(req);
expect(result).toBeUndefined();
});
// Test path pattern specificity
tap.test('should prioritize more specific path patterns', async () => {
const genericConfig = createRouteConfig(TEST_DOMAIN, '10.0.0.1', 8001);
genericConfig.match.path = '/api/*';
genericConfig.name = 'generic-api';
const specificConfig = createRouteConfig(TEST_DOMAIN, '10.0.0.2', 8002);
specificConfig.match.path = '/api/users';
specificConfig.name = 'specific-api';
specificConfig.priority = 10; // Higher priority
router.setRoutes([genericConfig, specificConfig]);
// The more specific '/api/users' should match before the '/api/*' wildcard
const req = createMockRequest(TEST_DOMAIN, '/api/users');
const result = router.routeReq(req);
expect(result).toEqual(specificConfig);
});
// Test multiple hostnames
tap.test('should handle multiple configured hostnames', async () => {
const routes = [
createRouteConfig(TEST_DOMAIN),
createRouteConfig(TEST_SUBDOMAIN)
];
router.setRoutes(routes);
// Test first domain routes correctly
const req1 = createMockRequest(TEST_DOMAIN);
const result1 = router.routeReq(req1);
expect(result1).toEqual(routes[0]);
// Test second domain routes correctly
const req2 = createMockRequest(TEST_SUBDOMAIN);
const result2 = router.routeReq(req2);
expect(result2).toEqual(routes[1]);
});
// Test handling missing host header
tap.test('should handle missing host header', async () => {
const defaultConfig = createRouteConfig('*');
router.setRoutes([defaultConfig]);
const req = createMockRequest('');
req.headers.host = undefined;
const result = router.routeReq(req);
expect(result).toEqual(defaultConfig);
});
// Test complex path parameters
tap.test('should handle complex path parameters', async () => {
const config = createRouteConfig(TEST_DOMAIN);
config.match.path = '/api/:version/users/:userId/posts/:postId';
router.setRoutes([config]);
const req = createMockRequest(TEST_DOMAIN, '/api/v1/users/123/posts/456');
const result = router.routeReqWithDetails(req);
expect(result).toBeTruthy();
expect(result.route).toEqual(config);
expect(result.pathParams).toBeTruthy();
expect(result.pathParams.version).toEqual('v1');
expect(result.pathParams.userId).toEqual('123');
expect(result.pathParams.postId).toEqual('456');
});
// Performance test
tap.test('should handle many configurations efficiently', async () => {
const configs = [];
// Create many configs with different hostnames
for (let i = 0; i < 100; i++) {
configs.push(createRouteConfig(`host-${i}.example.com`));
}
router.setRoutes(configs);
// Test middle of the list to avoid best/worst case
const req = createMockRequest('host-50.example.com');
const result = router.routeReq(req);
expect(result).toEqual(configs[50]);
});
// Test cleanup
tap.test('cleanup proxy router test environment', async () => {
// Clear all configurations
router.setRoutes([]);
// Verify empty state by testing that no routes match
const req = createMockRequest(TEST_DOMAIN);
const result = router.routeReq(req);
expect(result).toBeUndefined();
});
export default tap.start();

Some files were not shown because too many files have changed in this diff Show More