From 1df3b7af4ab1c696a1ee5db9bfb517744e157bd1 Mon Sep 17 00:00:00 2001 From: Juergen Kunz Date: Mon, 9 Feb 2026 10:55:46 +0000 Subject: [PATCH] feat(rustproxy): introduce a Rust-powered proxy engine and workspace with core crates for proxy functionality, ACME/TLS support, passthrough and HTTP proxies, metrics, nftables integration, routing/security, management IPC, tests, and README updates --- .gitignore | 3 +- changelog.md | 15 + readme.md | 719 ++++--- rust/Cargo.lock | 1760 +++++++++++++++++ rust/Cargo.toml | 98 + rust/config/example.json | 145 ++ rust/crates/rustproxy-config/Cargo.toml | 13 + rust/crates/rustproxy-config/src/helpers.rs | 334 ++++ rust/crates/rustproxy-config/src/lib.rs | 19 + .../rustproxy-config/src/proxy_options.rs | 439 ++++ .../rustproxy-config/src/route_types.rs | 603 ++++++ .../rustproxy-config/src/security_types.rs | 132 ++ rust/crates/rustproxy-config/src/tls_types.rs | 93 + .../crates/rustproxy-config/src/validation.rs | 158 ++ rust/crates/rustproxy-http/Cargo.toml | 24 + rust/crates/rustproxy-http/src/lib.rs | 14 + .../rustproxy-http/src/proxy_service.rs | 827 ++++++++ .../rustproxy-http/src/request_filter.rs | 263 +++ .../rustproxy-http/src/response_filter.rs | 92 + rust/crates/rustproxy-http/src/template.rs | 162 ++ .../rustproxy-http/src/upstream_selector.rs | 222 +++ rust/crates/rustproxy-metrics/Cargo.toml | 15 + .../crates/rustproxy-metrics/src/collector.rs | 251 +++ rust/crates/rustproxy-metrics/src/lib.rs | 11 + .../crates/rustproxy-metrics/src/log_dedup.rs | 219 ++ .../rustproxy-metrics/src/throughput.rs | 173 ++ rust/crates/rustproxy-nftables/Cargo.toml | 17 + rust/crates/rustproxy-nftables/src/lib.rs | 10 + .../rustproxy-nftables/src/nft_manager.rs | 238 +++ .../rustproxy-nftables/src/rule_builder.rs | 123 ++ rust/crates/rustproxy-passthrough/Cargo.toml | 25 + .../src/connection_record.rs | 155 ++ .../src/connection_tracker.rs | 402 ++++ .../rustproxy-passthrough/src/forwarder.rs | 325 +++ rust/crates/rustproxy-passthrough/src/lib.rs | 22 + .../src/proxy_protocol.rs | 129 ++ .../rustproxy-passthrough/src/sni_parser.rs | 287 +++ .../rustproxy-passthrough/src/socket_relay.rs | 126 ++ .../rustproxy-passthrough/src/tcp_listener.rs | 874 ++++++++ .../rustproxy-passthrough/src/tls_handler.rs | 190 ++ rust/crates/rustproxy-routing/Cargo.toml | 16 + rust/crates/rustproxy-routing/src/lib.rs | 9 + .../rustproxy-routing/src/matchers/domain.rs | 86 + .../rustproxy-routing/src/matchers/header.rs | 98 + .../rustproxy-routing/src/matchers/ip.rs | 126 ++ .../rustproxy-routing/src/matchers/mod.rs | 9 + .../rustproxy-routing/src/matchers/path.rs | 65 + .../rustproxy-routing/src/route_manager.rs | 545 +++++ rust/crates/rustproxy-security/Cargo.toml | 17 + .../rustproxy-security/src/basic_auth.rs | 111 ++ .../rustproxy-security/src/ip_filter.rs | 189 ++ .../crates/rustproxy-security/src/jwt_auth.rs | 174 ++ rust/crates/rustproxy-security/src/lib.rs | 13 + .../rustproxy-security/src/rate_limiter.rs | 97 + rust/crates/rustproxy-tls/Cargo.toml | 22 + rust/crates/rustproxy-tls/src/acme.rs | 360 ++++ rust/crates/rustproxy-tls/src/cert_manager.rs | 183 ++ rust/crates/rustproxy-tls/src/cert_store.rs | 314 +++ rust/crates/rustproxy-tls/src/lib.rs | 13 + rust/crates/rustproxy-tls/src/sni_resolver.rs | 139 ++ rust/crates/rustproxy/Cargo.toml | 44 + rust/crates/rustproxy/src/challenge_server.rs | 177 ++ rust/crates/rustproxy/src/lib.rs | 931 +++++++++ rust/crates/rustproxy/src/main.rs | 90 + rust/crates/rustproxy/src/management.rs | 470 +++++ rust/crates/rustproxy/tests/common/mod.rs | 402 ++++ .../rustproxy/tests/integration_http_proxy.rs | 453 +++++ .../tests/integration_proxy_lifecycle.rs | 250 +++ .../tests/integration_tcp_passthrough.rs | 197 ++ .../tests/integration_tls_passthrough.rs | 247 +++ .../tests/integration_tls_terminate.rs | 324 +++ test/test.acme-route-creation.ts | 218 -- test/test.acme-state-manager.node.ts | 188 -- test/test.acme-timing-simple.ts | 122 -- test/test.acme-timing.ts | 204 -- test/test.certificate-acme-update.ts | 77 - test/test.certificate-provision.ts | 423 ---- test/test.certificate-provisioning.ts | 241 --- test/test.cleanup-queue-bug.node.ts | 146 -- test/test.connect-disconnect-cleanup.node.ts | 240 --- ...t.connection-cleanup-comprehensive.node.ts | 277 --- test/test.connection-limits.node.ts | 304 --- test/test.fix-verification.ts | 83 - test/test.http-fix-unit.ts | 183 -- test/test.http-fix-verification.ts | 256 --- test/test.http-forwarding-fix.ts | 189 -- test/test.http-port8080-simple.ts | 246 --- test/test.http-proxy-security-limits.node.ts | 114 -- test/test.httpproxy.function-targets.ts | 405 ---- test/test.httpproxy.ts | 596 ------ test/test.keepalive-support.node.ts | 250 --- test/test.memory-leak-check.node.ts | 151 -- test/test.memory-leak-simple.ts | 59 - test/test.memory-leak-unit.ts | 131 -- test/test.metrics-collector.ts | 280 --- test/test.nftables-manager.ts | 188 -- test/test.nftables-status.ts | 166 -- test/test.port80-management.node.ts | 281 --- test/test.proxy-chain-cleanup.node.ts | 182 -- test/test.proxy-chain-simple.node.ts | 193 -- test/test.proxy-chaining-accumulation.node.ts | 364 ---- test/test.rapid-retry-cleanup.node.ts | 199 -- test/test.route-callback-simple.ts | 117 -- test/test.route-update-callback.node.ts | 343 ---- test/test.smartacme-integration.ts | 54 - test/test.stuck-connection-cleanup.node.ts | 144 -- test/test.websocket-keepalive.node.ts | 157 -- test/test.wrapped-socket.ts | 57 - test/test.zombie-connection-cleanup.node.ts | 304 --- ts/00_commitinfo_data.ts | 2 +- ts/index.ts | 16 +- ts/proxies/http-proxy/connection-pool.ts | 228 --- ts/proxies/http-proxy/context-creator.ts | 145 -- ts/proxies/http-proxy/default-certificates.ts | 150 -- ts/proxies/http-proxy/function-cache.ts | 279 --- ts/proxies/http-proxy/handlers/index.ts | 5 - ts/proxies/http-proxy/http-proxy.ts | 669 ------- ts/proxies/http-proxy/http-request-handler.ts | 331 ---- .../http-proxy/http2-request-handler.ts | 255 --- ts/proxies/http-proxy/index.ts | 18 - ts/proxies/http-proxy/models/http-types.ts | 148 -- ts/proxies/http-proxy/models/index.ts | 5 - ts/proxies/http-proxy/models/types.ts | 125 -- ts/proxies/http-proxy/request-handler.ts | 878 -------- ts/proxies/http-proxy/security-manager.ts | 413 ---- ts/proxies/http-proxy/websocket-handler.ts | 581 ------ ts/proxies/index.ts | 10 +- ts/proxies/smart-proxy/acme-state-manager.ts | 112 -- ts/proxies/smart-proxy/cert-store.ts | 92 - ts/proxies/smart-proxy/certificate-manager.ts | 895 --------- ts/proxies/smart-proxy/connection-manager.ts | 809 -------- ts/proxies/smart-proxy/http-proxy-bridge.ts | 213 -- ts/proxies/smart-proxy/index.ts | 19 +- ts/proxies/smart-proxy/metrics-collector.ts | 453 ----- ts/proxies/smart-proxy/models/interfaces.ts | 10 +- ts/proxies/smart-proxy/nftables-manager.ts | 271 --- ts/proxies/smart-proxy/port-manager.ts | 358 ---- .../smart-proxy/route-connection-handler.ts | 1712 ---------------- ts/proxies/smart-proxy/route-orchestrator.ts | 297 --- ts/proxies/smart-proxy/route-preprocessor.ts | 122 ++ ts/proxies/smart-proxy/rust-binary-locator.ts | 112 ++ .../smart-proxy/rust-metrics-adapter.ts | 136 ++ ts/proxies/smart-proxy/rust-proxy-bridge.ts | 278 +++ ts/proxies/smart-proxy/security-manager.ts | 269 --- ts/proxies/smart-proxy/smart-proxy.ts | 1075 +++------- .../smart-proxy/socket-handler-server.ts | 178 ++ ts/proxies/smart-proxy/throughput-tracker.ts | 138 -- ts/proxies/smart-proxy/timeout-manager.ts | 196 -- ts/proxies/smart-proxy/tls-manager.ts | 171 -- ts/routing/index.ts | 4 +- ts/routing/models/http-types.ts | 151 +- 151 files changed, 16927 insertions(+), 19432 deletions(-) create mode 100644 rust/Cargo.lock create mode 100644 rust/Cargo.toml create mode 100644 rust/config/example.json create mode 100644 rust/crates/rustproxy-config/Cargo.toml create mode 100644 rust/crates/rustproxy-config/src/helpers.rs create mode 100644 rust/crates/rustproxy-config/src/lib.rs create mode 100644 rust/crates/rustproxy-config/src/proxy_options.rs create mode 100644 rust/crates/rustproxy-config/src/route_types.rs create mode 100644 rust/crates/rustproxy-config/src/security_types.rs create mode 100644 rust/crates/rustproxy-config/src/tls_types.rs create mode 100644 rust/crates/rustproxy-config/src/validation.rs create mode 100644 rust/crates/rustproxy-http/Cargo.toml create mode 100644 rust/crates/rustproxy-http/src/lib.rs create mode 100644 rust/crates/rustproxy-http/src/proxy_service.rs create mode 100644 rust/crates/rustproxy-http/src/request_filter.rs create mode 100644 rust/crates/rustproxy-http/src/response_filter.rs create mode 100644 rust/crates/rustproxy-http/src/template.rs create mode 100644 rust/crates/rustproxy-http/src/upstream_selector.rs create mode 100644 rust/crates/rustproxy-metrics/Cargo.toml create mode 100644 rust/crates/rustproxy-metrics/src/collector.rs create mode 100644 rust/crates/rustproxy-metrics/src/lib.rs create mode 100644 rust/crates/rustproxy-metrics/src/log_dedup.rs create mode 100644 rust/crates/rustproxy-metrics/src/throughput.rs create mode 100644 rust/crates/rustproxy-nftables/Cargo.toml create mode 100644 rust/crates/rustproxy-nftables/src/lib.rs create mode 100644 rust/crates/rustproxy-nftables/src/nft_manager.rs create mode 100644 rust/crates/rustproxy-nftables/src/rule_builder.rs create mode 100644 rust/crates/rustproxy-passthrough/Cargo.toml create mode 100644 rust/crates/rustproxy-passthrough/src/connection_record.rs create mode 100644 rust/crates/rustproxy-passthrough/src/connection_tracker.rs create mode 100644 rust/crates/rustproxy-passthrough/src/forwarder.rs create mode 100644 rust/crates/rustproxy-passthrough/src/lib.rs create mode 100644 rust/crates/rustproxy-passthrough/src/proxy_protocol.rs create mode 100644 rust/crates/rustproxy-passthrough/src/sni_parser.rs create mode 100644 rust/crates/rustproxy-passthrough/src/socket_relay.rs create mode 100644 rust/crates/rustproxy-passthrough/src/tcp_listener.rs create mode 100644 rust/crates/rustproxy-passthrough/src/tls_handler.rs create mode 100644 rust/crates/rustproxy-routing/Cargo.toml create mode 100644 rust/crates/rustproxy-routing/src/lib.rs create mode 100644 rust/crates/rustproxy-routing/src/matchers/domain.rs create mode 100644 rust/crates/rustproxy-routing/src/matchers/header.rs create mode 100644 rust/crates/rustproxy-routing/src/matchers/ip.rs create mode 100644 rust/crates/rustproxy-routing/src/matchers/mod.rs create mode 100644 rust/crates/rustproxy-routing/src/matchers/path.rs create mode 100644 rust/crates/rustproxy-routing/src/route_manager.rs create mode 100644 rust/crates/rustproxy-security/Cargo.toml create mode 100644 rust/crates/rustproxy-security/src/basic_auth.rs create mode 100644 rust/crates/rustproxy-security/src/ip_filter.rs create mode 100644 rust/crates/rustproxy-security/src/jwt_auth.rs create mode 100644 rust/crates/rustproxy-security/src/lib.rs create mode 100644 rust/crates/rustproxy-security/src/rate_limiter.rs create mode 100644 rust/crates/rustproxy-tls/Cargo.toml create mode 100644 rust/crates/rustproxy-tls/src/acme.rs create mode 100644 rust/crates/rustproxy-tls/src/cert_manager.rs create mode 100644 rust/crates/rustproxy-tls/src/cert_store.rs create mode 100644 rust/crates/rustproxy-tls/src/lib.rs create mode 100644 rust/crates/rustproxy-tls/src/sni_resolver.rs create mode 100644 rust/crates/rustproxy/Cargo.toml create mode 100644 rust/crates/rustproxy/src/challenge_server.rs create mode 100644 rust/crates/rustproxy/src/lib.rs create mode 100644 rust/crates/rustproxy/src/main.rs create mode 100644 rust/crates/rustproxy/src/management.rs create mode 100644 rust/crates/rustproxy/tests/common/mod.rs create mode 100644 rust/crates/rustproxy/tests/integration_http_proxy.rs create mode 100644 rust/crates/rustproxy/tests/integration_proxy_lifecycle.rs create mode 100644 rust/crates/rustproxy/tests/integration_tcp_passthrough.rs create mode 100644 rust/crates/rustproxy/tests/integration_tls_passthrough.rs create mode 100644 rust/crates/rustproxy/tests/integration_tls_terminate.rs delete mode 100644 test/test.acme-route-creation.ts delete mode 100644 test/test.acme-state-manager.node.ts delete mode 100644 test/test.acme-timing-simple.ts delete mode 100644 test/test.acme-timing.ts delete mode 100644 test/test.certificate-acme-update.ts delete mode 100644 test/test.certificate-provision.ts delete mode 100644 test/test.certificate-provisioning.ts delete mode 100644 test/test.cleanup-queue-bug.node.ts delete mode 100644 test/test.connect-disconnect-cleanup.node.ts delete mode 100644 test/test.connection-cleanup-comprehensive.node.ts delete mode 100644 test/test.connection-limits.node.ts delete mode 100644 test/test.fix-verification.ts delete mode 100644 test/test.http-fix-unit.ts delete mode 100644 test/test.http-fix-verification.ts delete mode 100644 test/test.http-forwarding-fix.ts delete mode 100644 test/test.http-port8080-simple.ts delete mode 100644 test/test.http-proxy-security-limits.node.ts delete mode 100644 test/test.httpproxy.function-targets.ts delete mode 100644 test/test.httpproxy.ts delete mode 100644 test/test.keepalive-support.node.ts delete mode 100644 test/test.memory-leak-check.node.ts delete mode 100644 test/test.memory-leak-simple.ts delete mode 100644 test/test.memory-leak-unit.ts delete mode 100644 test/test.metrics-collector.ts delete mode 100644 test/test.nftables-manager.ts delete mode 100644 test/test.nftables-status.ts delete mode 100644 test/test.port80-management.node.ts delete mode 100644 test/test.proxy-chain-cleanup.node.ts delete mode 100644 test/test.proxy-chain-simple.node.ts delete mode 100644 test/test.proxy-chaining-accumulation.node.ts delete mode 100644 test/test.rapid-retry-cleanup.node.ts delete mode 100644 test/test.route-callback-simple.ts delete mode 100644 test/test.route-update-callback.node.ts delete mode 100644 test/test.smartacme-integration.ts delete mode 100644 test/test.stuck-connection-cleanup.node.ts delete mode 100644 test/test.websocket-keepalive.node.ts delete mode 100644 test/test.zombie-connection-cleanup.node.ts delete mode 100644 ts/proxies/http-proxy/connection-pool.ts delete mode 100644 ts/proxies/http-proxy/context-creator.ts delete mode 100644 ts/proxies/http-proxy/default-certificates.ts delete mode 100644 ts/proxies/http-proxy/function-cache.ts delete mode 100644 ts/proxies/http-proxy/handlers/index.ts delete mode 100644 ts/proxies/http-proxy/http-proxy.ts delete mode 100644 ts/proxies/http-proxy/http-request-handler.ts delete mode 100644 ts/proxies/http-proxy/http2-request-handler.ts delete mode 100644 ts/proxies/http-proxy/index.ts delete mode 100644 ts/proxies/http-proxy/models/http-types.ts delete mode 100644 ts/proxies/http-proxy/models/index.ts delete mode 100644 ts/proxies/http-proxy/models/types.ts delete mode 100644 ts/proxies/http-proxy/request-handler.ts delete mode 100644 ts/proxies/http-proxy/security-manager.ts delete mode 100644 ts/proxies/http-proxy/websocket-handler.ts delete mode 100644 ts/proxies/smart-proxy/acme-state-manager.ts delete mode 100644 ts/proxies/smart-proxy/cert-store.ts delete mode 100644 ts/proxies/smart-proxy/certificate-manager.ts delete mode 100644 ts/proxies/smart-proxy/connection-manager.ts delete mode 100644 ts/proxies/smart-proxy/http-proxy-bridge.ts delete mode 100644 ts/proxies/smart-proxy/metrics-collector.ts delete mode 100644 ts/proxies/smart-proxy/nftables-manager.ts delete mode 100644 ts/proxies/smart-proxy/port-manager.ts delete mode 100644 ts/proxies/smart-proxy/route-connection-handler.ts delete mode 100644 ts/proxies/smart-proxy/route-orchestrator.ts create mode 100644 ts/proxies/smart-proxy/route-preprocessor.ts create mode 100644 ts/proxies/smart-proxy/rust-binary-locator.ts create mode 100644 ts/proxies/smart-proxy/rust-metrics-adapter.ts create mode 100644 ts/proxies/smart-proxy/rust-proxy-bridge.ts delete mode 100644 ts/proxies/smart-proxy/security-manager.ts create mode 100644 ts/proxies/smart-proxy/socket-handler-server.ts delete mode 100644 ts/proxies/smart-proxy/throughput-tracker.ts delete mode 100644 ts/proxies/smart-proxy/timeout-manager.ts delete mode 100644 ts/proxies/smart-proxy/tls-manager.ts diff --git a/.gitignore b/.gitignore index ec34704..cf3f6ea 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,5 @@ dist/ dist_*/ #------# custom -.claude/* \ No newline at end of file +.claude/* +rust/target \ No newline at end of file diff --git a/changelog.md b/changelog.md index 95cd5cf..98ad20d 100644 --- a/changelog.md +++ b/changelog.md @@ -1,5 +1,20 @@ # Changelog +## 2026-02-09 - 22.5.0 - feat(rustproxy) +introduce a Rust-powered proxy engine and workspace with core crates for proxy functionality, ACME/TLS support, passthrough and HTTP proxies, metrics, nftables integration, routing/security, management IPC, tests, and README updates + +- Add Rust workspace and multiple crates: rustproxy, rustproxy-config, rustproxy-routing, rustproxy-tls, rustproxy-passthrough, rustproxy-http, rustproxy-nftables, rustproxy-metrics, rustproxy-security +- Implement ACME integration (instant-acme) and an HTTP-01 challenge server with certificate lifecycle management +- Add TLS management: cert store, cert manager, SNI resolver, TLS acceptor/connector and certificate hot-swap support +- Implement TCP/TLS passthrough engine with ClientHello SNI parsing, PROXY v1 support, connection tracking and bidirectional forwarder +- Add Hyper-based HTTP proxy components: request/response filtering, CORS, auth, header templating and upstream selection with load balancing +- Introduce metrics (throughput tracker, metrics collector) and log deduplication utilities +- Implement nftables manager and rule builder (safe no-op behavior when not running as root) +- Add route types, validation, helpers, route manager and matchers (domain/path/header/ip) +- Provide management IPC (JSON over stdin/stdout) for TypeScript wrapper control (start/stop/add/remove ports, load certificates, etc.) +- Include extensive unit and integration tests, test helpers, and an example Rust config.json +- Update README to document the Rust-powered engine, new features and rustBinaryPath lookup + ## 2026-01-31 - 22.4.2 - fix(tests) shorten long-lived connection test timeouts and update certificate metadata timestamps diff --git a/readme.md b/readme.md index 9a39331..859bb9a 100644 --- a/readme.md +++ b/readme.md @@ -1,6 +1,6 @@ # @push.rocks/smartproxy ๐Ÿš€ -**The Swiss Army Knife of Node.js Proxies** - A unified, high-performance proxy toolkit that handles everything from simple HTTP forwarding to complex enterprise routing scenarios. +**A high-performance, Rust-powered proxy toolkit for Node.js** โ€” unified route-based configuration for SSL/TLS termination, HTTP/HTTPS reverse proxying, WebSocket support, load balancing, custom protocol handlers, and kernel-level NFTables forwarding. ## ๐Ÿ“ฆ Installation @@ -16,22 +16,26 @@ For reporting bugs, issues, or security vulnerabilities, please visit [community ## ๐ŸŽฏ What is SmartProxy? -SmartProxy is a modern, production-ready proxy solution that brings order to the chaos of traffic management. Whether you're building microservices, deploying edge infrastructure, or need a battle-tested reverse proxy, SmartProxy has you covered. +SmartProxy is a production-ready proxy solution that takes the complexity out of traffic management. Under the hood, all networking โ€” TCP, TLS, HTTP reverse proxy, connection tracking, security enforcement, and NFTables โ€” is handled by a **Rust engine** for maximum performance, while you configure everything through a clean TypeScript API with full type safety. + +Whether you're building microservices, deploying edge infrastructure, or need a battle-tested reverse proxy with automatic Let's Encrypt certificates, SmartProxy has you covered. ### โšก Key Features | Feature | Description | |---------|-------------| +| ๐Ÿฆ€ **Rust-Powered Engine** | All networking handled by a high-performance Rust binary via IPC | | ๐Ÿ”€ **Unified Route-Based Config** | Clean match/action patterns for intuitive traffic routing | | ๐Ÿ”’ **Automatic SSL/TLS** | Zero-config HTTPS with Let's Encrypt ACME integration | -| ๐ŸŽฏ **Flexible Matching** | Route by port, domain, path, client IP, TLS version, or custom logic | +| ๐ŸŽฏ **Flexible Matching** | Route by port, domain, path, client IP, TLS version, headers, or custom logic | | ๐Ÿš„ **High-Performance** | Choose between user-space or kernel-level (NFTables) forwarding | -| โš–๏ธ **Load Balancing** | Distribute traffic with health checks and multiple algorithms | -| ๐Ÿ›ก๏ธ **Enterprise Security** | IP filtering, rate limiting, authentication, connection limits | +| โš–๏ธ **Load Balancing** | Round-robin, least-connections, IP-hash with health checks | +| ๐Ÿ›ก๏ธ **Enterprise Security** | IP filtering, rate limiting, basic auth, JWT auth, connection limits | | ๐Ÿ”Œ **WebSocket Support** | First-class WebSocket proxying with ping/pong keep-alive | -| ๐ŸŽฎ **Custom Protocols** | Socket handlers for implementing any protocol | +| ๐ŸŽฎ **Custom Protocols** | Socket handlers for implementing any protocol in TypeScript | | ๐Ÿ“Š **Live Metrics** | Real-time throughput, connection counts, and performance data | | ๐Ÿ”ง **Dynamic Management** | Add/remove ports and routes at runtime without restarts | +| ๐Ÿ”„ **PROXY Protocol** | Full PROXY protocol v1/v2 support for preserving client information | ## ๐Ÿš€ Quick Start @@ -43,16 +47,16 @@ import { SmartProxy, createCompleteHttpsServer } from '@push.rocks/smartproxy'; // Create a proxy with automatic HTTPS const proxy = new SmartProxy({ acme: { - email: 'ssl@yourdomain.com', // Your email for Let's Encrypt - useProduction: true // Use production servers + email: 'ssl@yourdomain.com', + useProduction: true }, routes: [ - // Complete HTTPS setup in one line! โœจ + // Complete HTTPS setup in one call! โœจ ...createCompleteHttpsServer('app.example.com', { host: 'localhost', port: 3000 }, { - certificate: 'auto' // Magic! ๐ŸŽฉ + certificate: 'auto' // Automatic Let's Encrypt cert ๐ŸŽฉ }) ] }); @@ -84,10 +88,11 @@ SmartProxy uses a powerful **match/action** pattern that makes routing predictab ``` Every route consists of: -- **Match** - What traffic to capture (ports, domains, paths, IPs) -- **Action** - What to do with it (forward, redirect, block, socket-handler) -- **Security** (optional) - Access controls, rate limits, authentication -- **Name/Priority** (optional) - For identification and ordering +- **Match** โ€” What traffic to capture (ports, domains, paths, IPs, headers) +- **Action** โ€” What to do with it (`forward` or `socket-handler`) +- **Security** (optional) โ€” IP allow/block lists, rate limits, authentication +- **Headers** (optional) โ€” Request/response header manipulation with template variables +- **Name/Priority** (optional) โ€” For identification and ordering ### ๐Ÿ”„ TLS Modes @@ -95,8 +100,8 @@ SmartProxy supports three TLS handling modes: | Mode | Description | Use Case | |------|-------------|----------| -| `passthrough` | Forward encrypted traffic as-is | Backend handles TLS | -| `terminate` | Decrypt at proxy, forward plain | Standard reverse proxy | +| `passthrough` | Forward encrypted traffic as-is (SNI-based routing) | Backend handles TLS | +| `terminate` | Decrypt at proxy, forward plain HTTP to backend | Standard reverse proxy | | `terminate-and-reencrypt` | Decrypt, then re-encrypt to backend | Zero-trust environments | ## ๐Ÿ’ก Common Use Cases @@ -116,53 +121,61 @@ const proxy = new SmartProxy({ ### โš–๏ธ Load Balancer with Health Checks ```typescript -import { createLoadBalancerRoute } from '@push.rocks/smartproxy'; +import { SmartProxy, createLoadBalancerRoute } from '@push.rocks/smartproxy'; -const route = createLoadBalancerRoute( - 'app.example.com', - [ - { host: 'server1.internal', port: 8080 }, - { host: 'server2.internal', port: 8080 }, - { host: 'server3.internal', port: 8080 } - ], - { - tls: { mode: 'terminate', certificate: 'auto' }, - loadBalancing: { - algorithm: 'round-robin', - healthCheck: { - path: '/health', - interval: 30000, - timeout: 5000 +const proxy = new SmartProxy({ + routes: [ + createLoadBalancerRoute( + 'app.example.com', + [ + { host: 'server1.internal', port: 8080 }, + { host: 'server2.internal', port: 8080 }, + { host: 'server3.internal', port: 8080 } + ], + { + tls: { mode: 'terminate', certificate: 'auto' }, + loadBalancing: { + algorithm: 'round-robin', + healthCheck: { + path: '/health', + interval: 30000, + timeout: 5000 + } + } } - } - } -); + ) + ] +}); ``` ### ๐Ÿ”Œ WebSocket Proxy ```typescript -import { createWebSocketRoute } from '@push.rocks/smartproxy'; +import { SmartProxy, createWebSocketRoute } from '@push.rocks/smartproxy'; -const route = createWebSocketRoute( - 'ws.example.com', - { host: 'websocket-server', port: 8080 }, - { - path: '/socket', - useTls: true, - certificate: 'auto', - pingInterval: 30000, // Keep connections alive - pingTimeout: 10000 - } -); +const proxy = new SmartProxy({ + routes: [ + createWebSocketRoute( + 'ws.example.com', + { host: 'websocket-server', port: 8080 }, + { + path: '/socket', + useTls: true, + certificate: 'auto', + pingInterval: 30000, + pingTimeout: 10000 + } + ) + ] +}); ``` ### ๐Ÿšฆ API Gateway with Rate Limiting ```typescript -import { createApiGatewayRoute, addRateLimiting } from '@push.rocks/smartproxy'; +import { SmartProxy, createApiGatewayRoute, addRateLimiting } from '@push.rocks/smartproxy'; -let route = createApiGatewayRoute( +let apiRoute = createApiGatewayRoute( 'api.example.com', '/api', { host: 'api-backend', port: 8080 }, @@ -173,20 +186,22 @@ let route = createApiGatewayRoute( } ); -// Add rate limiting - 100 requests per minute per IP -route = addRateLimiting(route, { +// Add rate limiting โ€” 100 requests per minute per IP +apiRoute = addRateLimiting(apiRoute, { maxRequests: 100, window: 60, keyBy: 'ip' }); + +const proxy = new SmartProxy({ routes: [apiRoute] }); ``` ### ๐ŸŽฎ Custom Protocol Handler -SmartProxy lets you implement any protocol with full socket control: +SmartProxy lets you implement any protocol with full socket control. Routes with JavaScript socket handlers are automatically relayed from the Rust engine back to your TypeScript code: ```typescript -import { createSocketHandlerRoute, SocketHandlers } from '@push.rocks/smartproxy'; +import { SmartProxy, createSocketHandlerRoute, SocketHandlers } from '@push.rocks/smartproxy'; // Use pre-built handlers const echoRoute = createSocketHandlerRoute( @@ -214,18 +229,21 @@ const customRoute = createSocketHandlerRoute( }); } ); + +const proxy = new SmartProxy({ routes: [echoRoute, customRoute] }); ``` **Pre-built Socket Handlers:** | Handler | Description | |---------|-------------| -| `SocketHandlers.echo` | Echo server - returns everything sent | +| `SocketHandlers.echo` | Echo server โ€” returns everything sent | | `SocketHandlers.proxy(host, port)` | TCP proxy to another server | | `SocketHandlers.lineProtocol(handler)` | Line-based text protocol | | `SocketHandlers.httpResponse(code, body)` | Simple HTTP response | -| `SocketHandlers.httpRedirect(url, code)` | HTTP redirect with templates | +| `SocketHandlers.httpRedirect(url, code)` | HTTP redirect with template variables (`{domain}`, `{path}`, `{port}`, `{clientIp}`) | | `SocketHandlers.httpServer(handler)` | Full HTTP request/response handling | +| `SocketHandlers.httpBlock(status, message)` | HTTP block response | | `SocketHandlers.block(message)` | Block with optional message | ### โšก High-Performance NFTables Forwarding @@ -233,48 +251,73 @@ const customRoute = createSocketHandlerRoute( For ultra-low latency on Linux, use kernel-level forwarding (requires root): ```typescript -import { createNfTablesTerminateRoute } from '@push.rocks/smartproxy'; +import { SmartProxy, createNfTablesTerminateRoute } from '@push.rocks/smartproxy'; -const route = createNfTablesTerminateRoute( - 'fast.example.com', - { host: 'backend', port: 8080 }, - { - ports: 443, - certificate: 'auto', - preserveSourceIP: true, // Backend sees real client IP - maxRate: '1gbps' // QoS rate limiting - } -); +const proxy = new SmartProxy({ + routes: [ + createNfTablesTerminateRoute( + 'fast.example.com', + { host: 'backend', port: 8080 }, + { + ports: 443, + certificate: 'auto', + preserveSourceIP: true, // Backend sees real client IP + maxRate: '1gbps' // QoS rate limiting + } + ) + ] +}); +``` + +### ๐Ÿ”’ SNI Passthrough (TLS Passthrough) + +Forward encrypted traffic to backends without terminating TLS โ€” the proxy routes based on the SNI hostname alone: + +```typescript +import { SmartProxy, createHttpsPassthroughRoute } from '@push.rocks/smartproxy'; + +const proxy = new SmartProxy({ + routes: [ + createHttpsPassthroughRoute('secure.example.com', { + host: 'backend-that-handles-tls', + port: 8443 + }) + ] +}); ``` ## ๐Ÿ”ง Advanced Features ### ๐ŸŽฏ Dynamic Routing -Route traffic based on runtime conditions: +Route traffic based on runtime conditions using function-based host/port resolution: ```typescript -{ - name: 'business-hours-only', - match: { - ports: 443, - domains: 'internal.example.com' - }, - action: { - type: 'forward', - targets: [{ - host: (context) => { - // Dynamic host selection based on path - return context.path?.startsWith('/premium') - ? 'premium-backend' - : 'standard-backend'; - }, - port: 8080 - }] - } -} +const proxy = new SmartProxy({ + routes: [{ + name: 'dynamic-backend', + match: { + ports: 443, + domains: 'app.example.com' + }, + action: { + type: 'forward', + targets: [{ + host: (context) => { + return context.path?.startsWith('/premium') + ? 'premium-backend' + : 'standard-backend'; + }, + port: 8080 + }], + tls: { mode: 'terminate', certificate: 'auto' } + } + }] +}); ``` +> **Note:** Routes with dynamic functions (host/port callbacks) are automatically relayed through the TypeScript socket handler server, since JavaScript functions can't be serialized to Rust. + ### ๐Ÿ”’ Security Controls Comprehensive per-route security options: @@ -285,7 +328,8 @@ Comprehensive per-route security options: match: { ports: 443, domains: 'api.example.com' }, action: { type: 'forward', - targets: [{ host: 'api-backend', port: 8080 }] + targets: [{ host: 'api-backend', port: 8080 }], + tls: { mode: 'terminate', certificate: 'auto' } }, security: { // IP-based access control @@ -294,17 +338,31 @@ Comprehensive per-route security options: // Connection limits maxConnections: 1000, - maxConnectionsPerIp: 10, // Rate limiting rateLimit: { + enabled: true, maxRequests: 100, - windowMs: 60000 - } + window: 60 + }, + + // Authentication + basicAuth: { users: [{ username: 'admin', password: 'secret' }] }, + jwtAuth: { secret: 'your-jwt-secret', algorithm: 'HS256' } } } ``` +**Security modifier helpers** let you add security to any existing route: + +```typescript +import { addRateLimiting, addBasicAuth, addJwtAuth } from '@push.rocks/smartproxy'; + +let route = createHttpsTerminateRoute('api.example.com', { host: 'backend', port: 8080 }); +route = addRateLimiting(route, { maxRequests: 100, window: 60, keyBy: 'ip' }); +route = addBasicAuth(route, { users: [{ username: 'admin', password: 'secret' }] }); +``` + ### ๐Ÿ“Š Runtime Management Control your proxy without restarts: @@ -313,21 +371,26 @@ Control your proxy without restarts: // Dynamic port management await proxy.addListeningPort(8443); await proxy.removeListeningPort(8080); +const ports = await proxy.getListeningPorts(); -// Update routes on the fly +// Update routes on the fly (atomic, mutex-locked) await proxy.updateRoutes([...newRoutes]); -// Monitor status -const status = proxy.getStatus(); -console.log(`Active connections: ${status.activeConnections}`); - -// Get detailed metrics +// Get real-time metrics const metrics = proxy.getMetrics(); -console.log(`Throughput: ${metrics.throughput.bytesPerSecond} bytes/sec`); +console.log(`Active connections: ${metrics.connections.active()}`); +console.log(`Requests/sec: ${metrics.throughput.requestsPerSecond()}`); + +// Get detailed statistics from the Rust engine +const stats = await proxy.getStatistics(); // Certificate management -const certInfo = proxy.getCertificateInfo('example.com'); -console.log(`Certificate expires: ${certInfo.expiresAt}`); +await proxy.provisionCertificate('my-route-name'); +await proxy.renewCertificate('my-route-name'); +const certStatus = await proxy.getCertificateStatus('my-route-name'); + +// NFTables status +const nftStatus = await proxy.getNfTablesStatus(); ``` ### ๐Ÿ”„ Header Manipulation @@ -338,51 +401,107 @@ Transform requests and responses with template variables: { action: { type: 'forward', - targets: [{ host: 'backend', port: 8080 }], - headers: { - request: { - 'X-Real-IP': '{clientIp}', - 'X-Request-ID': '{uuid}', - 'X-Forwarded-Proto': 'https' - }, - response: { - 'X-Powered-By': 'SmartProxy', - 'Strict-Transport-Security': 'max-age=31536000', - 'X-Frame-Options': 'DENY' - } + targets: [{ host: 'backend', port: 8080 }] + }, + headers: { + request: { + 'X-Real-IP': '{clientIp}', + 'X-Request-ID': '{uuid}', + 'X-Forwarded-Proto': 'https' + }, + response: { + 'Strict-Transport-Security': 'max-age=31536000', + 'X-Frame-Options': 'DENY' } } } ``` +### ๐Ÿ”€ PROXY Protocol Support + +Preserve original client information through proxy chains: + +```typescript +const proxy = new SmartProxy({ + // Accept PROXY protocol from trusted load balancers + acceptProxyProtocol: true, + proxyIPs: ['10.0.0.1', '10.0.0.2'], + + // Forward PROXY protocol to backends + sendProxyProtocol: true, + + routes: [...] +}); +``` + +### ๐Ÿ—๏ธ Custom Certificate Provisioning + +Supply your own certificates or integrate with external certificate providers: + +```typescript +const proxy = new SmartProxy({ + certProvisionFunction: async (domain: string) => { + // Return 'http01' to let the built-in ACME handle it + if (domain.endsWith('.example.com')) return 'http01'; + + // Or return a static certificate object + return { + publicKey: myPemCert, + privateKey: myPemKey, + }; + }, + certProvisionFallbackToAcme: true, // Fall back to ACME if callback fails + routes: [...] +}); +``` + ## ๐Ÿ›๏ธ Architecture -SmartProxy is built with a modular, extensible architecture: +SmartProxy uses a hybrid **Rust + TypeScript** architecture: ``` -SmartProxy -โ”œโ”€โ”€ ๐Ÿ“‹ RouteManager # Route matching and prioritization -โ”œโ”€โ”€ ๐Ÿ”Œ PortManager # Dynamic port lifecycle management -โ”œโ”€โ”€ ๐Ÿ”’ SmartCertManager # ACME/Let's Encrypt automation -โ”œโ”€โ”€ ๐Ÿšฆ ConnectionManager # Connection pooling and tracking -โ”œโ”€โ”€ ๐Ÿ“Š MetricsCollector # Real-time performance monitoring -โ”œโ”€โ”€ ๐Ÿ›ก๏ธ SecurityManager # Access control and rate limiting -โ”œโ”€โ”€ ๐Ÿ”ง ProtocolDetector # Smart HTTP/TLS/WebSocket detection -โ”œโ”€โ”€ โšก NFTablesManager # Kernel-level forwarding (Linux) -โ””โ”€โ”€ ๐ŸŒ HttpProxyBridge # HTTP/HTTPS request handling +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Your Application โ”‚ +โ”‚ (TypeScript โ€” routes, config, socket handlers) โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ”‚ IPC (JSON over stdin/stdout) +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ–ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Rust Proxy Engine โ”‚ +โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚ +โ”‚ โ”‚ TCP/TLS โ”‚ โ”‚ HTTP โ”‚ โ”‚ Route โ”‚ โ”‚ ACME โ”‚ โ”‚ +โ”‚ โ”‚ Listenerโ”‚ โ”‚ Reverse โ”‚ โ”‚ Matcher โ”‚ โ”‚ Cert Mgr โ”‚ โ”‚ +โ”‚ โ”‚ โ”‚ โ”‚ Proxy โ”‚ โ”‚ โ”‚ โ”‚ โ”‚ โ”‚ +โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ +โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚ +โ”‚ โ”‚ Securityโ”‚ โ”‚ Metrics โ”‚ โ”‚ Connec- โ”‚ โ”‚ NFTables โ”‚ โ”‚ +โ”‚ โ”‚ Enforce โ”‚ โ”‚ Collect โ”‚ โ”‚ tion โ”‚ โ”‚ Mgr โ”‚ โ”‚ +โ”‚ โ”‚ โ”‚ โ”‚ โ”‚ โ”‚ Tracker โ”‚ โ”‚ โ”‚ โ”‚ +โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ”‚ Unix Socket Relay +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ–ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ TypeScript Socket Handler Server โ”‚ +โ”‚ (for JS-defined socket handlers & dynamic routes) โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ ``` +- **Rust Engine** handles all networking, TLS, HTTP proxying, connection management, security, and metrics +- **TypeScript** provides the npm API, configuration types, route helpers, validation, and socket handler callbacks +- **IPC** โ€” JSON commands/events over stdin/stdout for seamless cross-language communication +- **Socket Relay** โ€” a Unix domain socket server for routes requiring TypeScript-side handling (socket handlers, dynamic host/port functions) + ## ๐ŸŽฏ Route Configuration Reference ### Match Criteria ```typescript interface IRouteMatch { - ports: number | number[] | string; // 80, [80, 443], '8000-8999' - domains?: string | string[]; // 'example.com', '*.example.com' - path?: string; // '/api/*', '/users/:id' - clientIp?: string | string[]; // '10.0.0.0/8', ['192.168.*'] - tlsVersion?: string | string[]; // ['TLSv1.2', 'TLSv1.3'] + ports: number | number[] | Array<{ from: number; to: number }>; // Port(s) to listen on + domains?: string | string[]; // 'example.com', '*.example.com' + path?: string; // '/api/*', '/users/:id' + clientIp?: string[]; // ['10.0.0.0/8', '192.168.*'] + tlsVersion?: string[]; // ['TLSv1.2', 'TLSv1.3'] + headers?: Record; // Match by HTTP headers } ``` @@ -390,69 +509,251 @@ interface IRouteMatch { | Type | Description | |------|-------------| -| `forward` | Proxy to one or more backend targets | -| `redirect` | HTTP redirect with status code | -| `block` | Block the connection | -| `socket-handler` | Custom socket handling function | +| `forward` | Proxy to one or more backend targets (with optional TLS, WebSocket, load balancing) | +| `socket-handler` | Custom socket handling function in TypeScript | + +### Target Options + +```typescript +interface IRouteTarget { + host: string | string[] | ((context: IRouteContext) => string); + port: number | 'preserve' | ((context: IRouteContext) => number); + tls?: { ... }; // Per-target TLS override + priority?: number; // Target priority + match?: ITargetMatch; // Sub-match within a route (by port, path, headers, method) +} +``` ### TLS Options ```typescript interface IRouteTls { mode: 'passthrough' | 'terminate' | 'terminate-and-reencrypt'; - certificate: 'auto' | { key: string; cert: string }; - // For terminate-and-reencrypt: - reencrypt?: { - host: string; - port: number; - ca?: string; // Custom CA for backend + certificate: 'auto' | { + key: string; + cert: string; + ca?: string; + keyFile?: string; + certFile?: string; + }; + acme?: { + email: string; + useProduction?: boolean; + challengePort?: number; + renewBeforeDays?: number; + }; + versions?: string[]; + ciphers?: string[]; + honorCipherOrder?: boolean; + sessionTimeout?: number; +} +``` + +### WebSocket Options + +```typescript +interface IRouteWebSocket { + enabled: boolean; + pingInterval?: number; // ms between pings + pingTimeout?: number; // ms to wait for pong + maxPayloadSize?: number; // Maximum frame payload + subprotocols?: string[]; // Allowed subprotocols + allowedOrigins?: string[]; // CORS origins +} +``` + +### Load Balancing Options + +```typescript +interface IRouteLoadBalancing { + algorithm: 'round-robin' | 'least-connections' | 'ip-hash'; + healthCheck?: { + path: string; + interval: number; // ms + timeout: number; // ms + unhealthyThreshold?: number; + healthyThreshold?: number; }; } ``` ## ๐Ÿ› ๏ธ Helper Functions Reference -All helpers are fully typed and documented: +All helpers are fully typed and return `IRouteConfig` or `IRouteConfig[]`: ```typescript import { // HTTP/HTTPS - createHttpRoute, - createHttpsTerminateRoute, - createHttpsPassthroughRoute, - createHttpToHttpsRedirect, - createCompleteHttpsServer, + createHttpRoute, // Plain HTTP route + createHttpsTerminateRoute, // HTTPS with TLS termination + createHttpsPassthroughRoute, // SNI passthrough (no termination) + createHttpToHttpsRedirect, // HTTP โ†’ HTTPS redirect + createCompleteHttpsServer, // HTTPS + redirect combo (returns IRouteConfig[]) // Load Balancing - createLoadBalancerRoute, - createSmartLoadBalancer, + createLoadBalancerRoute, // Multi-backend with health checks + createSmartLoadBalancer, // Dynamic domain-based backend selection // API & WebSocket - createApiRoute, - createApiGatewayRoute, - createWebSocketRoute, + createApiRoute, // API route with path matching + createApiGatewayRoute, // API gateway with CORS + createWebSocketRoute, // WebSocket-enabled route // Custom Protocols - createSocketHandlerRoute, - SocketHandlers, + createSocketHandlerRoute, // Custom socket handler + SocketHandlers, // Pre-built handlers (echo, proxy, block, etc.) - // NFTables (Linux) - createNfTablesRoute, - createNfTablesTerminateRoute, - createCompleteNfTablesHttpsServer, + // NFTables (Linux, requires root) + createNfTablesRoute, // Kernel-level packet forwarding + createNfTablesTerminateRoute, // NFTables + TLS termination + createCompleteNfTablesHttpsServer, // NFTables HTTPS + redirect combo // Dynamic Routing - createPortMappingRoute, - createOffsetPortMappingRoute, - createDynamicRoute, + createPortMappingRoute, // Port mapping with context + createOffsetPortMappingRoute, // Simple port offset + createDynamicRoute, // Dynamic host/port via functions // Security Modifiers - addRateLimiting, - addBasicAuth, - addJwtAuth + addRateLimiting, // Add rate limiting to any route + addBasicAuth, // Add basic auth to any route + addJwtAuth, // Add JWT auth to any route + + // Route Utilities + mergeRouteConfigs, // Deep-merge two route configs + findMatchingRoutes, // Find routes matching criteria + findBestMatchingRoute, // Find best matching route + cloneRoute, // Deep-clone a route + generateRouteId, // Generate deterministic route ID + RouteValidator, // Validate route configurations } from '@push.rocks/smartproxy'; ``` +## ๐Ÿ“– API Documentation + +### SmartProxy Class + +```typescript +class SmartProxy extends EventEmitter { + constructor(options: ISmartProxyOptions); + + // Lifecycle + start(): Promise; + stop(): Promise; + + // Route Management (atomic, mutex-locked) + updateRoutes(routes: IRouteConfig[]): Promise; + + // Port Management + addListeningPort(port: number): Promise; + removeListeningPort(port: number): Promise; + getListeningPorts(): Promise; + + // Monitoring & Metrics + getMetrics(): IMetrics; // Sync โ€” returns cached metrics adapter + getStatistics(): Promise; // Async โ€” queries Rust engine + + // Certificate Management + provisionCertificate(routeName: string): Promise; + renewCertificate(routeName: string): Promise; + getCertificateStatus(routeName: string): Promise; + getEligibleDomainsForCertificates(): string[]; + + // NFTables + getNfTablesStatus(): Promise>; + + // Events + on(event: 'error', handler: (err: Error) => void): this; +} +``` + +### Configuration Options + +```typescript +interface ISmartProxyOptions { + routes: IRouteConfig[]; // Required: array of route configs + + // ACME/Let's Encrypt + acme?: { + email: string; // Contact email for Let's Encrypt + useProduction?: boolean; // Use production servers (default: false) + port?: number; // HTTP-01 challenge port (default: 80) + renewThresholdDays?: number; // Days before expiry to renew (default: 30) + autoRenew?: boolean; // Enable auto-renewal (default: true) + certificateStore?: string; // Directory to store certs (default: './certs') + renewCheckIntervalHours?: number; // Renewal check interval (default: 24) + }; + + // Custom certificate provisioning + certProvisionFunction?: (domain: string) => Promise; + certProvisionFallbackToAcme?: boolean; // Fall back to ACME on failure (default: true) + + // Global defaults + defaults?: { + target?: { host: string; port: number }; + security?: { ipAllowList?: string[]; ipBlockList?: string[]; maxConnections?: number }; + }; + + // PROXY protocol + proxyIPs?: string[]; // Trusted proxy IPs + acceptProxyProtocol?: boolean; // Accept PROXY protocol headers + sendProxyProtocol?: boolean; // Send PROXY protocol to targets + + // Timeouts + connectionTimeout?: number; // Backend connection timeout (default: 30s) + initialDataTimeout?: number; // Initial data/SNI timeout (default: 120s) + socketTimeout?: number; // Socket inactivity timeout (default: 1h) + maxConnectionLifetime?: number; // Max connection lifetime (default: 24h) + inactivityTimeout?: number; // Inactivity timeout (default: 4h) + gracefulShutdownTimeout?: number; // Shutdown grace period (default: 30s) + + // Connection limits + maxConnectionsPerIP?: number; // Per-IP connection limit (default: 100) + connectionRateLimitPerMinute?: number; // Per-IP rate limit (default: 300/min) + + // Keep-alive + keepAliveTreatment?: 'standard' | 'extended' | 'immortal'; + keepAliveInactivityMultiplier?: number; // (default: 6) + extendedKeepAliveLifetime?: number; // (default: 7 days) + + // Metrics + metrics?: { + enabled?: boolean; + sampleIntervalMs?: number; + retentionSeconds?: number; + }; + + // Behavior + enableDetailedLogging?: boolean; // Verbose connection logging + enableTlsDebugLogging?: boolean; // TLS handshake debug logging + + // Rust binary + rustBinaryPath?: string; // Custom path to the Rust binary +} +``` + +### NfTablesProxy Class + +A standalone class for managing nftables NAT rules directly (Linux only, requires root): + +```typescript +import { NfTablesProxy } from '@push.rocks/smartproxy'; + +const nftProxy = new NfTablesProxy({ + fromPorts: [80, 443], + toHost: 'backend-server', + toPorts: [8080, 8443], + protocol: 'tcp', + preserveSourceIP: true, + enableIPv6: true, + maxRate: '1gbps', + useIPSets: true +}); + +await nftProxy.start(); // Apply nftables rules +const status = nftProxy.getStatus(); +await nftProxy.stop(); // Remove rules +``` + ## ๐Ÿ› Troubleshooting ### Certificate Issues @@ -460,93 +761,41 @@ import { - โœ… Port 80 must be accessible for ACME HTTP-01 challenges - โœ… Check DNS propagation with `dig` or `nslookup` - โœ… Verify the email in ACME configuration is valid +- โœ… Use `getCertificateStatus('route-name')` to check cert state ### Connection Problems - โœ… Check route priorities (higher number = matched first) - โœ… Verify security rules aren't blocking legitimate traffic - โœ… Test with `curl -v` for detailed connection output -- โœ… Enable debug logging for verbose output +- โœ… Enable debug logging with `enableDetailedLogging: true` + +### Rust Binary Not Found +SmartProxy searches for the Rust binary in this order: +1. `SMARTPROXY_RUST_BINARY` environment variable +2. Platform-specific npm package (`@push.rocks/smartproxy-linux-x64`, etc.) +3. Local dev build (`./rust/target/release/rustproxy`) +4. System PATH (`rustproxy`) + +Set `rustBinaryPath` in options to override. ### Performance Tuning - โœ… Use NFTables forwarding for high-traffic routes (Linux only) - โœ… Enable connection keep-alive where appropriate -- โœ… Monitor metrics to identify bottlenecks -- โœ… Adjust `maxConnections` based on your server resources - -### Debug Mode - -```typescript -const proxy = new SmartProxy({ - enableDetailedLogging: true, // Verbose connection logging - routes: [...] -}); -``` +- โœ… Use `getMetrics()` and `getStatistics()` to identify bottlenecks +- โœ… Adjust `maxConnectionsPerIP` and `connectionRateLimitPerMinute` based on your workload +- โœ… Use `passthrough` TLS mode when backend can handle TLS directly ## ๐Ÿ† Best Practices -1. **๐Ÿ“ Use Helper Functions** - They provide sensible defaults and prevent common mistakes -2. **๐ŸŽฏ Set Route Priorities** - More specific routes should have higher priority values -3. **๐Ÿ”’ Enable Security** - Always use IP filtering and rate limiting for public services -4. **๐Ÿ“Š Monitor Metrics** - Use the built-in metrics to identify issues early -5. **๐Ÿ”„ Certificate Monitoring** - Set up alerts for certificate expiration -6. **๐Ÿ›‘ Graceful Shutdown** - Always call `proxy.stop()` for clean connection termination -7. **๐Ÿ”ง Test Routes** - Validate your route configurations before deploying to production - -## ๐Ÿ“– API Documentation - -### SmartProxy Class - -```typescript -class SmartProxy { - constructor(options: ISmartProxyOptions); - - // Lifecycle - start(): Promise; - stop(): Promise; - - // Route Management - updateRoutes(routes: IRouteConfig[]): Promise; - - // Port Management - addListeningPort(port: number): Promise; - removeListeningPort(port: number): Promise; - getListeningPorts(): number[]; - - // Monitoring - getStatus(): IProxyStatus; - getMetrics(): IMetrics; - - // Certificate Management - getCertificateInfo(domain: string): ICertStatus | null; -} -``` - -### Configuration Options - -```typescript -interface ISmartProxyOptions { - routes: IRouteConfig[]; // Required: array of route configs - - // ACME/Let's Encrypt - acme?: { - email: string; // Contact email - useProduction?: boolean; // Use production servers (default: false) - port?: number; // Challenge port (default: 80) - renewThresholdDays?: number; // Days before expiry to renew (default: 30) - }; - - // Defaults - defaults?: { - target?: { host: string; port: number }; - security?: IRouteSecurity; - tls?: IRouteTls; - }; - - // Behavior - enableDetailedLogging?: boolean; - gracefulShutdownTimeout?: number; // ms to wait for connections to close -} -``` +1. **๐Ÿ“ Use Helper Functions** โ€” They provide sensible defaults and prevent common mistakes +2. **๐ŸŽฏ Set Route Priorities** โ€” More specific routes should have higher priority values +3. **๐Ÿ”’ Enable Security** โ€” Always use IP filtering and rate limiting for public-facing services +4. **๐Ÿ“Š Monitor Metrics** โ€” Use the built-in metrics to catch issues early +5. **๐Ÿ”„ Certificate Monitoring** โ€” Set up alerts before certificates expire +6. **๐Ÿ›‘ Graceful Shutdown** โ€” Always call `proxy.stop()` for clean connection termination +7. **โœ… Validate Routes** โ€” Use `RouteValidator.validateRoutes()` to catch config errors before deployment +8. **๐Ÿ”€ Atomic Updates** โ€” Use `updateRoutes()` for hot-reloading routes (mutex-locked, no downtime) +9. **๐ŸŽฎ Use Socket Handlers** โ€” For protocols beyond HTTP, implement custom socket handlers instead of fighting the proxy model ## License and Legal Information diff --git a/rust/Cargo.lock b/rust/Cargo.lock new file mode 100644 index 0000000..6479928 --- /dev/null +++ b/rust/Cargo.lock @@ -0,0 +1,1760 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "anstream" +version = "0.6.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" + +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys 0.61.2", +] + +[[package]] +name = "anyhow" +version = "1.0.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea" + +[[package]] +name = "arc-swap" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ded5f9a03ac8f24d1b8a25101ee812cd32cdc8c50a4c50237de2c4915850e73" +dependencies = [ + "rustversion", +] + +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "aws-lc-rs" +version = "1.15.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b7b6141e96a8c160799cc2d5adecd5cbbe5054cb8c7c4af53da0f83bb7ad256" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.37.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c34dda4df7017c8db52132f0f8a2e0f8161649d15723ed63fc00c82d0f2081a" +dependencies = [ + "cc", + "cmake", + "dunce", + "fs_extra", +] + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + +[[package]] +name = "bumpalo" +version = "3.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cc" +version = "1.2.55" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b26a0954ae34af09b50f0de26458fa95369a0d478d8236d3f93082b219bd29" +dependencies = [ + "find-msvc-tools", + "jobserver", + "libc", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "clap" +version = "4.5.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6899ea499e3fb9305a65d5ebf6e3d2248c5fab291f300ad0a704fbe142eae31a" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b12c8b680195a62a8364d16b8447b01b6c2c8f9aaf68bee653be34d4245e238" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.55" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3e64b0cc0439b12df2fa678eae89a1c56a529fd067a9115f7827f1fffd22b32" + +[[package]] +name = "cmake" +version = "0.1.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" +dependencies = [ + "cc", +] + +[[package]] +name = "colorchoice" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" + +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + +[[package]] +name = "deranged" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ececcb659e7ba858fb4f10388c250a7252eb0a27373f1a72b8748afdd248e587" +dependencies = [ + "powerfmt", +] + +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-core", + "futures-task", + "pin-project-lite", + "pin-utils", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi", + "wasm-bindgen", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + +[[package]] +name = "glob-match" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9985c9503b412198aa4197559e9a318524ebc4519c229bfa05a535828c950b9d" + +[[package]] +name = "h2" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "h2", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-native-certs", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "hyper", + "libc", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", +] + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", +] + +[[package]] +name = "instant-acme" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37221e690dcc5d0ea7c1f70decda6ae3495e72e8af06bca15e982193ffdf4fc4" +dependencies = [ + "async-trait", + "base64", + "bytes", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-util", + "ring", + "rustls-pki-types", + "serde", + "serde_json", + "thiserror 1.0.69", +] + +[[package]] +name = "ipnet" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + +[[package]] +name = "js-sys" +version = "0.3.85" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c942ebf8e95485ca0d52d97da7c5a2c387d0e7f0ba4c35e93bfcaee045955b3" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "jsonwebtoken" +version = "9.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde" +dependencies = [ + "base64", + "js-sys", + "pem", + "ring", + "serde", + "serde_json", + "simple_asn1", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "libc" +version = "0.2.180" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" + +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "mio" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-conv" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + +[[package]] +name = "openssl-probe" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "pem" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be" +dependencies = [ + "base64", + "serde_core", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rcgen" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" +dependencies = [ + "pem", + "ring", + "rustls-pki-types", + "time", + "yasna", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c" + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustix" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls" +version = "0.23.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b" +dependencies = [ + "aws-lc-rs", + "log", + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" +dependencies = [ + "aws-lc-rs", + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "rustproxy" +version = "0.1.0" +dependencies = [ + "anyhow", + "arc-swap", + "bytes", + "clap", + "dashmap", + "http-body-util", + "hyper", + "hyper-util", + "rcgen", + "rustls", + "rustproxy-config", + "rustproxy-http", + "rustproxy-metrics", + "rustproxy-nftables", + "rustproxy-passthrough", + "rustproxy-routing", + "rustproxy-security", + "rustproxy-tls", + "serde", + "serde_json", + "tokio", + "tokio-rustls", + "tokio-util", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "rustproxy-config" +version = "0.1.0" +dependencies = [ + "ipnet", + "serde", + "serde_json", + "thiserror 2.0.18", +] + +[[package]] +name = "rustproxy-http" +version = "0.1.0" +dependencies = [ + "anyhow", + "arc-swap", + "bytes", + "dashmap", + "http-body-util", + "hyper", + "hyper-util", + "regex", + "rustproxy-config", + "rustproxy-metrics", + "rustproxy-routing", + "rustproxy-security", + "thiserror 2.0.18", + "tokio", + "tracing", +] + +[[package]] +name = "rustproxy-metrics" +version = "0.1.0" +dependencies = [ + "dashmap", + "serde", + "serde_json", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "rustproxy-nftables" +version = "0.1.0" +dependencies = [ + "anyhow", + "libc", + "rustproxy-config", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "tracing", +] + +[[package]] +name = "rustproxy-passthrough" +version = "0.1.0" +dependencies = [ + "anyhow", + "arc-swap", + "dashmap", + "rustls", + "rustls-pemfile", + "rustproxy-config", + "rustproxy-http", + "rustproxy-metrics", + "rustproxy-routing", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "tokio-rustls", + "tokio-util", + "tracing", +] + +[[package]] +name = "rustproxy-routing" +version = "0.1.0" +dependencies = [ + "arc-swap", + "glob-match", + "ipnet", + "regex", + "rustproxy-config", + "thiserror 2.0.18", + "tracing", +] + +[[package]] +name = "rustproxy-security" +version = "0.1.0" +dependencies = [ + "base64", + "dashmap", + "ipnet", + "jsonwebtoken", + "rustproxy-config", + "serde", + "thiserror 2.0.18", + "tracing", +] + +[[package]] +name = "rustproxy-tls" +version = "0.1.0" +dependencies = [ + "anyhow", + "instant-acme", + "rcgen", + "rustls", + "rustproxy-config", + "serde", + "serde_json", + "tempfile", + "thiserror 2.0.18", + "tokio", + "tracing", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "schannel" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "security-framework" +version = "3.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3297343eaf830f66ede390ea39da1d462b6b0c1b000f420d0a83f898bbbe6ef" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + +[[package]] +name = "simple_asn1" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "297f631f50729c8c99b84667867963997ec0b50f32b2a7dbcab828ef0541e8bb" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror 2.0.18", + "time", +] + +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86f4aa3ad99f2088c990dfa82d367e19cb29268ed67c574d10d0a4bfe71f07e0" +dependencies = [ + "libc", + "windows-sys 0.60.2", +] + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "2.0.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tempfile" +version = "3.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl 2.0.18", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "time" +version = "0.3.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde_core", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" + +[[package]] +name = "time-macros" +version = "0.2.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" +dependencies = [ + "num-conv", + "time-core", +] + +[[package]] +name = "tokio" +version = "1.49.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72a2903cd7736441aac9df9d7688bd0ce48edccaadf181c3b90be801e81d3d86" +dependencies = [ + "bytes", + "libc", + "mio", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "unicode-ident" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64024a30ec1e37399cf85a7ffefebdb72205ca1c972291c51512360d90bd8566" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "008b239d9c740232e71bd39e8ef6429d27097518b6b30bdf9086833bd5b6d608" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5256bae2d58f54820e6490f9839c49780dff84c65aeab9e772f15d5f0e913a55" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f01b580c9ac74c8d8f0c0e4afb04eeef2acf145458e52c03845ee9cd23e3d12" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.5", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" + +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zmij" +version = "1.0.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ff05f8caa9038894637571ae6b9e29466c1f4f829d26c9b28f869a29cbe3445" diff --git a/rust/Cargo.toml b/rust/Cargo.toml new file mode 100644 index 0000000..fd2f7f8 --- /dev/null +++ b/rust/Cargo.toml @@ -0,0 +1,98 @@ +[workspace] +resolver = "2" +members = [ + "crates/rustproxy", + "crates/rustproxy-config", + "crates/rustproxy-routing", + "crates/rustproxy-tls", + "crates/rustproxy-passthrough", + "crates/rustproxy-http", + "crates/rustproxy-nftables", + "crates/rustproxy-metrics", + "crates/rustproxy-security", +] + +[workspace.package] +version = "0.1.0" +edition = "2021" +license = "MIT" +authors = ["Lossless GmbH "] + +[workspace.dependencies] +# Async runtime +tokio = { version = "1", features = ["full"] } + +# Serialization +serde = { version = "1", features = ["derive"] } +serde_json = "1" + +# HTTP proxy engine (hyper-based) +hyper = { version = "1", features = ["http1", "http2", "server", "client"] } +hyper-util = { version = "0.1", features = ["tokio", "http1", "http2", "client-legacy", "server-auto"] } +http-body-util = "0.1" +bytes = "1" + +# ACME / Let's Encrypt +instant-acme = { version = "0.7", features = ["hyper-rustls"] } + +# TLS for passthrough SNI +rustls = { version = "0.23", features = ["ring"] } +tokio-rustls = "0.26" +rustls-pemfile = "2" + +# Self-signed cert generation for tests +rcgen = "0.13" + +# Temp directories for tests +tempfile = "3" + +# Lock-free atomics +arc-swap = "1" + +# Concurrent maps +dashmap = "6" + +# Domain wildcard matching +glob-match = "0.2" + +# IP/CIDR parsing +ipnet = "2" + +# JWT authentication +jsonwebtoken = "9" + +# Structured logging +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +# Error handling +thiserror = "2" +anyhow = "1" + +# CLI +clap = { version = "4", features = ["derive"] } + +# Regex for URL rewriting +regex = "1" + +# Base64 for basic auth +base64 = "0.22" + +# Cancellation / utility +tokio-util = "0.7" + +# Async traits +async-trait = "0.1" + +# libc for uid checks +libc = "0.2" + +# Internal crates +rustproxy-config = { path = "crates/rustproxy-config" } +rustproxy-routing = { path = "crates/rustproxy-routing" } +rustproxy-tls = { path = "crates/rustproxy-tls" } +rustproxy-passthrough = { path = "crates/rustproxy-passthrough" } +rustproxy-http = { path = "crates/rustproxy-http" } +rustproxy-nftables = { path = "crates/rustproxy-nftables" } +rustproxy-metrics = { path = "crates/rustproxy-metrics" } +rustproxy-security = { path = "crates/rustproxy-security" } diff --git a/rust/config/example.json b/rust/config/example.json new file mode 100644 index 0000000..b75a9fe --- /dev/null +++ b/rust/config/example.json @@ -0,0 +1,145 @@ +{ + "routes": [ + { + "id": "https-passthrough", + "name": "HTTPS Passthrough to Backend", + "match": { + "ports": 443, + "domains": "backend.example.com" + }, + "action": { + "type": "forward", + "targets": [ + { + "host": "10.0.0.1", + "port": 443 + } + ], + "tls": { + "mode": "passthrough" + } + }, + "priority": 10, + "enabled": true + }, + { + "id": "https-terminate", + "name": "HTTPS Terminate for API", + "match": { + "ports": 443, + "domains": "api.example.com" + }, + "action": { + "type": "forward", + "targets": [ + { + "host": "localhost", + "port": 8080 + } + ], + "tls": { + "mode": "terminate", + "certificate": "auto" + } + }, + "priority": 20, + "enabled": true + }, + { + "id": "http-redirect", + "name": "HTTP to HTTPS Redirect", + "match": { + "ports": 80, + "domains": ["api.example.com", "www.example.com"] + }, + "action": { + "type": "forward", + "targets": [ + { + "host": "localhost", + "port": 8080 + } + ] + }, + "priority": 0 + }, + { + "id": "load-balanced", + "name": "Load Balanced Backend", + "match": { + "ports": 443, + "domains": "*.example.com" + }, + "action": { + "type": "forward", + "targets": [ + { + "host": "backend1.internal", + "port": 8080 + }, + { + "host": "backend2.internal", + "port": 8080 + }, + { + "host": "backend3.internal", + "port": 8080 + } + ], + "tls": { + "mode": "terminate", + "certificate": "auto" + }, + "loadBalancing": { + "algorithm": "round-robin", + "healthCheck": { + "path": "/health", + "interval": 30, + "timeout": 5, + "unhealthyThreshold": 3, + "healthyThreshold": 2 + } + } + }, + "security": { + "ipAllowList": ["10.0.0.0/8", "192.168.0.0/16"], + "maxConnections": 1000, + "rateLimit": { + "enabled": true, + "maxRequests": 100, + "window": 60, + "keyBy": "ip" + } + }, + "headers": { + "request": { + "X-Forwarded-For": "{clientIp}", + "X-Real-IP": "{clientIp}" + }, + "response": { + "X-Powered-By": "RustProxy" + }, + "cors": { + "enabled": true, + "allowOrigin": "*", + "allowMethods": "GET,POST,PUT,DELETE,OPTIONS", + "allowHeaders": "Content-Type,Authorization", + "allowCredentials": false, + "maxAge": 86400 + } + }, + "priority": 5 + } + ], + "acme": { + "email": "admin@example.com", + "useProduction": false, + "port": 80 + }, + "connectionTimeout": 30000, + "socketTimeout": 3600000, + "maxConnectionsPerIp": 100, + "connectionRateLimitPerMinute": 300, + "keepAliveTreatment": "extended", + "enableDetailedLogging": false +} diff --git a/rust/crates/rustproxy-config/Cargo.toml b/rust/crates/rustproxy-config/Cargo.toml new file mode 100644 index 0000000..9bb304c --- /dev/null +++ b/rust/crates/rustproxy-config/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "rustproxy-config" +version.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true +description = "Configuration types for RustProxy, compatible with SmartProxy JSON schema" + +[dependencies] +serde = { workspace = true } +serde_json = { workspace = true } +thiserror = { workspace = true } +ipnet = { workspace = true } diff --git a/rust/crates/rustproxy-config/src/helpers.rs b/rust/crates/rustproxy-config/src/helpers.rs new file mode 100644 index 0000000..d10c4b2 --- /dev/null +++ b/rust/crates/rustproxy-config/src/helpers.rs @@ -0,0 +1,334 @@ +use crate::route_types::*; +use crate::tls_types::*; + +/// Create a simple HTTP forwarding route. +/// Equivalent to SmartProxy's `createHttpRoute()`. +pub fn create_http_route( + domains: impl Into, + target_host: impl Into, + target_port: u16, +) -> RouteConfig { + RouteConfig { + id: None, + route_match: RouteMatch { + ports: PortRange::Single(80), + domains: Some(domains.into()), + path: None, + client_ip: None, + tls_version: None, + headers: None, + }, + action: RouteAction { + action_type: RouteActionType::Forward, + targets: Some(vec![RouteTarget { + target_match: None, + host: HostSpec::Single(target_host.into()), + port: PortSpec::Fixed(target_port), + tls: None, + websocket: None, + load_balancing: None, + send_proxy_protocol: None, + headers: None, + advanced: None, + priority: None, + }]), + tls: None, + websocket: None, + load_balancing: None, + advanced: None, + options: None, + forwarding_engine: None, + nftables: None, + send_proxy_protocol: None, + }, + headers: None, + security: None, + name: None, + description: None, + priority: None, + tags: None, + enabled: None, + } +} + +/// Create an HTTPS termination route. +/// Equivalent to SmartProxy's `createHttpsTerminateRoute()`. +pub fn create_https_terminate_route( + domains: impl Into, + target_host: impl Into, + target_port: u16, +) -> RouteConfig { + let mut route = create_http_route(domains, target_host, target_port); + route.route_match.ports = PortRange::Single(443); + route.action.tls = Some(RouteTls { + mode: TlsMode::Terminate, + certificate: Some(CertificateSpec::Auto("auto".to_string())), + acme: None, + versions: None, + ciphers: None, + honor_cipher_order: None, + session_timeout: None, + }); + route +} + +/// Create a TLS passthrough route. +/// Equivalent to SmartProxy's `createHttpsPassthroughRoute()`. +pub fn create_https_passthrough_route( + domains: impl Into, + target_host: impl Into, + target_port: u16, +) -> RouteConfig { + let mut route = create_http_route(domains, target_host, target_port); + route.route_match.ports = PortRange::Single(443); + route.action.tls = Some(RouteTls { + mode: TlsMode::Passthrough, + certificate: None, + acme: None, + versions: None, + ciphers: None, + honor_cipher_order: None, + session_timeout: None, + }); + route +} + +/// Create an HTTP-to-HTTPS redirect route. +/// Equivalent to SmartProxy's `createHttpToHttpsRedirect()`. +pub fn create_http_to_https_redirect( + domains: impl Into, +) -> RouteConfig { + let domains = domains.into(); + RouteConfig { + id: None, + route_match: RouteMatch { + ports: PortRange::Single(80), + domains: Some(domains), + path: None, + client_ip: None, + tls_version: None, + headers: None, + }, + action: RouteAction { + action_type: RouteActionType::Forward, + targets: None, + tls: None, + websocket: None, + load_balancing: None, + advanced: Some(RouteAdvanced { + timeout: None, + headers: None, + keep_alive: None, + static_files: None, + test_response: Some(RouteTestResponse { + status: 301, + headers: { + let mut h = std::collections::HashMap::new(); + h.insert("Location".to_string(), "https://{domain}{path}".to_string()); + h + }, + body: String::new(), + }), + url_rewrite: None, + }), + options: None, + forwarding_engine: None, + nftables: None, + send_proxy_protocol: None, + }, + headers: None, + security: None, + name: Some("HTTP to HTTPS Redirect".to_string()), + description: None, + priority: None, + tags: None, + enabled: None, + } +} + +/// Create a complete HTTPS server with HTTP redirect. +/// Equivalent to SmartProxy's `createCompleteHttpsServer()`. +pub fn create_complete_https_server( + domain: impl Into, + target_host: impl Into, + target_port: u16, +) -> Vec { + let domain = domain.into(); + let target_host = target_host.into(); + + vec![ + create_http_to_https_redirect(DomainSpec::Single(domain.clone())), + create_https_terminate_route( + DomainSpec::Single(domain), + target_host, + target_port, + ), + ] +} + +/// Create a load balancer route. +/// Equivalent to SmartProxy's `createLoadBalancerRoute()`. +pub fn create_load_balancer_route( + domains: impl Into, + targets: Vec<(String, u16)>, + tls: Option, +) -> RouteConfig { + let route_targets: Vec = targets + .into_iter() + .map(|(host, port)| RouteTarget { + target_match: None, + host: HostSpec::Single(host), + port: PortSpec::Fixed(port), + tls: None, + websocket: None, + load_balancing: None, + send_proxy_protocol: None, + headers: None, + advanced: None, + priority: None, + }) + .collect(); + + let port = if tls.is_some() { 443 } else { 80 }; + + RouteConfig { + id: None, + route_match: RouteMatch { + ports: PortRange::Single(port), + domains: Some(domains.into()), + path: None, + client_ip: None, + tls_version: None, + headers: None, + }, + action: RouteAction { + action_type: RouteActionType::Forward, + targets: Some(route_targets), + tls, + websocket: None, + load_balancing: Some(RouteLoadBalancing { + algorithm: LoadBalancingAlgorithm::RoundRobin, + health_check: None, + }), + advanced: None, + options: None, + forwarding_engine: None, + nftables: None, + send_proxy_protocol: None, + }, + headers: None, + security: None, + name: Some("Load Balancer".to_string()), + description: None, + priority: None, + tags: None, + enabled: None, + } +} + +// Convenience conversions for DomainSpec +impl From<&str> for DomainSpec { + fn from(s: &str) -> Self { + DomainSpec::Single(s.to_string()) + } +} + +impl From for DomainSpec { + fn from(s: String) -> Self { + DomainSpec::Single(s) + } +} + +impl From> for DomainSpec { + fn from(v: Vec) -> Self { + DomainSpec::List(v) + } +} + +impl From> for DomainSpec { + fn from(v: Vec<&str>) -> Self { + DomainSpec::List(v.into_iter().map(|s| s.to_string()).collect()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tls_types::TlsMode; + + #[test] + fn test_create_http_route() { + let route = create_http_route("example.com", "localhost", 8080); + assert_eq!(route.route_match.ports.to_ports(), vec![80]); + let domains = route.route_match.domains.as_ref().unwrap().to_vec(); + assert_eq!(domains, vec!["example.com"]); + let target = &route.action.targets.as_ref().unwrap()[0]; + assert_eq!(target.host.first(), "localhost"); + assert_eq!(target.port.resolve(80), 8080); + assert!(route.action.tls.is_none()); + } + + #[test] + fn test_create_https_terminate_route() { + let route = create_https_terminate_route("api.example.com", "backend", 3000); + assert_eq!(route.route_match.ports.to_ports(), vec![443]); + let tls = route.action.tls.as_ref().unwrap(); + assert_eq!(tls.mode, TlsMode::Terminate); + assert!(tls.certificate.as_ref().unwrap().is_auto()); + } + + #[test] + fn test_create_https_passthrough_route() { + let route = create_https_passthrough_route("secure.example.com", "backend", 443); + assert_eq!(route.route_match.ports.to_ports(), vec![443]); + let tls = route.action.tls.as_ref().unwrap(); + assert_eq!(tls.mode, TlsMode::Passthrough); + assert!(tls.certificate.is_none()); + } + + #[test] + fn test_create_http_to_https_redirect() { + let route = create_http_to_https_redirect("example.com"); + assert_eq!(route.route_match.ports.to_ports(), vec![80]); + assert!(route.action.targets.is_none()); + let test_response = route.action.advanced.as_ref().unwrap().test_response.as_ref().unwrap(); + assert_eq!(test_response.status, 301); + assert!(test_response.headers.contains_key("Location")); + } + + #[test] + fn test_create_complete_https_server() { + let routes = create_complete_https_server("example.com", "backend", 8080); + assert_eq!(routes.len(), 2); + // First route is HTTP redirect + assert_eq!(routes[0].route_match.ports.to_ports(), vec![80]); + // Second route is HTTPS terminate + assert_eq!(routes[1].route_match.ports.to_ports(), vec![443]); + } + + #[test] + fn test_create_load_balancer_route() { + let targets = vec![ + ("backend1".to_string(), 8080), + ("backend2".to_string(), 8080), + ("backend3".to_string(), 8080), + ]; + let route = create_load_balancer_route("*.example.com", targets, None); + assert_eq!(route.route_match.ports.to_ports(), vec![80]); + assert_eq!(route.action.targets.as_ref().unwrap().len(), 3); + let lb = route.action.load_balancing.as_ref().unwrap(); + assert_eq!(lb.algorithm, LoadBalancingAlgorithm::RoundRobin); + } + + #[test] + fn test_domain_spec_from_str() { + let spec: DomainSpec = "example.com".into(); + assert_eq!(spec.to_vec(), vec!["example.com"]); + } + + #[test] + fn test_domain_spec_from_vec() { + let spec: DomainSpec = vec!["a.com", "b.com"].into(); + assert_eq!(spec.to_vec(), vec!["a.com", "b.com"]); + } +} diff --git a/rust/crates/rustproxy-config/src/lib.rs b/rust/crates/rustproxy-config/src/lib.rs new file mode 100644 index 0000000..e62471c --- /dev/null +++ b/rust/crates/rustproxy-config/src/lib.rs @@ -0,0 +1,19 @@ +//! # rustproxy-config +//! +//! Configuration types for RustProxy, fully compatible with SmartProxy's JSON schema. +//! All types use `#[serde(rename_all = "camelCase")]` to match TypeScript field naming. + +pub mod route_types; +pub mod proxy_options; +pub mod tls_types; +pub mod security_types; +pub mod validation; +pub mod helpers; + +// Re-export all primary types +pub use route_types::*; +pub use proxy_options::*; +pub use tls_types::*; +pub use security_types::*; +pub use validation::*; +pub use helpers::*; diff --git a/rust/crates/rustproxy-config/src/proxy_options.rs b/rust/crates/rustproxy-config/src/proxy_options.rs new file mode 100644 index 0000000..0847a7f --- /dev/null +++ b/rust/crates/rustproxy-config/src/proxy_options.rs @@ -0,0 +1,439 @@ +use serde::{Deserialize, Serialize}; + +use crate::route_types::RouteConfig; + +/// Global ACME configuration options. +/// Matches TypeScript: `IAcmeOptions` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AcmeOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub enabled: Option, + /// Required when any route uses certificate: 'auto' + #[serde(skip_serializing_if = "Option::is_none")] + pub email: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub environment: Option, + /// Alias for email + #[serde(skip_serializing_if = "Option::is_none")] + pub account_email: Option, + /// Port for HTTP-01 challenges (default: 80) + #[serde(skip_serializing_if = "Option::is_none")] + pub port: Option, + /// Use Let's Encrypt production (default: false) + #[serde(skip_serializing_if = "Option::is_none")] + pub use_production: Option, + /// Days before expiry to renew (default: 30) + #[serde(skip_serializing_if = "Option::is_none")] + pub renew_threshold_days: Option, + /// Enable automatic renewal (default: true) + #[serde(skip_serializing_if = "Option::is_none")] + pub auto_renew: Option, + /// Directory to store certificates (default: './certs') + #[serde(skip_serializing_if = "Option::is_none")] + pub certificate_store: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub skip_configured_certs: Option, + /// How often to check for renewals (default: 24) + #[serde(skip_serializing_if = "Option::is_none")] + pub renew_check_interval_hours: Option, +} + +/// ACME environment. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum AcmeEnvironment { + Production, + Staging, +} + +/// Default target configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct DefaultTarget { + pub host: String, + pub port: u16, +} + +/// Default security configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct DefaultSecurity { + #[serde(skip_serializing_if = "Option::is_none")] + pub ip_allow_list: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub ip_block_list: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_connections: Option, +} + +/// Default configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct DefaultConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub target: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub security: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub preserve_source_ip: Option, +} + +/// Keep-alive treatment. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum KeepAliveTreatment { + Standard, + Extended, + Immortal, +} + +/// Metrics configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct MetricsConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub enabled: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub sample_interval_ms: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub retention_seconds: Option, +} + +/// RustProxy configuration options. +/// Matches TypeScript: `ISmartProxyOptions` +/// +/// This is the top-level configuration that can be loaded from a JSON file +/// or constructed programmatically. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RustProxyOptions { + /// The unified configuration array (required) + pub routes: Vec, + + /// Preserve client IP when forwarding + #[serde(skip_serializing_if = "Option::is_none")] + pub preserve_source_ip: Option, + + /// List of trusted proxy IPs that can send PROXY protocol + #[serde(skip_serializing_if = "Option::is_none")] + pub proxy_ips: Option>, + + /// Global option to accept PROXY protocol + #[serde(skip_serializing_if = "Option::is_none")] + pub accept_proxy_protocol: Option, + + /// Global option to send PROXY protocol to all targets + #[serde(skip_serializing_if = "Option::is_none")] + pub send_proxy_protocol: Option, + + /// Global/default settings + #[serde(skip_serializing_if = "Option::is_none")] + pub defaults: Option, + + // โ”€โ”€โ”€ Timeout Settings โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + /// Timeout for establishing connection to backend (ms), default: 30000 + #[serde(skip_serializing_if = "Option::is_none")] + pub connection_timeout: Option, + + /// Timeout for initial data/SNI (ms), default: 60000 + #[serde(skip_serializing_if = "Option::is_none")] + pub initial_data_timeout: Option, + + /// Socket inactivity timeout (ms), default: 3600000 + #[serde(skip_serializing_if = "Option::is_none")] + pub socket_timeout: Option, + + /// How often to check for inactive connections (ms), default: 60000 + #[serde(skip_serializing_if = "Option::is_none")] + pub inactivity_check_interval: Option, + + /// Default max connection lifetime (ms), default: 86400000 + #[serde(skip_serializing_if = "Option::is_none")] + pub max_connection_lifetime: Option, + + /// Inactivity timeout (ms), default: 14400000 + #[serde(skip_serializing_if = "Option::is_none")] + pub inactivity_timeout: Option, + + /// Maximum time to wait for connections to close during shutdown (ms) + #[serde(skip_serializing_if = "Option::is_none")] + pub graceful_shutdown_timeout: Option, + + // โ”€โ”€โ”€ Socket Optimization โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + /// Disable Nagle's algorithm (default: true) + #[serde(skip_serializing_if = "Option::is_none")] + pub no_delay: Option, + + /// Enable TCP keepalive (default: true) + #[serde(skip_serializing_if = "Option::is_none")] + pub keep_alive: Option, + + /// Initial delay before sending keepalive probes (ms) + #[serde(skip_serializing_if = "Option::is_none")] + pub keep_alive_initial_delay: Option, + + /// Maximum bytes to buffer during connection setup + #[serde(skip_serializing_if = "Option::is_none")] + pub max_pending_data_size: Option, + + // โ”€โ”€โ”€ Enhanced Features โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + /// Disable inactivity checking entirely + #[serde(skip_serializing_if = "Option::is_none")] + pub disable_inactivity_check: Option, + + /// Enable TCP keep-alive probes + #[serde(skip_serializing_if = "Option::is_none")] + pub enable_keep_alive_probes: Option, + + /// Enable detailed connection logging + #[serde(skip_serializing_if = "Option::is_none")] + pub enable_detailed_logging: Option, + + /// Enable TLS handshake debug logging + #[serde(skip_serializing_if = "Option::is_none")] + pub enable_tls_debug_logging: Option, + + /// Randomize timeouts to prevent thundering herd + #[serde(skip_serializing_if = "Option::is_none")] + pub enable_randomized_timeouts: Option, + + // โ”€โ”€โ”€ Rate Limiting โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + /// Maximum simultaneous connections from a single IP + #[serde(skip_serializing_if = "Option::is_none")] + pub max_connections_per_ip: Option, + + /// Max new connections per minute from a single IP + #[serde(skip_serializing_if = "Option::is_none")] + pub connection_rate_limit_per_minute: Option, + + // โ”€โ”€โ”€ Keep-Alive Settings โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + /// How to treat keep-alive connections + #[serde(skip_serializing_if = "Option::is_none")] + pub keep_alive_treatment: Option, + + /// Multiplier for inactivity timeout for keep-alive connections + #[serde(skip_serializing_if = "Option::is_none")] + pub keep_alive_inactivity_multiplier: Option, + + /// Extended lifetime for keep-alive connections (ms) + #[serde(skip_serializing_if = "Option::is_none")] + pub extended_keep_alive_lifetime: Option, + + // โ”€โ”€โ”€ HttpProxy Integration โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + /// Array of ports to forward to HttpProxy + #[serde(skip_serializing_if = "Option::is_none")] + pub use_http_proxy: Option>, + + /// Port where HttpProxy is listening (default: 8443) + #[serde(skip_serializing_if = "Option::is_none")] + pub http_proxy_port: Option, + + // โ”€โ”€โ”€ Metrics โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + /// Metrics configuration + #[serde(skip_serializing_if = "Option::is_none")] + pub metrics: Option, + + // โ”€โ”€โ”€ ACME โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + /// Global ACME configuration + #[serde(skip_serializing_if = "Option::is_none")] + pub acme: Option, +} + +impl Default for RustProxyOptions { + fn default() -> Self { + Self { + routes: Vec::new(), + preserve_source_ip: None, + proxy_ips: None, + accept_proxy_protocol: None, + send_proxy_protocol: None, + defaults: None, + connection_timeout: None, + initial_data_timeout: None, + socket_timeout: None, + inactivity_check_interval: None, + max_connection_lifetime: None, + inactivity_timeout: None, + graceful_shutdown_timeout: None, + no_delay: None, + keep_alive: None, + keep_alive_initial_delay: None, + max_pending_data_size: None, + disable_inactivity_check: None, + enable_keep_alive_probes: None, + enable_detailed_logging: None, + enable_tls_debug_logging: None, + enable_randomized_timeouts: None, + max_connections_per_ip: None, + connection_rate_limit_per_minute: None, + keep_alive_treatment: None, + keep_alive_inactivity_multiplier: None, + extended_keep_alive_lifetime: None, + use_http_proxy: None, + http_proxy_port: None, + metrics: None, + acme: None, + } + } +} + +impl RustProxyOptions { + /// Load configuration from a JSON file. + pub fn from_file(path: &str) -> Result> { + let content = std::fs::read_to_string(path)?; + let options: Self = serde_json::from_str(&content)?; + Ok(options) + } + + /// Get the effective connection timeout in milliseconds. + pub fn effective_connection_timeout(&self) -> u64 { + self.connection_timeout.unwrap_or(30_000) + } + + /// Get the effective initial data timeout in milliseconds. + pub fn effective_initial_data_timeout(&self) -> u64 { + self.initial_data_timeout.unwrap_or(60_000) + } + + /// Get the effective socket timeout in milliseconds. + pub fn effective_socket_timeout(&self) -> u64 { + self.socket_timeout.unwrap_or(3_600_000) + } + + /// Get the effective max connection lifetime in milliseconds. + pub fn effective_max_connection_lifetime(&self) -> u64 { + self.max_connection_lifetime.unwrap_or(86_400_000) + } + + /// Get all unique ports that routes listen on. + pub fn all_listening_ports(&self) -> Vec { + let mut ports: Vec = self.routes + .iter() + .flat_map(|r| r.listening_ports()) + .collect(); + ports.sort(); + ports.dedup(); + ports + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::helpers::*; + + #[test] + fn test_serde_roundtrip_minimal() { + let options = RustProxyOptions { + routes: vec![create_http_route("example.com", "localhost", 8080)], + ..Default::default() + }; + let json = serde_json::to_string(&options).unwrap(); + let parsed: RustProxyOptions = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.routes.len(), 1); + } + + #[test] + fn test_serde_roundtrip_full() { + let options = RustProxyOptions { + routes: vec![ + create_http_route("a.com", "backend1", 8080), + create_https_passthrough_route("b.com", "backend2", 443), + ], + connection_timeout: Some(5000), + socket_timeout: Some(60000), + max_connections_per_ip: Some(100), + acme: Some(AcmeOptions { + enabled: Some(true), + email: Some("admin@example.com".to_string()), + environment: Some(AcmeEnvironment::Staging), + account_email: None, + port: None, + use_production: None, + renew_threshold_days: None, + auto_renew: None, + certificate_store: None, + skip_configured_certs: None, + renew_check_interval_hours: None, + }), + ..Default::default() + }; + let json = serde_json::to_string_pretty(&options).unwrap(); + let parsed: RustProxyOptions = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.routes.len(), 2); + assert_eq!(parsed.connection_timeout, Some(5000)); + } + + #[test] + fn test_default_timeouts() { + let options = RustProxyOptions::default(); + assert_eq!(options.effective_connection_timeout(), 30_000); + assert_eq!(options.effective_initial_data_timeout(), 60_000); + assert_eq!(options.effective_socket_timeout(), 3_600_000); + assert_eq!(options.effective_max_connection_lifetime(), 86_400_000); + } + + #[test] + fn test_custom_timeouts() { + let options = RustProxyOptions { + connection_timeout: Some(5000), + initial_data_timeout: Some(10000), + socket_timeout: Some(30000), + max_connection_lifetime: Some(60000), + ..Default::default() + }; + assert_eq!(options.effective_connection_timeout(), 5000); + assert_eq!(options.effective_initial_data_timeout(), 10000); + assert_eq!(options.effective_socket_timeout(), 30000); + assert_eq!(options.effective_max_connection_lifetime(), 60000); + } + + #[test] + fn test_all_listening_ports() { + let options = RustProxyOptions { + routes: vec![ + create_http_route("a.com", "backend", 8080), // port 80 + create_https_passthrough_route("b.com", "backend", 443), // port 443 + create_http_route("c.com", "backend", 9090), // port 80 (duplicate) + ], + ..Default::default() + }; + let ports = options.all_listening_ports(); + assert_eq!(ports, vec![80, 443]); + } + + #[test] + fn test_camel_case_field_names() { + let options = RustProxyOptions { + connection_timeout: Some(5000), + max_connections_per_ip: Some(100), + keep_alive_treatment: Some(KeepAliveTreatment::Extended), + ..Default::default() + }; + let json = serde_json::to_string(&options).unwrap(); + assert!(json.contains("connectionTimeout")); + assert!(json.contains("maxConnectionsPerIp")); + assert!(json.contains("keepAliveTreatment")); + } + + #[test] + fn test_deserialize_example_json() { + let content = std::fs::read_to_string( + concat!(env!("CARGO_MANIFEST_DIR"), "/../../config/example.json") + ).unwrap(); + let options: RustProxyOptions = serde_json::from_str(&content).unwrap(); + assert_eq!(options.routes.len(), 4); + let ports = options.all_listening_ports(); + assert!(ports.contains(&80)); + assert!(ports.contains(&443)); + } +} diff --git a/rust/crates/rustproxy-config/src/route_types.rs b/rust/crates/rustproxy-config/src/route_types.rs new file mode 100644 index 0000000..066e289 --- /dev/null +++ b/rust/crates/rustproxy-config/src/route_types.rs @@ -0,0 +1,603 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +use crate::tls_types::RouteTls; +use crate::security_types::RouteSecurity; + +// โ”€โ”€โ”€ Port Range โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// Port range specification format. +/// Matches TypeScript: `type TPortRange = number | number[] | Array<{ from: number; to: number }>` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum PortRange { + /// Single port number + Single(u16), + /// Array of port numbers + List(Vec), + /// Array of port ranges + Ranges(Vec), +} + +impl PortRange { + /// Expand the port range into a flat list of ports. + pub fn to_ports(&self) -> Vec { + match self { + PortRange::Single(p) => vec![*p], + PortRange::List(ports) => ports.clone(), + PortRange::Ranges(ranges) => { + ranges.iter().flat_map(|r| r.from..=r.to).collect() + } + } + } +} + +/// A from-to port range. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PortRangeSpec { + pub from: u16, + pub to: u16, +} + +// โ”€โ”€โ”€ Route Action Type โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// Supported action types for route configurations. +/// Matches TypeScript: `type TRouteActionType = 'forward' | 'socket-handler'` +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum RouteActionType { + Forward, + SocketHandler, +} + +// โ”€โ”€โ”€ Forwarding Engine โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// Forwarding engine specification. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ForwardingEngine { + Node, + Nftables, +} + +// โ”€โ”€โ”€ Route Match โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// Domain specification: single string or array. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum DomainSpec { + Single(String), + List(Vec), +} + +impl DomainSpec { + pub fn to_vec(&self) -> Vec<&str> { + match self { + DomainSpec::Single(s) => vec![s.as_str()], + DomainSpec::List(v) => v.iter().map(|s| s.as_str()).collect(), + } + } +} + +/// Header match value: either exact string or regex pattern. +/// In JSON, all values come as strings. Regex patterns are prefixed with `/` and suffixed with `/`. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum HeaderMatchValue { + Exact(String), +} + +/// Route match criteria for incoming requests. +/// Matches TypeScript: `IRouteMatch` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteMatch { + /// Listen on these ports (required) + pub ports: PortRange, + + /// Optional domain patterns to match (default: all domains) + #[serde(skip_serializing_if = "Option::is_none")] + pub domains: Option, + + /// Match specific paths + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, + + /// Match specific client IPs + #[serde(skip_serializing_if = "Option::is_none")] + pub client_ip: Option>, + + /// Match specific TLS versions + #[serde(skip_serializing_if = "Option::is_none")] + pub tls_version: Option>, + + /// Match specific HTTP headers + #[serde(skip_serializing_if = "Option::is_none")] + pub headers: Option>, +} + +// โ”€โ”€โ”€ Target Match โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// Target-specific match criteria for sub-routing within a route. +/// Matches TypeScript: `ITargetMatch` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TargetMatch { + /// Match specific ports from the route + #[serde(skip_serializing_if = "Option::is_none")] + pub ports: Option>, + /// Match specific paths (supports wildcards like /api/*) + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option, + /// Match specific HTTP headers + #[serde(skip_serializing_if = "Option::is_none")] + pub headers: Option>, + /// Match specific HTTP methods + #[serde(skip_serializing_if = "Option::is_none")] + pub method: Option>, +} + +// โ”€โ”€โ”€ WebSocket Config โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// WebSocket configuration. +/// Matches TypeScript: `IRouteWebSocket` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteWebSocket { + pub enabled: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub ping_interval: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub ping_timeout: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_payload_size: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub custom_headers: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub subprotocols: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub rewrite_path: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub allowed_origins: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub authenticate_request: Option, +} + +// โ”€โ”€โ”€ Load Balancing โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// Load balancing algorithm. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum LoadBalancingAlgorithm { + RoundRobin, + LeastConnections, + IpHash, +} + +/// Health check configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct HealthCheck { + pub path: String, + pub interval: u64, + pub timeout: u64, + pub unhealthy_threshold: u32, + pub healthy_threshold: u32, +} + +/// Load balancing configuration. +/// Matches TypeScript: `IRouteLoadBalancing` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteLoadBalancing { + pub algorithm: LoadBalancingAlgorithm, + #[serde(skip_serializing_if = "Option::is_none")] + pub health_check: Option, +} + +// โ”€โ”€โ”€ CORS โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// Allowed origin specification. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum AllowOrigin { + Single(String), + List(Vec), +} + +/// CORS configuration for a route. +/// Matches TypeScript: `IRouteCors` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteCors { + pub enabled: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub allow_origin: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub allow_methods: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub allow_headers: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub allow_credentials: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub expose_headers: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_age: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub preflight: Option, +} + +// โ”€โ”€โ”€ Headers โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// Headers configuration. +/// Matches TypeScript: `IRouteHeaders` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteHeaders { + /// Headers to add/modify for requests to backend + #[serde(skip_serializing_if = "Option::is_none")] + pub request: Option>, + /// Headers to add/modify for responses to client + #[serde(skip_serializing_if = "Option::is_none")] + pub response: Option>, + /// CORS configuration + #[serde(skip_serializing_if = "Option::is_none")] + pub cors: Option, +} + +// โ”€โ”€โ”€ Static Files โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// Static file server configuration. +/// Matches TypeScript: `IRouteStaticFiles` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteStaticFiles { + pub root: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub index: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub headers: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub directory: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub index_files: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_control: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub expires: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub follow_symlinks: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub disable_directory_listing: Option, +} + +// โ”€โ”€โ”€ Test Response โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// Test route response configuration. +/// Matches TypeScript: `IRouteTestResponse` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteTestResponse { + pub status: u16, + pub headers: HashMap, + pub body: String, +} + +// โ”€โ”€โ”€ URL Rewriting โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// URL rewriting configuration. +/// Matches TypeScript: `IRouteUrlRewrite` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteUrlRewrite { + /// RegExp pattern to match in URL + pub pattern: String, + /// Replacement pattern + pub target: String, + /// RegExp flags + #[serde(skip_serializing_if = "Option::is_none")] + pub flags: Option, + /// Only apply to path, not query string + #[serde(skip_serializing_if = "Option::is_none")] + pub only_rewrite_path: Option, +} + +// โ”€โ”€โ”€ Advanced Options โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// Advanced options for route actions. +/// Matches TypeScript: `IRouteAdvanced` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteAdvanced { + #[serde(skip_serializing_if = "Option::is_none")] + pub timeout: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub headers: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub keep_alive: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub static_files: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub test_response: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub url_rewrite: Option, +} + +// โ”€โ”€โ”€ NFTables Options โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// NFTables protocol type. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum NfTablesProtocol { + Tcp, + Udp, + All, +} + +/// NFTables-specific configuration options. +/// Matches TypeScript: `INfTablesOptions` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct NfTablesOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub preserve_source_ip: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub protocol: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_rate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub table_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub use_ip_sets: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub use_advanced_nat: Option, +} + +// โ”€โ”€โ”€ Backend Protocol โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// Backend protocol. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum BackendProtocol { + Http1, + Http2, +} + +/// Action options. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ActionOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub backend_protocol: Option, + /// Catch-all for additional options + #[serde(flatten)] + pub extra: HashMap, +} + +// โ”€โ”€โ”€ Route Target โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// Host specification: single string or array of strings. +/// Note: Dynamic host functions are only available via programmatic API, not JSON. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum HostSpec { + Single(String), + List(Vec), +} + +impl HostSpec { + pub fn to_vec(&self) -> Vec<&str> { + match self { + HostSpec::Single(s) => vec![s.as_str()], + HostSpec::List(v) => v.iter().map(|s| s.as_str()).collect(), + } + } + + pub fn first(&self) -> &str { + match self { + HostSpec::Single(s) => s.as_str(), + HostSpec::List(v) => v.first().map(|s| s.as_str()).unwrap_or(""), + } + } +} + +/// Port specification: number or "preserve". +/// Note: Dynamic port functions are only available via programmatic API, not JSON. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum PortSpec { + /// Fixed port number + Fixed(u16), + /// Special string value like "preserve" + Special(String), +} + +impl PortSpec { + /// Resolve the port, using incoming_port when "preserve" is specified. + pub fn resolve(&self, incoming_port: u16) -> u16 { + match self { + PortSpec::Fixed(p) => *p, + PortSpec::Special(s) if s == "preserve" => incoming_port, + PortSpec::Special(_) => incoming_port, // fallback + } + } +} + +/// Target configuration for forwarding with sub-matching and overrides. +/// Matches TypeScript: `IRouteTarget` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteTarget { + /// Optional sub-matching criteria within the route + #[serde(rename = "match")] + #[serde(skip_serializing_if = "Option::is_none")] + pub target_match: Option, + + /// Target host(s) + pub host: HostSpec, + + /// Target port + pub port: PortSpec, + + /// Override route-level TLS settings + #[serde(skip_serializing_if = "Option::is_none")] + pub tls: Option, + + /// Override route-level WebSocket settings + #[serde(skip_serializing_if = "Option::is_none")] + pub websocket: Option, + + /// Override route-level load balancing + #[serde(skip_serializing_if = "Option::is_none")] + pub load_balancing: Option, + + /// Override route-level proxy protocol setting + #[serde(skip_serializing_if = "Option::is_none")] + pub send_proxy_protocol: Option, + + /// Override route-level headers + #[serde(skip_serializing_if = "Option::is_none")] + pub headers: Option, + + /// Override route-level advanced settings + #[serde(skip_serializing_if = "Option::is_none")] + pub advanced: Option, + + /// Priority for matching (higher values checked first, default: 0) + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, +} + +// โ”€โ”€โ”€ Route Action โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// Action configuration for route handling. +/// Matches TypeScript: `IRouteAction` +/// +/// Note: `socketHandler` is not serializable in JSON. Use the programmatic API +/// for socket handler routes. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteAction { + /// Basic routing type + #[serde(rename = "type")] + pub action_type: RouteActionType, + + /// Targets for forwarding (array supports multiple targets with sub-matching) + #[serde(skip_serializing_if = "Option::is_none")] + pub targets: Option>, + + /// TLS handling (default for all targets) + #[serde(skip_serializing_if = "Option::is_none")] + pub tls: Option, + + /// WebSocket support (default for all targets) + #[serde(skip_serializing_if = "Option::is_none")] + pub websocket: Option, + + /// Load balancing options (default for all targets) + #[serde(skip_serializing_if = "Option::is_none")] + pub load_balancing: Option, + + /// Advanced options (default for all targets) + #[serde(skip_serializing_if = "Option::is_none")] + pub advanced: Option, + + /// Additional options + #[serde(skip_serializing_if = "Option::is_none")] + pub options: Option, + + /// Forwarding engine specification + #[serde(skip_serializing_if = "Option::is_none")] + pub forwarding_engine: Option, + + /// NFTables-specific options + #[serde(skip_serializing_if = "Option::is_none")] + pub nftables: Option, + + /// PROXY protocol support (default for all targets) + #[serde(skip_serializing_if = "Option::is_none")] + pub send_proxy_protocol: Option, +} + +// โ”€โ”€โ”€ Route Config โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + +/// The core unified configuration interface. +/// Matches TypeScript: `IRouteConfig` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteConfig { + /// Unique identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + + /// What to match + #[serde(rename = "match")] + pub route_match: RouteMatch, + + /// What to do with matched traffic + pub action: RouteAction, + + /// Custom headers + #[serde(skip_serializing_if = "Option::is_none")] + pub headers: Option, + + /// Security features + #[serde(skip_serializing_if = "Option::is_none")] + pub security: Option, + + /// Human-readable name for this route + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + + /// Description of the route's purpose + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + + /// Controls matching order (higher = matched first) + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, + + /// Arbitrary tags for categorization + #[serde(skip_serializing_if = "Option::is_none")] + pub tags: Option>, + + /// Whether the route is active (default: true) + #[serde(skip_serializing_if = "Option::is_none")] + pub enabled: Option, +} + +impl RouteConfig { + /// Check if this route is enabled (defaults to true). + pub fn is_enabled(&self) -> bool { + self.enabled.unwrap_or(true) + } + + /// Get the effective priority (defaults to 0). + pub fn effective_priority(&self) -> i32 { + self.priority.unwrap_or(0) + } + + /// Get all ports this route listens on. + pub fn listening_ports(&self) -> Vec { + self.route_match.ports.to_ports() + } + + /// Get the TLS mode for this route (from action-level or first target). + pub fn tls_mode(&self) -> Option<&crate::tls_types::TlsMode> { + // Check action-level TLS first + if let Some(tls) = &self.action.tls { + return Some(&tls.mode); + } + // Check first target's TLS + if let Some(targets) = &self.action.targets { + if let Some(first) = targets.first() { + if let Some(tls) = &first.tls { + return Some(&tls.mode); + } + } + } + None + } +} diff --git a/rust/crates/rustproxy-config/src/security_types.rs b/rust/crates/rustproxy-config/src/security_types.rs new file mode 100644 index 0000000..4a6ad10 --- /dev/null +++ b/rust/crates/rustproxy-config/src/security_types.rs @@ -0,0 +1,132 @@ +use serde::{Deserialize, Serialize}; + +/// Rate limiting configuration. +/// Matches TypeScript: `IRouteRateLimit` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteRateLimit { + pub enabled: bool, + pub max_requests: u64, + /// Time window in seconds + pub window: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub key_by: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub header_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error_message: Option, +} + +/// Rate limit key selection. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum RateLimitKeyBy { + Ip, + Path, + Header, +} + +/// Authentication type. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum AuthenticationType { + Basic, + Digest, + Oauth, + Jwt, +} + +/// Authentication credentials. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AuthCredentials { + pub username: String, + pub password: String, +} + +/// Authentication options. +/// Matches TypeScript: `IRouteAuthentication` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteAuthentication { + #[serde(rename = "type")] + pub auth_type: AuthenticationType, + #[serde(skip_serializing_if = "Option::is_none")] + pub credentials: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub realm: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub jwt_secret: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub jwt_issuer: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub oauth_provider: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub oauth_client_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub oauth_client_secret: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub oauth_redirect_uri: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub options: Option, +} + +/// Basic auth configuration. +/// Matches TypeScript: `IRouteSecurity.basicAuth` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct BasicAuthConfig { + pub enabled: bool, + pub users: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub realm: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub exclude_paths: Option>, +} + +/// JWT auth configuration. +/// Matches TypeScript: `IRouteSecurity.jwtAuth` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct JwtAuthConfig { + pub enabled: bool, + pub secret: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub algorithm: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub issuer: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub audience: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub expires_in: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub exclude_paths: Option>, +} + +/// Security options for routes. +/// Matches TypeScript: `IRouteSecurity` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteSecurity { + /// IP addresses that are allowed to connect + #[serde(skip_serializing_if = "Option::is_none")] + pub ip_allow_list: Option>, + /// IP addresses that are blocked from connecting + #[serde(skip_serializing_if = "Option::is_none")] + pub ip_block_list: Option>, + /// Maximum concurrent connections + #[serde(skip_serializing_if = "Option::is_none")] + pub max_connections: Option, + /// Authentication configuration + #[serde(skip_serializing_if = "Option::is_none")] + pub authentication: Option, + /// Rate limiting + #[serde(skip_serializing_if = "Option::is_none")] + pub rate_limit: Option, + /// Basic auth + #[serde(skip_serializing_if = "Option::is_none")] + pub basic_auth: Option, + /// JWT auth + #[serde(skip_serializing_if = "Option::is_none")] + pub jwt_auth: Option, +} diff --git a/rust/crates/rustproxy-config/src/tls_types.rs b/rust/crates/rustproxy-config/src/tls_types.rs new file mode 100644 index 0000000..eca4cfb --- /dev/null +++ b/rust/crates/rustproxy-config/src/tls_types.rs @@ -0,0 +1,93 @@ +use serde::{Deserialize, Serialize}; + +/// TLS handling modes for route configurations. +/// Matches TypeScript: `type TTlsMode = 'passthrough' | 'terminate' | 'terminate-and-reencrypt'` +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum TlsMode { + Passthrough, + Terminate, + TerminateAndReencrypt, +} + +/// Static certificate configuration (PEM-encoded). +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CertificateConfig { + /// PEM-encoded private key + pub key: String, + /// PEM-encoded certificate + pub cert: String, + /// PEM-encoded CA chain + #[serde(skip_serializing_if = "Option::is_none")] + pub ca: Option, + /// Path to key file (overrides key) + #[serde(skip_serializing_if = "Option::is_none")] + pub key_file: Option, + /// Path to cert file (overrides cert) + #[serde(skip_serializing_if = "Option::is_none")] + pub cert_file: Option, +} + +/// Certificate specification: either automatic (ACME) or static. +/// Matches TypeScript: `certificate?: 'auto' | { key, cert, ca?, keyFile?, certFile? }` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum CertificateSpec { + /// Use ACME (Let's Encrypt) for automatic provisioning + Auto(String), // "auto" + /// Static certificate configuration + Static(CertificateConfig), +} + +impl CertificateSpec { + /// Check if this is an auto (ACME) certificate + pub fn is_auto(&self) -> bool { + matches!(self, CertificateSpec::Auto(s) if s == "auto") + } +} + +/// ACME configuration for automatic certificate provisioning. +/// Matches TypeScript: `IRouteAcme` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteAcme { + /// Contact email for ACME account + pub email: String, + /// Use production ACME servers (default: false) + #[serde(skip_serializing_if = "Option::is_none")] + pub use_production: Option, + /// Port for HTTP-01 challenges (default: 80) + #[serde(skip_serializing_if = "Option::is_none")] + pub challenge_port: Option, + /// Days before expiry to renew (default: 30) + #[serde(skip_serializing_if = "Option::is_none")] + pub renew_before_days: Option, +} + +/// TLS configuration for route actions. +/// Matches TypeScript: `IRouteTls` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteTls { + /// TLS mode (passthrough, terminate, terminate-and-reencrypt) + pub mode: TlsMode, + /// Certificate configuration (auto or static) + #[serde(skip_serializing_if = "Option::is_none")] + pub certificate: Option, + /// ACME options when certificate is 'auto' + #[serde(skip_serializing_if = "Option::is_none")] + pub acme: Option, + /// Allowed TLS versions + #[serde(skip_serializing_if = "Option::is_none")] + pub versions: Option>, + /// OpenSSL cipher string + #[serde(skip_serializing_if = "Option::is_none")] + pub ciphers: Option, + /// Use server's cipher preferences + #[serde(skip_serializing_if = "Option::is_none")] + pub honor_cipher_order: Option, + /// TLS session timeout in seconds + #[serde(skip_serializing_if = "Option::is_none")] + pub session_timeout: Option, +} diff --git a/rust/crates/rustproxy-config/src/validation.rs b/rust/crates/rustproxy-config/src/validation.rs new file mode 100644 index 0000000..4998f57 --- /dev/null +++ b/rust/crates/rustproxy-config/src/validation.rs @@ -0,0 +1,158 @@ +use thiserror::Error; + +use crate::route_types::{RouteConfig, RouteActionType}; + +/// Validation errors for route configurations. +#[derive(Debug, Error)] +pub enum ValidationError { + #[error("Route '{name}' has no targets but action type is 'forward'")] + MissingTargets { name: String }, + + #[error("Route '{name}' has empty targets list")] + EmptyTargets { name: String }, + + #[error("Route '{name}' has no ports specified")] + NoPorts { name: String }, + + #[error("Route '{name}' port {port} is invalid (must be 1-65535)")] + InvalidPort { name: String, port: u16 }, + + #[error("Route '{name}': socket-handler action type is not supported in JSON config")] + SocketHandlerInJson { name: String }, + + #[error("Route '{name}': duplicate route ID '{id}'")] + DuplicateId { name: String, id: String }, + + #[error("Route '{name}': {message}")] + Custom { name: String, message: String }, +} + +/// Validate a single route configuration. +pub fn validate_route(route: &RouteConfig) -> Result<(), Vec> { + let mut errors = Vec::new(); + let name = route.name.clone().unwrap_or_else(|| { + route.id.clone().unwrap_or_else(|| "unnamed".to_string()) + }); + + // Check ports + let ports = route.listening_ports(); + if ports.is_empty() { + errors.push(ValidationError::NoPorts { name: name.clone() }); + } + for &port in &ports { + if port == 0 { + errors.push(ValidationError::InvalidPort { + name: name.clone(), + port, + }); + } + } + + // Check forward action has targets + if route.action.action_type == RouteActionType::Forward { + match &route.action.targets { + None => { + errors.push(ValidationError::MissingTargets { name: name.clone() }); + } + Some(targets) if targets.is_empty() => { + errors.push(ValidationError::EmptyTargets { name: name.clone() }); + } + _ => {} + } + } + + if errors.is_empty() { + Ok(()) + } else { + Err(errors) + } +} + +/// Validate an entire list of routes. +pub fn validate_routes(routes: &[RouteConfig]) -> Result<(), Vec> { + let mut all_errors = Vec::new(); + let mut seen_ids = std::collections::HashSet::new(); + + for route in routes { + // Check for duplicate IDs + if let Some(id) = &route.id { + if !seen_ids.insert(id.clone()) { + let name = route.name.clone().unwrap_or_else(|| id.clone()); + all_errors.push(ValidationError::DuplicateId { + name, + id: id.clone(), + }); + } + } + + // Validate individual route + if let Err(errors) = validate_route(route) { + all_errors.extend(errors); + } + } + + if all_errors.is_empty() { + Ok(()) + } else { + Err(all_errors) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::route_types::*; + + fn make_valid_route() -> RouteConfig { + crate::helpers::create_http_route("example.com", "localhost", 8080) + } + + #[test] + fn test_valid_route_passes() { + let route = make_valid_route(); + assert!(validate_route(&route).is_ok()); + } + + #[test] + fn test_missing_targets() { + let mut route = make_valid_route(); + route.action.targets = None; + let errors = validate_route(&route).unwrap_err(); + assert!(errors.iter().any(|e| matches!(e, ValidationError::MissingTargets { .. }))); + } + + #[test] + fn test_empty_targets() { + let mut route = make_valid_route(); + route.action.targets = Some(vec![]); + let errors = validate_route(&route).unwrap_err(); + assert!(errors.iter().any(|e| matches!(e, ValidationError::EmptyTargets { .. }))); + } + + #[test] + fn test_invalid_port_zero() { + let mut route = make_valid_route(); + route.route_match.ports = PortRange::Single(0); + let errors = validate_route(&route).unwrap_err(); + assert!(errors.iter().any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. }))); + } + + #[test] + fn test_duplicate_ids() { + let mut r1 = make_valid_route(); + r1.id = Some("route-1".to_string()); + let mut r2 = make_valid_route(); + r2.id = Some("route-1".to_string()); + let errors = validate_routes(&[r1, r2]).unwrap_err(); + assert!(errors.iter().any(|e| matches!(e, ValidationError::DuplicateId { .. }))); + } + + #[test] + fn test_multiple_errors_collected() { + let mut r1 = make_valid_route(); + r1.action.targets = None; // MissingTargets + r1.route_match.ports = PortRange::Single(0); // InvalidPort + let errors = validate_route(&r1).unwrap_err(); + assert!(errors.len() >= 2); + } +} diff --git a/rust/crates/rustproxy-http/Cargo.toml b/rust/crates/rustproxy-http/Cargo.toml new file mode 100644 index 0000000..f3c7ba1 --- /dev/null +++ b/rust/crates/rustproxy-http/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "rustproxy-http" +version.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true +description = "Hyper-based HTTP proxy service for RustProxy" + +[dependencies] +rustproxy-config = { workspace = true } +rustproxy-routing = { workspace = true } +rustproxy-security = { workspace = true } +rustproxy-metrics = { workspace = true } +hyper = { workspace = true } +hyper-util = { workspace = true } +regex = { workspace = true } +http-body-util = { workspace = true } +bytes = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } +thiserror = { workspace = true } +anyhow = { workspace = true } +arc-swap = { workspace = true } +dashmap = { workspace = true } diff --git a/rust/crates/rustproxy-http/src/lib.rs b/rust/crates/rustproxy-http/src/lib.rs new file mode 100644 index 0000000..1ad4cee --- /dev/null +++ b/rust/crates/rustproxy-http/src/lib.rs @@ -0,0 +1,14 @@ +//! # rustproxy-http +//! +//! Hyper-based HTTP proxy service for RustProxy. +//! Handles HTTP request parsing, route-based forwarding, and response filtering. + +pub mod proxy_service; +pub mod request_filter; +pub mod response_filter; +pub mod template; +pub mod upstream_selector; + +pub use proxy_service::*; +pub use template::*; +pub use upstream_selector::*; diff --git a/rust/crates/rustproxy-http/src/proxy_service.rs b/rust/crates/rustproxy-http/src/proxy_service.rs new file mode 100644 index 0000000..935ca29 --- /dev/null +++ b/rust/crates/rustproxy-http/src/proxy_service.rs @@ -0,0 +1,827 @@ +//! Hyper-based HTTP proxy service. +//! +//! Accepts decrypted TCP streams (from TLS termination or plain TCP), +//! parses HTTP requests, matches routes, and forwards to upstream backends. +//! Supports HTTP/1.1 keep-alive, HTTP/2 (auto-detect), and WebSocket upgrade. + +use std::collections::HashMap; +use std::sync::Arc; + +use bytes::Bytes; +use http_body_util::{BodyExt, Full, combinators::BoxBody}; +use hyper::body::Incoming; +use hyper::{Request, Response, StatusCode}; +use hyper_util::rt::TokioIo; +use regex::Regex; +use tokio::net::TcpStream; +use tracing::{debug, error, info, warn}; + +use rustproxy_routing::RouteManager; +use rustproxy_metrics::MetricsCollector; + +use crate::request_filter::RequestFilter; +use crate::response_filter::ResponseFilter; +use crate::upstream_selector::UpstreamSelector; + +/// HTTP proxy service that processes HTTP traffic. +pub struct HttpProxyService { + route_manager: Arc, + metrics: Arc, + upstream_selector: UpstreamSelector, +} + +impl HttpProxyService { + pub fn new(route_manager: Arc, metrics: Arc) -> Self { + Self { + route_manager, + metrics, + upstream_selector: UpstreamSelector::new(), + } + } + + /// Handle an incoming HTTP connection on a plain TCP stream. + pub async fn handle_connection( + self: Arc, + stream: TcpStream, + peer_addr: std::net::SocketAddr, + port: u16, + ) { + self.handle_io(stream, peer_addr, port).await; + } + + /// Handle an incoming HTTP connection on any IO type (plain TCP or TLS-terminated). + /// + /// Uses HTTP/1.1 with upgrade support. For clients that negotiate HTTP/2, + /// use `handle_io_auto` instead. + pub async fn handle_io( + self: Arc, + stream: I, + peer_addr: std::net::SocketAddr, + port: u16, + ) + where + I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, + { + let io = TokioIo::new(stream); + + let service = hyper::service::service_fn(move |req: Request| { + let svc = Arc::clone(&self); + let peer = peer_addr; + async move { + svc.handle_request(req, peer, port).await + } + }); + + // Use http1::Builder with upgrades for WebSocket support + let conn = hyper::server::conn::http1::Builder::new() + .keep_alive(true) + .serve_connection(io, service) + .with_upgrades(); + + if let Err(e) = conn.await { + debug!("HTTP connection error from {}: {}", peer_addr, e); + } + } + + /// Handle a single HTTP request. + async fn handle_request( + &self, + req: Request, + peer_addr: std::net::SocketAddr, + port: u16, + ) -> Result>, hyper::Error> { + let host = req.headers() + .get("host") + .and_then(|h| h.to_str().ok()) + .map(|h| { + // Strip port from host header + h.split(':').next().unwrap_or(h).to_string() + }); + + let path = req.uri().path().to_string(); + let method = req.method().clone(); + + // Extract headers for matching + let headers: HashMap = req.headers() + .iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) + .collect(); + + debug!("HTTP {} {} (host: {:?}) from {}", method, path, host, peer_addr); + + // Check for CORS preflight + if method == hyper::Method::OPTIONS { + if let Some(response) = RequestFilter::handle_cors_preflight(&req) { + return Ok(response); + } + } + + // Match route + let ctx = rustproxy_routing::MatchContext { + port, + domain: host.as_deref(), + path: Some(&path), + client_ip: Some(&peer_addr.ip().to_string()), + tls_version: None, + headers: Some(&headers), + is_tls: false, + }; + + let route_match = match self.route_manager.find_route(&ctx) { + Some(rm) => rm, + None => { + debug!("No route matched for HTTP request to {:?}{}", host, path); + return Ok(error_response(StatusCode::BAD_GATEWAY, "No route matched")); + } + }; + + let route_id = route_match.route.id.as_deref(); + self.metrics.connection_opened(route_id); + + // Apply request filters (IP check, rate limiting, auth) + if let Some(ref security) = route_match.route.security { + if let Some(response) = RequestFilter::apply(security, &req, &peer_addr) { + self.metrics.connection_closed(route_id); + return Ok(response); + } + } + + // Check for test response (returns immediately, no upstream needed) + if let Some(ref advanced) = route_match.route.action.advanced { + if let Some(ref test_response) = advanced.test_response { + self.metrics.connection_closed(route_id); + return Ok(Self::build_test_response(test_response)); + } + } + + // Check for static file serving + if let Some(ref advanced) = route_match.route.action.advanced { + if let Some(ref static_files) = advanced.static_files { + self.metrics.connection_closed(route_id); + return Ok(Self::serve_static_file(&path, static_files)); + } + } + + // Select upstream + let target = match route_match.target { + Some(t) => t, + None => { + self.metrics.connection_closed(route_id); + return Ok(error_response(StatusCode::BAD_GATEWAY, "No target available")); + } + }; + + let upstream = self.upstream_selector.select(target, &peer_addr, port); + let upstream_key = format!("{}:{}", upstream.host, upstream.port); + self.upstream_selector.connection_started(&upstream_key); + + // Check for WebSocket upgrade + let is_websocket = req.headers() + .get("upgrade") + .and_then(|v| v.to_str().ok()) + .map(|v| v.eq_ignore_ascii_case("websocket")) + .unwrap_or(false); + + if is_websocket { + let result = self.handle_websocket_upgrade( + req, peer_addr, &upstream, route_match.route, route_id, &upstream_key, + ).await; + // Note: for WebSocket, connection_ended is called inside + // the spawned tunnel task when the connection closes. + return result; + } + + // Determine backend protocol + let use_h2 = route_match.route.action.options.as_ref() + .and_then(|o| o.backend_protocol.as_ref()) + .map(|p| *p == rustproxy_config::BackendProtocol::Http2) + .unwrap_or(false); + + // Build the upstream path (path + query), applying URL rewriting if configured + let upstream_path = { + let raw_path = match req.uri().query() { + Some(q) => format!("{}?{}", path, q), + None => path.clone(), + }; + Self::apply_url_rewrite(&raw_path, &route_match.route) + }; + + // Build upstream request - stream body instead of buffering + let (parts, body) = req.into_parts(); + + // Apply request headers from route config + let mut upstream_headers = parts.headers.clone(); + if let Some(ref route_headers) = route_match.route.headers { + if let Some(ref request_headers) = route_headers.request { + for (key, value) in request_headers { + if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) { + if let Ok(val) = hyper::header::HeaderValue::from_str(value) { + upstream_headers.insert(name, val); + } + } + } + } + } + + // Connect to upstream + let upstream_stream = match TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)).await { + Ok(s) => s, + Err(e) => { + error!("Failed to connect to upstream {}:{}: {}", upstream.host, upstream.port, e); + self.upstream_selector.connection_ended(&upstream_key); + self.metrics.connection_closed(route_id); + return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable")); + } + }; + upstream_stream.set_nodelay(true).ok(); + + let io = TokioIo::new(upstream_stream); + + let result = if use_h2 { + // HTTP/2 backend + self.forward_h2(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id).await + } else { + // HTTP/1.1 backend (default) + self.forward_h1(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id).await + }; + self.upstream_selector.connection_ended(&upstream_key); + result + } + + /// Forward request to backend via HTTP/1.1 with body streaming. + async fn forward_h1( + &self, + io: TokioIo, + parts: hyper::http::request::Parts, + body: Incoming, + upstream_headers: hyper::HeaderMap, + upstream_path: &str, + upstream: &crate::upstream_selector::UpstreamSelection, + route: &rustproxy_config::RouteConfig, + route_id: Option<&str>, + ) -> Result>, hyper::Error> { + let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await { + Ok(h) => h, + Err(e) => { + error!("Upstream handshake failed: {}", e); + self.metrics.connection_closed(route_id); + return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend handshake failed")); + } + }; + + tokio::spawn(async move { + if let Err(e) = conn.await { + debug!("Upstream connection error: {}", e); + } + }); + + let mut upstream_req = Request::builder() + .method(parts.method) + .uri(upstream_path) + .version(parts.version); + + if let Some(headers) = upstream_req.headers_mut() { + *headers = upstream_headers; + if let Ok(host_val) = hyper::header::HeaderValue::from_str( + &format!("{}:{}", upstream.host, upstream.port) + ) { + headers.insert(hyper::header::HOST, host_val); + } + } + + // Stream the request body through to upstream + let upstream_req = upstream_req.body(body).unwrap(); + + let upstream_response = match sender.send_request(upstream_req).await { + Ok(resp) => resp, + Err(e) => { + error!("Upstream request failed: {}", e); + self.metrics.connection_closed(route_id); + return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend request failed")); + } + }; + + self.build_streaming_response(upstream_response, route, route_id).await + } + + /// Forward request to backend via HTTP/2 with body streaming. + async fn forward_h2( + &self, + io: TokioIo, + parts: hyper::http::request::Parts, + body: Incoming, + upstream_headers: hyper::HeaderMap, + upstream_path: &str, + upstream: &crate::upstream_selector::UpstreamSelection, + route: &rustproxy_config::RouteConfig, + route_id: Option<&str>, + ) -> Result>, hyper::Error> { + let exec = hyper_util::rt::TokioExecutor::new(); + let (mut sender, conn) = match hyper::client::conn::http2::handshake(exec, io).await { + Ok(h) => h, + Err(e) => { + error!("HTTP/2 upstream handshake failed: {}", e); + self.metrics.connection_closed(route_id); + return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 handshake failed")); + } + }; + + tokio::spawn(async move { + if let Err(e) = conn.await { + debug!("HTTP/2 upstream connection error: {}", e); + } + }); + + let mut upstream_req = Request::builder() + .method(parts.method) + .uri(upstream_path); + + if let Some(headers) = upstream_req.headers_mut() { + *headers = upstream_headers; + if let Ok(host_val) = hyper::header::HeaderValue::from_str( + &format!("{}:{}", upstream.host, upstream.port) + ) { + headers.insert(hyper::header::HOST, host_val); + } + } + + // Stream the request body through to upstream + let upstream_req = upstream_req.body(body).unwrap(); + + let upstream_response = match sender.send_request(upstream_req).await { + Ok(resp) => resp, + Err(e) => { + error!("HTTP/2 upstream request failed: {}", e); + self.metrics.connection_closed(route_id); + return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 request failed")); + } + }; + + self.build_streaming_response(upstream_response, route, route_id).await + } + + /// Build the client-facing response from an upstream response, streaming the body. + async fn build_streaming_response( + &self, + upstream_response: Response, + route: &rustproxy_config::RouteConfig, + route_id: Option<&str>, + ) -> Result>, hyper::Error> { + let (resp_parts, resp_body) = upstream_response.into_parts(); + + let mut response = Response::builder() + .status(resp_parts.status); + + if let Some(headers) = response.headers_mut() { + *headers = resp_parts.headers; + ResponseFilter::apply_headers(route, headers, None); + } + + self.metrics.connection_closed(route_id); + + // Stream the response body directly from upstream to client + let body: BoxBody = BoxBody::new(resp_body); + + Ok(response.body(body).unwrap()) + } + + /// Handle a WebSocket upgrade request. + async fn handle_websocket_upgrade( + &self, + req: Request, + peer_addr: std::net::SocketAddr, + upstream: &crate::upstream_selector::UpstreamSelection, + route: &rustproxy_config::RouteConfig, + route_id: Option<&str>, + upstream_key: &str, + ) -> Result>, hyper::Error> { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + // Get WebSocket config from route + let ws_config = route.action.websocket.as_ref(); + + // Check allowed origins if configured + if let Some(ws) = ws_config { + if let Some(ref allowed_origins) = ws.allowed_origins { + let origin = req.headers() + .get("origin") + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + if !allowed_origins.is_empty() && !allowed_origins.iter().any(|o| o == "*" || o == origin) { + self.upstream_selector.connection_ended(upstream_key); + self.metrics.connection_closed(route_id); + return Ok(error_response(StatusCode::FORBIDDEN, "Origin not allowed")); + } + } + } + + info!("WebSocket upgrade from {} -> {}:{}", peer_addr, upstream.host, upstream.port); + + let mut upstream_stream = match TcpStream::connect( + format!("{}:{}", upstream.host, upstream.port) + ).await { + Ok(s) => s, + Err(e) => { + error!("WebSocket: failed to connect upstream {}:{}: {}", upstream.host, upstream.port, e); + self.upstream_selector.connection_ended(upstream_key); + self.metrics.connection_closed(route_id); + return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable")); + } + }; + upstream_stream.set_nodelay(true).ok(); + + let path = req.uri().path().to_string(); + let upstream_path = { + let raw = match req.uri().query() { + Some(q) => format!("{}?{}", path, q), + None => path, + }; + // Apply rewrite_path if configured + if let Some(ws) = ws_config { + if let Some(ref rewrite_path) = ws.rewrite_path { + rewrite_path.clone() + } else { + raw + } + } else { + raw + } + }; + + let (parts, _body) = req.into_parts(); + + let mut raw_request = format!( + "{} {} HTTP/1.1\r\n", + parts.method, upstream_path + ); + + let upstream_host = format!("{}:{}", upstream.host, upstream.port); + for (name, value) in parts.headers.iter() { + if name == hyper::header::HOST { + raw_request.push_str(&format!("host: {}\r\n", upstream_host)); + } else { + raw_request.push_str(&format!("{}: {}\r\n", name, value.to_str().unwrap_or(""))); + } + } + + if let Some(ref route_headers) = route.headers { + if let Some(ref request_headers) = route_headers.request { + for (key, value) in request_headers { + raw_request.push_str(&format!("{}: {}\r\n", key, value)); + } + } + } + + // Apply WebSocket custom headers + if let Some(ws) = ws_config { + if let Some(ref custom_headers) = ws.custom_headers { + for (key, value) in custom_headers { + raw_request.push_str(&format!("{}: {}\r\n", key, value)); + } + } + } + + raw_request.push_str("\r\n"); + + if let Err(e) = upstream_stream.write_all(raw_request.as_bytes()).await { + error!("WebSocket: failed to send upgrade request to upstream: {}", e); + self.upstream_selector.connection_ended(upstream_key); + self.metrics.connection_closed(route_id); + return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend write failed")); + } + + let mut response_buf = Vec::with_capacity(4096); + let mut temp = [0u8; 1]; + loop { + match upstream_stream.read(&mut temp).await { + Ok(0) => { + error!("WebSocket: upstream closed before completing handshake"); + self.upstream_selector.connection_ended(upstream_key); + self.metrics.connection_closed(route_id); + return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend closed")); + } + Ok(_) => { + response_buf.push(temp[0]); + if response_buf.len() >= 4 { + let len = response_buf.len(); + if response_buf[len-4..] == *b"\r\n\r\n" { + break; + } + } + if response_buf.len() > 8192 { + error!("WebSocket: upstream response headers too large"); + self.upstream_selector.connection_ended(upstream_key); + self.metrics.connection_closed(route_id); + return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend response too large")); + } + } + Err(e) => { + error!("WebSocket: failed to read upstream response: {}", e); + self.upstream_selector.connection_ended(upstream_key); + self.metrics.connection_closed(route_id); + return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend read failed")); + } + } + } + + let response_str = String::from_utf8_lossy(&response_buf); + + let status_line = response_str.lines().next().unwrap_or(""); + let status_code = status_line + .split_whitespace() + .nth(1) + .and_then(|s| s.parse::().ok()) + .unwrap_or(0); + + if status_code != 101 { + debug!("WebSocket: upstream rejected upgrade with status {}", status_code); + self.upstream_selector.connection_ended(upstream_key); + self.metrics.connection_closed(route_id); + return Ok(error_response( + StatusCode::from_u16(status_code).unwrap_or(StatusCode::BAD_GATEWAY), + "WebSocket upgrade rejected by backend", + )); + } + + let mut client_resp = Response::builder() + .status(StatusCode::SWITCHING_PROTOCOLS); + + if let Some(resp_headers) = client_resp.headers_mut() { + for line in response_str.lines().skip(1) { + let line = line.trim(); + if line.is_empty() { + break; + } + if let Some((name, value)) = line.split_once(':') { + let name = name.trim(); + let value = value.trim(); + if let Ok(header_name) = hyper::header::HeaderName::from_bytes(name.as_bytes()) { + if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) { + resp_headers.insert(header_name, header_value); + } + } + } + } + } + + let on_client_upgrade = hyper::upgrade::on( + Request::from_parts(parts, http_body_util::Empty::::new()) + ); + + let metrics = Arc::clone(&self.metrics); + let route_id_owned = route_id.map(|s| s.to_string()); + let upstream_selector = self.upstream_selector.clone(); + let upstream_key_owned = upstream_key.to_string(); + + tokio::spawn(async move { + let client_upgraded = match on_client_upgrade.await { + Ok(upgraded) => upgraded, + Err(e) => { + debug!("WebSocket: client upgrade failed: {}", e); + upstream_selector.connection_ended(&upstream_key_owned); + if let Some(ref rid) = route_id_owned { + metrics.connection_closed(Some(rid.as_str())); + } + return; + } + }; + + let client_io = TokioIo::new(client_upgraded); + + let (mut cr, mut cw) = tokio::io::split(client_io); + let (mut ur, mut uw) = tokio::io::split(upstream_stream); + + let c2u = tokio::spawn(async move { + let mut buf = vec![0u8; 65536]; + let mut total = 0u64; + loop { + let n = match cr.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + if uw.write_all(&buf[..n]).await.is_err() { + break; + } + total += n as u64; + } + let _ = uw.shutdown().await; + total + }); + + let u2c = tokio::spawn(async move { + let mut buf = vec![0u8; 65536]; + let mut total = 0u64; + loop { + let n = match ur.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + if cw.write_all(&buf[..n]).await.is_err() { + break; + } + total += n as u64; + } + let _ = cw.shutdown().await; + total + }); + + let bytes_in = c2u.await.unwrap_or(0); + let bytes_out = u2c.await.unwrap_or(0); + + debug!("WebSocket tunnel closed: {} bytes in, {} bytes out", bytes_in, bytes_out); + + upstream_selector.connection_ended(&upstream_key_owned); + if let Some(ref rid) = route_id_owned { + metrics.record_bytes(bytes_in, bytes_out, Some(rid.as_str())); + metrics.connection_closed(Some(rid.as_str())); + } + }); + + let body: BoxBody = BoxBody::new( + http_body_util::Empty::::new().map_err(|never| match never {}) + ); + Ok(client_resp.body(body).unwrap()) + } + + /// Build a test response from config (no upstream connection needed). + fn build_test_response(config: &rustproxy_config::RouteTestResponse) -> Response> { + let mut response = Response::builder() + .status(StatusCode::from_u16(config.status).unwrap_or(StatusCode::OK)); + + if let Some(headers) = response.headers_mut() { + for (key, value) in &config.headers { + if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) { + if let Ok(val) = hyper::header::HeaderValue::from_str(value) { + headers.insert(name, val); + } + } + } + } + + let body = Full::new(Bytes::from(config.body.clone())) + .map_err(|never| match never {}); + response.body(BoxBody::new(body)).unwrap() + } + + /// Apply URL rewriting rules from route config. + fn apply_url_rewrite(path: &str, route: &rustproxy_config::RouteConfig) -> String { + let rewrite = match route.action.advanced.as_ref() + .and_then(|a| a.url_rewrite.as_ref()) + { + Some(r) => r, + None => return path.to_string(), + }; + + // Determine what to rewrite + let (subject, suffix) = if rewrite.only_rewrite_path.unwrap_or(false) { + // Only rewrite the path portion (before ?) + match path.split_once('?') { + Some((p, q)) => (p.to_string(), format!("?{}", q)), + None => (path.to_string(), String::new()), + } + } else { + (path.to_string(), String::new()) + }; + + match Regex::new(&rewrite.pattern) { + Ok(re) => { + let result = re.replace_all(&subject, rewrite.target.as_str()); + format!("{}{}", result, suffix) + } + Err(e) => { + warn!("Invalid URL rewrite pattern '{}': {}", rewrite.pattern, e); + path.to_string() + } + } + } + + /// Serve a static file from the configured directory. + fn serve_static_file( + path: &str, + config: &rustproxy_config::RouteStaticFiles, + ) -> Response> { + use std::path::Path; + + let root = Path::new(&config.root); + + // Sanitize path to prevent directory traversal + let clean_path = path.trim_start_matches('/'); + let clean_path = clean_path.replace("..", ""); + + let mut file_path = root.join(&clean_path); + + // If path points to a directory, try index files + if file_path.is_dir() || clean_path.is_empty() { + let index_files = config.index_files.as_deref() + .or(config.index.as_deref()) + .unwrap_or(&[]); + let default_index = vec!["index.html".to_string()]; + let index_files = if index_files.is_empty() { &default_index } else { index_files }; + + let mut found = false; + for index in index_files { + let candidate = if clean_path.is_empty() { + root.join(index) + } else { + file_path.join(index) + }; + if candidate.is_file() { + file_path = candidate; + found = true; + break; + } + } + if !found { + return error_response(StatusCode::NOT_FOUND, "Not found"); + } + } + + // Ensure the resolved path is within the root (prevent traversal) + let canonical_root = match root.canonicalize() { + Ok(p) => p, + Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"), + }; + let canonical_file = match file_path.canonicalize() { + Ok(p) => p, + Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"), + }; + if !canonical_file.starts_with(&canonical_root) { + return error_response(StatusCode::FORBIDDEN, "Forbidden"); + } + + // Check if symlinks are allowed + if config.follow_symlinks == Some(false) && canonical_file != file_path { + return error_response(StatusCode::FORBIDDEN, "Forbidden"); + } + + // Read the file + match std::fs::read(&file_path) { + Ok(content) => { + let content_type = guess_content_type(&file_path); + let mut response = Response::builder() + .status(StatusCode::OK) + .header("Content-Type", content_type); + + // Apply cache-control if configured + if let Some(ref cache_control) = config.cache_control { + response = response.header("Cache-Control", cache_control.as_str()); + } + + // Apply custom headers + if let Some(ref headers) = config.headers { + for (key, value) in headers { + response = response.header(key.as_str(), value.as_str()); + } + } + + let body = Full::new(Bytes::from(content)) + .map_err(|never| match never {}); + response.body(BoxBody::new(body)).unwrap() + } + Err(_) => error_response(StatusCode::NOT_FOUND, "Not found"), + } + } +} + +/// Guess MIME content type from file extension. +fn guess_content_type(path: &std::path::Path) -> &'static str { + match path.extension().and_then(|e| e.to_str()) { + Some("html") | Some("htm") => "text/html; charset=utf-8", + Some("css") => "text/css; charset=utf-8", + Some("js") | Some("mjs") => "application/javascript; charset=utf-8", + Some("json") => "application/json; charset=utf-8", + Some("xml") => "application/xml; charset=utf-8", + Some("txt") => "text/plain; charset=utf-8", + Some("png") => "image/png", + Some("jpg") | Some("jpeg") => "image/jpeg", + Some("gif") => "image/gif", + Some("svg") => "image/svg+xml", + Some("ico") => "image/x-icon", + Some("woff") => "font/woff", + Some("woff2") => "font/woff2", + Some("ttf") => "font/ttf", + Some("pdf") => "application/pdf", + Some("wasm") => "application/wasm", + _ => "application/octet-stream", + } +} + +impl Default for HttpProxyService { + fn default() -> Self { + Self { + route_manager: Arc::new(RouteManager::new(vec![])), + metrics: Arc::new(MetricsCollector::new()), + upstream_selector: UpstreamSelector::new(), + } + } +} + +fn error_response(status: StatusCode, message: &str) -> Response> { + let body = Full::new(Bytes::from(message.to_string())) + .map_err(|never| match never {}); + Response::builder() + .status(status) + .header("Content-Type", "text/plain") + .body(BoxBody::new(body)) + .unwrap() +} diff --git a/rust/crates/rustproxy-http/src/request_filter.rs b/rust/crates/rustproxy-http/src/request_filter.rs new file mode 100644 index 0000000..7bfa777 --- /dev/null +++ b/rust/crates/rustproxy-http/src/request_filter.rs @@ -0,0 +1,263 @@ +//! Request filtering: security checks, auth, CORS preflight. + +use std::net::SocketAddr; +use std::sync::Arc; + +use bytes::Bytes; +use http_body_util::Full; +use http_body_util::BodyExt; +use hyper::body::Incoming; +use hyper::{Request, Response, StatusCode}; +use http_body_util::combinators::BoxBody; + +use rustproxy_config::RouteSecurity; +use rustproxy_security::{IpFilter, BasicAuthValidator, JwtValidator, RateLimiter}; + +pub struct RequestFilter; + +impl RequestFilter { + /// Apply security filters. Returns Some(response) if the request should be blocked. + pub fn apply( + security: &RouteSecurity, + req: &Request, + peer_addr: &SocketAddr, + ) -> Option>> { + Self::apply_with_rate_limiter(security, req, peer_addr, None) + } + + /// Apply security filters with an optional shared rate limiter. + /// Returns Some(response) if the request should be blocked. + pub fn apply_with_rate_limiter( + security: &RouteSecurity, + req: &Request, + peer_addr: &SocketAddr, + rate_limiter: Option<&Arc>, + ) -> Option>> { + let client_ip = peer_addr.ip(); + let request_path = req.uri().path(); + + // IP filter + if security.ip_allow_list.is_some() || security.ip_block_list.is_some() { + let allow = security.ip_allow_list.as_deref().unwrap_or(&[]); + let block = security.ip_block_list.as_deref().unwrap_or(&[]); + let filter = IpFilter::new(allow, block); + let normalized = IpFilter::normalize_ip(&client_ip); + if !filter.is_allowed(&normalized) { + return Some(error_response(StatusCode::FORBIDDEN, "Access denied")); + } + } + + // Rate limiting + if let Some(ref rate_limit_config) = security.rate_limit { + if rate_limit_config.enabled { + // Use shared rate limiter if provided, otherwise create ephemeral one + let should_block = if let Some(limiter) = rate_limiter { + let key = Self::rate_limit_key(rate_limit_config, req, peer_addr); + !limiter.check(&key) + } else { + // Create a per-check limiter (less ideal but works for non-shared case) + let limiter = RateLimiter::new( + rate_limit_config.max_requests, + rate_limit_config.window, + ); + let key = Self::rate_limit_key(rate_limit_config, req, peer_addr); + !limiter.check(&key) + }; + + if should_block { + let message = rate_limit_config.error_message + .as_deref() + .unwrap_or("Rate limit exceeded"); + return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message)); + } + } + } + + // Check exclude paths before auth + let should_skip_auth = Self::path_matches_exclude_list(request_path, security); + + if !should_skip_auth { + // Basic auth + if let Some(ref basic_auth) = security.basic_auth { + if basic_auth.enabled { + // Check basic auth exclude paths + let skip_basic = basic_auth.exclude_paths.as_ref() + .map(|paths| Self::path_matches_any(request_path, paths)) + .unwrap_or(false); + + if !skip_basic { + let users: Vec<(String, String)> = basic_auth.users.iter() + .map(|c| (c.username.clone(), c.password.clone())) + .collect(); + let validator = BasicAuthValidator::new(users, basic_auth.realm.clone()); + + let auth_header = req.headers() + .get("authorization") + .and_then(|v| v.to_str().ok()); + + match auth_header { + Some(header) => { + if validator.validate(header).is_none() { + return Some(Response::builder() + .status(StatusCode::UNAUTHORIZED) + .header("WWW-Authenticate", validator.www_authenticate()) + .body(boxed_body("Invalid credentials")) + .unwrap()); + } + } + None => { + return Some(Response::builder() + .status(StatusCode::UNAUTHORIZED) + .header("WWW-Authenticate", validator.www_authenticate()) + .body(boxed_body("Authentication required")) + .unwrap()); + } + } + } + } + } + + // JWT auth + if let Some(ref jwt_auth) = security.jwt_auth { + if jwt_auth.enabled { + // Check JWT auth exclude paths + let skip_jwt = jwt_auth.exclude_paths.as_ref() + .map(|paths| Self::path_matches_any(request_path, paths)) + .unwrap_or(false); + + if !skip_jwt { + let validator = JwtValidator::new( + &jwt_auth.secret, + jwt_auth.algorithm.as_deref(), + jwt_auth.issuer.as_deref(), + jwt_auth.audience.as_deref(), + ); + + let auth_header = req.headers() + .get("authorization") + .and_then(|v| v.to_str().ok()); + + match auth_header.and_then(JwtValidator::extract_token) { + Some(token) => { + if validator.validate(token).is_err() { + return Some(error_response(StatusCode::UNAUTHORIZED, "Invalid token")); + } + } + None => { + return Some(error_response(StatusCode::UNAUTHORIZED, "Bearer token required")); + } + } + } + } + } + } + + None + } + + /// Check if a request path matches any pattern in the exclude list. + fn path_matches_exclude_list(_path: &str, _security: &RouteSecurity) -> bool { + // No global exclude paths on RouteSecurity currently, + // but we check per-auth exclude paths above. + // This can be extended if a global exclude_paths is added. + false + } + + /// Check if a path matches any pattern in the list. + /// Supports simple glob patterns: `/health*` matches `/health`, `/healthz`, `/health/check` + fn path_matches_any(path: &str, patterns: &[String]) -> bool { + for pattern in patterns { + if pattern.ends_with('*') { + let prefix = &pattern[..pattern.len() - 1]; + if path.starts_with(prefix) { + return true; + } + } else if path == pattern { + return true; + } + } + false + } + + /// Determine the rate limit key based on configuration. + fn rate_limit_key( + config: &rustproxy_config::RouteRateLimit, + req: &Request, + peer_addr: &SocketAddr, + ) -> String { + use rustproxy_config::RateLimitKeyBy; + match config.key_by.as_ref().unwrap_or(&RateLimitKeyBy::Ip) { + RateLimitKeyBy::Ip => peer_addr.ip().to_string(), + RateLimitKeyBy::Path => req.uri().path().to_string(), + RateLimitKeyBy::Header => { + if let Some(ref header_name) = config.header_name { + req.headers() + .get(header_name.as_str()) + .and_then(|v| v.to_str().ok()) + .unwrap_or("unknown") + .to_string() + } else { + peer_addr.ip().to_string() + } + } + } + } + + /// Check IP-based security (for use in passthrough / TCP-level connections). + /// Returns true if allowed, false if blocked. + pub fn check_ip_security(security: &RouteSecurity, client_ip: &std::net::IpAddr) -> bool { + if security.ip_allow_list.is_some() || security.ip_block_list.is_some() { + let allow = security.ip_allow_list.as_deref().unwrap_or(&[]); + let block = security.ip_block_list.as_deref().unwrap_or(&[]); + let filter = IpFilter::new(allow, block); + let normalized = IpFilter::normalize_ip(client_ip); + filter.is_allowed(&normalized) + } else { + true + } + } + + /// Handle CORS preflight (OPTIONS) requests. + /// Returns Some(response) if this is a CORS preflight that should be handled. + pub fn handle_cors_preflight( + req: &Request, + ) -> Option>> { + if req.method() != hyper::Method::OPTIONS { + return None; + } + + // Check for CORS preflight indicators + let has_origin = req.headers().contains_key("origin"); + let has_request_method = req.headers().contains_key("access-control-request-method"); + + if !has_origin || !has_request_method { + return None; + } + + let origin = req.headers() + .get("origin") + .and_then(|v| v.to_str().ok()) + .unwrap_or("*"); + + Some(Response::builder() + .status(StatusCode::NO_CONTENT) + .header("Access-Control-Allow-Origin", origin) + .header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS") + .header("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With") + .header("Access-Control-Max-Age", "86400") + .body(boxed_body("")) + .unwrap()) + } +} + +fn error_response(status: StatusCode, message: &str) -> Response> { + Response::builder() + .status(status) + .header("Content-Type", "text/plain") + .body(boxed_body(message)) + .unwrap() +} + +fn boxed_body(data: &str) -> BoxBody { + BoxBody::new(Full::new(Bytes::from(data.to_string())).map_err(|never| match never {})) +} diff --git a/rust/crates/rustproxy-http/src/response_filter.rs b/rust/crates/rustproxy-http/src/response_filter.rs new file mode 100644 index 0000000..1441628 --- /dev/null +++ b/rust/crates/rustproxy-http/src/response_filter.rs @@ -0,0 +1,92 @@ +//! Response filtering: CORS headers, custom headers, security headers. + +use hyper::header::{HeaderMap, HeaderName, HeaderValue}; +use rustproxy_config::RouteConfig; + +use crate::template::{RequestContext, expand_template}; + +pub struct ResponseFilter; + +impl ResponseFilter { + /// Apply response headers from route config and CORS settings. + /// If a `RequestContext` is provided, template variables in header values will be expanded. + pub fn apply_headers(route: &RouteConfig, headers: &mut HeaderMap, req_ctx: Option<&RequestContext>) { + // Apply custom response headers from route config + if let Some(ref route_headers) = route.headers { + if let Some(ref response_headers) = route_headers.response { + for (key, value) in response_headers { + if let Ok(name) = HeaderName::from_bytes(key.as_bytes()) { + let expanded = match req_ctx { + Some(ctx) => expand_template(value, ctx), + None => value.clone(), + }; + if let Ok(val) = HeaderValue::from_str(&expanded) { + headers.insert(name, val); + } + } + } + } + + // Apply CORS headers if configured + if let Some(ref cors) = route_headers.cors { + if cors.enabled { + Self::apply_cors_headers(cors, headers); + } + } + } + } + + fn apply_cors_headers(cors: &rustproxy_config::RouteCors, headers: &mut HeaderMap) { + // Allow-Origin + if let Some(ref origin) = cors.allow_origin { + let origin_str = match origin { + rustproxy_config::AllowOrigin::Single(s) => s.clone(), + rustproxy_config::AllowOrigin::List(list) => list.join(", "), + }; + if let Ok(val) = HeaderValue::from_str(&origin_str) { + headers.insert("access-control-allow-origin", val); + } + } else { + headers.insert( + "access-control-allow-origin", + HeaderValue::from_static("*"), + ); + } + + // Allow-Methods + if let Some(ref methods) = cors.allow_methods { + if let Ok(val) = HeaderValue::from_str(methods) { + headers.insert("access-control-allow-methods", val); + } + } + + // Allow-Headers + if let Some(ref allow_headers) = cors.allow_headers { + if let Ok(val) = HeaderValue::from_str(allow_headers) { + headers.insert("access-control-allow-headers", val); + } + } + + // Allow-Credentials + if cors.allow_credentials == Some(true) { + headers.insert( + "access-control-allow-credentials", + HeaderValue::from_static("true"), + ); + } + + // Expose-Headers + if let Some(ref expose) = cors.expose_headers { + if let Ok(val) = HeaderValue::from_str(expose) { + headers.insert("access-control-expose-headers", val); + } + } + + // Max-Age + if let Some(max_age) = cors.max_age { + if let Ok(val) = HeaderValue::from_str(&max_age.to_string()) { + headers.insert("access-control-max-age", val); + } + } + } +} diff --git a/rust/crates/rustproxy-http/src/template.rs b/rust/crates/rustproxy-http/src/template.rs new file mode 100644 index 0000000..a6333bc --- /dev/null +++ b/rust/crates/rustproxy-http/src/template.rs @@ -0,0 +1,162 @@ +//! Header template variable expansion. +//! +//! Supports expanding template variables like `{clientIp}`, `{domain}`, etc. +//! in header values before they are applied to requests or responses. + +use std::collections::HashMap; +use std::time::{SystemTime, UNIX_EPOCH}; + +/// Context for template variable expansion. +pub struct RequestContext { + pub client_ip: String, + pub domain: String, + pub port: u16, + pub path: String, + pub route_name: String, + pub connection_id: u64, +} + +/// Expand template variables in a header value. +/// Supported variables: {clientIp}, {domain}, {port}, {path}, {routeName}, {connectionId}, {timestamp} +pub fn expand_template(template: &str, ctx: &RequestContext) -> String { + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + template + .replace("{clientIp}", &ctx.client_ip) + .replace("{domain}", &ctx.domain) + .replace("{port}", &ctx.port.to_string()) + .replace("{path}", &ctx.path) + .replace("{routeName}", &ctx.route_name) + .replace("{connectionId}", &ctx.connection_id.to_string()) + .replace("{timestamp}", ×tamp.to_string()) +} + +/// Expand templates in a map of header key-value pairs. +pub fn expand_headers( + headers: &HashMap, + ctx: &RequestContext, +) -> HashMap { + headers.iter() + .map(|(k, v)| (k.clone(), expand_template(v, ctx))) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_context() -> RequestContext { + RequestContext { + client_ip: "192.168.1.100".to_string(), + domain: "example.com".to_string(), + port: 443, + path: "/api/v1/users".to_string(), + route_name: "api-route".to_string(), + connection_id: 42, + } + } + + #[test] + fn test_expand_client_ip() { + let ctx = test_context(); + assert_eq!(expand_template("{clientIp}", &ctx), "192.168.1.100"); + } + + #[test] + fn test_expand_domain() { + let ctx = test_context(); + assert_eq!(expand_template("{domain}", &ctx), "example.com"); + } + + #[test] + fn test_expand_port() { + let ctx = test_context(); + assert_eq!(expand_template("{port}", &ctx), "443"); + } + + #[test] + fn test_expand_path() { + let ctx = test_context(); + assert_eq!(expand_template("{path}", &ctx), "/api/v1/users"); + } + + #[test] + fn test_expand_route_name() { + let ctx = test_context(); + assert_eq!(expand_template("{routeName}", &ctx), "api-route"); + } + + #[test] + fn test_expand_connection_id() { + let ctx = test_context(); + assert_eq!(expand_template("{connectionId}", &ctx), "42"); + } + + #[test] + fn test_expand_timestamp() { + let ctx = test_context(); + let result = expand_template("{timestamp}", &ctx); + // Timestamp should be a valid number + let ts: u64 = result.parse().expect("timestamp should be a number"); + // Should be a reasonable Unix timestamp (after 2020) + assert!(ts > 1_577_836_800); + } + + #[test] + fn test_expand_mixed_template() { + let ctx = test_context(); + let result = expand_template("client={clientIp}, host={domain}:{port}", &ctx); + assert_eq!(result, "client=192.168.1.100, host=example.com:443"); + } + + #[test] + fn test_expand_no_variables() { + let ctx = test_context(); + assert_eq!(expand_template("plain-value", &ctx), "plain-value"); + } + + #[test] + fn test_expand_empty_string() { + let ctx = test_context(); + assert_eq!(expand_template("", &ctx), ""); + } + + #[test] + fn test_expand_multiple_same_variable() { + let ctx = test_context(); + let result = expand_template("{clientIp}-{clientIp}", &ctx); + assert_eq!(result, "192.168.1.100-192.168.1.100"); + } + + #[test] + fn test_expand_headers_map() { + let ctx = test_context(); + let mut headers = HashMap::new(); + headers.insert("X-Forwarded-For".to_string(), "{clientIp}".to_string()); + headers.insert("X-Route".to_string(), "{routeName}".to_string()); + headers.insert("X-Static".to_string(), "no-template".to_string()); + + let result = expand_headers(&headers, &ctx); + assert_eq!(result.get("X-Forwarded-For").unwrap(), "192.168.1.100"); + assert_eq!(result.get("X-Route").unwrap(), "api-route"); + assert_eq!(result.get("X-Static").unwrap(), "no-template"); + } + + #[test] + fn test_expand_all_variables_in_one() { + let ctx = test_context(); + let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}"; + let result = expand_template(template, &ctx); + assert_eq!(result, "192.168.1.100|example.com|443|/api/v1/users|api-route|42"); + } + + #[test] + fn test_expand_unknown_variable_left_as_is() { + let ctx = test_context(); + let result = expand_template("{unknownVar}", &ctx); + assert_eq!(result, "{unknownVar}"); + } +} diff --git a/rust/crates/rustproxy-http/src/upstream_selector.rs b/rust/crates/rustproxy-http/src/upstream_selector.rs new file mode 100644 index 0000000..e611159 --- /dev/null +++ b/rust/crates/rustproxy-http/src/upstream_selector.rs @@ -0,0 +1,222 @@ +//! Route-aware upstream selection with load balancing. + +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::sync::Mutex; + +use dashmap::DashMap; +use rustproxy_config::{RouteTarget, LoadBalancingAlgorithm}; + +/// Upstream selection result. +pub struct UpstreamSelection { + pub host: String, + pub port: u16, + pub use_tls: bool, +} + +/// Selects upstream backends with load balancing support. +pub struct UpstreamSelector { + /// Round-robin counters per route (keyed by first target host:port) + round_robin: Mutex>, + /// Active connection counts per host (keyed by "host:port") + active_connections: Arc>, +} + +impl UpstreamSelector { + pub fn new() -> Self { + Self { + round_robin: Mutex::new(HashMap::new()), + active_connections: Arc::new(DashMap::new()), + } + } + + /// Select an upstream target based on the route target config and load balancing. + pub fn select( + &self, + target: &RouteTarget, + client_addr: &SocketAddr, + incoming_port: u16, + ) -> UpstreamSelection { + let hosts = target.host.to_vec(); + let port = target.port.resolve(incoming_port); + + if hosts.len() <= 1 { + return UpstreamSelection { + host: hosts.first().map(|s| s.to_string()).unwrap_or_default(), + port, + use_tls: target.tls.is_some(), + }; + } + + // Determine load balancing algorithm + let algorithm = target.load_balancing.as_ref() + .map(|lb| &lb.algorithm) + .unwrap_or(&LoadBalancingAlgorithm::RoundRobin); + + let idx = match algorithm { + LoadBalancingAlgorithm::RoundRobin => { + self.round_robin_select(&hosts, port) + } + LoadBalancingAlgorithm::IpHash => { + let hash = Self::ip_hash(client_addr); + hash % hosts.len() + } + LoadBalancingAlgorithm::LeastConnections => { + self.least_connections_select(&hosts, port) + } + }; + + UpstreamSelection { + host: hosts[idx].to_string(), + port, + use_tls: target.tls.is_some(), + } + } + + fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize { + let key = format!("{}:{}", hosts[0], port); + let mut counters = self.round_robin.lock().unwrap(); + let counter = counters + .entry(key) + .or_insert_with(|| AtomicUsize::new(0)); + let idx = counter.fetch_add(1, Ordering::Relaxed); + idx % hosts.len() + } + + fn least_connections_select(&self, hosts: &[&str], port: u16) -> usize { + let mut min_conns = u64::MAX; + let mut min_idx = 0; + + for (i, host) in hosts.iter().enumerate() { + let key = format!("{}:{}", host, port); + let conns = self.active_connections + .get(&key) + .map(|entry| entry.value().load(Ordering::Relaxed)) + .unwrap_or(0); + if conns < min_conns { + min_conns = conns; + min_idx = i; + } + } + + min_idx + } + + /// Record that a connection to the given host has started. + pub fn connection_started(&self, host: &str) { + self.active_connections + .entry(host.to_string()) + .or_insert_with(|| AtomicU64::new(0)) + .fetch_add(1, Ordering::Relaxed); + } + + /// Record that a connection to the given host has ended. + pub fn connection_ended(&self, host: &str) { + if let Some(counter) = self.active_connections.get(host) { + let prev = counter.value().fetch_sub(1, Ordering::Relaxed); + // Guard against underflow (shouldn't happen, but be safe) + if prev == 0 { + counter.value().store(0, Ordering::Relaxed); + } + } + } + + fn ip_hash(addr: &SocketAddr) -> usize { + let ip_str = addr.ip().to_string(); + let mut hash: usize = 5381; + for byte in ip_str.bytes() { + hash = hash.wrapping_mul(33).wrapping_add(byte as usize); + } + hash + } +} + +impl Default for UpstreamSelector { + fn default() -> Self { + Self::new() + } +} + +impl Clone for UpstreamSelector { + fn clone(&self) -> Self { + Self { + round_robin: Mutex::new(HashMap::new()), + active_connections: Arc::clone(&self.active_connections), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rustproxy_config::*; + + fn make_target(hosts: Vec<&str>, port: u16) -> RouteTarget { + RouteTarget { + target_match: None, + host: if hosts.len() == 1 { + HostSpec::Single(hosts[0].to_string()) + } else { + HostSpec::List(hosts.iter().map(|s| s.to_string()).collect()) + }, + port: PortSpec::Fixed(port), + tls: None, + websocket: None, + load_balancing: None, + send_proxy_protocol: None, + headers: None, + advanced: None, + priority: None, + } + } + + #[test] + fn test_single_host() { + let selector = UpstreamSelector::new(); + let target = make_target(vec!["backend"], 8080); + let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap(); + let result = selector.select(&target, &addr, 80); + assert_eq!(result.host, "backend"); + assert_eq!(result.port, 8080); + } + + #[test] + fn test_round_robin() { + let selector = UpstreamSelector::new(); + let mut target = make_target(vec!["a", "b", "c"], 8080); + target.load_balancing = Some(RouteLoadBalancing { + algorithm: LoadBalancingAlgorithm::RoundRobin, + health_check: None, + }); + let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap(); + + let r1 = selector.select(&target, &addr, 80); + let r2 = selector.select(&target, &addr, 80); + let r3 = selector.select(&target, &addr, 80); + let r4 = selector.select(&target, &addr, 80); + + // Should cycle through a, b, c, a + assert_eq!(r1.host, "a"); + assert_eq!(r2.host, "b"); + assert_eq!(r3.host, "c"); + assert_eq!(r4.host, "a"); + } + + #[test] + fn test_ip_hash_consistent() { + let selector = UpstreamSelector::new(); + let mut target = make_target(vec!["a", "b", "c"], 8080); + target.load_balancing = Some(RouteLoadBalancing { + algorithm: LoadBalancingAlgorithm::IpHash, + health_check: None, + }); + let addr: SocketAddr = "10.0.0.5:1234".parse().unwrap(); + + let r1 = selector.select(&target, &addr, 80); + let r2 = selector.select(&target, &addr, 80); + // Same IP should always get same backend + assert_eq!(r1.host, r2.host); + } +} diff --git a/rust/crates/rustproxy-metrics/Cargo.toml b/rust/crates/rustproxy-metrics/Cargo.toml new file mode 100644 index 0000000..bcf827c --- /dev/null +++ b/rust/crates/rustproxy-metrics/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "rustproxy-metrics" +version.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true +description = "Metrics and throughput tracking for RustProxy" + +[dependencies] +dashmap = { workspace = true } +tracing = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tokio = { workspace = true } +tokio-util = { workspace = true } diff --git a/rust/crates/rustproxy-metrics/src/collector.rs b/rust/crates/rustproxy-metrics/src/collector.rs new file mode 100644 index 0000000..8932ed6 --- /dev/null +++ b/rust/crates/rustproxy-metrics/src/collector.rs @@ -0,0 +1,251 @@ +use dashmap::DashMap; +use serde::{Deserialize, Serialize}; +use std::sync::atomic::{AtomicU64, Ordering}; + +/// Aggregated metrics snapshot. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Metrics { + pub active_connections: u64, + pub total_connections: u64, + pub bytes_in: u64, + pub bytes_out: u64, + pub throughput_in_bytes_per_sec: u64, + pub throughput_out_bytes_per_sec: u64, + pub routes: std::collections::HashMap, +} + +/// Per-route metrics. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RouteMetrics { + pub active_connections: u64, + pub total_connections: u64, + pub bytes_in: u64, + pub bytes_out: u64, + pub throughput_in_bytes_per_sec: u64, + pub throughput_out_bytes_per_sec: u64, +} + +/// Statistics snapshot. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Statistics { + pub active_connections: u64, + pub total_connections: u64, + pub routes_count: u64, + pub listening_ports: Vec, + pub uptime_seconds: u64, +} + +/// Metrics collector tracking connections and throughput. +pub struct MetricsCollector { + active_connections: AtomicU64, + total_connections: AtomicU64, + total_bytes_in: AtomicU64, + total_bytes_out: AtomicU64, + /// Per-route active connection counts + route_connections: DashMap, + /// Per-route total connection counts + route_total_connections: DashMap, + /// Per-route byte counters + route_bytes_in: DashMap, + route_bytes_out: DashMap, +} + +impl MetricsCollector { + pub fn new() -> Self { + Self { + active_connections: AtomicU64::new(0), + total_connections: AtomicU64::new(0), + total_bytes_in: AtomicU64::new(0), + total_bytes_out: AtomicU64::new(0), + route_connections: DashMap::new(), + route_total_connections: DashMap::new(), + route_bytes_in: DashMap::new(), + route_bytes_out: DashMap::new(), + } + } + + /// Record a new connection. + pub fn connection_opened(&self, route_id: Option<&str>) { + self.active_connections.fetch_add(1, Ordering::Relaxed); + self.total_connections.fetch_add(1, Ordering::Relaxed); + + if let Some(route_id) = route_id { + self.route_connections + .entry(route_id.to_string()) + .or_insert_with(|| AtomicU64::new(0)) + .fetch_add(1, Ordering::Relaxed); + self.route_total_connections + .entry(route_id.to_string()) + .or_insert_with(|| AtomicU64::new(0)) + .fetch_add(1, Ordering::Relaxed); + } + } + + /// Record a connection closing. + pub fn connection_closed(&self, route_id: Option<&str>) { + self.active_connections.fetch_sub(1, Ordering::Relaxed); + + if let Some(route_id) = route_id { + if let Some(counter) = self.route_connections.get(route_id) { + let val = counter.load(Ordering::Relaxed); + if val > 0 { + counter.fetch_sub(1, Ordering::Relaxed); + } + } + } + } + + /// Record bytes transferred. + pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64, route_id: Option<&str>) { + self.total_bytes_in.fetch_add(bytes_in, Ordering::Relaxed); + self.total_bytes_out.fetch_add(bytes_out, Ordering::Relaxed); + + if let Some(route_id) = route_id { + self.route_bytes_in + .entry(route_id.to_string()) + .or_insert_with(|| AtomicU64::new(0)) + .fetch_add(bytes_in, Ordering::Relaxed); + self.route_bytes_out + .entry(route_id.to_string()) + .or_insert_with(|| AtomicU64::new(0)) + .fetch_add(bytes_out, Ordering::Relaxed); + } + } + + /// Get current active connection count. + pub fn active_connections(&self) -> u64 { + self.active_connections.load(Ordering::Relaxed) + } + + /// Get total connection count. + pub fn total_connections(&self) -> u64 { + self.total_connections.load(Ordering::Relaxed) + } + + /// Get total bytes received. + pub fn total_bytes_in(&self) -> u64 { + self.total_bytes_in.load(Ordering::Relaxed) + } + + /// Get total bytes sent. + pub fn total_bytes_out(&self) -> u64 { + self.total_bytes_out.load(Ordering::Relaxed) + } + + /// Get a full metrics snapshot including per-route data. + pub fn snapshot(&self) -> Metrics { + let mut routes = std::collections::HashMap::new(); + + // Collect per-route metrics + for entry in self.route_total_connections.iter() { + let route_id = entry.key().clone(); + let total = entry.value().load(Ordering::Relaxed); + let active = self.route_connections + .get(&route_id) + .map(|c| c.load(Ordering::Relaxed)) + .unwrap_or(0); + let bytes_in = self.route_bytes_in + .get(&route_id) + .map(|c| c.load(Ordering::Relaxed)) + .unwrap_or(0); + let bytes_out = self.route_bytes_out + .get(&route_id) + .map(|c| c.load(Ordering::Relaxed)) + .unwrap_or(0); + + routes.insert(route_id, RouteMetrics { + active_connections: active, + total_connections: total, + bytes_in, + bytes_out, + throughput_in_bytes_per_sec: 0, + throughput_out_bytes_per_sec: 0, + }); + } + + Metrics { + active_connections: self.active_connections(), + total_connections: self.total_connections(), + bytes_in: self.total_bytes_in(), + bytes_out: self.total_bytes_out(), + throughput_in_bytes_per_sec: 0, + throughput_out_bytes_per_sec: 0, + routes, + } + } +} + +impl Default for MetricsCollector { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_initial_state_zeros() { + let collector = MetricsCollector::new(); + assert_eq!(collector.active_connections(), 0); + assert_eq!(collector.total_connections(), 0); + } + + #[test] + fn test_connection_opened_increments() { + let collector = MetricsCollector::new(); + collector.connection_opened(None); + assert_eq!(collector.active_connections(), 1); + assert_eq!(collector.total_connections(), 1); + collector.connection_opened(None); + assert_eq!(collector.active_connections(), 2); + assert_eq!(collector.total_connections(), 2); + } + + #[test] + fn test_connection_closed_decrements() { + let collector = MetricsCollector::new(); + collector.connection_opened(None); + collector.connection_opened(None); + assert_eq!(collector.active_connections(), 2); + collector.connection_closed(None); + assert_eq!(collector.active_connections(), 1); + // total_connections should stay at 2 + assert_eq!(collector.total_connections(), 2); + } + + #[test] + fn test_route_specific_tracking() { + let collector = MetricsCollector::new(); + collector.connection_opened(Some("route-a")); + collector.connection_opened(Some("route-a")); + collector.connection_opened(Some("route-b")); + + assert_eq!(collector.active_connections(), 3); + assert_eq!(collector.total_connections(), 3); + + collector.connection_closed(Some("route-a")); + assert_eq!(collector.active_connections(), 2); + } + + #[test] + fn test_record_bytes() { + let collector = MetricsCollector::new(); + collector.record_bytes(100, 200, Some("route-a")); + collector.record_bytes(50, 75, Some("route-a")); + collector.record_bytes(25, 30, None); + + let total_in = collector.total_bytes_in.load(Ordering::Relaxed); + let total_out = collector.total_bytes_out.load(Ordering::Relaxed); + assert_eq!(total_in, 175); + assert_eq!(total_out, 305); + + // Route-specific bytes + let route_in = collector.route_bytes_in.get("route-a").unwrap(); + assert_eq!(route_in.load(Ordering::Relaxed), 150); + } +} diff --git a/rust/crates/rustproxy-metrics/src/lib.rs b/rust/crates/rustproxy-metrics/src/lib.rs new file mode 100644 index 0000000..874a64e --- /dev/null +++ b/rust/crates/rustproxy-metrics/src/lib.rs @@ -0,0 +1,11 @@ +//! # rustproxy-metrics +//! +//! Metrics and throughput tracking for RustProxy. + +pub mod throughput; +pub mod collector; +pub mod log_dedup; + +pub use throughput::*; +pub use collector::*; +pub use log_dedup::*; diff --git a/rust/crates/rustproxy-metrics/src/log_dedup.rs b/rust/crates/rustproxy-metrics/src/log_dedup.rs new file mode 100644 index 0000000..1e43620 --- /dev/null +++ b/rust/crates/rustproxy-metrics/src/log_dedup.rs @@ -0,0 +1,219 @@ +use dashmap::DashMap; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{Duration, Instant}; +use tracing::info; + +/// An aggregated event during the deduplication window. +struct AggregatedEvent { + category: String, + first_message: String, + count: AtomicU64, + first_seen: Instant, + #[allow(dead_code)] + last_seen: Instant, +} + +/// Log deduplicator that batches similar events over a time window. +/// +/// Events are grouped by a composite key of `category:key`. Within each +/// deduplication window (`flush_interval`) identical events are counted +/// instead of being emitted individually. When the window expires (or the +/// batch reaches `max_batch_size`) a single summary line is written via +/// `tracing::info!`. +pub struct LogDeduplicator { + events: DashMap, + flush_interval: Duration, + max_batch_size: u64, + #[allow(dead_code)] + rapid_threshold: u64, // events/sec that triggers immediate flush +} + +impl LogDeduplicator { + pub fn new() -> Self { + Self { + events: DashMap::new(), + flush_interval: Duration::from_secs(5), + max_batch_size: 100, + rapid_threshold: 50, + } + } + + /// Log an event, deduplicating by `category` + `key`. + /// + /// If the batch for this composite key reaches `max_batch_size` the + /// accumulated events are flushed immediately. + pub fn log(&self, category: &str, key: &str, message: &str) { + let map_key = format!("{}:{}", category, key); + let now = Instant::now(); + + let entry = self.events.entry(map_key).or_insert_with(|| AggregatedEvent { + category: category.to_string(), + first_message: message.to_string(), + count: AtomicU64::new(0), + first_seen: now, + last_seen: now, + }); + + let count = entry.count.fetch_add(1, Ordering::Relaxed) + 1; + + // Check if we should flush (batch size exceeded) + if count >= self.max_batch_size { + drop(entry); + self.flush(); + } + } + + /// Flush all accumulated events, emitting summary log lines. + pub fn flush(&self) { + // Collect and remove all events + self.events.retain(|_key, event| { + let count = event.count.load(Ordering::Relaxed); + if count > 0 { + let elapsed = event.first_seen.elapsed(); + if count == 1 { + info!("[{}] {}", event.category, event.first_message); + } else { + info!( + "[SUMMARY] {} {} events in {:.1}s: {}", + count, + event.category, + elapsed.as_secs_f64(), + event.first_message + ); + } + } + false // remove all entries after flushing + }); + } + + /// Start a background flush task that periodically drains accumulated + /// events. The task runs until the supplied `CancellationToken` is + /// cancelled, at which point it performs one final flush before exiting. + pub fn start_flush_task(self: &Arc, cancel: tokio_util::sync::CancellationToken) { + let dedup = Arc::clone(self); + let interval = self.flush_interval; + tokio::spawn(async move { + loop { + tokio::select! { + _ = cancel.cancelled() => { + dedup.flush(); + break; + } + _ = tokio::time::sleep(interval) => { + dedup.flush(); + } + } + } + }); + } +} + +impl Default for LogDeduplicator { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_single_event_emitted_as_is() { + let dedup = LogDeduplicator::new(); + dedup.log("conn", "open", "connection opened from 1.2.3.4"); + // One event should exist + assert_eq!(dedup.events.len(), 1); + let entry = dedup.events.get("conn:open").unwrap(); + assert_eq!(entry.count.load(Ordering::Relaxed), 1); + assert_eq!(entry.first_message, "connection opened from 1.2.3.4"); + drop(entry); + dedup.flush(); + // After flush, map should be empty + assert_eq!(dedup.events.len(), 0); + } + + #[test] + fn test_duplicate_events_aggregated() { + let dedup = LogDeduplicator::new(); + for _ in 0..10 { + dedup.log("conn", "timeout", "connection timed out"); + } + assert_eq!(dedup.events.len(), 1); + let entry = dedup.events.get("conn:timeout").unwrap(); + assert_eq!(entry.count.load(Ordering::Relaxed), 10); + drop(entry); + dedup.flush(); + assert_eq!(dedup.events.len(), 0); + } + + #[test] + fn test_different_keys_separate() { + let dedup = LogDeduplicator::new(); + dedup.log("conn", "open", "opened"); + dedup.log("conn", "close", "closed"); + dedup.log("tls", "handshake", "TLS handshake"); + assert_eq!(dedup.events.len(), 3); + dedup.flush(); + assert_eq!(dedup.events.len(), 0); + } + + #[test] + fn test_flush_clears_events() { + let dedup = LogDeduplicator::new(); + dedup.log("a", "b", "msg1"); + dedup.log("a", "b", "msg2"); + dedup.flush(); + assert_eq!(dedup.events.len(), 0); + // Logging after flush creates a new entry + dedup.log("a", "b", "msg3"); + assert_eq!(dedup.events.len(), 1); + let entry = dedup.events.get("a:b").unwrap(); + assert_eq!(entry.count.load(Ordering::Relaxed), 1); + assert_eq!(entry.first_message, "msg3"); + } + + #[test] + fn test_max_batch_triggers_flush() { + let dedup = LogDeduplicator::new(); + // max_batch_size defaults to 100 + for i in 0..100 { + dedup.log("flood", "key", &format!("event {}", i)); + } + // After hitting max_batch_size the events map should have been flushed + assert_eq!(dedup.events.len(), 0); + } + + #[test] + fn test_default_trait() { + let dedup = LogDeduplicator::default(); + assert_eq!(dedup.flush_interval, Duration::from_secs(5)); + assert_eq!(dedup.max_batch_size, 100); + } + + #[tokio::test] + async fn test_background_flush_task() { + let dedup = Arc::new(LogDeduplicator { + events: DashMap::new(), + flush_interval: Duration::from_millis(50), + max_batch_size: 100, + rapid_threshold: 50, + }); + + let cancel = tokio_util::sync::CancellationToken::new(); + dedup.start_flush_task(cancel.clone()); + + // Log some events + dedup.log("bg", "test", "background flush test"); + assert_eq!(dedup.events.len(), 1); + + // Wait for the background task to flush + tokio::time::sleep(Duration::from_millis(100)).await; + assert_eq!(dedup.events.len(), 0); + + // Cancel the task + cancel.cancel(); + tokio::time::sleep(Duration::from_millis(20)).await; + } +} diff --git a/rust/crates/rustproxy-metrics/src/throughput.rs b/rust/crates/rustproxy-metrics/src/throughput.rs new file mode 100644 index 0000000..e73833f --- /dev/null +++ b/rust/crates/rustproxy-metrics/src/throughput.rs @@ -0,0 +1,173 @@ +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{Instant, SystemTime, UNIX_EPOCH}; + +/// A single throughput sample. +#[derive(Debug, Clone, Copy)] +pub struct ThroughputSample { + pub timestamp_ms: u64, + pub bytes_in: u64, + pub bytes_out: u64, +} + +/// Circular buffer for 1Hz throughput sampling. +/// Matches smartproxy's ThroughputTracker. +pub struct ThroughputTracker { + /// Circular buffer of samples + samples: Vec, + /// Current write index + write_index: usize, + /// Number of valid samples + count: usize, + /// Maximum number of samples to retain + capacity: usize, + /// Accumulated bytes since last sample + pending_bytes_in: AtomicU64, + pending_bytes_out: AtomicU64, + /// When the tracker was created + created_at: Instant, +} + +impl ThroughputTracker { + /// Create a new tracker with the given capacity (seconds of retention). + pub fn new(retention_seconds: usize) -> Self { + Self { + samples: Vec::with_capacity(retention_seconds), + write_index: 0, + count: 0, + capacity: retention_seconds, + pending_bytes_in: AtomicU64::new(0), + pending_bytes_out: AtomicU64::new(0), + created_at: Instant::now(), + } + } + + /// Record bytes (called from data flow callbacks). + pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64) { + self.pending_bytes_in.fetch_add(bytes_in, Ordering::Relaxed); + self.pending_bytes_out.fetch_add(bytes_out, Ordering::Relaxed); + } + + /// Take a sample (called at 1Hz). + pub fn sample(&mut self) { + let bytes_in = self.pending_bytes_in.swap(0, Ordering::Relaxed); + let bytes_out = self.pending_bytes_out.swap(0, Ordering::Relaxed); + let timestamp_ms = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as u64; + + let sample = ThroughputSample { + timestamp_ms, + bytes_in, + bytes_out, + }; + + if self.samples.len() < self.capacity { + self.samples.push(sample); + } else { + self.samples[self.write_index] = sample; + } + self.write_index = (self.write_index + 1) % self.capacity; + self.count = (self.count + 1).min(self.capacity); + } + + /// Get throughput over the last N seconds. + pub fn throughput(&self, window_seconds: usize) -> (u64, u64) { + let window = window_seconds.min(self.count); + if window == 0 { + return (0, 0); + } + + let mut total_in = 0u64; + let mut total_out = 0u64; + + for i in 0..window { + let idx = if self.write_index >= i + 1 { + self.write_index - i - 1 + } else { + self.capacity - (i + 1 - self.write_index) + }; + if idx < self.samples.len() { + total_in += self.samples[idx].bytes_in; + total_out += self.samples[idx].bytes_out; + } + } + + (total_in / window as u64, total_out / window as u64) + } + + /// Get instant throughput (last 1 second). + pub fn instant(&self) -> (u64, u64) { + self.throughput(1) + } + + /// Get recent throughput (last 10 seconds). + pub fn recent(&self) -> (u64, u64) { + self.throughput(10) + } + + /// How long this tracker has been alive. + pub fn uptime(&self) -> std::time::Duration { + self.created_at.elapsed() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_throughput() { + let tracker = ThroughputTracker::new(60); + let (bytes_in, bytes_out) = tracker.throughput(10); + assert_eq!(bytes_in, 0); + assert_eq!(bytes_out, 0); + } + + #[test] + fn test_single_sample() { + let mut tracker = ThroughputTracker::new(60); + tracker.record_bytes(1000, 2000); + tracker.sample(); + let (bytes_in, bytes_out) = tracker.instant(); + assert_eq!(bytes_in, 1000); + assert_eq!(bytes_out, 2000); + } + + #[test] + fn test_circular_buffer_wrap() { + let mut tracker = ThroughputTracker::new(3); // Small capacity + for i in 0..5 { + tracker.record_bytes(i * 100, i * 200); + tracker.sample(); + } + // Should still work after wrapping + let (bytes_in, bytes_out) = tracker.throughput(3); + assert!(bytes_in > 0); + assert!(bytes_out > 0); + } + + #[test] + fn test_window_averaging() { + let mut tracker = ThroughputTracker::new(60); + // Record 3 samples of different sizes + tracker.record_bytes(100, 200); + tracker.sample(); + tracker.record_bytes(200, 400); + tracker.sample(); + tracker.record_bytes(300, 600); + tracker.sample(); + + // Average over 3 samples: (100+200+300)/3 = 200, (200+400+600)/3 = 400 + let (avg_in, avg_out) = tracker.throughput(3); + assert_eq!(avg_in, 200); + assert_eq!(avg_out, 400); + } + + #[test] + fn test_uptime_positive() { + let tracker = ThroughputTracker::new(60); + std::thread::sleep(std::time::Duration::from_millis(10)); + assert!(tracker.uptime().as_millis() >= 10); + } +} diff --git a/rust/crates/rustproxy-nftables/Cargo.toml b/rust/crates/rustproxy-nftables/Cargo.toml new file mode 100644 index 0000000..7632d52 --- /dev/null +++ b/rust/crates/rustproxy-nftables/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "rustproxy-nftables" +version.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true +description = "NFTables kernel-level forwarding for RustProxy" + +[dependencies] +rustproxy-config = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } +thiserror = { workspace = true } +anyhow = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +libc = { workspace = true } diff --git a/rust/crates/rustproxy-nftables/src/lib.rs b/rust/crates/rustproxy-nftables/src/lib.rs new file mode 100644 index 0000000..0e5211a --- /dev/null +++ b/rust/crates/rustproxy-nftables/src/lib.rs @@ -0,0 +1,10 @@ +//! # rustproxy-nftables +//! +//! NFTables kernel-level forwarding for RustProxy. +//! Generates and manages nft CLI rules for DNAT/SNAT. + +pub mod nft_manager; +pub mod rule_builder; + +pub use nft_manager::*; +pub use rule_builder::*; diff --git a/rust/crates/rustproxy-nftables/src/nft_manager.rs b/rust/crates/rustproxy-nftables/src/nft_manager.rs new file mode 100644 index 0000000..7751048 --- /dev/null +++ b/rust/crates/rustproxy-nftables/src/nft_manager.rs @@ -0,0 +1,238 @@ +use thiserror::Error; +use std::collections::HashMap; +use tracing::{debug, info, warn}; + +#[derive(Debug, Error)] +pub enum NftError { + #[error("nft command failed: {0}")] + CommandFailed(String), + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + #[error("Not running as root")] + NotRoot, +} + +/// Manager for nftables rules. +/// +/// Executes `nft` CLI commands to manage kernel-level packet forwarding. +/// Requires root privileges; operations are skipped gracefully if not root. +pub struct NftManager { + table_name: String, + /// Active rules indexed by route ID + active_rules: HashMap>, + /// Whether the table has been initialized + table_initialized: bool, +} + +impl NftManager { + pub fn new(table_name: Option) -> Self { + Self { + table_name: table_name.unwrap_or_else(|| "rustproxy".to_string()), + active_rules: HashMap::new(), + table_initialized: false, + } + } + + /// Check if we are running as root. + fn is_root() -> bool { + unsafe { libc::geteuid() == 0 } + } + + /// Execute a single nft command via the CLI. + async fn exec_nft(command: &str) -> Result { + // The command starts with "nft ", strip it to get the args + let args = if command.starts_with("nft ") { + &command[4..] + } else { + command + }; + + let output = tokio::process::Command::new("nft") + .args(args.split_whitespace()) + .output() + .await + .map_err(NftError::Io)?; + + if output.status.success() { + Ok(String::from_utf8_lossy(&output.stdout).to_string()) + } else { + let stderr = String::from_utf8_lossy(&output.stderr); + Err(NftError::CommandFailed(format!( + "Command '{}' failed: {}", + command, stderr + ))) + } + } + + /// Ensure the nftables table and chains are set up. + async fn ensure_table(&mut self) -> Result<(), NftError> { + if self.table_initialized { + return Ok(()); + } + + let setup_commands = crate::rule_builder::build_table_setup(&self.table_name); + for cmd in &setup_commands { + Self::exec_nft(cmd).await?; + } + + self.table_initialized = true; + info!("NFTables table '{}' initialized", self.table_name); + Ok(()) + } + + /// Apply rules for a route. + /// + /// Executes the nft commands via the CLI. If not running as root, + /// the rules are stored locally but not applied to the kernel. + pub async fn apply_rules(&mut self, route_id: &str, rules: Vec) -> Result<(), NftError> { + if !Self::is_root() { + warn!("Not running as root, nftables rules will not be applied to kernel"); + self.active_rules.insert(route_id.to_string(), rules); + return Ok(()); + } + + self.ensure_table().await?; + + for cmd in &rules { + Self::exec_nft(cmd).await?; + debug!("Applied nft rule: {}", cmd); + } + + info!("Applied {} nftables rules for route '{}'", rules.len(), route_id); + self.active_rules.insert(route_id.to_string(), rules); + Ok(()) + } + + /// Remove rules for a route. + /// + /// Currently removes the route from tracking. To fully remove specific + /// rules would require handle-based tracking; for now, cleanup() removes + /// the entire table. + pub async fn remove_rules(&mut self, route_id: &str) -> Result<(), NftError> { + if let Some(rules) = self.active_rules.remove(route_id) { + info!("Removed {} tracked nft rules for route '{}'", rules.len(), route_id); + } + Ok(()) + } + + /// Clean up all managed rules by deleting the entire nftables table. + pub async fn cleanup(&mut self) -> Result<(), NftError> { + if !Self::is_root() { + warn!("Not running as root, skipping nftables cleanup"); + self.active_rules.clear(); + self.table_initialized = false; + return Ok(()); + } + + if self.table_initialized { + let cleanup_commands = crate::rule_builder::build_table_cleanup(&self.table_name); + for cmd in &cleanup_commands { + match Self::exec_nft(cmd).await { + Ok(_) => debug!("Cleanup: {}", cmd), + Err(e) => warn!("Cleanup command failed (may be ok): {}", e), + } + } + info!("NFTables table '{}' cleaned up", self.table_name); + } + + self.active_rules.clear(); + self.table_initialized = false; + Ok(()) + } + + /// Get the table name. + pub fn table_name(&self) -> &str { + &self.table_name + } + + /// Whether the table has been initialized in the kernel. + pub fn is_initialized(&self) -> bool { + self.table_initialized + } + + /// Get the number of active route rule sets. + pub fn active_route_count(&self) -> usize { + self.active_rules.len() + } + + /// Get the status of all active rules. + pub fn status(&self) -> HashMap { + let mut status = HashMap::new(); + for (route_id, rules) in &self.active_rules { + status.insert( + route_id.clone(), + serde_json::json!({ + "ruleCount": rules.len(), + "rules": rules, + }), + ); + } + status + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_new_default_table_name() { + let mgr = NftManager::new(None); + assert_eq!(mgr.table_name(), "rustproxy"); + assert!(!mgr.is_initialized()); + } + + #[test] + fn test_new_custom_table_name() { + let mgr = NftManager::new(Some("custom".to_string())); + assert_eq!(mgr.table_name(), "custom"); + } + + #[tokio::test] + async fn test_apply_rules_non_root() { + let mut mgr = NftManager::new(None); + // When not root, rules are stored but not applied to kernel + let rules = vec!["nft add rule ip rustproxy prerouting tcp dport 443 dnat to 10.0.0.1:8443".to_string()]; + mgr.apply_rules("route-1", rules).await.unwrap(); + assert_eq!(mgr.active_route_count(), 1); + + let status = mgr.status(); + assert!(status.contains_key("route-1")); + assert_eq!(status["route-1"]["ruleCount"], 1); + } + + #[tokio::test] + async fn test_remove_rules() { + let mut mgr = NftManager::new(None); + let rules = vec!["nft add rule test".to_string()]; + mgr.apply_rules("route-1", rules).await.unwrap(); + assert_eq!(mgr.active_route_count(), 1); + + mgr.remove_rules("route-1").await.unwrap(); + assert_eq!(mgr.active_route_count(), 0); + } + + #[tokio::test] + async fn test_cleanup_non_root() { + let mut mgr = NftManager::new(None); + let rules = vec!["nft add rule test".to_string()]; + mgr.apply_rules("route-1", rules).await.unwrap(); + mgr.apply_rules("route-2", vec!["nft add rule test2".to_string()]).await.unwrap(); + + mgr.cleanup().await.unwrap(); + assert_eq!(mgr.active_route_count(), 0); + assert!(!mgr.is_initialized()); + } + + #[tokio::test] + async fn test_status_multiple_routes() { + let mut mgr = NftManager::new(None); + mgr.apply_rules("web", vec!["rule1".to_string(), "rule2".to_string()]).await.unwrap(); + mgr.apply_rules("api", vec!["rule3".to_string()]).await.unwrap(); + + let status = mgr.status(); + assert_eq!(status.len(), 2); + assert_eq!(status["web"]["ruleCount"], 2); + assert_eq!(status["api"]["ruleCount"], 1); + } +} diff --git a/rust/crates/rustproxy-nftables/src/rule_builder.rs b/rust/crates/rustproxy-nftables/src/rule_builder.rs new file mode 100644 index 0000000..da25a25 --- /dev/null +++ b/rust/crates/rustproxy-nftables/src/rule_builder.rs @@ -0,0 +1,123 @@ +use rustproxy_config::{NfTablesOptions, NfTablesProtocol}; + +/// Build nftables DNAT rule for port forwarding. +pub fn build_dnat_rule( + table_name: &str, + chain_name: &str, + source_port: u16, + target_host: &str, + target_port: u16, + options: &NfTablesOptions, +) -> Vec { + let protocol = match options.protocol.as_ref().unwrap_or(&NfTablesProtocol::Tcp) { + NfTablesProtocol::Tcp => "tcp", + NfTablesProtocol::Udp => "udp", + NfTablesProtocol::All => "tcp", // TODO: handle "all" + }; + + let mut rules = Vec::new(); + + // DNAT rule + rules.push(format!( + "nft add rule ip {} {} {} dport {} dnat to {}:{}", + table_name, chain_name, protocol, source_port, target_host, target_port, + )); + + // SNAT rule if preserving source IP is not enabled + if !options.preserve_source_ip.unwrap_or(false) { + rules.push(format!( + "nft add rule ip {} postrouting {} dport {} masquerade", + table_name, protocol, target_port, + )); + } + + // Rate limiting + if let Some(max_rate) = &options.max_rate { + rules.push(format!( + "nft add rule ip {} {} {} dport {} limit rate {} accept", + table_name, chain_name, protocol, source_port, max_rate, + )); + } + + rules +} + +/// Build the initial table and chain setup commands. +pub fn build_table_setup(table_name: &str) -> Vec { + vec![ + format!("nft add table ip {}", table_name), + format!("nft add chain ip {} prerouting {{ type nat hook prerouting priority 0 \\; }}", table_name), + format!("nft add chain ip {} postrouting {{ type nat hook postrouting priority 100 \\; }}", table_name), + ] +} + +/// Build cleanup commands to remove the table. +pub fn build_table_cleanup(table_name: &str) -> Vec { + vec![format!("nft delete table ip {}", table_name)] +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_options() -> NfTablesOptions { + NfTablesOptions { + preserve_source_ip: None, + protocol: None, + max_rate: None, + priority: None, + table_name: None, + use_ip_sets: None, + use_advanced_nat: None, + } + } + + #[test] + fn test_basic_dnat_rule() { + let options = make_options(); + let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options); + assert!(rules.len() >= 1); + assert!(rules[0].contains("dnat to 10.0.0.1:8443")); + assert!(rules[0].contains("dport 443")); + } + + #[test] + fn test_preserve_source_ip() { + let mut options = make_options(); + options.preserve_source_ip = Some(true); + let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options); + // When preserving source IP, no masquerade rule + assert!(rules.iter().all(|r| !r.contains("masquerade"))); + } + + #[test] + fn test_without_preserve_source_ip() { + let options = make_options(); + let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options); + assert!(rules.iter().any(|r| r.contains("masquerade"))); + } + + #[test] + fn test_rate_limited_rule() { + let mut options = make_options(); + options.max_rate = Some("100/second".to_string()); + let rules = build_dnat_rule("rustproxy", "prerouting", 80, "10.0.0.1", 8080, &options); + assert!(rules.iter().any(|r| r.contains("limit rate 100/second"))); + } + + #[test] + fn test_table_setup_commands() { + let commands = build_table_setup("rustproxy"); + assert_eq!(commands.len(), 3); + assert!(commands[0].contains("add table ip rustproxy")); + assert!(commands[1].contains("prerouting")); + assert!(commands[2].contains("postrouting")); + } + + #[test] + fn test_table_cleanup() { + let commands = build_table_cleanup("rustproxy"); + assert_eq!(commands.len(), 1); + assert!(commands[0].contains("delete table ip rustproxy")); + } +} diff --git a/rust/crates/rustproxy-passthrough/Cargo.toml b/rust/crates/rustproxy-passthrough/Cargo.toml new file mode 100644 index 0000000..d43e92c --- /dev/null +++ b/rust/crates/rustproxy-passthrough/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "rustproxy-passthrough" +version.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true +description = "Raw TCP/SNI passthrough engine for RustProxy" + +[dependencies] +rustproxy-config = { workspace = true } +rustproxy-routing = { workspace = true } +rustproxy-metrics = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } +thiserror = { workspace = true } +anyhow = { workspace = true } +dashmap = { workspace = true } +arc-swap = { workspace = true } +rustproxy-http = { workspace = true } +rustls = { workspace = true } +tokio-rustls = { workspace = true } +rustls-pemfile = { workspace = true } +tokio-util = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } diff --git a/rust/crates/rustproxy-passthrough/src/connection_record.rs b/rust/crates/rustproxy-passthrough/src/connection_record.rs new file mode 100644 index 0000000..3913565 --- /dev/null +++ b/rust/crates/rustproxy-passthrough/src/connection_record.rs @@ -0,0 +1,155 @@ +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::time::{Duration, Instant}; + +/// Per-connection tracking record with atomics for lock-free updates. +/// +/// Each field uses atomics so that the forwarding tasks can update +/// bytes_received / bytes_sent / last_activity without holding any lock, +/// while the zombie scanner reads them concurrently. +pub struct ConnectionRecord { + /// Unique connection ID assigned by the ConnectionTracker. + pub id: u64, + /// Wall-clock instant when this connection was created. + pub created_at: Instant, + /// Milliseconds since `created_at` when the last activity occurred. + /// Updated atomically by the forwarding loops. + pub last_activity: AtomicU64, + /// Total bytes received from the client (inbound). + pub bytes_received: AtomicU64, + /// Total bytes sent to the client (outbound / from backend). + pub bytes_sent: AtomicU64, + /// True once the client side of the connection has closed. + pub client_closed: AtomicBool, + /// True once the backend side of the connection has closed. + pub backend_closed: AtomicBool, + /// Whether this connection uses TLS (affects zombie thresholds). + pub is_tls: AtomicBool, + /// Whether this connection has keep-alive semantics. + pub has_keep_alive: AtomicBool, +} + +impl ConnectionRecord { + /// Create a new connection record with the given ID. + /// All counters start at zero, all flags start as false. + pub fn new(id: u64) -> Self { + Self { + id, + created_at: Instant::now(), + last_activity: AtomicU64::new(0), + bytes_received: AtomicU64::new(0), + bytes_sent: AtomicU64::new(0), + client_closed: AtomicBool::new(false), + backend_closed: AtomicBool::new(false), + is_tls: AtomicBool::new(false), + has_keep_alive: AtomicBool::new(false), + } + } + + /// Update `last_activity` to reflect the current elapsed time. + pub fn touch(&self) { + let elapsed_ms = self.created_at.elapsed().as_millis() as u64; + self.last_activity.store(elapsed_ms, Ordering::Relaxed); + } + + /// Record `n` bytes received from the client (inbound). + pub fn record_bytes_in(&self, n: u64) { + self.bytes_received.fetch_add(n, Ordering::Relaxed); + self.touch(); + } + + /// Record `n` bytes sent to the client (outbound / from backend). + pub fn record_bytes_out(&self, n: u64) { + self.bytes_sent.fetch_add(n, Ordering::Relaxed); + self.touch(); + } + + /// How long since the last activity on this connection. + pub fn idle_duration(&self) -> Duration { + let last_ms = self.last_activity.load(Ordering::Relaxed); + let age_ms = self.created_at.elapsed().as_millis() as u64; + Duration::from_millis(age_ms.saturating_sub(last_ms)) + } + + /// Total age of this connection (time since creation). + pub fn age(&self) -> Duration { + self.created_at.elapsed() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::thread; + + #[test] + fn test_new_record() { + let record = ConnectionRecord::new(42); + assert_eq!(record.id, 42); + assert_eq!(record.bytes_received.load(Ordering::Relaxed), 0); + assert_eq!(record.bytes_sent.load(Ordering::Relaxed), 0); + assert!(!record.client_closed.load(Ordering::Relaxed)); + assert!(!record.backend_closed.load(Ordering::Relaxed)); + assert!(!record.is_tls.load(Ordering::Relaxed)); + assert!(!record.has_keep_alive.load(Ordering::Relaxed)); + } + + #[test] + fn test_record_bytes() { + let record = ConnectionRecord::new(1); + record.record_bytes_in(100); + record.record_bytes_in(200); + assert_eq!(record.bytes_received.load(Ordering::Relaxed), 300); + + record.record_bytes_out(50); + record.record_bytes_out(75); + assert_eq!(record.bytes_sent.load(Ordering::Relaxed), 125); + } + + #[test] + fn test_touch_updates_activity() { + let record = ConnectionRecord::new(1); + assert_eq!(record.last_activity.load(Ordering::Relaxed), 0); + + // Sleep briefly so elapsed time is nonzero + thread::sleep(Duration::from_millis(10)); + record.touch(); + + let activity = record.last_activity.load(Ordering::Relaxed); + assert!(activity >= 10, "last_activity should be at least 10ms, got {}", activity); + } + + #[test] + fn test_idle_duration() { + let record = ConnectionRecord::new(1); + // Initially idle_duration ~ age since last_activity is 0 + thread::sleep(Duration::from_millis(20)); + let idle = record.idle_duration(); + assert!(idle >= Duration::from_millis(20)); + + // After touch, idle should be near zero + record.touch(); + let idle = record.idle_duration(); + assert!(idle < Duration::from_millis(10)); + } + + #[test] + fn test_age() { + let record = ConnectionRecord::new(1); + thread::sleep(Duration::from_millis(20)); + let age = record.age(); + assert!(age >= Duration::from_millis(20)); + } + + #[test] + fn test_flags() { + let record = ConnectionRecord::new(1); + record.client_closed.store(true, Ordering::Relaxed); + record.is_tls.store(true, Ordering::Relaxed); + record.has_keep_alive.store(true, Ordering::Relaxed); + + assert!(record.client_closed.load(Ordering::Relaxed)); + assert!(!record.backend_closed.load(Ordering::Relaxed)); + assert!(record.is_tls.load(Ordering::Relaxed)); + assert!(record.has_keep_alive.load(Ordering::Relaxed)); + } +} diff --git a/rust/crates/rustproxy-passthrough/src/connection_tracker.rs b/rust/crates/rustproxy-passthrough/src/connection_tracker.rs new file mode 100644 index 0000000..04d2f5b --- /dev/null +++ b/rust/crates/rustproxy-passthrough/src/connection_tracker.rs @@ -0,0 +1,402 @@ +use dashmap::DashMap; +use std::collections::VecDeque; +use std::net::IpAddr; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio_util::sync::CancellationToken; +use tracing::{debug, warn}; + +use super::connection_record::ConnectionRecord; + +/// Thresholds for zombie detection (non-TLS connections). +const HALF_ZOMBIE_TIMEOUT_PLAIN: Duration = Duration::from_secs(30); +/// Thresholds for zombie detection (TLS connections). +const HALF_ZOMBIE_TIMEOUT_TLS: Duration = Duration::from_secs(300); +/// Stuck connection timeout (non-TLS): received data but never sent any. +const STUCK_TIMEOUT_PLAIN: Duration = Duration::from_secs(60); +/// Stuck connection timeout (TLS): received data but never sent any. +const STUCK_TIMEOUT_TLS: Duration = Duration::from_secs(300); + +/// Tracks active connections per IP and enforces per-IP limits and rate limiting. +/// Also maintains per-connection records for zombie detection. +pub struct ConnectionTracker { + /// Active connection counts per IP + active: DashMap, + /// Connection timestamps per IP for rate limiting + timestamps: DashMap>, + /// Maximum concurrent connections per IP (None = unlimited) + max_per_ip: Option, + /// Maximum new connections per minute per IP (None = unlimited) + rate_limit_per_minute: Option, + /// Per-connection tracking records for zombie detection + connections: DashMap>, + /// Monotonically increasing connection ID counter + next_id: AtomicU64, +} + +impl ConnectionTracker { + pub fn new(max_per_ip: Option, rate_limit_per_minute: Option) -> Self { + Self { + active: DashMap::new(), + timestamps: DashMap::new(), + max_per_ip, + rate_limit_per_minute, + connections: DashMap::new(), + next_id: AtomicU64::new(1), + } + } + + /// Try to accept a new connection from the given IP. + /// Returns true if allowed, false if over limit. + pub fn try_accept(&self, ip: &IpAddr) -> bool { + // Check per-IP connection limit + if let Some(max) = self.max_per_ip { + let count = self.active + .get(ip) + .map(|c| c.value().load(Ordering::Relaxed)) + .unwrap_or(0); + if count >= max { + return false; + } + } + + // Check rate limit + if let Some(rate_limit) = self.rate_limit_per_minute { + let now = Instant::now(); + let one_minute = std::time::Duration::from_secs(60); + let mut entry = self.timestamps.entry(*ip).or_default(); + let timestamps = entry.value_mut(); + + // Remove timestamps older than 1 minute + while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) { + timestamps.pop_front(); + } + + if timestamps.len() as u64 >= rate_limit { + return false; + } + timestamps.push_back(now); + } + + true + } + + /// Record that a connection was opened from the given IP. + pub fn connection_opened(&self, ip: &IpAddr) { + self.active + .entry(*ip) + .or_insert_with(|| AtomicU64::new(0)) + .value() + .fetch_add(1, Ordering::Relaxed); + } + + /// Record that a connection was closed from the given IP. + pub fn connection_closed(&self, ip: &IpAddr) { + if let Some(counter) = self.active.get(ip) { + let prev = counter.value().fetch_sub(1, Ordering::Relaxed); + // Clean up zero entries + if prev <= 1 { + drop(counter); + self.active.remove(ip); + } + } + } + + /// Get the current number of active connections for an IP. + pub fn active_connections(&self, ip: &IpAddr) -> u64 { + self.active + .get(ip) + .map(|c| c.value().load(Ordering::Relaxed)) + .unwrap_or(0) + } + + /// Get the total number of tracked IPs. + pub fn tracked_ips(&self) -> usize { + self.active.len() + } + + /// Register a new connection and return its tracking record. + /// + /// The returned `Arc` should be passed to the forwarding + /// loop so it can update bytes / activity atomics in real time. + pub fn register_connection(&self, is_tls: bool) -> Arc { + let id = self.next_id.fetch_add(1, Ordering::Relaxed); + let record = Arc::new(ConnectionRecord::new(id)); + record.is_tls.store(is_tls, Ordering::Relaxed); + self.connections.insert(id, Arc::clone(&record)); + record + } + + /// Remove a connection record when the connection is fully closed. + pub fn unregister_connection(&self, id: u64) { + self.connections.remove(&id); + } + + /// Scan all tracked connections and return IDs of zombie connections. + /// + /// A connection is considered a zombie in any of these cases: + /// - **Full zombie**: both `client_closed` and `backend_closed` are true. + /// - **Half zombie**: one side closed for longer than the threshold + /// (5 min for TLS, 30s for non-TLS). + /// - **Stuck**: `bytes_received > 0` but `bytes_sent == 0` for longer + /// than the stuck threshold (5 min for TLS, 60s for non-TLS). + pub fn scan_zombies(&self) -> Vec { + let mut zombies = Vec::new(); + + for entry in self.connections.iter() { + let record = entry.value(); + let id = *entry.key(); + let is_tls = record.is_tls.load(Ordering::Relaxed); + let client_closed = record.client_closed.load(Ordering::Relaxed); + let backend_closed = record.backend_closed.load(Ordering::Relaxed); + let idle = record.idle_duration(); + let bytes_in = record.bytes_received.load(Ordering::Relaxed); + let bytes_out = record.bytes_sent.load(Ordering::Relaxed); + + // Full zombie: both sides closed + if client_closed && backend_closed { + zombies.push(id); + continue; + } + + // Half zombie: one side closed for too long + let half_timeout = if is_tls { + HALF_ZOMBIE_TIMEOUT_TLS + } else { + HALF_ZOMBIE_TIMEOUT_PLAIN + }; + + if (client_closed || backend_closed) && idle >= half_timeout { + zombies.push(id); + continue; + } + + // Stuck: received data but never sent anything for too long + let stuck_timeout = if is_tls { + STUCK_TIMEOUT_TLS + } else { + STUCK_TIMEOUT_PLAIN + }; + + if bytes_in > 0 && bytes_out == 0 && idle >= stuck_timeout { + zombies.push(id); + } + } + + zombies + } + + /// Start a background task that periodically scans for zombie connections. + /// + /// The scanner runs every 10 seconds and logs any zombies it finds. + /// It stops when the provided `CancellationToken` is cancelled. + pub fn start_zombie_scanner(self: &Arc, cancel: CancellationToken) { + let tracker = Arc::clone(self); + tokio::spawn(async move { + let interval = Duration::from_secs(10); + loop { + tokio::select! { + _ = cancel.cancelled() => { + debug!("Zombie scanner shutting down"); + break; + } + _ = tokio::time::sleep(interval) => { + let zombies = tracker.scan_zombies(); + if !zombies.is_empty() { + warn!( + "Detected {} zombie connection(s): {:?}", + zombies.len(), + zombies + ); + } + } + } + } + }); + } + + /// Get the total number of tracked connections (with records). + pub fn total_connections(&self) -> usize { + self.connections.len() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic_tracking() { + let tracker = ConnectionTracker::new(None, None); + let ip: IpAddr = "127.0.0.1".parse().unwrap(); + + assert!(tracker.try_accept(&ip)); + tracker.connection_opened(&ip); + assert_eq!(tracker.active_connections(&ip), 1); + + tracker.connection_opened(&ip); + assert_eq!(tracker.active_connections(&ip), 2); + + tracker.connection_closed(&ip); + assert_eq!(tracker.active_connections(&ip), 1); + + tracker.connection_closed(&ip); + assert_eq!(tracker.active_connections(&ip), 0); + } + + #[test] + fn test_per_ip_limit() { + let tracker = ConnectionTracker::new(Some(2), None); + let ip: IpAddr = "10.0.0.1".parse().unwrap(); + + assert!(tracker.try_accept(&ip)); + tracker.connection_opened(&ip); + + assert!(tracker.try_accept(&ip)); + tracker.connection_opened(&ip); + + // Third connection should be rejected + assert!(!tracker.try_accept(&ip)); + + // Different IP should still be allowed + let ip2: IpAddr = "10.0.0.2".parse().unwrap(); + assert!(tracker.try_accept(&ip2)); + } + + #[test] + fn test_rate_limit() { + let tracker = ConnectionTracker::new(None, Some(3)); + let ip: IpAddr = "10.0.0.1".parse().unwrap(); + + assert!(tracker.try_accept(&ip)); + assert!(tracker.try_accept(&ip)); + assert!(tracker.try_accept(&ip)); + // 4th attempt within the minute should be rejected + assert!(!tracker.try_accept(&ip)); + } + + #[test] + fn test_no_limits() { + let tracker = ConnectionTracker::new(None, None); + let ip: IpAddr = "10.0.0.1".parse().unwrap(); + + for _ in 0..1000 { + assert!(tracker.try_accept(&ip)); + tracker.connection_opened(&ip); + } + assert_eq!(tracker.active_connections(&ip), 1000); + } + + #[test] + fn test_tracked_ips() { + let tracker = ConnectionTracker::new(None, None); + assert_eq!(tracker.tracked_ips(), 0); + + let ip1: IpAddr = "10.0.0.1".parse().unwrap(); + let ip2: IpAddr = "10.0.0.2".parse().unwrap(); + + tracker.connection_opened(&ip1); + tracker.connection_opened(&ip2); + assert_eq!(tracker.tracked_ips(), 2); + + tracker.connection_closed(&ip1); + assert_eq!(tracker.tracked_ips(), 1); + } + + #[test] + fn test_register_unregister_connection() { + let tracker = ConnectionTracker::new(None, None); + assert_eq!(tracker.total_connections(), 0); + + let record1 = tracker.register_connection(false); + assert_eq!(tracker.total_connections(), 1); + assert!(!record1.is_tls.load(Ordering::Relaxed)); + + let record2 = tracker.register_connection(true); + assert_eq!(tracker.total_connections(), 2); + assert!(record2.is_tls.load(Ordering::Relaxed)); + + // IDs should be unique + assert_ne!(record1.id, record2.id); + + tracker.unregister_connection(record1.id); + assert_eq!(tracker.total_connections(), 1); + + tracker.unregister_connection(record2.id); + assert_eq!(tracker.total_connections(), 0); + } + + #[test] + fn test_full_zombie_detection() { + let tracker = ConnectionTracker::new(None, None); + let record = tracker.register_connection(false); + + // Not a zombie initially + assert!(tracker.scan_zombies().is_empty()); + + // Set both sides closed -> full zombie + record.client_closed.store(true, Ordering::Relaxed); + record.backend_closed.store(true, Ordering::Relaxed); + + let zombies = tracker.scan_zombies(); + assert_eq!(zombies.len(), 1); + assert_eq!(zombies[0], record.id); + } + + #[test] + fn test_half_zombie_not_triggered_immediately() { + let tracker = ConnectionTracker::new(None, None); + let record = tracker.register_connection(false); + record.touch(); // mark activity now + + // Only one side closed, but just now -> not a zombie yet + record.client_closed.store(true, Ordering::Relaxed); + assert!(tracker.scan_zombies().is_empty()); + } + + #[test] + fn test_stuck_connection_not_triggered_immediately() { + let tracker = ConnectionTracker::new(None, None); + let record = tracker.register_connection(false); + record.touch(); // mark activity now + + // Has received data but sent nothing -> but just started, not stuck yet + record.bytes_received.store(1000, Ordering::Relaxed); + assert!(tracker.scan_zombies().is_empty()); + } + + #[test] + fn test_unregister_removes_from_zombie_scan() { + let tracker = ConnectionTracker::new(None, None); + let record = tracker.register_connection(false); + let id = record.id; + + // Make it a full zombie + record.client_closed.store(true, Ordering::Relaxed); + record.backend_closed.store(true, Ordering::Relaxed); + assert_eq!(tracker.scan_zombies().len(), 1); + + // Unregister should remove it + tracker.unregister_connection(id); + assert!(tracker.scan_zombies().is_empty()); + } + + #[test] + fn test_total_connections() { + let tracker = ConnectionTracker::new(None, None); + assert_eq!(tracker.total_connections(), 0); + + let r1 = tracker.register_connection(false); + let r2 = tracker.register_connection(true); + let r3 = tracker.register_connection(false); + assert_eq!(tracker.total_connections(), 3); + + tracker.unregister_connection(r2.id); + assert_eq!(tracker.total_connections(), 2); + + tracker.unregister_connection(r1.id); + tracker.unregister_connection(r3.id); + assert_eq!(tracker.total_connections(), 0); + } +} diff --git a/rust/crates/rustproxy-passthrough/src/forwarder.rs b/rust/crates/rustproxy-passthrough/src/forwarder.rs new file mode 100644 index 0000000..8a1e713 --- /dev/null +++ b/rust/crates/rustproxy-passthrough/src/forwarder.rs @@ -0,0 +1,325 @@ +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; +use tokio_util::sync::CancellationToken; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use tracing::debug; + +use super::connection_record::ConnectionRecord; + +/// Statistics for a forwarded connection. +#[derive(Debug, Default)] +pub struct ForwardStats { + pub bytes_in: AtomicU64, + pub bytes_out: AtomicU64, +} + +/// Perform bidirectional TCP forwarding between client and backend. +/// +/// This is the core data path for passthrough connections. +/// Returns (bytes_from_client, bytes_from_backend) when the connection closes. +pub async fn forward_bidirectional( + mut client: TcpStream, + mut backend: TcpStream, + initial_data: Option<&[u8]>, +) -> std::io::Result<(u64, u64)> { + // Send initial data (peeked bytes) to backend + if let Some(data) = initial_data { + backend.write_all(data).await?; + } + + let (mut client_read, mut client_write) = client.split(); + let (mut backend_read, mut backend_write) = backend.split(); + + let client_to_backend = async { + let mut buf = vec![0u8; 65536]; + let mut total = initial_data.map_or(0u64, |d| d.len() as u64); + loop { + let n = client_read.read(&mut buf).await?; + if n == 0 { + break; + } + backend_write.write_all(&buf[..n]).await?; + total += n as u64; + } + backend_write.shutdown().await?; + Ok::(total) + }; + + let backend_to_client = async { + let mut buf = vec![0u8; 65536]; + let mut total = 0u64; + loop { + let n = backend_read.read(&mut buf).await?; + if n == 0 { + break; + } + client_write.write_all(&buf[..n]).await?; + total += n as u64; + } + client_write.shutdown().await?; + Ok::(total) + }; + + let (c2b, b2c) = tokio::join!(client_to_backend, backend_to_client); + + Ok((c2b.unwrap_or(0), b2c.unwrap_or(0))) +} + +/// Perform bidirectional TCP forwarding with inactivity and max lifetime timeouts. +/// +/// Returns (bytes_from_client, bytes_from_backend) when the connection closes or times out. +pub async fn forward_bidirectional_with_timeouts( + client: TcpStream, + mut backend: TcpStream, + initial_data: Option<&[u8]>, + inactivity_timeout: std::time::Duration, + max_lifetime: std::time::Duration, + cancel: CancellationToken, +) -> std::io::Result<(u64, u64)> { + // Send initial data (peeked bytes) to backend + if let Some(data) = initial_data { + backend.write_all(data).await?; + } + + let (mut client_read, mut client_write) = client.into_split(); + let (mut backend_read, mut backend_write) = backend.into_split(); + + let last_activity = Arc::new(AtomicU64::new(0)); + let start = std::time::Instant::now(); + + let la1 = Arc::clone(&last_activity); + let initial_len = initial_data.map_or(0u64, |d| d.len() as u64); + let c2b = tokio::spawn(async move { + let mut buf = vec![0u8; 65536]; + let mut total = initial_len; + loop { + let n = match client_read.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + if backend_write.write_all(&buf[..n]).await.is_err() { + break; + } + total += n as u64; + la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); + } + let _ = backend_write.shutdown().await; + total + }); + + let la2 = Arc::clone(&last_activity); + let b2c = tokio::spawn(async move { + let mut buf = vec![0u8; 65536]; + let mut total = 0u64; + loop { + let n = match backend_read.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + if client_write.write_all(&buf[..n]).await.is_err() { + break; + } + total += n as u64; + la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); + } + let _ = client_write.shutdown().await; + total + }); + + // Watchdog: inactivity, max lifetime, and cancellation + let la_watch = Arc::clone(&last_activity); + let c2b_handle = c2b.abort_handle(); + let b2c_handle = b2c.abort_handle(); + let watchdog = tokio::spawn(async move { + let check_interval = std::time::Duration::from_secs(5); + let mut last_seen = 0u64; + loop { + tokio::select! { + _ = cancel.cancelled() => { + debug!("Connection cancelled by shutdown"); + c2b_handle.abort(); + b2c_handle.abort(); + break; + } + _ = tokio::time::sleep(check_interval) => { + // Check max lifetime + if start.elapsed() >= max_lifetime { + debug!("Connection exceeded max lifetime, closing"); + c2b_handle.abort(); + b2c_handle.abort(); + break; + } + + // Check inactivity + let current = la_watch.load(Ordering::Relaxed); + if current == last_seen { + let elapsed_since_activity = start.elapsed().as_millis() as u64 - current; + if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 { + debug!("Connection inactive for {}ms, closing", elapsed_since_activity); + c2b_handle.abort(); + b2c_handle.abort(); + break; + } + } + last_seen = current; + } + } + } + }); + + let bytes_in = c2b.await.unwrap_or(0); + let bytes_out = b2c.await.unwrap_or(0); + watchdog.abort(); + Ok((bytes_in, bytes_out)) +} + +/// Forward bidirectional with a callback for byte counting. +pub async fn forward_bidirectional_with_stats( + client: TcpStream, + backend: TcpStream, + initial_data: Option<&[u8]>, + stats: Arc, +) -> std::io::Result<()> { + let (bytes_in, bytes_out) = forward_bidirectional(client, backend, initial_data).await?; + stats.bytes_in.fetch_add(bytes_in, Ordering::Relaxed); + stats.bytes_out.fetch_add(bytes_out, Ordering::Relaxed); + Ok(()) +} + +/// Perform bidirectional TCP forwarding with inactivity / lifetime timeouts, +/// updating a `ConnectionRecord` with byte counts and activity timestamps +/// in real time for zombie detection. +/// +/// When `record` is `None`, this behaves identically to +/// `forward_bidirectional_with_timeouts`. +/// +/// The record's `client_closed` / `backend_closed` flags are set when the +/// respective copy loop terminates, giving the zombie scanner visibility +/// into half-open connections. +pub async fn forward_bidirectional_with_record( + client: TcpStream, + mut backend: TcpStream, + initial_data: Option<&[u8]>, + inactivity_timeout: std::time::Duration, + max_lifetime: std::time::Duration, + cancel: CancellationToken, + record: Option>, +) -> std::io::Result<(u64, u64)> { + // Send initial data (peeked bytes) to backend + if let Some(data) = initial_data { + backend.write_all(data).await?; + if let Some(ref r) = record { + r.record_bytes_in(data.len() as u64); + } + } + + let (mut client_read, mut client_write) = client.into_split(); + let (mut backend_read, mut backend_write) = backend.into_split(); + + let last_activity = Arc::new(AtomicU64::new(0)); + let start = std::time::Instant::now(); + + let la1 = Arc::clone(&last_activity); + let initial_len = initial_data.map_or(0u64, |d| d.len() as u64); + let rec1 = record.clone(); + let c2b = tokio::spawn(async move { + let mut buf = vec![0u8; 65536]; + let mut total = initial_len; + loop { + let n = match client_read.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + if backend_write.write_all(&buf[..n]).await.is_err() { + break; + } + total += n as u64; + let now_ms = start.elapsed().as_millis() as u64; + la1.store(now_ms, Ordering::Relaxed); + if let Some(ref r) = rec1 { + r.record_bytes_in(n as u64); + } + } + let _ = backend_write.shutdown().await; + // Mark client side as closed + if let Some(ref r) = rec1 { + r.client_closed.store(true, Ordering::Relaxed); + } + total + }); + + let la2 = Arc::clone(&last_activity); + let rec2 = record.clone(); + let b2c = tokio::spawn(async move { + let mut buf = vec![0u8; 65536]; + let mut total = 0u64; + loop { + let n = match backend_read.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + if client_write.write_all(&buf[..n]).await.is_err() { + break; + } + total += n as u64; + let now_ms = start.elapsed().as_millis() as u64; + la2.store(now_ms, Ordering::Relaxed); + if let Some(ref r) = rec2 { + r.record_bytes_out(n as u64); + } + } + let _ = client_write.shutdown().await; + // Mark backend side as closed + if let Some(ref r) = rec2 { + r.backend_closed.store(true, Ordering::Relaxed); + } + total + }); + + // Watchdog: inactivity, max lifetime, and cancellation + let la_watch = Arc::clone(&last_activity); + let c2b_handle = c2b.abort_handle(); + let b2c_handle = b2c.abort_handle(); + let watchdog = tokio::spawn(async move { + let check_interval = std::time::Duration::from_secs(5); + let mut last_seen = 0u64; + loop { + tokio::select! { + _ = cancel.cancelled() => { + debug!("Connection cancelled by shutdown"); + c2b_handle.abort(); + b2c_handle.abort(); + break; + } + _ = tokio::time::sleep(check_interval) => { + // Check max lifetime + if start.elapsed() >= max_lifetime { + debug!("Connection exceeded max lifetime, closing"); + c2b_handle.abort(); + b2c_handle.abort(); + break; + } + + // Check inactivity + let current = la_watch.load(Ordering::Relaxed); + if current == last_seen { + let elapsed_since_activity = start.elapsed().as_millis() as u64 - current; + if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 { + debug!("Connection inactive for {}ms, closing", elapsed_since_activity); + c2b_handle.abort(); + b2c_handle.abort(); + break; + } + } + last_seen = current; + } + } + } + }); + + let bytes_in = c2b.await.unwrap_or(0); + let bytes_out = b2c.await.unwrap_or(0); + watchdog.abort(); + Ok((bytes_in, bytes_out)) +} diff --git a/rust/crates/rustproxy-passthrough/src/lib.rs b/rust/crates/rustproxy-passthrough/src/lib.rs new file mode 100644 index 0000000..ba46cff --- /dev/null +++ b/rust/crates/rustproxy-passthrough/src/lib.rs @@ -0,0 +1,22 @@ +//! # rustproxy-passthrough +//! +//! Raw TCP/SNI passthrough engine for RustProxy. +//! Handles TCP listening, TLS ClientHello SNI extraction, and bidirectional forwarding. + +pub mod tcp_listener; +pub mod sni_parser; +pub mod forwarder; +pub mod proxy_protocol; +pub mod tls_handler; +pub mod connection_record; +pub mod connection_tracker; +pub mod socket_relay; + +pub use tcp_listener::*; +pub use sni_parser::*; +pub use forwarder::*; +pub use proxy_protocol::*; +pub use tls_handler::*; +pub use connection_record::*; +pub use connection_tracker::*; +pub use socket_relay::*; diff --git a/rust/crates/rustproxy-passthrough/src/proxy_protocol.rs b/rust/crates/rustproxy-passthrough/src/proxy_protocol.rs new file mode 100644 index 0000000..d9bcd73 --- /dev/null +++ b/rust/crates/rustproxy-passthrough/src/proxy_protocol.rs @@ -0,0 +1,129 @@ +use std::net::SocketAddr; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ProxyProtocolError { + #[error("Invalid PROXY protocol header")] + InvalidHeader, + #[error("Unsupported PROXY protocol version")] + UnsupportedVersion, + #[error("Parse error: {0}")] + Parse(String), +} + +/// Parsed PROXY protocol v1 header. +#[derive(Debug, Clone)] +pub struct ProxyProtocolHeader { + pub source_addr: SocketAddr, + pub dest_addr: SocketAddr, + pub protocol: ProxyProtocol, +} + +/// Protocol in PROXY header. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ProxyProtocol { + Tcp4, + Tcp6, + Unknown, +} + +/// Parse a PROXY protocol v1 header from data. +/// +/// Format: `PROXY TCP4 \r\n` +pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtocolError> { + // Find the end of the header line + let line_end = data + .windows(2) + .position(|w| w == b"\r\n") + .ok_or(ProxyProtocolError::InvalidHeader)?; + + let line = std::str::from_utf8(&data[..line_end]) + .map_err(|_| ProxyProtocolError::InvalidHeader)?; + + if !line.starts_with("PROXY ") { + return Err(ProxyProtocolError::InvalidHeader); + } + + let parts: Vec<&str> = line.split(' ').collect(); + if parts.len() != 6 { + return Err(ProxyProtocolError::InvalidHeader); + } + + let protocol = match parts[1] { + "TCP4" => ProxyProtocol::Tcp4, + "TCP6" => ProxyProtocol::Tcp6, + "UNKNOWN" => ProxyProtocol::Unknown, + _ => return Err(ProxyProtocolError::UnsupportedVersion), + }; + + let src_ip: std::net::IpAddr = parts[2] + .parse() + .map_err(|_| ProxyProtocolError::Parse("Invalid source IP".to_string()))?; + let dst_ip: std::net::IpAddr = parts[3] + .parse() + .map_err(|_| ProxyProtocolError::Parse("Invalid destination IP".to_string()))?; + let src_port: u16 = parts[4] + .parse() + .map_err(|_| ProxyProtocolError::Parse("Invalid source port".to_string()))?; + let dst_port: u16 = parts[5] + .parse() + .map_err(|_| ProxyProtocolError::Parse("Invalid destination port".to_string()))?; + + let header = ProxyProtocolHeader { + source_addr: SocketAddr::new(src_ip, src_port), + dest_addr: SocketAddr::new(dst_ip, dst_port), + protocol, + }; + + // Consumed bytes = line + \r\n + Ok((header, line_end + 2)) +} + +/// Generate a PROXY protocol v1 header string. +pub fn generate_v1(source: &SocketAddr, dest: &SocketAddr) -> String { + let proto = if source.is_ipv4() { "TCP4" } else { "TCP6" }; + format!( + "PROXY {} {} {} {} {}\r\n", + proto, + source.ip(), + dest.ip(), + source.port(), + dest.port() + ) +} + +/// Check if data starts with a PROXY protocol v1 header. +pub fn is_proxy_protocol_v1(data: &[u8]) -> bool { + data.starts_with(b"PROXY ") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_v1_tcp4() { + let header = b"PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n"; + let (parsed, consumed) = parse_v1(header).unwrap(); + assert_eq!(consumed, header.len()); + assert_eq!(parsed.protocol, ProxyProtocol::Tcp4); + assert_eq!(parsed.source_addr.ip().to_string(), "192.168.1.100"); + assert_eq!(parsed.source_addr.port(), 12345); + assert_eq!(parsed.dest_addr.ip().to_string(), "10.0.0.1"); + assert_eq!(parsed.dest_addr.port(), 443); + } + + #[test] + fn test_generate_v1() { + let source: SocketAddr = "192.168.1.100:12345".parse().unwrap(); + let dest: SocketAddr = "10.0.0.1:443".parse().unwrap(); + let header = generate_v1(&source, &dest); + assert_eq!(header, "PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n"); + } + + #[test] + fn test_is_proxy_protocol() { + assert!(is_proxy_protocol_v1(b"PROXY TCP4 ...")); + assert!(!is_proxy_protocol_v1(b"GET / HTTP/1.1")); + } +} diff --git a/rust/crates/rustproxy-passthrough/src/sni_parser.rs b/rust/crates/rustproxy-passthrough/src/sni_parser.rs new file mode 100644 index 0000000..8eeefd2 --- /dev/null +++ b/rust/crates/rustproxy-passthrough/src/sni_parser.rs @@ -0,0 +1,287 @@ +//! ClientHello SNI extraction via manual byte parsing. +//! No TLS stack needed - we just parse enough of the ClientHello to extract the SNI. + +/// Result of SNI extraction. +#[derive(Debug)] +pub enum SniResult { + /// Successfully extracted SNI hostname. + Found(String), + /// TLS ClientHello detected but no SNI extension present. + NoSni, + /// Not a TLS ClientHello (plain HTTP or other protocol). + NotTls, + /// Need more data to determine. + NeedMoreData, +} + +/// Extract the SNI hostname from a TLS ClientHello message. +/// +/// This parses just enough of the TLS record to find the SNI extension, +/// without performing any actual TLS operations. +pub fn extract_sni(data: &[u8]) -> SniResult { + // Minimum TLS record header is 5 bytes + if data.len() < 5 { + return SniResult::NeedMoreData; + } + + // Check for TLS record: content_type=22 (Handshake) + if data[0] != 0x16 { + return SniResult::NotTls; + } + + // TLS version (major.minor) - accept any + // data[1..2] = version + + // Record length + let record_len = ((data[3] as usize) << 8) | (data[4] as usize); + let _total_len = 5 + record_len; + + // We need at least the handshake header (5 TLS + 4 handshake = 9) + if data.len() < 9 { + return SniResult::NeedMoreData; + } + + // Handshake type = 1 (ClientHello) + if data[5] != 0x01 { + return SniResult::NotTls; + } + + // Handshake length (3 bytes) - informational, we parse incrementally + let _handshake_len = ((data[6] as usize) << 16) + | ((data[7] as usize) << 8) + | (data[8] as usize); + + let hello = &data[9..]; + + // ClientHello structure: + // 2 bytes: client version + // 32 bytes: random + // 1 byte: session_id length + session_id + let mut pos = 2 + 32; // skip version + random + + if pos >= hello.len() { + return SniResult::NeedMoreData; + } + + // Session ID + let session_id_len = hello[pos] as usize; + pos += 1 + session_id_len; + + if pos + 2 > hello.len() { + return SniResult::NeedMoreData; + } + + // Cipher suites + let cipher_suites_len = ((hello[pos] as usize) << 8) | (hello[pos + 1] as usize); + pos += 2 + cipher_suites_len; + + if pos + 1 > hello.len() { + return SniResult::NeedMoreData; + } + + // Compression methods + let compression_len = hello[pos] as usize; + pos += 1 + compression_len; + + if pos + 2 > hello.len() { + // No extensions + return SniResult::NoSni; + } + + // Extensions length + let extensions_len = ((hello[pos] as usize) << 8) | (hello[pos + 1] as usize); + pos += 2; + + let extensions_end = pos + extensions_len; + if extensions_end > hello.len() { + // Partial extensions, try to parse what we have + } + + // Parse extensions looking for SNI (type 0x0000) + while pos + 4 <= hello.len() && pos < extensions_end { + let ext_type = ((hello[pos] as u16) << 8) | (hello[pos + 1] as u16); + let ext_len = ((hello[pos + 2] as usize) << 8) | (hello[pos + 3] as usize); + pos += 4; + + if ext_type == 0x0000 { + // SNI extension + return parse_sni_extension(&hello[pos..(pos + ext_len).min(hello.len())], ext_len); + } + + pos += ext_len; + } + + SniResult::NoSni +} + +/// Parse the SNI extension data. +fn parse_sni_extension(data: &[u8], _ext_len: usize) -> SniResult { + if data.len() < 5 { + return SniResult::NeedMoreData; + } + + // Server name list length + let _list_len = ((data[0] as usize) << 8) | (data[1] as usize); + + // Server name type (0 = hostname) + if data[2] != 0x00 { + return SniResult::NoSni; + } + + // Hostname length + let name_len = ((data[3] as usize) << 8) | (data[4] as usize); + + if data.len() < 5 + name_len { + return SniResult::NeedMoreData; + } + + match std::str::from_utf8(&data[5..5 + name_len]) { + Ok(hostname) => SniResult::Found(hostname.to_lowercase()), + Err(_) => SniResult::NoSni, + } +} + +/// Check if the initial bytes look like a TLS ClientHello. +pub fn is_tls(data: &[u8]) -> bool { + data.len() >= 3 && data[0] == 0x16 && data[1] == 0x03 +} + +/// Check if the initial bytes look like HTTP. +pub fn is_http(data: &[u8]) -> bool { + if data.len() < 4 { + return false; + } + // Check for common HTTP methods + let starts = [ + b"GET " as &[u8], + b"POST", + b"PUT ", + b"HEAD", + b"DELE", + b"PATC", + b"OPTI", + b"CONN", + ]; + starts.iter().any(|s| data.starts_with(s)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_not_tls() { + let http_data = b"GET / HTTP/1.1\r\n"; + assert!(matches!(extract_sni(http_data), SniResult::NotTls)); + } + + #[test] + fn test_too_short() { + assert!(matches!(extract_sni(&[0x16, 0x03]), SniResult::NeedMoreData)); + } + + #[test] + fn test_is_tls() { + assert!(is_tls(&[0x16, 0x03, 0x01])); + assert!(!is_tls(&[0x47, 0x45, 0x54])); // "GET" + } + + #[test] + fn test_is_http() { + assert!(is_http(b"GET /")); + assert!(is_http(b"POST /api")); + assert!(!is_http(&[0x16, 0x03, 0x01])); + } + + #[test] + fn test_real_client_hello() { + // A minimal TLS 1.2 ClientHello with SNI "example.com" + let client_hello: Vec = build_test_client_hello("example.com"); + match extract_sni(&client_hello) { + SniResult::Found(sni) => assert_eq!(sni, "example.com"), + other => panic!("Expected Found, got {:?}", other), + } + } + + /// Build a minimal TLS ClientHello for testing. + fn build_test_client_hello(hostname: &str) -> Vec { + let hostname_bytes = hostname.as_bytes(); + + // SNI extension + let sni_ext_data = { + let mut d = Vec::new(); + // Server name list length + let name_entry_len = 3 + hostname_bytes.len(); // type(1) + len(2) + name + d.push(((name_entry_len >> 8) & 0xFF) as u8); + d.push((name_entry_len & 0xFF) as u8); + // Host name type = 0 + d.push(0x00); + // Host name length + d.push(((hostname_bytes.len() >> 8) & 0xFF) as u8); + d.push((hostname_bytes.len() & 0xFF) as u8); + // Host name + d.extend_from_slice(hostname_bytes); + d + }; + + // Extension: type=0x0000 (SNI), length, data + let sni_extension = { + let mut e = Vec::new(); + e.push(0x00); e.push(0x00); // SNI type + e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8); + e.push((sni_ext_data.len() & 0xFF) as u8); + e.extend_from_slice(&sni_ext_data); + e + }; + + // Extensions block + let extensions = { + let mut ext = Vec::new(); + ext.push(((sni_extension.len() >> 8) & 0xFF) as u8); + ext.push((sni_extension.len() & 0xFF) as u8); + ext.extend_from_slice(&sni_extension); + ext + }; + + // ClientHello body + let hello_body = { + let mut h = Vec::new(); + // Client version TLS 1.2 + h.push(0x03); h.push(0x03); + // Random (32 bytes) + h.extend_from_slice(&[0u8; 32]); + // Session ID length = 0 + h.push(0x00); + // Cipher suites: length=2, one suite + h.push(0x00); h.push(0x02); + h.push(0x00); h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA + // Compression methods: length=1, null + h.push(0x01); h.push(0x00); + // Extensions + h.extend_from_slice(&extensions); + h + }; + + // Handshake: type=1 (ClientHello), length + let handshake = { + let mut hs = Vec::new(); + hs.push(0x01); // ClientHello + // 3-byte length + hs.push(((hello_body.len() >> 16) & 0xFF) as u8); + hs.push(((hello_body.len() >> 8) & 0xFF) as u8); + hs.push((hello_body.len() & 0xFF) as u8); + hs.extend_from_slice(&hello_body); + hs + }; + + // TLS record: type=0x16, version TLS 1.0, length + let mut record = Vec::new(); + record.push(0x16); // Handshake + record.push(0x03); record.push(0x01); // TLS 1.0 + record.push(((handshake.len() >> 8) & 0xFF) as u8); + record.push((handshake.len() & 0xFF) as u8); + record.extend_from_slice(&handshake); + + record + } +} diff --git a/rust/crates/rustproxy-passthrough/src/socket_relay.rs b/rust/crates/rustproxy-passthrough/src/socket_relay.rs new file mode 100644 index 0000000..671d353 --- /dev/null +++ b/rust/crates/rustproxy-passthrough/src/socket_relay.rs @@ -0,0 +1,126 @@ +//! Socket handler relay for connecting client connections to a TypeScript handler +//! via a Unix domain socket. +//! +//! Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay. + +use tokio::net::UnixStream; +use tokio::io::{AsyncWriteExt, AsyncReadExt}; +use tokio::net::TcpStream; +use serde::Serialize; +use tracing::debug; + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct RelayMetadata { + connection_id: u64, + remote_ip: String, + remote_port: u16, + local_port: u16, + sni: Option, + route_name: String, + initial_data_base64: Option, +} + +/// Relay a client connection to a TypeScript handler via Unix domain socket. +/// +/// Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay. +pub async fn relay_to_handler( + client: TcpStream, + relay_socket_path: &str, + connection_id: u64, + remote_ip: String, + remote_port: u16, + local_port: u16, + sni: Option, + route_name: String, + initial_data: Option<&[u8]>, +) -> std::io::Result<()> { + debug!( + "Relaying connection {} to handler socket {}", + connection_id, relay_socket_path + ); + + // Connect to TypeScript handler Unix socket + let mut handler = UnixStream::connect(relay_socket_path).await?; + + // Build and send metadata header + let initial_data_base64 = initial_data.map(base64_encode); + + let metadata = RelayMetadata { + connection_id, + remote_ip, + remote_port, + local_port, + sni, + route_name, + initial_data_base64, + }; + + let metadata_json = serde_json::to_string(&metadata) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + + handler.write_all(metadata_json.as_bytes()).await?; + handler.write_all(b"\n").await?; + + // Bidirectional relay between client and handler + let (mut client_read, mut client_write) = client.into_split(); + let (mut handler_read, mut handler_write) = handler.into_split(); + + let c2h = tokio::spawn(async move { + let mut buf = vec![0u8; 65536]; + loop { + let n = match client_read.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + if handler_write.write_all(&buf[..n]).await.is_err() { + break; + } + } + let _ = handler_write.shutdown().await; + }); + + let h2c = tokio::spawn(async move { + let mut buf = vec![0u8; 65536]; + loop { + let n = match handler_read.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + if client_write.write_all(&buf[..n]).await.is_err() { + break; + } + } + let _ = client_write.shutdown().await; + }); + + let _ = tokio::join!(c2h, h2c); + + debug!("Relay connection {} completed", connection_id); + Ok(()) +} + +/// Simple base64 encoding without external dependency. +fn base64_encode(data: &[u8]) -> String { + const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + let mut result = String::new(); + for chunk in data.chunks(3) { + let b0 = chunk[0] as u32; + let b1 = if chunk.len() > 1 { chunk[1] as u32 } else { 0 }; + let b2 = if chunk.len() > 2 { chunk[2] as u32 } else { 0 }; + let n = (b0 << 16) | (b1 << 8) | b2; + result.push(CHARS[((n >> 18) & 0x3F) as usize] as char); + result.push(CHARS[((n >> 12) & 0x3F) as usize] as char); + if chunk.len() > 1 { + result.push(CHARS[((n >> 6) & 0x3F) as usize] as char); + } else { + result.push('='); + } + if chunk.len() > 2 { + result.push(CHARS[(n & 0x3F) as usize] as char); + } else { + result.push('='); + } + } + result +} diff --git a/rust/crates/rustproxy-passthrough/src/tcp_listener.rs b/rust/crates/rustproxy-passthrough/src/tcp_listener.rs new file mode 100644 index 0000000..de71b53 --- /dev/null +++ b/rust/crates/rustproxy-passthrough/src/tcp_listener.rs @@ -0,0 +1,874 @@ +use std::collections::HashMap; +use std::sync::Arc; +use tokio::net::TcpListener; +use tokio_util::sync::CancellationToken; +use tracing::{info, error, debug, warn}; +use thiserror::Error; + +use rustproxy_routing::RouteManager; +use rustproxy_metrics::MetricsCollector; +use rustproxy_http::HttpProxyService; +use crate::sni_parser; +use crate::forwarder; +use crate::tls_handler; +use crate::connection_tracker::ConnectionTracker; + +#[derive(Debug, Error)] +pub enum ListenerError { + #[error("Failed to bind port {port}: {source}")] + BindFailed { port: u16, source: std::io::Error }, + #[error("Port {0} already bound")] + AlreadyBound(u16), + #[error("IO error: {0}")] + Io(#[from] std::io::Error), +} + +/// TLS configuration for a specific domain. +#[derive(Clone)] +pub struct TlsCertConfig { + pub cert_pem: String, + pub key_pem: String, +} + +/// Timeout and connection management configuration. +#[derive(Debug, Clone)] +pub struct ConnectionConfig { + /// Timeout for establishing connection to backend (ms) + pub connection_timeout_ms: u64, + /// Timeout for initial data/SNI peek (ms) + pub initial_data_timeout_ms: u64, + /// Socket inactivity timeout (ms) + pub socket_timeout_ms: u64, + /// Maximum connection lifetime (ms) + pub max_connection_lifetime_ms: u64, + /// Graceful shutdown timeout (ms) + pub graceful_shutdown_timeout_ms: u64, + /// Maximum connections per IP (None = unlimited) + pub max_connections_per_ip: Option, + /// Connection rate limit per minute per IP (None = unlimited) + pub connection_rate_limit_per_minute: Option, + /// Keep-alive treatment + pub keep_alive_treatment: Option, + /// Inactivity multiplier for keep-alive connections + pub keep_alive_inactivity_multiplier: Option, + /// Extended keep-alive lifetime (ms) for Extended treatment mode + pub extended_keep_alive_lifetime_ms: Option, + /// Whether to accept PROXY protocol + pub accept_proxy_protocol: bool, + /// Whether to send PROXY protocol + pub send_proxy_protocol: bool, +} + +impl Default for ConnectionConfig { + fn default() -> Self { + Self { + connection_timeout_ms: 30_000, + initial_data_timeout_ms: 60_000, + socket_timeout_ms: 3_600_000, + max_connection_lifetime_ms: 86_400_000, + graceful_shutdown_timeout_ms: 30_000, + max_connections_per_ip: None, + connection_rate_limit_per_minute: None, + keep_alive_treatment: None, + keep_alive_inactivity_multiplier: None, + extended_keep_alive_lifetime_ms: None, + accept_proxy_protocol: false, + send_proxy_protocol: false, + } + } +} + +/// Manages TCP listeners for all configured ports. +pub struct TcpListenerManager { + /// Active listeners indexed by port + listeners: HashMap>, + /// Shared route manager + route_manager: Arc, + /// Shared metrics collector + metrics: Arc, + /// TLS acceptors indexed by domain + tls_configs: Arc>, + /// HTTP proxy service for HTTP-level forwarding + http_proxy: Arc, + /// Connection configuration + conn_config: Arc, + /// Connection tracker for per-IP limits + conn_tracker: Arc, + /// Cancellation token for graceful shutdown + cancel_token: CancellationToken, +} + +impl TcpListenerManager { + pub fn new(route_manager: Arc) -> Self { + let metrics = Arc::new(MetricsCollector::new()); + let http_proxy = Arc::new(HttpProxyService::new( + Arc::clone(&route_manager), + Arc::clone(&metrics), + )); + let conn_config = ConnectionConfig::default(); + let conn_tracker = Arc::new(ConnectionTracker::new( + conn_config.max_connections_per_ip, + conn_config.connection_rate_limit_per_minute, + )); + Self { + listeners: HashMap::new(), + route_manager, + metrics, + tls_configs: Arc::new(HashMap::new()), + http_proxy, + conn_config: Arc::new(conn_config), + conn_tracker, + cancel_token: CancellationToken::new(), + } + } + + /// Create with a metrics collector. + pub fn with_metrics(route_manager: Arc, metrics: Arc) -> Self { + let http_proxy = Arc::new(HttpProxyService::new( + Arc::clone(&route_manager), + Arc::clone(&metrics), + )); + let conn_config = ConnectionConfig::default(); + let conn_tracker = Arc::new(ConnectionTracker::new( + conn_config.max_connections_per_ip, + conn_config.connection_rate_limit_per_minute, + )); + Self { + listeners: HashMap::new(), + route_manager, + metrics, + tls_configs: Arc::new(HashMap::new()), + http_proxy, + conn_config: Arc::new(conn_config), + conn_tracker, + cancel_token: CancellationToken::new(), + } + } + + /// Set connection configuration. + pub fn set_connection_config(&mut self, config: ConnectionConfig) { + self.conn_tracker = Arc::new(ConnectionTracker::new( + config.max_connections_per_ip, + config.connection_rate_limit_per_minute, + )); + self.conn_config = Arc::new(config); + } + + /// Set TLS certificate configurations. + pub fn set_tls_configs(&mut self, configs: HashMap) { + self.tls_configs = Arc::new(configs); + } + + /// Start listening on a port. + pub async fn add_port(&mut self, port: u16) -> Result<(), ListenerError> { + if self.listeners.contains_key(&port) { + return Err(ListenerError::AlreadyBound(port)); + } + + let addr = format!("0.0.0.0:{}", port); + let listener = TcpListener::bind(&addr).await.map_err(|e| { + ListenerError::BindFailed { port, source: e } + })?; + + info!("Listening on port {}", port); + + let route_manager = Arc::clone(&self.route_manager); + let metrics = Arc::clone(&self.metrics); + let tls_configs = Arc::clone(&self.tls_configs); + let http_proxy = Arc::clone(&self.http_proxy); + let conn_config = Arc::clone(&self.conn_config); + let conn_tracker = Arc::clone(&self.conn_tracker); + let cancel = self.cancel_token.clone(); + + let handle = tokio::spawn(async move { + Self::accept_loop( + listener, port, route_manager, metrics, tls_configs, + http_proxy, conn_config, conn_tracker, cancel, + ).await; + }); + + self.listeners.insert(port, handle); + Ok(()) + } + + /// Stop listening on a port. + pub fn remove_port(&mut self, port: u16) -> bool { + if let Some(handle) = self.listeners.remove(&port) { + handle.abort(); + info!("Stopped listening on port {}", port); + true + } else { + false + } + } + + /// Get all currently listening ports. + pub fn listening_ports(&self) -> Vec { + let mut ports: Vec = self.listeners.keys().copied().collect(); + ports.sort(); + ports + } + + /// Stop all listeners gracefully. + /// + /// Signals cancellation and waits up to `graceful_shutdown_timeout_ms` for + /// connections to drain, then aborts remaining tasks. + pub async fn graceful_stop(&mut self) { + let timeout_ms = self.conn_config.graceful_shutdown_timeout_ms; + info!("Initiating graceful shutdown (timeout: {}ms)", timeout_ms); + + // Signal all accept loops to stop accepting new connections + self.cancel_token.cancel(); + + // Wait for existing connections to drain + let timeout = std::time::Duration::from_millis(timeout_ms); + let deadline = tokio::time::Instant::now() + timeout; + + for (port, handle) in self.listeners.drain() { + let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); + if remaining.is_zero() { + handle.abort(); + warn!("Force-stopped listener on port {} (timeout exceeded)", port); + } else { + match tokio::time::timeout(remaining, handle).await { + Ok(_) => info!("Listener on port {} stopped gracefully", port), + Err(_) => { + warn!("Listener on port {} did not stop in time, aborting", port); + } + } + } + } + + // Reset cancellation token for potential restart + self.cancel_token = CancellationToken::new(); + info!("Graceful shutdown complete"); + } + + /// Stop all listeners immediately (backward compatibility). + pub fn stop_all(&mut self) { + self.cancel_token.cancel(); + for (port, handle) in self.listeners.drain() { + handle.abort(); + info!("Stopped listening on port {}", port); + } + self.cancel_token = CancellationToken::new(); + } + + /// Update the route manager (for hot-reload). + pub fn update_route_manager(&mut self, route_manager: Arc) { + self.route_manager = route_manager; + } + + /// Get a reference to the metrics collector. + pub fn metrics(&self) -> &Arc { + &self.metrics + } + + /// Accept loop for a single port. + async fn accept_loop( + listener: TcpListener, + port: u16, + route_manager: Arc, + metrics: Arc, + tls_configs: Arc>, + http_proxy: Arc, + conn_config: Arc, + conn_tracker: Arc, + cancel: CancellationToken, + ) { + loop { + tokio::select! { + _ = cancel.cancelled() => { + info!("Accept loop on port {} shutting down", port); + break; + } + result = listener.accept() => { + match result { + Ok((stream, peer_addr)) => { + let ip = peer_addr.ip(); + + // Check per-IP limits and rate limiting + if !conn_tracker.try_accept(&ip) { + debug!("Rejected connection from {} (per-IP limit or rate limit)", peer_addr); + drop(stream); + continue; + } + + conn_tracker.connection_opened(&ip); + + let rm = Arc::clone(&route_manager); + let m = Arc::clone(&metrics); + let tc = Arc::clone(&tls_configs); + let hp = Arc::clone(&http_proxy); + let cc = Arc::clone(&conn_config); + let ct = Arc::clone(&conn_tracker); + let cn = cancel.clone(); + debug!("Accepted connection from {} on port {}", peer_addr, port); + + tokio::spawn(async move { + let result = Self::handle_connection( + stream, port, peer_addr, rm, m, tc, hp, cc, cn, + ).await; + if let Err(e) = result { + debug!("Connection error from {}: {}", peer_addr, e); + } + ct.connection_closed(&ip); + }); + } + Err(e) => { + error!("Accept error on port {}: {}", port, e); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + } + } + } + } + } + + /// Handle a single incoming connection. + async fn handle_connection( + mut stream: tokio::net::TcpStream, + port: u16, + peer_addr: std::net::SocketAddr, + route_manager: Arc, + metrics: Arc, + tls_configs: Arc>, + http_proxy: Arc, + conn_config: Arc, + cancel: CancellationToken, + ) -> Result<(), Box> { + use tokio::io::AsyncReadExt; + + stream.set_nodelay(true)?; + + // Handle PROXY protocol if configured + let mut effective_peer_addr = peer_addr; + if conn_config.accept_proxy_protocol { + let mut proxy_peek = vec![0u8; 256]; + let pn = match tokio::time::timeout( + std::time::Duration::from_millis(conn_config.initial_data_timeout_ms), + stream.peek(&mut proxy_peek), + ).await { + Ok(Ok(n)) => n, + Ok(Err(e)) => return Err(e.into()), + Err(_) => return Err("Initial data timeout (proxy protocol peek)".into()), + }; + + if pn > 0 && crate::proxy_protocol::is_proxy_protocol_v1(&proxy_peek[..pn]) { + match crate::proxy_protocol::parse_v1(&proxy_peek[..pn]) { + Ok((header, consumed)) => { + debug!("PROXY protocol: real client {} -> {}", header.source_addr, header.dest_addr); + effective_peer_addr = header.source_addr; + // Consume the proxy protocol header bytes + let mut discard = vec![0u8; consumed]; + stream.read_exact(&mut discard).await?; + } + Err(e) => { + debug!("Failed to parse PROXY protocol header: {}", e); + // Not a PROXY protocol header, continue normally + } + } + } + } + let peer_addr = effective_peer_addr; + + // Peek at initial bytes with timeout + let mut peek_buf = vec![0u8; 4096]; + let n = match tokio::time::timeout( + std::time::Duration::from_millis(conn_config.initial_data_timeout_ms), + stream.peek(&mut peek_buf), + ).await { + Ok(Ok(n)) => n, + Ok(Err(e)) => return Err(e.into()), + Err(_) => return Err("Initial data timeout".into()), + }; + let initial_data = &peek_buf[..n]; + + // Determine connection type and extract SNI if TLS + let is_tls = sni_parser::is_tls(initial_data); + let is_http = sni_parser::is_http(initial_data); + let domain = if is_tls { + match sni_parser::extract_sni(initial_data) { + sni_parser::SniResult::Found(sni) => Some(sni), + sni_parser::SniResult::NoSni => None, + sni_parser::SniResult::NeedMoreData => { + let mut bigger_buf = vec![0u8; 16384]; + let n = match tokio::time::timeout( + std::time::Duration::from_millis(conn_config.initial_data_timeout_ms), + stream.peek(&mut bigger_buf), + ).await { + Ok(Ok(n)) => n, + Ok(Err(e)) => return Err(e.into()), + Err(_) => return Err("SNI data timeout".into()), + }; + match sni_parser::extract_sni(&bigger_buf[..n]) { + sni_parser::SniResult::Found(sni) => Some(sni), + _ => None, + } + } + sni_parser::SniResult::NotTls => None, + } + } else { + None + }; + + // Match route + let ctx = rustproxy_routing::MatchContext { + port, + domain: domain.as_deref(), + path: None, + client_ip: Some(&peer_addr.ip().to_string()), + tls_version: None, + headers: None, + is_tls, + }; + + let route_match = route_manager.find_route(&ctx); + + let route_match = match route_match { + Some(rm) => rm, + None => { + debug!("No route matched for port {} domain {:?}", port, domain); + return Ok(()); + } + }; + + let route_id = route_match.route.id.as_deref(); + + // Check route-level IP security for passthrough connections + if let Some(ref security) = route_match.route.security { + if !rustproxy_http::request_filter::RequestFilter::check_ip_security( + security, + &peer_addr.ip(), + ) { + debug!("Connection from {} blocked by route security", peer_addr); + return Ok(()); + } + } + + // Track connection in metrics + metrics.connection_opened(route_id); + + let target = match route_match.target { + Some(t) => t, + None => { + debug!("Route matched but no target available"); + metrics.connection_closed(route_id); + return Ok(()); + } + }; + + let target_host = target.host.first().to_string(); + let target_port = target.port.resolve(port); + let tls_mode = route_match.route.tls_mode(); + + // Connection timeout for backend connections + let connect_timeout = std::time::Duration::from_millis(conn_config.connection_timeout_ms); + let base_inactivity_ms = conn_config.socket_timeout_ms; + let (inactivity_timeout, max_lifetime) = match conn_config.keep_alive_treatment.as_ref() { + Some(rustproxy_config::KeepAliveTreatment::Extended) => { + let multiplier = conn_config.keep_alive_inactivity_multiplier.unwrap_or(6.0); + let extended_lifetime = conn_config.extended_keep_alive_lifetime_ms + .unwrap_or(7 * 24 * 3600 * 1000); // 7 days default + ( + std::time::Duration::from_millis((base_inactivity_ms as f64 * multiplier) as u64), + std::time::Duration::from_millis(extended_lifetime), + ) + } + Some(rustproxy_config::KeepAliveTreatment::Immortal) => { + ( + std::time::Duration::from_millis(base_inactivity_ms), + std::time::Duration::from_secs(u64::MAX / 2), + ) + } + _ => { + // Standard + ( + std::time::Duration::from_millis(base_inactivity_ms), + std::time::Duration::from_millis(conn_config.max_connection_lifetime_ms), + ) + } + }; + + // Determine if we should send PROXY protocol to backend + let should_send_proxy = conn_config.send_proxy_protocol + || route_match.route.action.send_proxy_protocol.unwrap_or(false) + || target.send_proxy_protocol.unwrap_or(false); + + // Generate PROXY protocol header if needed + let proxy_header = if should_send_proxy { + let dest = std::net::SocketAddr::new( + target_host.parse().unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)), + target_port, + ); + Some(crate::proxy_protocol::generate_v1(&peer_addr, &dest)) + } else { + None + }; + + let result = match tls_mode { + Some(rustproxy_config::TlsMode::Passthrough) => { + // Raw TCP passthrough - connect to backend and forward + let mut backend = match tokio::time::timeout( + connect_timeout, + tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)), + ).await { + Ok(Ok(s)) => s, + Ok(Err(e)) => return Err(e.into()), + Err(_) => return Err("Backend connection timeout".into()), + }; + backend.set_nodelay(true)?; + + // Send PROXY protocol header if configured + if let Some(ref header) = proxy_header { + use tokio::io::AsyncWriteExt; + backend.write_all(header.as_bytes()).await?; + } + + debug!( + "Passthrough: {} -> {}:{} (SNI: {:?})", + peer_addr, target_host, target_port, domain + ); + + let mut actual_buf = vec![0u8; n]; + stream.read_exact(&mut actual_buf).await?; + + let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts( + stream, backend, Some(&actual_buf), + inactivity_timeout, max_lifetime, cancel, + ).await?; + metrics.record_bytes(bytes_in, bytes_out, route_id); + Ok(()) + } + Some(rustproxy_config::TlsMode::Terminate) => { + let tls_config = Self::find_tls_config(&domain, &tls_configs)?; + + // TLS accept with timeout, applying route-level TLS settings + let route_tls = route_match.route.action.tls.as_ref(); + let acceptor = tls_handler::build_tls_acceptor_with_config( + &tls_config.cert_pem, &tls_config.key_pem, route_tls, + )?; + let tls_stream = match tokio::time::timeout( + std::time::Duration::from_millis(conn_config.initial_data_timeout_ms), + tls_handler::accept_tls(stream, &acceptor), + ).await { + Ok(Ok(s)) => s, + Ok(Err(e)) => return Err(e), + Err(_) => return Err("TLS handshake timeout".into()), + }; + + // Peek at decrypted data to determine if HTTP + let mut buf_stream = tokio::io::BufReader::new(tls_stream); + let peeked = { + use tokio::io::AsyncBufReadExt; + match buf_stream.fill_buf().await { + Ok(data) => sni_parser::is_http(data), + Err(_) => false, + } + }; + + if peeked { + debug!( + "TLS Terminate + HTTP: {} -> {}:{} (domain: {:?})", + peer_addr, target_host, target_port, domain + ); + http_proxy.handle_io(buf_stream, peer_addr, port).await; + } else { + debug!( + "TLS Terminate + TCP: {} -> {}:{} (domain: {:?})", + peer_addr, target_host, target_port, domain + ); + // Raw TCP forwarding of decrypted stream + let backend = match tokio::time::timeout( + connect_timeout, + tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)), + ).await { + Ok(Ok(s)) => s, + Ok(Err(e)) => return Err(e.into()), + Err(_) => return Err("Backend connection timeout".into()), + }; + backend.set_nodelay(true)?; + + let (tls_read, tls_write) = tokio::io::split(buf_stream); + let (backend_read, backend_write) = tokio::io::split(backend); + + let (bytes_in, bytes_out) = Self::forward_bidirectional_split_with_timeouts( + tls_read, tls_write, backend_read, backend_write, + inactivity_timeout, max_lifetime, + ).await; + + metrics.record_bytes(bytes_in, bytes_out, route_id); + } + Ok(()) + } + Some(rustproxy_config::TlsMode::TerminateAndReencrypt) => { + let route_tls = route_match.route.action.tls.as_ref(); + Self::handle_tls_terminate_reencrypt( + stream, n, &domain, &target_host, target_port, + peer_addr, &tls_configs, &metrics, route_id, &conn_config, route_tls, + ).await + } + None => { + if is_http { + // Plain HTTP - use HTTP proxy for request-level routing + debug!("HTTP proxy: {} on port {}", peer_addr, port); + http_proxy.handle_connection(stream, peer_addr, port).await; + Ok(()) + } else { + // Plain TCP forwarding (non-HTTP) + let mut backend = match tokio::time::timeout( + connect_timeout, + tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)), + ).await { + Ok(Ok(s)) => s, + Ok(Err(e)) => return Err(e.into()), + Err(_) => return Err("Backend connection timeout".into()), + }; + backend.set_nodelay(true)?; + + // Send PROXY protocol header if configured + if let Some(ref header) = proxy_header { + use tokio::io::AsyncWriteExt; + backend.write_all(header.as_bytes()).await?; + } + + debug!( + "Forward: {} -> {}:{}", + peer_addr, target_host, target_port + ); + + let mut actual_buf = vec![0u8; n]; + stream.read_exact(&mut actual_buf).await?; + + let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts( + stream, backend, Some(&actual_buf), + inactivity_timeout, max_lifetime, cancel, + ).await?; + metrics.record_bytes(bytes_in, bytes_out, route_id); + Ok(()) + } + } + }; + + metrics.connection_closed(route_id); + result + } + + /// Handle TLS terminate-and-reencrypt: accept TLS from client, connect TLS to backend. + async fn handle_tls_terminate_reencrypt( + stream: tokio::net::TcpStream, + _peek_len: usize, + domain: &Option, + target_host: &str, + target_port: u16, + peer_addr: std::net::SocketAddr, + tls_configs: &HashMap, + metrics: &MetricsCollector, + route_id: Option<&str>, + conn_config: &ConnectionConfig, + route_tls: Option<&rustproxy_config::RouteTls>, + ) -> Result<(), Box> { + let tls_config = Self::find_tls_config(domain, tls_configs)?; + let acceptor = tls_handler::build_tls_acceptor_with_config( + &tls_config.cert_pem, &tls_config.key_pem, route_tls, + )?; + + // Accept TLS from client with timeout + let client_tls = match tokio::time::timeout( + std::time::Duration::from_millis(conn_config.initial_data_timeout_ms), + tls_handler::accept_tls(stream, &acceptor), + ).await { + Ok(Ok(s)) => s, + Ok(Err(e)) => return Err(e), + Err(_) => return Err("TLS handshake timeout".into()), + }; + + debug!( + "TLS Terminate+Reencrypt: {} -> {}:{} (domain: {:?})", + peer_addr, target_host, target_port, domain + ); + + // Connect to backend over TLS with timeout + let backend_tls = match tokio::time::timeout( + std::time::Duration::from_millis(conn_config.connection_timeout_ms), + tls_handler::connect_tls(target_host, target_port), + ).await { + Ok(Ok(s)) => s, + Ok(Err(e)) => return Err(e), + Err(_) => return Err("Backend TLS connection timeout".into()), + }; + + // Forward between two TLS streams + let (client_read, client_write) = tokio::io::split(client_tls); + let (backend_read, backend_write) = tokio::io::split(backend_tls); + + let base_inactivity_ms = conn_config.socket_timeout_ms; + let (inactivity_timeout, max_lifetime) = match conn_config.keep_alive_treatment.as_ref() { + Some(rustproxy_config::KeepAliveTreatment::Extended) => { + let multiplier = conn_config.keep_alive_inactivity_multiplier.unwrap_or(6.0); + let extended_lifetime = conn_config.extended_keep_alive_lifetime_ms + .unwrap_or(7 * 24 * 3600 * 1000); // 7 days default + ( + std::time::Duration::from_millis((base_inactivity_ms as f64 * multiplier) as u64), + std::time::Duration::from_millis(extended_lifetime), + ) + } + Some(rustproxy_config::KeepAliveTreatment::Immortal) => { + ( + std::time::Duration::from_millis(base_inactivity_ms), + std::time::Duration::from_secs(u64::MAX / 2), + ) + } + _ => { + // Standard + ( + std::time::Duration::from_millis(base_inactivity_ms), + std::time::Duration::from_millis(conn_config.max_connection_lifetime_ms), + ) + } + }; + + let (bytes_in, bytes_out) = Self::forward_bidirectional_split_with_timeouts( + client_read, client_write, backend_read, backend_write, + inactivity_timeout, max_lifetime, + ).await; + + metrics.record_bytes(bytes_in, bytes_out, route_id); + Ok(()) + } + + /// Find the TLS config for a given domain. + fn find_tls_config<'a>( + domain: &Option, + tls_configs: &'a HashMap, + ) -> Result<&'a TlsCertConfig, Box> { + if let Some(domain) = domain { + // Try exact match + if let Some(config) = tls_configs.get(domain) { + return Ok(config); + } + // Try wildcard + if let Some(dot_pos) = domain.find('.') { + let wildcard = format!("*.{}", &domain[dot_pos + 1..]); + if let Some(config) = tls_configs.get(&wildcard) { + return Ok(config); + } + } + } + // Try default/fallback cert + if let Some(config) = tls_configs.get("*") { + return Ok(config); + } + // Try first available cert + if let Some((_key, config)) = tls_configs.iter().next() { + return Ok(config); + } + Err("No TLS certificate available for this domain".into()) + } + + /// Forward bidirectional between two split streams with inactivity and lifetime timeouts. + async fn forward_bidirectional_split_with_timeouts( + mut client_read: R1, + mut client_write: W1, + mut backend_read: R2, + mut backend_write: W2, + inactivity_timeout: std::time::Duration, + max_lifetime: std::time::Duration, + ) -> (u64, u64) + where + R1: tokio::io::AsyncRead + Unpin + Send + 'static, + W1: tokio::io::AsyncWrite + Unpin + Send + 'static, + R2: tokio::io::AsyncRead + Unpin + Send + 'static, + W2: tokio::io::AsyncWrite + Unpin + Send + 'static, + { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use std::sync::Arc; + use std::sync::atomic::{AtomicU64, Ordering}; + + let last_activity = Arc::new(AtomicU64::new(0)); + let start = std::time::Instant::now(); + + let la1 = Arc::clone(&last_activity); + let c2b = tokio::spawn(async move { + let mut buf = vec![0u8; 65536]; + let mut total = 0u64; + loop { + let n = match client_read.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + if backend_write.write_all(&buf[..n]).await.is_err() { + break; + } + total += n as u64; + la1.store( + start.elapsed().as_millis() as u64, + Ordering::Relaxed, + ); + } + let _ = backend_write.shutdown().await; + total + }); + + let la2 = Arc::clone(&last_activity); + let b2c = tokio::spawn(async move { + let mut buf = vec![0u8; 65536]; + let mut total = 0u64; + loop { + let n = match backend_read.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + if client_write.write_all(&buf[..n]).await.is_err() { + break; + } + total += n as u64; + la2.store( + start.elapsed().as_millis() as u64, + Ordering::Relaxed, + ); + } + let _ = client_write.shutdown().await; + total + }); + + // Watchdog task: check for inactivity and max lifetime + let la_watch = Arc::clone(&last_activity); + let c2b_handle = c2b.abort_handle(); + let b2c_handle = b2c.abort_handle(); + let watchdog = tokio::spawn(async move { + let check_interval = std::time::Duration::from_secs(5); + let mut last_seen = 0u64; + loop { + tokio::time::sleep(check_interval).await; + + // Check max lifetime + if start.elapsed() >= max_lifetime { + debug!("Connection exceeded max lifetime, closing"); + c2b_handle.abort(); + b2c_handle.abort(); + break; + } + + // Check inactivity + let current = la_watch.load(Ordering::Relaxed); + if current == last_seen { + // No activity since last check + let elapsed_since_activity = start.elapsed().as_millis() as u64 - current; + if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 { + debug!("Connection inactive for {}ms, closing", elapsed_since_activity); + c2b_handle.abort(); + b2c_handle.abort(); + break; + } + } + last_seen = current; + } + }); + + let bytes_in = c2b.await.unwrap_or(0); + let bytes_out = b2c.await.unwrap_or(0); + watchdog.abort(); + (bytes_in, bytes_out) + } +} diff --git a/rust/crates/rustproxy-passthrough/src/tls_handler.rs b/rust/crates/rustproxy-passthrough/src/tls_handler.rs new file mode 100644 index 0000000..5abddb5 --- /dev/null +++ b/rust/crates/rustproxy-passthrough/src/tls_handler.rs @@ -0,0 +1,190 @@ +use std::io::BufReader; +use std::sync::Arc; + +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use rustls::ServerConfig; +use tokio::net::TcpStream; +use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream}; +use tracing::debug; + +/// Ensure the default crypto provider is installed. +fn ensure_crypto_provider() { + let _ = rustls::crypto::ring::default_provider().install_default(); +} + +/// Build a TLS acceptor from PEM-encoded cert and key data. +pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result> { + build_tls_acceptor_with_config(cert_pem, key_pem, None) +} + +/// Build a TLS acceptor with optional RouteTls configuration for version/cipher tuning. +pub fn build_tls_acceptor_with_config( + cert_pem: &str, + key_pem: &str, + tls_config: Option<&rustproxy_config::RouteTls>, +) -> Result> { + ensure_crypto_provider(); + let certs = load_certs(cert_pem)?; + let key = load_private_key(key_pem)?; + + let mut config = if let Some(route_tls) = tls_config { + // Apply TLS version restrictions + let versions = resolve_tls_versions(route_tls.versions.as_deref()); + let builder = ServerConfig::builder_with_protocol_versions(&versions); + builder + .with_no_client_auth() + .with_single_cert(certs, key)? + } else { + ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key)? + }; + + // Apply session timeout if configured + if let Some(route_tls) = tls_config { + if let Some(timeout_secs) = route_tls.session_timeout { + config.session_storage = rustls::server::ServerSessionMemoryCache::new( + 256, // max sessions + ); + debug!("TLS session timeout configured: {}s", timeout_secs); + } + } + + Ok(TlsAcceptor::from(Arc::new(config))) +} + +/// Resolve TLS version strings to rustls SupportedProtocolVersion. +fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> { + let versions = match versions { + Some(v) if !v.is_empty() => v, + _ => return vec![&rustls::version::TLS12, &rustls::version::TLS13], + }; + + let mut result = Vec::new(); + for v in versions { + match v.as_str() { + "TLSv1.2" | "TLS1.2" | "1.2" | "TLSv12" => { + if !result.contains(&&rustls::version::TLS12) { + result.push(&rustls::version::TLS12); + } + } + "TLSv1.3" | "TLS1.3" | "1.3" | "TLSv13" => { + if !result.contains(&&rustls::version::TLS13) { + result.push(&rustls::version::TLS13); + } + } + other => { + debug!("Unknown TLS version '{}', ignoring", other); + } + } + } + + if result.is_empty() { + // Fallback to both if no valid versions specified + vec![&rustls::version::TLS12, &rustls::version::TLS13] + } else { + result + } +} + +/// Accept a TLS connection from a client stream. +pub async fn accept_tls( + stream: TcpStream, + acceptor: &TlsAcceptor, +) -> Result, Box> { + let tls_stream = acceptor.accept(stream).await?; + debug!("TLS handshake completed"); + Ok(tls_stream) +} + +/// Connect to a backend with TLS (for terminate-and-reencrypt mode). +pub async fn connect_tls( + host: &str, + port: u16, +) -> Result, Box> { + ensure_crypto_provider(); + let config = rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(InsecureVerifier)) + .with_no_client_auth(); + + let connector = TlsConnector::from(Arc::new(config)); + + let stream = TcpStream::connect(format!("{}:{}", host, port)).await?; + stream.set_nodelay(true)?; + + let server_name = rustls::pki_types::ServerName::try_from(host.to_string())?; + let tls_stream = connector.connect(server_name, stream).await?; + debug!("Backend TLS connection established to {}:{}", host, port); + Ok(tls_stream) +} + +/// Load certificates from PEM string. +fn load_certs(pem: &str) -> Result>, Box> { + let mut reader = BufReader::new(pem.as_bytes()); + let certs: Vec> = rustls_pemfile::certs(&mut reader) + .collect::, _>>()?; + if certs.is_empty() { + return Err("No certificates found in PEM data".into()); + } + Ok(certs) +} + +/// Load private key from PEM string. +fn load_private_key(pem: &str) -> Result, Box> { + let mut reader = BufReader::new(pem.as_bytes()); + // Try PKCS8 first, then RSA, then EC + let key = rustls_pemfile::private_key(&mut reader)? + .ok_or("No private key found in PEM data")?; + Ok(key) +} + +/// Insecure certificate verifier for backend connections (terminate-and-reencrypt). +/// In internal networks, backends may use self-signed certs. +#[derive(Debug)] +struct InsecureVerifier; + +impl rustls::client::danger::ServerCertVerifier for InsecureVerifier { + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &rustls::pki_types::ServerName<'_>, + _ocsp_response: &[u8], + _now: rustls::pki_types::UnixTime, + ) -> Result { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + vec![ + rustls::SignatureScheme::RSA_PKCS1_SHA256, + rustls::SignatureScheme::RSA_PKCS1_SHA384, + rustls::SignatureScheme::RSA_PKCS1_SHA512, + rustls::SignatureScheme::ECDSA_NISTP256_SHA256, + rustls::SignatureScheme::ECDSA_NISTP384_SHA384, + rustls::SignatureScheme::ED25519, + rustls::SignatureScheme::RSA_PSS_SHA256, + rustls::SignatureScheme::RSA_PSS_SHA384, + rustls::SignatureScheme::RSA_PSS_SHA512, + ] + } +} diff --git a/rust/crates/rustproxy-routing/Cargo.toml b/rust/crates/rustproxy-routing/Cargo.toml new file mode 100644 index 0000000..6696b6a --- /dev/null +++ b/rust/crates/rustproxy-routing/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "rustproxy-routing" +version.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true +description = "Route matching engine for RustProxy" + +[dependencies] +rustproxy-config = { workspace = true } +glob-match = { workspace = true } +ipnet = { workspace = true } +regex = { workspace = true } +tracing = { workspace = true } +thiserror = { workspace = true } +arc-swap = { workspace = true } diff --git a/rust/crates/rustproxy-routing/src/lib.rs b/rust/crates/rustproxy-routing/src/lib.rs new file mode 100644 index 0000000..972461e --- /dev/null +++ b/rust/crates/rustproxy-routing/src/lib.rs @@ -0,0 +1,9 @@ +//! # rustproxy-routing +//! +//! Route matching engine for RustProxy. +//! Provides domain/path/IP/header matchers and a port-indexed RouteManager. + +pub mod route_manager; +pub mod matchers; + +pub use route_manager::*; diff --git a/rust/crates/rustproxy-routing/src/matchers/domain.rs b/rust/crates/rustproxy-routing/src/matchers/domain.rs new file mode 100644 index 0000000..b7529d1 --- /dev/null +++ b/rust/crates/rustproxy-routing/src/matchers/domain.rs @@ -0,0 +1,86 @@ +/// Match a domain against a pattern supporting wildcards. +/// +/// Supported patterns: +/// - `*` matches any domain +/// - `*.example.com` matches any subdomain of example.com +/// - `example.com` exact match +/// - `**.example.com` matches any depth of subdomain +pub fn domain_matches(pattern: &str, domain: &str) -> bool { + let pattern = pattern.trim().to_lowercase(); + let domain = domain.trim().to_lowercase(); + + if pattern == "*" { + return true; + } + + if pattern == domain { + return true; + } + + // Wildcard patterns + if pattern.starts_with("*.") { + let suffix = &pattern[2..]; // e.g., "example.com" + // Match exact parent or any single-level subdomain + if domain == suffix { + return true; + } + if domain.ends_with(&format!(".{}", suffix)) { + // Check it's a single level subdomain for `*.` + let prefix = &domain[..domain.len() - suffix.len() - 1]; + return !prefix.contains('.'); + } + return false; + } + + if pattern.starts_with("**.") { + let suffix = &pattern[3..]; + // Match exact parent or any depth of subdomain + return domain == suffix || domain.ends_with(&format!(".{}", suffix)); + } + + // Use glob-match for more complex patterns + glob_match::glob_match(&pattern, &domain) +} + +/// Check if a domain matches any of the given patterns. +pub fn domain_matches_any(patterns: &[&str], domain: &str) -> bool { + patterns.iter().any(|p| domain_matches(p, domain)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_exact_match() { + assert!(domain_matches("example.com", "example.com")); + assert!(!domain_matches("example.com", "other.com")); + } + + #[test] + fn test_wildcard_all() { + assert!(domain_matches("*", "anything.com")); + assert!(domain_matches("*", "sub.domain.example.com")); + } + + #[test] + fn test_wildcard_subdomain() { + assert!(domain_matches("*.example.com", "www.example.com")); + assert!(domain_matches("*.example.com", "api.example.com")); + assert!(domain_matches("*.example.com", "example.com")); + assert!(!domain_matches("*.example.com", "deep.sub.example.com")); + } + + #[test] + fn test_double_wildcard() { + assert!(domain_matches("**.example.com", "www.example.com")); + assert!(domain_matches("**.example.com", "deep.sub.example.com")); + assert!(domain_matches("**.example.com", "example.com")); + } + + #[test] + fn test_case_insensitive() { + assert!(domain_matches("Example.COM", "example.com")); + assert!(domain_matches("*.EXAMPLE.com", "WWW.example.COM")); + } +} diff --git a/rust/crates/rustproxy-routing/src/matchers/header.rs b/rust/crates/rustproxy-routing/src/matchers/header.rs new file mode 100644 index 0000000..a2cf656 --- /dev/null +++ b/rust/crates/rustproxy-routing/src/matchers/header.rs @@ -0,0 +1,98 @@ +use std::collections::HashMap; +use regex::Regex; + +/// Match HTTP headers against a set of patterns. +/// +/// Pattern values can be: +/// - Exact string: `"application/json"` +/// - Regex (surrounded by /): `"/^text\/.*/"` +pub fn headers_match( + patterns: &HashMap, + headers: &HashMap, +) -> bool { + for (key, pattern) in patterns { + let key_lower = key.to_lowercase(); + + // Find the header (case-insensitive) + let header_value = headers + .iter() + .find(|(k, _)| k.to_lowercase() == key_lower) + .map(|(_, v)| v.as_str()); + + let header_value = match header_value { + Some(v) => v, + None => return false, // Required header not present + }; + + // Check if pattern is a regex (surrounded by /) + if pattern.starts_with('/') && pattern.ends_with('/') && pattern.len() > 2 { + let regex_str = &pattern[1..pattern.len() - 1]; + match Regex::new(regex_str) { + Ok(re) => { + if !re.is_match(header_value) { + return false; + } + } + Err(_) => { + // Invalid regex, fall back to exact match + if header_value != pattern { + return false; + } + } + } + } else { + // Exact match + if header_value != pattern { + return false; + } + } + } + + true +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_exact_header_match() { + let patterns: HashMap = { + let mut m = HashMap::new(); + m.insert("Content-Type".to_string(), "application/json".to_string()); + m + }; + let headers: HashMap = { + let mut m = HashMap::new(); + m.insert("content-type".to_string(), "application/json".to_string()); + m + }; + assert!(headers_match(&patterns, &headers)); + } + + #[test] + fn test_regex_header_match() { + let patterns: HashMap = { + let mut m = HashMap::new(); + m.insert("Content-Type".to_string(), "/^text\\/.*/".to_string()); + m + }; + let headers: HashMap = { + let mut m = HashMap::new(); + m.insert("content-type".to_string(), "text/html".to_string()); + m + }; + assert!(headers_match(&patterns, &headers)); + } + + #[test] + fn test_missing_header() { + let patterns: HashMap = { + let mut m = HashMap::new(); + m.insert("X-Custom".to_string(), "value".to_string()); + m + }; + let headers: HashMap = HashMap::new(); + assert!(!headers_match(&patterns, &headers)); + } +} diff --git a/rust/crates/rustproxy-routing/src/matchers/ip.rs b/rust/crates/rustproxy-routing/src/matchers/ip.rs new file mode 100644 index 0000000..d92f84b --- /dev/null +++ b/rust/crates/rustproxy-routing/src/matchers/ip.rs @@ -0,0 +1,126 @@ +use std::net::IpAddr; +use std::str::FromStr; +use ipnet::IpNet; + +/// Match an IP address against a pattern. +/// +/// Supported patterns: +/// - `*` matches any IP +/// - `192.168.1.0/24` CIDR range +/// - `192.168.1.100` exact match +/// - `192.168.1.*` wildcard (converted to CIDR) +/// - `::ffff:192.168.1.100` IPv6-mapped IPv4 +pub fn ip_matches(pattern: &str, ip: &str) -> bool { + let pattern = pattern.trim(); + + if pattern == "*" { + return true; + } + + // Normalize IPv4-mapped IPv6 + let normalized_ip = normalize_ip_str(ip); + + // Try CIDR match + if pattern.contains('/') { + if let Ok(net) = IpNet::from_str(pattern) { + if let Ok(addr) = IpAddr::from_str(&normalized_ip) { + return net.contains(&addr); + } + } + return false; + } + + // Handle wildcard patterns like 192.168.1.* + if pattern.contains('*') { + let pattern_cidr = wildcard_to_cidr(pattern); + if let Some(cidr) = pattern_cidr { + if let Ok(net) = IpNet::from_str(&cidr) { + if let Ok(addr) = IpAddr::from_str(&normalized_ip) { + return net.contains(&addr); + } + } + } + return false; + } + + // Exact match + let normalized_pattern = normalize_ip_str(pattern); + normalized_ip == normalized_pattern +} + +/// Check if an IP matches any of the given patterns. +pub fn ip_matches_any(patterns: &[String], ip: &str) -> bool { + patterns.iter().any(|p| ip_matches(p, ip)) +} + +/// Normalize IPv4-mapped IPv6 addresses. +fn normalize_ip_str(ip: &str) -> String { + let ip = ip.trim(); + if ip.starts_with("::ffff:") { + return ip[7..].to_string(); + } + ip.to_string() +} + +/// Convert a wildcard IP pattern to CIDR notation. +/// e.g., "192.168.1.*" -> "192.168.1.0/24" +fn wildcard_to_cidr(pattern: &str) -> Option { + let parts: Vec<&str> = pattern.split('.').collect(); + if parts.len() != 4 { + return None; + } + + let mut octets = [0u8; 4]; + let mut prefix_len = 0; + + for (i, part) in parts.iter().enumerate() { + if *part == "*" { + break; + } + if let Ok(n) = part.parse::() { + octets[i] = n; + prefix_len += 8; + } else { + return None; + } + } + + Some(format!("{}.{}.{}.{}/{}", octets[0], octets[1], octets[2], octets[3], prefix_len)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_wildcard_all() { + assert!(ip_matches("*", "192.168.1.100")); + assert!(ip_matches("*", "::1")); + } + + #[test] + fn test_exact_match() { + assert!(ip_matches("192.168.1.100", "192.168.1.100")); + assert!(!ip_matches("192.168.1.100", "192.168.1.101")); + } + + #[test] + fn test_cidr() { + assert!(ip_matches("192.168.1.0/24", "192.168.1.100")); + assert!(ip_matches("192.168.1.0/24", "192.168.1.1")); + assert!(!ip_matches("192.168.1.0/24", "192.168.2.1")); + } + + #[test] + fn test_wildcard_pattern() { + assert!(ip_matches("192.168.1.*", "192.168.1.100")); + assert!(ip_matches("192.168.1.*", "192.168.1.1")); + assert!(!ip_matches("192.168.1.*", "192.168.2.1")); + } + + #[test] + fn test_ipv6_mapped() { + assert!(ip_matches("192.168.1.100", "::ffff:192.168.1.100")); + assert!(ip_matches("192.168.1.0/24", "::ffff:192.168.1.50")); + } +} diff --git a/rust/crates/rustproxy-routing/src/matchers/mod.rs b/rust/crates/rustproxy-routing/src/matchers/mod.rs new file mode 100644 index 0000000..938594e --- /dev/null +++ b/rust/crates/rustproxy-routing/src/matchers/mod.rs @@ -0,0 +1,9 @@ +pub mod domain; +pub mod path; +pub mod ip; +pub mod header; + +pub use domain::*; +pub use path::*; +pub use ip::*; +pub use header::*; diff --git a/rust/crates/rustproxy-routing/src/matchers/path.rs b/rust/crates/rustproxy-routing/src/matchers/path.rs new file mode 100644 index 0000000..d305206 --- /dev/null +++ b/rust/crates/rustproxy-routing/src/matchers/path.rs @@ -0,0 +1,65 @@ +/// Match a URL path against a pattern supporting wildcards. +/// +/// Supported patterns: +/// - `/api/*` matches `/api/anything` (single level) +/// - `/api/**` matches `/api/any/depth/here` +/// - `/exact/path` exact match +/// - `/prefix*` prefix match +pub fn path_matches(pattern: &str, path: &str) -> bool { + // Exact match + if pattern == path { + return true; + } + + // Double-star: match any depth + if pattern.ends_with("/**") { + let prefix = &pattern[..pattern.len() - 3]; + return path == prefix || path.starts_with(&format!("{}/", prefix)); + } + + // Single-star at end: match single path segment + if pattern.ends_with("/*") { + let prefix = &pattern[..pattern.len() - 2]; + if path == prefix { + return true; + } + if path.starts_with(&format!("{}/", prefix)) { + let rest = &path[prefix.len() + 1..]; + // Single level means no more slashes + return !rest.contains('/'); + } + return false; + } + + // Star anywhere: use glob matching + if pattern.contains('*') { + return glob_match::glob_match(pattern, path); + } + + false +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_exact_path() { + assert!(path_matches("/api/users", "/api/users")); + assert!(!path_matches("/api/users", "/api/posts")); + } + + #[test] + fn test_single_wildcard() { + assert!(path_matches("/api/*", "/api/users")); + assert!(path_matches("/api/*", "/api/posts")); + assert!(!path_matches("/api/*", "/api/users/123")); + } + + #[test] + fn test_double_wildcard() { + assert!(path_matches("/api/**", "/api/users")); + assert!(path_matches("/api/**", "/api/users/123")); + assert!(path_matches("/api/**", "/api/users/123/posts")); + } +} diff --git a/rust/crates/rustproxy-routing/src/route_manager.rs b/rust/crates/rustproxy-routing/src/route_manager.rs new file mode 100644 index 0000000..5f67f0e --- /dev/null +++ b/rust/crates/rustproxy-routing/src/route_manager.rs @@ -0,0 +1,545 @@ +use std::collections::HashMap; + +use rustproxy_config::{RouteConfig, RouteTarget, TlsMode}; +use crate::matchers; + +/// Context for route matching (subset of connection info). +pub struct MatchContext<'a> { + pub port: u16, + pub domain: Option<&'a str>, + pub path: Option<&'a str>, + pub client_ip: Option<&'a str>, + pub tls_version: Option<&'a str>, + pub headers: Option<&'a HashMap>, + pub is_tls: bool, +} + +/// Result of a route match. +pub struct RouteMatchResult<'a> { + pub route: &'a RouteConfig, + pub target: Option<&'a RouteTarget>, +} + +/// Port-indexed route lookup with priority-based matching. +/// This is the core routing engine. +pub struct RouteManager { + /// Routes indexed by port for O(1) port lookup. + port_index: HashMap>, + /// All routes, sorted by priority (highest first). + routes: Vec, +} + +impl RouteManager { + /// Create a new RouteManager from a list of routes. + pub fn new(routes: Vec) -> Self { + let mut manager = Self { + port_index: HashMap::new(), + routes: Vec::new(), + }; + + // Filter enabled routes and sort by priority + let mut enabled_routes: Vec = routes + .into_iter() + .filter(|r| r.is_enabled()) + .collect(); + enabled_routes.sort_by(|a, b| b.effective_priority().cmp(&a.effective_priority())); + + // Build port index + for (idx, route) in enabled_routes.iter().enumerate() { + for port in route.listening_ports() { + manager.port_index + .entry(port) + .or_default() + .push(idx); + } + } + + manager.routes = enabled_routes; + manager + } + + /// Find the best matching route for the given context. + pub fn find_route<'a>(&'a self, ctx: &MatchContext<'_>) -> Option> { + // Get routes for this port + let route_indices = self.port_index.get(&ctx.port)?; + + for &idx in route_indices { + let route = &self.routes[idx]; + + if self.matches_route(route, ctx) { + // Find the best matching target within the route + let target = self.find_target(route, ctx); + return Some(RouteMatchResult { route, target }); + } + } + + None + } + + /// Check if a route matches the given context. + fn matches_route(&self, route: &RouteConfig, ctx: &MatchContext<'_>) -> bool { + let rm = &route.route_match; + + // Domain matching + if let Some(ref domains) = rm.domains { + if let Some(domain) = ctx.domain { + let patterns = domains.to_vec(); + if !matchers::domain_matches_any(&patterns, domain) { + return false; + } + } + // If no domain provided but route requires domain, it depends on context + // For TLS passthrough, we need SNI; for other cases we may still match + } + + // Path matching + if let Some(ref pattern) = rm.path { + if let Some(path) = ctx.path { + if !matchers::path_matches(pattern, path) { + return false; + } + } else { + // Route requires path but none provided + return false; + } + } + + // Client IP matching + if let Some(ref client_ips) = rm.client_ip { + if let Some(ip) = ctx.client_ip { + if !matchers::ip_matches_any(client_ips, ip) { + return false; + } + } else { + return false; + } + } + + // TLS version matching + if let Some(ref tls_versions) = rm.tls_version { + if let Some(version) = ctx.tls_version { + if !tls_versions.iter().any(|v| v == version) { + return false; + } + } else { + return false; + } + } + + // Header matching + if let Some(ref patterns) = rm.headers { + if let Some(headers) = ctx.headers { + if !matchers::headers_match(patterns, headers) { + return false; + } + } else { + return false; + } + } + + true + } + + /// Find the best matching target within a route. + fn find_target<'a>(&self, route: &'a RouteConfig, ctx: &MatchContext<'_>) -> Option<&'a RouteTarget> { + let targets = route.action.targets.as_ref()?; + + if targets.len() == 1 && targets[0].target_match.is_none() { + return Some(&targets[0]); + } + + // Sort candidates by priority (already in order from config) + let mut best: Option<&RouteTarget> = None; + let mut best_priority = i32::MIN; + + for target in targets { + let priority = target.priority.unwrap_or(0); + + if let Some(ref tm) = target.target_match { + if !self.matches_target(tm, ctx) { + continue; + } + } + + if priority > best_priority || best.is_none() { + best = Some(target); + best_priority = priority; + } + } + + // Fall back to first target without match criteria + best.or_else(|| { + targets.iter().find(|t| t.target_match.is_none()) + }) + } + + /// Check if a target match criteria matches the context. + fn matches_target( + &self, + tm: &rustproxy_config::TargetMatch, + ctx: &MatchContext<'_>, + ) -> bool { + // Port matching + if let Some(ref ports) = tm.ports { + if !ports.contains(&ctx.port) { + return false; + } + } + + // Path matching + if let Some(ref pattern) = tm.path { + if let Some(path) = ctx.path { + if !matchers::path_matches(pattern, path) { + return false; + } + } else { + return false; + } + } + + // Header matching + if let Some(ref patterns) = tm.headers { + if let Some(headers) = ctx.headers { + if !matchers::headers_match(patterns, headers) { + return false; + } + } else { + return false; + } + } + + true + } + + /// Get all unique listening ports. + pub fn listening_ports(&self) -> Vec { + let mut ports: Vec = self.port_index.keys().copied().collect(); + ports.sort(); + ports + } + + /// Get all routes for a specific port. + pub fn routes_for_port(&self, port: u16) -> Vec<&RouteConfig> { + self.port_index + .get(&port) + .map(|indices| indices.iter().map(|&i| &self.routes[i]).collect()) + .unwrap_or_default() + } + + /// Get the total number of enabled routes. + pub fn route_count(&self) -> usize { + self.routes.len() + } + + /// Check if any route on the given port requires SNI. + pub fn port_requires_sni(&self, port: u16) -> bool { + let routes = self.routes_for_port(port); + + // If multiple passthrough routes on same port, SNI is needed + let passthrough_routes: Vec<_> = routes + .iter() + .filter(|r| { + r.tls_mode() == Some(&TlsMode::Passthrough) + }) + .collect(); + + if passthrough_routes.len() > 1 { + return true; + } + + // Single passthrough route with specific domain restriction needs SNI + if let Some(route) = passthrough_routes.first() { + if let Some(ref domains) = route.route_match.domains { + let domain_list = domains.to_vec(); + // If it's not just a wildcard, SNI is needed + if !domain_list.iter().all(|d| *d == "*") { + return true; + } + } + } + + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rustproxy_config::*; + + fn make_route(port: u16, domain: Option<&str>, priority: i32) -> RouteConfig { + RouteConfig { + id: None, + route_match: RouteMatch { + ports: PortRange::Single(port), + domains: domain.map(|d| DomainSpec::Single(d.to_string())), + path: None, + client_ip: None, + tls_version: None, + headers: None, + }, + action: RouteAction { + action_type: RouteActionType::Forward, + targets: Some(vec![RouteTarget { + target_match: None, + host: HostSpec::Single("localhost".to_string()), + port: PortSpec::Fixed(8080), + tls: None, + websocket: None, + load_balancing: None, + send_proxy_protocol: None, + headers: None, + advanced: None, + priority: None, + }]), + tls: None, + websocket: None, + load_balancing: None, + advanced: None, + options: None, + forwarding_engine: None, + nftables: None, + send_proxy_protocol: None, + }, + headers: None, + security: None, + name: None, + description: None, + priority: Some(priority), + tags: None, + enabled: None, + } + } + + #[test] + fn test_basic_routing() { + let routes = vec![ + make_route(80, Some("example.com"), 0), + make_route(80, Some("other.com"), 0), + ]; + let manager = RouteManager::new(routes); + + let ctx = MatchContext { + port: 80, + domain: Some("example.com"), + path: None, + client_ip: None, + tls_version: None, + headers: None, + is_tls: false, + }; + + let result = manager.find_route(&ctx); + assert!(result.is_some()); + } + + #[test] + fn test_priority_ordering() { + let routes = vec![ + make_route(80, Some("*.example.com"), 0), + make_route(80, Some("api.example.com"), 10), // Higher priority + ]; + let manager = RouteManager::new(routes); + + let ctx = MatchContext { + port: 80, + domain: Some("api.example.com"), + path: None, + client_ip: None, + tls_version: None, + headers: None, + is_tls: false, + }; + + let result = manager.find_route(&ctx).unwrap(); + // Should match the higher-priority specific route + assert!(result.route.route_match.domains.as_ref() + .map(|d| d.to_vec()) + .unwrap() + .contains(&"api.example.com")); + } + + #[test] + fn test_no_match() { + let routes = vec![make_route(80, Some("example.com"), 0)]; + let manager = RouteManager::new(routes); + + let ctx = MatchContext { + port: 443, // Different port + domain: Some("example.com"), + path: None, + client_ip: None, + tls_version: None, + headers: None, + is_tls: false, + }; + + assert!(manager.find_route(&ctx).is_none()); + } + + #[test] + fn test_disabled_routes_excluded() { + let mut route = make_route(80, Some("example.com"), 0); + route.enabled = Some(false); + let manager = RouteManager::new(vec![route]); + assert_eq!(manager.route_count(), 0); + } + + #[test] + fn test_listening_ports() { + let routes = vec![ + make_route(80, Some("a.com"), 0), + make_route(443, Some("b.com"), 0), + make_route(80, Some("c.com"), 0), // duplicate port + ]; + let manager = RouteManager::new(routes); + let ports = manager.listening_ports(); + assert_eq!(ports, vec![80, 443]); + } + + #[test] + fn test_port_requires_sni_single_passthrough() { + let mut route = make_route(443, Some("example.com"), 0); + route.action.tls = Some(RouteTls { + mode: TlsMode::Passthrough, + certificate: None, + acme: None, + versions: None, + ciphers: None, + honor_cipher_order: None, + session_timeout: None, + }); + let manager = RouteManager::new(vec![route]); + // Single passthrough route with specific domain needs SNI + assert!(manager.port_requires_sni(443)); + } + + #[test] + fn test_port_requires_sni_wildcard_only() { + let mut route = make_route(443, Some("*"), 0); + route.action.tls = Some(RouteTls { + mode: TlsMode::Passthrough, + certificate: None, + acme: None, + versions: None, + ciphers: None, + honor_cipher_order: None, + session_timeout: None, + }); + let manager = RouteManager::new(vec![route]); + // Single passthrough route with wildcard doesn't need SNI + assert!(!manager.port_requires_sni(443)); + } + + #[test] + fn test_routes_for_port() { + let routes = vec![ + make_route(80, Some("a.com"), 0), + make_route(80, Some("b.com"), 0), + make_route(443, Some("c.com"), 0), + ]; + let manager = RouteManager::new(routes); + assert_eq!(manager.routes_for_port(80).len(), 2); + assert_eq!(manager.routes_for_port(443).len(), 1); + assert_eq!(manager.routes_for_port(8080).len(), 0); + } + + #[test] + fn test_wildcard_domain_matches_any() { + let routes = vec![make_route(80, Some("*"), 0)]; + let manager = RouteManager::new(routes); + + let ctx = MatchContext { + port: 80, + domain: Some("anything.example.com"), + path: None, + client_ip: None, + tls_version: None, + headers: None, + is_tls: false, + }; + + assert!(manager.find_route(&ctx).is_some()); + } + + #[test] + fn test_no_domain_route_matches_any_domain() { + let routes = vec![make_route(80, None, 0)]; + let manager = RouteManager::new(routes); + + let ctx = MatchContext { + port: 80, + domain: Some("example.com"), + path: None, + client_ip: None, + tls_version: None, + headers: None, + is_tls: false, + }; + + assert!(manager.find_route(&ctx).is_some()); + } + + #[test] + fn test_target_sub_matching() { + let mut route = make_route(80, Some("example.com"), 0); + route.action.targets = Some(vec![ + RouteTarget { + target_match: Some(rustproxy_config::TargetMatch { + ports: None, + path: Some("/api/*".to_string()), + headers: None, + method: None, + }), + host: HostSpec::Single("api-backend".to_string()), + port: PortSpec::Fixed(3000), + tls: None, + websocket: None, + load_balancing: None, + send_proxy_protocol: None, + headers: None, + advanced: None, + priority: Some(10), + }, + RouteTarget { + target_match: None, + host: HostSpec::Single("default-backend".to_string()), + port: PortSpec::Fixed(8080), + tls: None, + websocket: None, + load_balancing: None, + send_proxy_protocol: None, + headers: None, + advanced: None, + priority: None, + }, + ]); + let manager = RouteManager::new(vec![route]); + + // Should match the API target + let ctx = MatchContext { + port: 80, + domain: Some("example.com"), + path: Some("/api/users"), + client_ip: None, + tls_version: None, + headers: None, + is_tls: false, + }; + let result = manager.find_route(&ctx).unwrap(); + assert_eq!(result.target.unwrap().host.first(), "api-backend"); + + // Should fall back to default target + let ctx = MatchContext { + port: 80, + domain: Some("example.com"), + path: Some("/home"), + client_ip: None, + tls_version: None, + headers: None, + is_tls: false, + }; + let result = manager.find_route(&ctx).unwrap(); + assert_eq!(result.target.unwrap().host.first(), "default-backend"); + } +} diff --git a/rust/crates/rustproxy-security/Cargo.toml b/rust/crates/rustproxy-security/Cargo.toml new file mode 100644 index 0000000..c62ed7c --- /dev/null +++ b/rust/crates/rustproxy-security/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "rustproxy-security" +version.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true +description = "IP filtering, rate limiting, and authentication for RustProxy" + +[dependencies] +rustproxy-config = { workspace = true } +dashmap = { workspace = true } +ipnet = { workspace = true } +jsonwebtoken = { workspace = true } +base64 = { workspace = true } +tracing = { workspace = true } +thiserror = { workspace = true } +serde = { workspace = true } diff --git a/rust/crates/rustproxy-security/src/basic_auth.rs b/rust/crates/rustproxy-security/src/basic_auth.rs new file mode 100644 index 0000000..7e963a7 --- /dev/null +++ b/rust/crates/rustproxy-security/src/basic_auth.rs @@ -0,0 +1,111 @@ +use base64::Engine; +use base64::engine::general_purpose::STANDARD as BASE64; + +/// Basic auth validator. +pub struct BasicAuthValidator { + users: Vec<(String, String)>, + realm: String, +} + +impl BasicAuthValidator { + pub fn new(users: Vec<(String, String)>, realm: Option) -> Self { + Self { + users, + realm: realm.unwrap_or_else(|| "Restricted".to_string()), + } + } + + /// Validate an Authorization header value. + /// Returns the username if valid. + pub fn validate(&self, auth_header: &str) -> Option { + let auth_header = auth_header.trim(); + if !auth_header.starts_with("Basic ") { + return None; + } + + let encoded = &auth_header[6..]; + let decoded = BASE64.decode(encoded).ok()?; + let credentials = String::from_utf8(decoded).ok()?; + + let mut parts = credentials.splitn(2, ':'); + let username = parts.next()?; + let password = parts.next()?; + + for (u, p) in &self.users { + if u == username && p == password { + return Some(username.to_string()); + } + } + + None + } + + /// Get the realm for WWW-Authenticate header. + pub fn realm(&self) -> &str { + &self.realm + } + + /// Generate the WWW-Authenticate header value. + pub fn www_authenticate(&self) -> String { + format!("Basic realm=\"{}\"", self.realm) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use base64::Engine; + + fn make_validator() -> BasicAuthValidator { + BasicAuthValidator::new( + vec![ + ("admin".to_string(), "secret".to_string()), + ("user".to_string(), "pass".to_string()), + ], + Some("TestRealm".to_string()), + ) + } + + fn encode_basic(user: &str, pass: &str) -> String { + let encoded = BASE64.encode(format!("{}:{}", user, pass)); + format!("Basic {}", encoded) + } + + #[test] + fn test_valid_credentials() { + let validator = make_validator(); + let header = encode_basic("admin", "secret"); + assert_eq!(validator.validate(&header), Some("admin".to_string())); + } + + #[test] + fn test_invalid_password() { + let validator = make_validator(); + let header = encode_basic("admin", "wrong"); + assert_eq!(validator.validate(&header), None); + } + + #[test] + fn test_not_basic_scheme() { + let validator = make_validator(); + assert_eq!(validator.validate("Bearer sometoken"), None); + } + + #[test] + fn test_malformed_base64() { + let validator = make_validator(); + assert_eq!(validator.validate("Basic !!!not-base64!!!"), None); + } + + #[test] + fn test_www_authenticate_format() { + let validator = make_validator(); + assert_eq!(validator.www_authenticate(), "Basic realm=\"TestRealm\""); + } + + #[test] + fn test_default_realm() { + let validator = BasicAuthValidator::new(vec![], None); + assert_eq!(validator.www_authenticate(), "Basic realm=\"Restricted\""); + } +} diff --git a/rust/crates/rustproxy-security/src/ip_filter.rs b/rust/crates/rustproxy-security/src/ip_filter.rs new file mode 100644 index 0000000..3d0f5dc --- /dev/null +++ b/rust/crates/rustproxy-security/src/ip_filter.rs @@ -0,0 +1,189 @@ +use ipnet::IpNet; +use std::net::IpAddr; +use std::str::FromStr; + +/// IP filter supporting CIDR ranges, wildcards, and exact matches. +pub struct IpFilter { + allow_list: Vec, + block_list: Vec, +} + +/// Represents an IP pattern for matching. +#[derive(Debug)] +enum IpPattern { + /// Exact IP match + Exact(IpAddr), + /// CIDR range match + Cidr(IpNet), + /// Wildcard (matches everything) + Wildcard, +} + +impl IpPattern { + fn parse(s: &str) -> Self { + let s = s.trim(); + if s == "*" { + return IpPattern::Wildcard; + } + if let Ok(net) = IpNet::from_str(s) { + return IpPattern::Cidr(net); + } + if let Ok(addr) = IpAddr::from_str(s) { + return IpPattern::Exact(addr); + } + // Try as CIDR by appending default prefix + if let Ok(addr) = IpAddr::from_str(s) { + return IpPattern::Exact(addr); + } + // Fallback: treat as exact, will never match an invalid string + IpPattern::Exact(IpAddr::from_str("0.0.0.0").unwrap()) + } + + fn matches(&self, ip: &IpAddr) -> bool { + match self { + IpPattern::Wildcard => true, + IpPattern::Exact(addr) => addr == ip, + IpPattern::Cidr(net) => net.contains(ip), + } + } +} + +impl IpFilter { + /// Create a new IP filter from allow and block lists. + pub fn new(allow_list: &[String], block_list: &[String]) -> Self { + Self { + allow_list: allow_list.iter().map(|s| IpPattern::parse(s)).collect(), + block_list: block_list.iter().map(|s| IpPattern::parse(s)).collect(), + } + } + + /// Check if an IP is allowed. + /// If allow_list is non-empty, IP must match at least one entry. + /// If block_list is non-empty, IP must NOT match any entry. + pub fn is_allowed(&self, ip: &IpAddr) -> bool { + // Check block list first + if !self.block_list.is_empty() { + for pattern in &self.block_list { + if pattern.matches(ip) { + return false; + } + } + } + + // If allow list is non-empty, must match at least one + if !self.allow_list.is_empty() { + return self.allow_list.iter().any(|p| p.matches(ip)); + } + + true + } + + /// Normalize IPv4-mapped IPv6 addresses (::ffff:x.x.x.x -> x.x.x.x) + pub fn normalize_ip(ip: &IpAddr) -> IpAddr { + match ip { + IpAddr::V6(v6) => { + if let Some(v4) = v6.to_ipv4_mapped() { + IpAddr::V4(v4) + } else { + *ip + } + } + _ => *ip, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_lists_allow_all() { + let filter = IpFilter::new(&[], &[]); + let ip: IpAddr = "192.168.1.1".parse().unwrap(); + assert!(filter.is_allowed(&ip)); + } + + #[test] + fn test_allow_list_exact() { + let filter = IpFilter::new( + &["10.0.0.1".to_string()], + &[], + ); + let allowed: IpAddr = "10.0.0.1".parse().unwrap(); + let denied: IpAddr = "10.0.0.2".parse().unwrap(); + assert!(filter.is_allowed(&allowed)); + assert!(!filter.is_allowed(&denied)); + } + + #[test] + fn test_allow_list_cidr() { + let filter = IpFilter::new( + &["10.0.0.0/8".to_string()], + &[], + ); + let allowed: IpAddr = "10.255.255.255".parse().unwrap(); + let denied: IpAddr = "192.168.1.1".parse().unwrap(); + assert!(filter.is_allowed(&allowed)); + assert!(!filter.is_allowed(&denied)); + } + + #[test] + fn test_block_list() { + let filter = IpFilter::new( + &[], + &["192.168.1.100".to_string()], + ); + let blocked: IpAddr = "192.168.1.100".parse().unwrap(); + let allowed: IpAddr = "192.168.1.101".parse().unwrap(); + assert!(!filter.is_allowed(&blocked)); + assert!(filter.is_allowed(&allowed)); + } + + #[test] + fn test_block_trumps_allow() { + let filter = IpFilter::new( + &["10.0.0.0/8".to_string()], + &["10.0.0.5".to_string()], + ); + let blocked: IpAddr = "10.0.0.5".parse().unwrap(); + let allowed: IpAddr = "10.0.0.6".parse().unwrap(); + assert!(!filter.is_allowed(&blocked)); + assert!(filter.is_allowed(&allowed)); + } + + #[test] + fn test_wildcard_allow() { + let filter = IpFilter::new( + &["*".to_string()], + &[], + ); + let ip: IpAddr = "1.2.3.4".parse().unwrap(); + assert!(filter.is_allowed(&ip)); + } + + #[test] + fn test_wildcard_block() { + let filter = IpFilter::new( + &[], + &["*".to_string()], + ); + let ip: IpAddr = "1.2.3.4".parse().unwrap(); + assert!(!filter.is_allowed(&ip)); + } + + #[test] + fn test_normalize_ipv4_mapped_ipv6() { + let mapped: IpAddr = "::ffff:192.168.1.1".parse().unwrap(); + let normalized = IpFilter::normalize_ip(&mapped); + let expected: IpAddr = "192.168.1.1".parse().unwrap(); + assert_eq!(normalized, expected); + } + + #[test] + fn test_normalize_pure_ipv4() { + let ip: IpAddr = "10.0.0.1".parse().unwrap(); + let normalized = IpFilter::normalize_ip(&ip); + assert_eq!(normalized, ip); + } +} diff --git a/rust/crates/rustproxy-security/src/jwt_auth.rs b/rust/crates/rustproxy-security/src/jwt_auth.rs new file mode 100644 index 0000000..e7a07de --- /dev/null +++ b/rust/crates/rustproxy-security/src/jwt_auth.rs @@ -0,0 +1,174 @@ +use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm}; +use serde::{Deserialize, Serialize}; + +/// JWT claims (minimal structure). +#[derive(Debug, Serialize, Deserialize)] +pub struct Claims { + pub sub: Option, + pub exp: Option, + pub iss: Option, + pub aud: Option, +} + +/// JWT auth validator. +pub struct JwtValidator { + decoding_key: DecodingKey, + validation: Validation, +} + +impl JwtValidator { + pub fn new( + secret: &str, + algorithm: Option<&str>, + issuer: Option<&str>, + audience: Option<&str>, + ) -> Self { + let algo = match algorithm { + Some("HS384") => Algorithm::HS384, + Some("HS512") => Algorithm::HS512, + Some("RS256") => Algorithm::RS256, + _ => Algorithm::HS256, + }; + + let mut validation = Validation::new(algo); + if let Some(iss) = issuer { + validation.set_issuer(&[iss]); + } + if let Some(aud) = audience { + validation.set_audience(&[aud]); + } + + Self { + decoding_key: DecodingKey::from_secret(secret.as_bytes()), + validation, + } + } + + /// Validate a JWT token string (without "Bearer " prefix). + /// Returns the claims if valid. + pub fn validate(&self, token: &str) -> Result { + decode::(token, &self.decoding_key, &self.validation) + .map(|data| data.claims) + .map_err(|e| e.to_string()) + } + + /// Extract token from Authorization header. + pub fn extract_token(auth_header: &str) -> Option<&str> { + let header = auth_header.trim(); + if header.starts_with("Bearer ") { + Some(&header[7..]) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use jsonwebtoken::{encode, EncodingKey, Header}; + + fn make_token(secret: &str, claims: &Claims) -> String { + encode( + &Header::default(), + claims, + &EncodingKey::from_secret(secret.as_bytes()), + ) + .unwrap() + } + + fn future_exp() -> u64 { + use std::time::{SystemTime, UNIX_EPOCH}; + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + + 3600 + } + + fn past_exp() -> u64 { + use std::time::{SystemTime, UNIX_EPOCH}; + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + - 3600 + } + + #[test] + fn test_valid_token() { + let secret = "test-secret"; + let claims = Claims { + sub: Some("user123".to_string()), + exp: Some(future_exp()), + iss: None, + aud: None, + }; + let token = make_token(secret, &claims); + let validator = JwtValidator::new(secret, None, None, None); + let result = validator.validate(&token); + assert!(result.is_ok()); + assert_eq!(result.unwrap().sub, Some("user123".to_string())); + } + + #[test] + fn test_expired_token() { + let secret = "test-secret"; + let claims = Claims { + sub: Some("user123".to_string()), + exp: Some(past_exp()), + iss: None, + aud: None, + }; + let token = make_token(secret, &claims); + let validator = JwtValidator::new(secret, None, None, None); + assert!(validator.validate(&token).is_err()); + } + + #[test] + fn test_wrong_secret() { + let claims = Claims { + sub: Some("user123".to_string()), + exp: Some(future_exp()), + iss: None, + aud: None, + }; + let token = make_token("correct-secret", &claims); + let validator = JwtValidator::new("wrong-secret", None, None, None); + assert!(validator.validate(&token).is_err()); + } + + #[test] + fn test_issuer_validation() { + let secret = "test-secret"; + let claims = Claims { + sub: Some("user123".to_string()), + exp: Some(future_exp()), + iss: Some("my-issuer".to_string()), + aud: None, + }; + let token = make_token(secret, &claims); + + // Correct issuer + let validator = JwtValidator::new(secret, None, Some("my-issuer"), None); + assert!(validator.validate(&token).is_ok()); + + // Wrong issuer + let validator = JwtValidator::new(secret, None, Some("other-issuer"), None); + assert!(validator.validate(&token).is_err()); + } + + #[test] + fn test_extract_token_bearer() { + assert_eq!( + JwtValidator::extract_token("Bearer abc123"), + Some("abc123") + ); + } + + #[test] + fn test_extract_token_non_bearer() { + assert_eq!(JwtValidator::extract_token("Basic abc123"), None); + assert_eq!(JwtValidator::extract_token("abc123"), None); + } +} diff --git a/rust/crates/rustproxy-security/src/lib.rs b/rust/crates/rustproxy-security/src/lib.rs new file mode 100644 index 0000000..7753d9e --- /dev/null +++ b/rust/crates/rustproxy-security/src/lib.rs @@ -0,0 +1,13 @@ +//! # rustproxy-security +//! +//! IP filtering, rate limiting, and authentication for RustProxy. + +pub mod ip_filter; +pub mod rate_limiter; +pub mod basic_auth; +pub mod jwt_auth; + +pub use ip_filter::*; +pub use rate_limiter::*; +pub use basic_auth::*; +pub use jwt_auth::*; diff --git a/rust/crates/rustproxy-security/src/rate_limiter.rs b/rust/crates/rustproxy-security/src/rate_limiter.rs new file mode 100644 index 0000000..444cc80 --- /dev/null +++ b/rust/crates/rustproxy-security/src/rate_limiter.rs @@ -0,0 +1,97 @@ +use dashmap::DashMap; +use std::time::Instant; + +/// Sliding window rate limiter. +pub struct RateLimiter { + /// Map of key -> list of request timestamps + windows: DashMap>, + /// Maximum requests per window + max_requests: u64, + /// Window duration in seconds + window_seconds: u64, +} + +impl RateLimiter { + pub fn new(max_requests: u64, window_seconds: u64) -> Self { + Self { + windows: DashMap::new(), + max_requests, + window_seconds, + } + } + + /// Check if a request is allowed for the given key. + /// Returns true if allowed, false if rate limited. + pub fn check(&self, key: &str) -> bool { + let now = Instant::now(); + let window = std::time::Duration::from_secs(self.window_seconds); + + let mut entry = self.windows.entry(key.to_string()).or_default(); + let timestamps = entry.value_mut(); + + // Remove expired entries + timestamps.retain(|t| now.duration_since(*t) < window); + + if timestamps.len() as u64 >= self.max_requests { + false + } else { + timestamps.push(now); + true + } + } + + /// Clean up expired entries (call periodically). + pub fn cleanup(&self) { + let now = Instant::now(); + let window = std::time::Duration::from_secs(self.window_seconds); + + self.windows.retain(|_, timestamps| { + timestamps.retain(|t| now.duration_since(*t) < window); + !timestamps.is_empty() + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_allow_under_limit() { + let limiter = RateLimiter::new(5, 60); + for _ in 0..5 { + assert!(limiter.check("client-1")); + } + } + + #[test] + fn test_block_over_limit() { + let limiter = RateLimiter::new(3, 60); + assert!(limiter.check("client-1")); + assert!(limiter.check("client-1")); + assert!(limiter.check("client-1")); + assert!(!limiter.check("client-1")); // 4th request blocked + } + + #[test] + fn test_different_keys_independent() { + let limiter = RateLimiter::new(2, 60); + assert!(limiter.check("client-a")); + assert!(limiter.check("client-a")); + assert!(!limiter.check("client-a")); // blocked + // Different key should still be allowed + assert!(limiter.check("client-b")); + assert!(limiter.check("client-b")); + } + + #[test] + fn test_cleanup_removes_expired() { + let limiter = RateLimiter::new(100, 0); // 0 second window = immediately expired + limiter.check("client-1"); + // Sleep briefly to let entries expire + std::thread::sleep(std::time::Duration::from_millis(10)); + limiter.cleanup(); + // After cleanup, the key should be allowed again (entries expired) + assert!(limiter.check("client-1")); + } +} diff --git a/rust/crates/rustproxy-tls/Cargo.toml b/rust/crates/rustproxy-tls/Cargo.toml new file mode 100644 index 0000000..cb90a8e --- /dev/null +++ b/rust/crates/rustproxy-tls/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "rustproxy-tls" +version.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true +description = "TLS certificate management for RustProxy" + +[dependencies] +rustproxy-config = { workspace = true } +tokio = { workspace = true } +rustls = { workspace = true } +instant-acme = { workspace = true } +tracing = { workspace = true } +thiserror = { workspace = true } +anyhow = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +rcgen = { workspace = true } + +[dev-dependencies] +tempfile = { workspace = true } diff --git a/rust/crates/rustproxy-tls/src/acme.rs b/rust/crates/rustproxy-tls/src/acme.rs new file mode 100644 index 0000000..973f072 --- /dev/null +++ b/rust/crates/rustproxy-tls/src/acme.rs @@ -0,0 +1,360 @@ +//! ACME (Let's Encrypt) integration using instant-acme. +//! +//! This module handles HTTP-01 challenge creation and certificate provisioning. +//! Supports persisting ACME account credentials to disk for reuse across restarts. + +use std::path::{Path, PathBuf}; +use instant_acme::{ + Account, NewAccount, NewOrder, Identifier, ChallengeType, OrderStatus, + AccountCredentials, +}; +use rcgen::{CertificateParams, KeyPair}; +use thiserror::Error; +use tracing::{debug, info, warn}; + +#[derive(Debug, Error)] +pub enum AcmeError { + #[error("ACME account creation failed: {0}")] + AccountCreation(String), + #[error("ACME order failed: {0}")] + OrderFailed(String), + #[error("Challenge failed: {0}")] + ChallengeFailed(String), + #[error("Certificate finalization failed: {0}")] + FinalizationFailed(String), + #[error("No HTTP-01 challenge found")] + NoHttp01Challenge, + #[error("Timeout waiting for order: {0}")] + Timeout(String), + #[error("Account persistence error: {0}")] + Persistence(String), +} + +/// Pending HTTP-01 challenge that needs to be served. +pub struct PendingChallenge { + pub token: String, + pub key_authorization: String, + pub domain: String, +} + +/// ACME client wrapper around instant-acme. +pub struct AcmeClient { + use_production: bool, + email: String, + /// Optional directory where account.json is persisted. + account_dir: Option, +} + +impl AcmeClient { + pub fn new(email: String, use_production: bool) -> Self { + Self { + use_production, + email, + account_dir: None, + } + } + + /// Create a new client with account persistence at the given directory. + pub fn with_persistence(email: String, use_production: bool, account_dir: impl AsRef) -> Self { + Self { + use_production, + email, + account_dir: Some(account_dir.as_ref().to_path_buf()), + } + } + + /// Get or create an ACME account, persisting credentials if account_dir is set. + async fn get_or_create_account(&self) -> Result { + let directory_url = self.directory_url(); + + // Try to restore from persisted credentials + if let Some(ref dir) = self.account_dir { + let account_file = dir.join("account.json"); + if account_file.exists() { + match std::fs::read_to_string(&account_file) { + Ok(json) => { + match serde_json::from_str::(&json) { + Ok(credentials) => { + match Account::from_credentials(credentials).await { + Ok(account) => { + debug!("Restored ACME account from {}", account_file.display()); + return Ok(account); + } + Err(e) => { + warn!("Failed to restore ACME account, creating new: {}", e); + } + } + } + Err(e) => { + warn!("Invalid account.json, creating new account: {}", e); + } + } + } + Err(e) => { + warn!("Could not read account.json: {}", e); + } + } + } + } + + // Create a new account + let contact = format!("mailto:{}", self.email); + let (account, credentials) = Account::create( + &NewAccount { + contact: &[&contact], + terms_of_service_agreed: true, + only_return_existing: false, + }, + directory_url, + None, + ) + .await + .map_err(|e| AcmeError::AccountCreation(e.to_string()))?; + + debug!("ACME account created"); + + // Persist credentials if we have a directory + if let Some(ref dir) = self.account_dir { + if let Err(e) = std::fs::create_dir_all(dir) { + warn!("Failed to create account directory {}: {}", dir.display(), e); + } else { + let account_file = dir.join("account.json"); + match serde_json::to_string_pretty(&credentials) { + Ok(json) => { + if let Err(e) = std::fs::write(&account_file, &json) { + warn!("Failed to persist ACME account to {}: {}", account_file.display(), e); + } else { + info!("ACME account credentials persisted to {}", account_file.display()); + } + } + Err(e) => { + warn!("Failed to serialize account credentials: {}", e); + } + } + } + } + + Ok(account) + } + + /// Request a certificate for a domain using the HTTP-01 challenge. + /// + /// Returns (cert_chain_pem, private_key_pem) on success. + /// + /// The caller must serve the HTTP-01 challenge at: + /// `http:///.well-known/acme-challenge/` + /// + /// The `challenge_handler` closure is called with a `PendingChallenge` + /// and must arrange for the challenge response to be served. It should + /// return once the challenge is ready to be validated. + pub async fn provision( + &self, + domain: &str, + challenge_handler: F, + ) -> Result<(String, String), AcmeError> + where + F: FnOnce(PendingChallenge) -> Fut, + Fut: std::future::Future>, + { + info!("Starting ACME provisioning for {} via {}", domain, self.directory_url()); + + // 1. Get or create ACME account (with persistence) + let account = self.get_or_create_account().await?; + + // 2. Create order + let identifier = Identifier::Dns(domain.to_string()); + let mut order = account + .new_order(&NewOrder { + identifiers: &[identifier], + }) + .await + .map_err(|e| AcmeError::OrderFailed(e.to_string()))?; + + debug!("ACME order created"); + + // 3. Get authorizations and find HTTP-01 challenge + let authorizations = order + .authorizations() + .await + .map_err(|e| AcmeError::OrderFailed(e.to_string()))?; + + // Find the HTTP-01 challenge + let (challenge_token, challenge_url) = authorizations + .iter() + .flat_map(|auth| auth.challenges.iter()) + .find(|c| c.r#type == ChallengeType::Http01) + .map(|c| { + let key_auth = order.key_authorization(c); + ( + PendingChallenge { + token: c.token.clone(), + key_authorization: key_auth.as_str().to_string(), + domain: domain.to_string(), + }, + c.url.clone(), + ) + }) + .ok_or(AcmeError::NoHttp01Challenge)?; + + // Call the handler to set up challenge serving + challenge_handler(challenge_token).await?; + + // 4. Notify ACME server that challenge is ready + order + .set_challenge_ready(&challenge_url) + .await + .map_err(|e| AcmeError::ChallengeFailed(e.to_string()))?; + + debug!("Challenge marked as ready, waiting for validation..."); + + // 5. Poll for order to become ready + let mut attempts = 0; + let state = loop { + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + let state = order + .refresh() + .await + .map_err(|e| AcmeError::OrderFailed(e.to_string()))?; + + match state.status { + OrderStatus::Ready | OrderStatus::Valid => break state.status, + OrderStatus::Invalid => { + return Err(AcmeError::ChallengeFailed( + "Order became invalid (challenge failed)".to_string(), + )); + } + _ => { + attempts += 1; + if attempts > 30 { + return Err(AcmeError::Timeout( + "Order did not become ready within 60 seconds".to_string(), + )); + } + } + } + }; + + debug!("Order ready, finalizing..."); + + // 6. Generate CSR and finalize + let key_pair = KeyPair::generate().map_err(|e| { + AcmeError::FinalizationFailed(format!("Key generation failed: {}", e)) + })?; + + let mut params = CertificateParams::new(vec![domain.to_string()]).map_err(|e| { + AcmeError::FinalizationFailed(format!("CSR params failed: {}", e)) + })?; + params.distinguished_name.push(rcgen::DnType::CommonName, domain); + + let csr = params.serialize_request(&key_pair).map_err(|e| { + AcmeError::FinalizationFailed(format!("CSR serialization failed: {}", e)) + })?; + + if state == OrderStatus::Ready { + order + .finalize(csr.der()) + .await + .map_err(|e| AcmeError::FinalizationFailed(e.to_string()))?; + } + + // 7. Wait for certificate to be issued + let mut attempts = 0; + loop { + let state = order + .refresh() + .await + .map_err(|e| AcmeError::OrderFailed(e.to_string()))?; + if state.status == OrderStatus::Valid { + break; + } + if state.status == OrderStatus::Invalid { + return Err(AcmeError::FinalizationFailed( + "Order became invalid during finalization".to_string(), + )); + } + attempts += 1; + if attempts > 15 { + return Err(AcmeError::Timeout( + "Certificate not issued within 30 seconds".to_string(), + )); + } + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + } + + // 8. Download certificate + let cert_chain_pem = order + .certificate() + .await + .map_err(|e| AcmeError::FinalizationFailed(e.to_string()))? + .ok_or_else(|| { + AcmeError::FinalizationFailed("No certificate returned".to_string()) + })?; + + let private_key_pem = key_pair.serialize_pem(); + + info!("Certificate provisioned successfully for {}", domain); + + Ok((cert_chain_pem, private_key_pem)) + } + + /// Restore an ACME account from stored credentials. + pub async fn restore_account( + &self, + credentials: AccountCredentials, + ) -> Result { + Account::from_credentials(credentials) + .await + .map_err(|e| AcmeError::AccountCreation(e.to_string())) + } + + /// Get the ACME directory URL based on production/staging. + pub fn directory_url(&self) -> &str { + if self.use_production { + "https://acme-v02.api.letsencrypt.org/directory" + } else { + "https://acme-staging-v02.api.letsencrypt.org/directory" + } + } + + /// Whether this client is configured for production. + pub fn is_production(&self) -> bool { + self.use_production + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_directory_url_staging() { + let client = AcmeClient::new("test@example.com".to_string(), false); + assert!(client.directory_url().contains("staging")); + assert!(!client.is_production()); + } + + #[test] + fn test_directory_url_production() { + let client = AcmeClient::new("test@example.com".to_string(), true); + assert!(!client.directory_url().contains("staging")); + assert!(client.is_production()); + } + + #[test] + fn test_with_persistence_sets_account_dir() { + let tmp = tempfile::tempdir().unwrap(); + let client = AcmeClient::with_persistence( + "test@example.com".to_string(), + false, + tmp.path(), + ); + assert!(client.account_dir.is_some()); + assert_eq!(client.account_dir.unwrap(), tmp.path()); + } + + #[test] + fn test_without_persistence_no_account_dir() { + let client = AcmeClient::new("test@example.com".to_string(), false); + assert!(client.account_dir.is_none()); + } +} diff --git a/rust/crates/rustproxy-tls/src/cert_manager.rs b/rust/crates/rustproxy-tls/src/cert_manager.rs new file mode 100644 index 0000000..b9af9b7 --- /dev/null +++ b/rust/crates/rustproxy-tls/src/cert_manager.rs @@ -0,0 +1,183 @@ +use std::time::{SystemTime, UNIX_EPOCH}; +use thiserror::Error; +use tracing::info; + +use crate::cert_store::{CertStore, CertBundle, CertMetadata, CertSource}; +use crate::acme::AcmeClient; + +#[derive(Debug, Error)] +pub enum CertManagerError { + #[error("ACME provisioning failed for {domain}: {message}")] + AcmeFailure { domain: String, message: String }, + #[error("Certificate store error: {0}")] + Store(#[from] crate::cert_store::CertStoreError), + #[error("No ACME email configured")] + NoEmail, +} + +/// Certificate lifecycle manager. +/// Handles ACME provisioning, static cert loading, and renewal. +pub struct CertManager { + store: CertStore, + acme_email: Option, + use_production: bool, + renew_before_days: u32, +} + +impl CertManager { + pub fn new( + store: CertStore, + acme_email: Option, + use_production: bool, + renew_before_days: u32, + ) -> Self { + Self { + store, + acme_email, + use_production, + renew_before_days, + } + } + + /// Get a certificate for a domain (from cache). + pub fn get_cert(&self, domain: &str) -> Option<&CertBundle> { + self.store.get(domain) + } + + /// Create an ACME client using this manager's configuration. + /// Returns None if no ACME email is configured. + /// Account credentials are persisted in the cert store base directory. + pub fn acme_client(&self) -> Option { + self.acme_email.as_ref().map(|email| { + AcmeClient::with_persistence( + email.clone(), + self.use_production, + self.store.base_dir(), + ) + }) + } + + /// Load a static certificate into the store. + pub fn load_static( + &mut self, + domain: String, + bundle: CertBundle, + ) -> Result<(), CertManagerError> { + self.store.store(domain, bundle)?; + Ok(()) + } + + /// Check and return domains that need certificate renewal. + /// + /// A certificate needs renewal if it expires within `renew_before_days`. + /// Returns a list of domain names needing renewal. + pub fn check_renewals(&self) -> Vec { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + let renewal_threshold = self.renew_before_days as u64 * 86400; + let mut needs_renewal = Vec::new(); + + for (domain, bundle) in self.store.iter() { + // Only auto-renew ACME certs + if bundle.metadata.source != CertSource::Acme { + continue; + } + + let time_until_expiry = bundle.metadata.expires_at.saturating_sub(now); + if time_until_expiry < renewal_threshold { + info!( + "Certificate for {} needs renewal (expires in {} days)", + domain, + time_until_expiry / 86400 + ); + needs_renewal.push(domain.clone()); + } + } + + needs_renewal + } + + /// Renew a certificate for a domain. + /// + /// Performs the full ACME provision+store flow. The `challenge_setup` closure + /// is called to arrange for the HTTP-01 challenge to be served. It receives + /// (token, key_authorization) and must make the challenge response available. + /// + /// Returns the new CertBundle on success. + pub async fn renew_domain( + &mut self, + domain: &str, + challenge_setup: F, + ) -> Result + where + F: FnOnce(String, String) -> Fut, + Fut: std::future::Future, + { + let acme_client = self.acme_client() + .ok_or(CertManagerError::NoEmail)?; + + info!("Renewing certificate for {}", domain); + + let domain_owned = domain.to_string(); + let result = acme_client.provision(&domain_owned, |pending| { + let token = pending.token.clone(); + let key_auth = pending.key_authorization.clone(); + async move { + challenge_setup(token, key_auth).await; + Ok(()) + } + }).await.map_err(|e| CertManagerError::AcmeFailure { + domain: domain.to_string(), + message: e.to_string(), + })?; + + let (cert_pem, key_pem) = result; + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + let bundle = CertBundle { + cert_pem, + key_pem, + ca_pem: None, + metadata: CertMetadata { + domain: domain.to_string(), + source: CertSource::Acme, + issued_at: now, + expires_at: now + 90 * 86400, + renewed_at: Some(now), + }, + }; + + self.store.store(domain.to_string(), bundle.clone())?; + info!("Certificate renewed and stored for {}", domain); + + Ok(bundle) + } + + /// Load all certificates from disk. + pub fn load_all(&mut self) -> Result { + let loaded = self.store.load_all()?; + info!("Loaded {} certificates from store", loaded); + Ok(loaded) + } + + /// Whether this manager has an ACME email configured. + pub fn has_acme(&self) -> bool { + self.acme_email.is_some() + } + + /// Get reference to the underlying store. + pub fn store(&self) -> &CertStore { + &self.store + } + + /// Get mutable reference to the underlying store. + pub fn store_mut(&mut self) -> &mut CertStore { + &mut self.store + } +} diff --git a/rust/crates/rustproxy-tls/src/cert_store.rs b/rust/crates/rustproxy-tls/src/cert_store.rs new file mode 100644 index 0000000..0391ed6 --- /dev/null +++ b/rust/crates/rustproxy-tls/src/cert_store.rs @@ -0,0 +1,314 @@ +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum CertStoreError { + #[error("Certificate not found for domain: {0}")] + NotFound(String), + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + #[error("Invalid certificate: {0}")] + Invalid(String), + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), +} + +/// Certificate metadata stored alongside certs on disk. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CertMetadata { + pub domain: String, + pub source: CertSource, + pub issued_at: u64, + pub expires_at: u64, + pub renewed_at: Option, +} + +/// How a certificate was obtained. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum CertSource { + Acme, + Static, + Custom, + SelfSigned, +} + +/// An in-memory certificate bundle. +#[derive(Debug, Clone)] +pub struct CertBundle { + pub key_pem: String, + pub cert_pem: String, + pub ca_pem: Option, + pub metadata: CertMetadata, +} + +/// Filesystem-backed certificate store. +/// +/// File layout per domain: +/// ```text +/// {base_dir}/{domain}/ +/// key.pem +/// cert.pem +/// ca.pem (optional) +/// metadata.json +/// ``` +pub struct CertStore { + base_dir: PathBuf, + /// In-memory cache of loaded certs + cache: HashMap, +} + +impl CertStore { + /// Create a new cert store at the given directory. + pub fn new(base_dir: impl AsRef) -> Self { + Self { + base_dir: base_dir.as_ref().to_path_buf(), + cache: HashMap::new(), + } + } + + /// Get a certificate by domain. + pub fn get(&self, domain: &str) -> Option<&CertBundle> { + self.cache.get(domain) + } + + /// Store a certificate to both cache and filesystem. + pub fn store(&mut self, domain: String, bundle: CertBundle) -> Result<(), CertStoreError> { + // Sanitize domain for directory name (replace wildcards) + let dir_name = domain.replace('*', "_wildcard_"); + let cert_dir = self.base_dir.join(&dir_name); + + // Create directory + std::fs::create_dir_all(&cert_dir)?; + + // Write key + std::fs::write(cert_dir.join("key.pem"), &bundle.key_pem)?; + + // Write cert + std::fs::write(cert_dir.join("cert.pem"), &bundle.cert_pem)?; + + // Write CA cert if present + if let Some(ref ca) = bundle.ca_pem { + std::fs::write(cert_dir.join("ca.pem"), ca)?; + } + + // Write metadata + let metadata_json = serde_json::to_string_pretty(&bundle.metadata)?; + std::fs::write(cert_dir.join("metadata.json"), metadata_json)?; + + // Update cache + self.cache.insert(domain, bundle); + Ok(()) + } + + /// Check if a certificate exists for a domain. + pub fn has(&self, domain: &str) -> bool { + self.cache.contains_key(domain) + } + + /// Load all certificates from the base directory. + pub fn load_all(&mut self) -> Result { + if !self.base_dir.exists() { + return Ok(0); + } + + let entries = std::fs::read_dir(&self.base_dir)?; + let mut loaded = 0; + + for entry in entries { + let entry = entry?; + let path = entry.path(); + + if !path.is_dir() { + continue; + } + + let metadata_path = path.join("metadata.json"); + let key_path = path.join("key.pem"); + let cert_path = path.join("cert.pem"); + + // All three files must exist + if !metadata_path.exists() || !key_path.exists() || !cert_path.exists() { + continue; + } + + // Load metadata + let metadata_str = std::fs::read_to_string(&metadata_path)?; + let metadata: CertMetadata = serde_json::from_str(&metadata_str)?; + + // Load key and cert + let key_pem = std::fs::read_to_string(&key_path)?; + let cert_pem = std::fs::read_to_string(&cert_path)?; + + // Load CA cert if present + let ca_path = path.join("ca.pem"); + let ca_pem = if ca_path.exists() { + Some(std::fs::read_to_string(&ca_path)?) + } else { + None + }; + + let domain = metadata.domain.clone(); + let bundle = CertBundle { + key_pem, + cert_pem, + ca_pem, + metadata, + }; + + self.cache.insert(domain, bundle); + loaded += 1; + } + + Ok(loaded) + } + + /// Get the base directory. + pub fn base_dir(&self) -> &Path { + &self.base_dir + } + + /// Get the number of cached certificates. + pub fn count(&self) -> usize { + self.cache.len() + } + + /// Iterate over all cached certificates. + pub fn iter(&self) -> impl Iterator { + self.cache.iter() + } + + /// Remove a certificate from cache and filesystem. + pub fn remove(&mut self, domain: &str) -> Result { + let removed = self.cache.remove(domain).is_some(); + if removed { + let dir_name = domain.replace('*', "_wildcard_"); + let cert_dir = self.base_dir.join(&dir_name); + if cert_dir.exists() { + std::fs::remove_dir_all(&cert_dir)?; + } + } + Ok(removed) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_test_bundle(domain: &str) -> CertBundle { + CertBundle { + key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n".to_string(), + cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n".to_string(), + ca_pem: None, + metadata: CertMetadata { + domain: domain.to_string(), + source: CertSource::Static, + issued_at: 1700000000, + expires_at: 1700000000 + 90 * 86400, + renewed_at: None, + }, + } + } + + #[test] + fn test_store_and_load_roundtrip() { + let tmp = tempfile::tempdir().unwrap(); + let mut store = CertStore::new(tmp.path()); + + let bundle = make_test_bundle("example.com"); + store.store("example.com".to_string(), bundle.clone()).unwrap(); + + // Verify files exist + let cert_dir = tmp.path().join("example.com"); + assert!(cert_dir.join("key.pem").exists()); + assert!(cert_dir.join("cert.pem").exists()); + assert!(cert_dir.join("metadata.json").exists()); + assert!(!cert_dir.join("ca.pem").exists()); // No CA cert + + // Load into a fresh store + let mut store2 = CertStore::new(tmp.path()); + let loaded = store2.load_all().unwrap(); + assert_eq!(loaded, 1); + + let loaded_bundle = store2.get("example.com").unwrap(); + assert_eq!(loaded_bundle.key_pem, bundle.key_pem); + assert_eq!(loaded_bundle.cert_pem, bundle.cert_pem); + assert_eq!(loaded_bundle.metadata.domain, "example.com"); + assert_eq!(loaded_bundle.metadata.source, CertSource::Static); + } + + #[test] + fn test_store_with_ca_cert() { + let tmp = tempfile::tempdir().unwrap(); + let mut store = CertStore::new(tmp.path()); + + let mut bundle = make_test_bundle("secure.com"); + bundle.ca_pem = Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string()); + store.store("secure.com".to_string(), bundle).unwrap(); + + let cert_dir = tmp.path().join("secure.com"); + assert!(cert_dir.join("ca.pem").exists()); + + let mut store2 = CertStore::new(tmp.path()); + store2.load_all().unwrap(); + let loaded = store2.get("secure.com").unwrap(); + assert!(loaded.ca_pem.is_some()); + } + + #[test] + fn test_load_all_multiple_certs() { + let tmp = tempfile::tempdir().unwrap(); + let mut store = CertStore::new(tmp.path()); + + store.store("a.com".to_string(), make_test_bundle("a.com")).unwrap(); + store.store("b.com".to_string(), make_test_bundle("b.com")).unwrap(); + store.store("c.com".to_string(), make_test_bundle("c.com")).unwrap(); + + let mut store2 = CertStore::new(tmp.path()); + let loaded = store2.load_all().unwrap(); + assert_eq!(loaded, 3); + assert!(store2.has("a.com")); + assert!(store2.has("b.com")); + assert!(store2.has("c.com")); + } + + #[test] + fn test_load_all_missing_directory() { + let mut store = CertStore::new("/nonexistent/path/to/certs"); + let loaded = store.load_all().unwrap(); + assert_eq!(loaded, 0); + } + + #[test] + fn test_remove_cert() { + let tmp = tempfile::tempdir().unwrap(); + let mut store = CertStore::new(tmp.path()); + + store.store("remove-me.com".to_string(), make_test_bundle("remove-me.com")).unwrap(); + assert!(store.has("remove-me.com")); + + let removed = store.remove("remove-me.com").unwrap(); + assert!(removed); + assert!(!store.has("remove-me.com")); + assert!(!tmp.path().join("remove-me.com").exists()); + } + + #[test] + fn test_wildcard_domain_storage() { + let tmp = tempfile::tempdir().unwrap(); + let mut store = CertStore::new(tmp.path()); + + store.store("*.example.com".to_string(), make_test_bundle("*.example.com")).unwrap(); + + // Directory should use sanitized name + assert!(tmp.path().join("_wildcard_.example.com").exists()); + + let mut store2 = CertStore::new(tmp.path()); + store2.load_all().unwrap(); + assert!(store2.has("*.example.com")); + } +} diff --git a/rust/crates/rustproxy-tls/src/lib.rs b/rust/crates/rustproxy-tls/src/lib.rs new file mode 100644 index 0000000..4bdaeb5 --- /dev/null +++ b/rust/crates/rustproxy-tls/src/lib.rs @@ -0,0 +1,13 @@ +//! # rustproxy-tls +//! +//! TLS certificate management for RustProxy. +//! Handles ACME (Let's Encrypt), static certificates, and dynamic SNI resolution. + +pub mod cert_store; +pub mod cert_manager; +pub mod acme; +pub mod sni_resolver; + +pub use cert_store::*; +pub use cert_manager::*; +pub use sni_resolver::*; diff --git a/rust/crates/rustproxy-tls/src/sni_resolver.rs b/rust/crates/rustproxy-tls/src/sni_resolver.rs new file mode 100644 index 0000000..a0d9994 --- /dev/null +++ b/rust/crates/rustproxy-tls/src/sni_resolver.rs @@ -0,0 +1,139 @@ +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; + +use crate::cert_store::CertBundle; + +/// Dynamic SNI-based certificate resolver. +/// Used by the TLS stack to select the right certificate based on client SNI. +pub struct SniResolver { + /// Domain -> certificate bundle mapping + certs: RwLock>>, + /// Fallback certificate (used when no SNI or no match) + fallback: RwLock>>, +} + +impl SniResolver { + pub fn new() -> Self { + Self { + certs: RwLock::new(HashMap::new()), + fallback: RwLock::new(None), + } + } + + /// Register a certificate for a domain. + pub fn add_cert(&self, domain: String, bundle: CertBundle) { + let mut certs = self.certs.write().unwrap(); + certs.insert(domain, Arc::new(bundle)); + } + + /// Set the fallback certificate. + pub fn set_fallback(&self, bundle: CertBundle) { + let mut fallback = self.fallback.write().unwrap(); + *fallback = Some(Arc::new(bundle)); + } + + /// Resolve a certificate for the given SNI domain. + pub fn resolve(&self, domain: &str) -> Option> { + let certs = self.certs.read().unwrap(); + + // Try exact match + if let Some(bundle) = certs.get(domain) { + return Some(Arc::clone(bundle)); + } + + // Try wildcard match (e.g., *.example.com) + if let Some(dot_pos) = domain.find('.') { + let wildcard = format!("*.{}", &domain[dot_pos + 1..]); + if let Some(bundle) = certs.get(&wildcard) { + return Some(Arc::clone(bundle)); + } + } + + // Fallback + let fallback = self.fallback.read().unwrap(); + fallback.clone() + } + + /// Remove a certificate for a domain. + pub fn remove_cert(&self, domain: &str) { + let mut certs = self.certs.write().unwrap(); + certs.remove(domain); + } + + /// Get the number of registered certificates. + pub fn cert_count(&self) -> usize { + self.certs.read().unwrap().len() + } +} + +impl Default for SniResolver { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cert_store::{CertBundle, CertMetadata, CertSource}; + + fn make_bundle(domain: &str) -> CertBundle { + CertBundle { + key_pem: format!("KEY-{}", domain), + cert_pem: format!("CERT-{}", domain), + ca_pem: None, + metadata: CertMetadata { + domain: domain.to_string(), + source: CertSource::Static, + issued_at: 0, + expires_at: 0, + renewed_at: None, + }, + } + } + + #[test] + fn test_exact_domain_resolve() { + let resolver = SniResolver::new(); + resolver.add_cert("example.com".to_string(), make_bundle("example.com")); + let result = resolver.resolve("example.com"); + assert!(result.is_some()); + assert_eq!(result.unwrap().cert_pem, "CERT-example.com"); + } + + #[test] + fn test_wildcard_resolve() { + let resolver = SniResolver::new(); + resolver.add_cert("*.example.com".to_string(), make_bundle("*.example.com")); + let result = resolver.resolve("sub.example.com"); + assert!(result.is_some()); + assert_eq!(result.unwrap().cert_pem, "CERT-*.example.com"); + } + + #[test] + fn test_fallback() { + let resolver = SniResolver::new(); + resolver.set_fallback(make_bundle("fallback")); + let result = resolver.resolve("unknown.com"); + assert!(result.is_some()); + assert_eq!(result.unwrap().cert_pem, "CERT-fallback"); + } + + #[test] + fn test_no_match_no_fallback() { + let resolver = SniResolver::new(); + resolver.add_cert("example.com".to_string(), make_bundle("example.com")); + let result = resolver.resolve("other.com"); + assert!(result.is_none()); + } + + #[test] + fn test_remove_cert() { + let resolver = SniResolver::new(); + resolver.add_cert("example.com".to_string(), make_bundle("example.com")); + assert_eq!(resolver.cert_count(), 1); + resolver.remove_cert("example.com"); + assert_eq!(resolver.cert_count(), 0); + assert!(resolver.resolve("example.com").is_none()); + } +} diff --git a/rust/crates/rustproxy/Cargo.toml b/rust/crates/rustproxy/Cargo.toml new file mode 100644 index 0000000..875ccd8 --- /dev/null +++ b/rust/crates/rustproxy/Cargo.toml @@ -0,0 +1,44 @@ +[package] +name = "rustproxy" +version.workspace = true +edition.workspace = true +license.workspace = true +authors.workspace = true +description = "High-performance multi-protocol proxy built on Pingora, compatible with SmartProxy configuration" + +[[bin]] +name = "rustproxy" +path = "src/main.rs" + +[lib] +name = "rustproxy" +path = "src/lib.rs" + +[dependencies] +rustproxy-config = { workspace = true } +rustproxy-routing = { workspace = true } +rustproxy-tls = { workspace = true } +rustproxy-passthrough = { workspace = true } +rustproxy-http = { workspace = true } +rustproxy-nftables = { workspace = true } +rustproxy-metrics = { workspace = true } +rustproxy-security = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } +clap = { workspace = true } +anyhow = { workspace = true } +arc-swap = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +rustls = { workspace = true } +tokio-rustls = { workspace = true } +tokio-util = { workspace = true } +dashmap = { workspace = true } +hyper = { workspace = true } +hyper-util = { workspace = true } +http-body-util = { workspace = true } +bytes = { workspace = true } + +[dev-dependencies] +rcgen = { workspace = true } diff --git a/rust/crates/rustproxy/src/challenge_server.rs b/rust/crates/rustproxy/src/challenge_server.rs new file mode 100644 index 0000000..1b1a27f --- /dev/null +++ b/rust/crates/rustproxy/src/challenge_server.rs @@ -0,0 +1,177 @@ +//! HTTP-01 ACME challenge server. +//! +//! A lightweight HTTP server that serves ACME challenge responses at +//! `/.well-known/acme-challenge/`. + +use std::sync::Arc; + +use bytes::Bytes; +use dashmap::DashMap; +use http_body_util::Full; +use hyper::body::Incoming; +use hyper::{Request, Response, StatusCode}; +use hyper_util::rt::TokioIo; +use tokio::net::TcpListener; +use tokio_util::sync::CancellationToken; +use tracing::{debug, info, error}; + +/// ACME HTTP-01 challenge server. +pub struct ChallengeServer { + /// Token -> key authorization mapping + challenges: Arc>, + /// Cancellation token to stop the server + cancel: CancellationToken, + /// Server task handle + handle: Option>, +} + +impl ChallengeServer { + /// Create a new challenge server (not yet started). + pub fn new() -> Self { + Self { + challenges: Arc::new(DashMap::new()), + cancel: CancellationToken::new(), + handle: None, + } + } + + /// Register a challenge token -> key_authorization mapping. + pub fn set_challenge(&self, token: String, key_authorization: String) { + debug!("Registered ACME challenge: token={}", token); + self.challenges.insert(token, key_authorization); + } + + /// Remove a challenge token. + pub fn remove_challenge(&self, token: &str) { + self.challenges.remove(token); + } + + /// Start the challenge server on the given port. + pub async fn start(&mut self, port: u16) -> Result<(), Box> { + let addr = format!("0.0.0.0:{}", port); + let listener = TcpListener::bind(&addr).await?; + info!("ACME challenge server listening on port {}", port); + + let challenges = Arc::clone(&self.challenges); + let cancel = self.cancel.clone(); + + let handle = tokio::spawn(async move { + loop { + tokio::select! { + _ = cancel.cancelled() => { + info!("ACME challenge server stopping"); + break; + } + result = listener.accept() => { + match result { + Ok((stream, _)) => { + let challenges = Arc::clone(&challenges); + tokio::spawn(async move { + let io = TokioIo::new(stream); + let service = hyper::service::service_fn(move |req: Request| { + let challenges = Arc::clone(&challenges); + async move { + Self::handle_request(req, &challenges) + } + }); + + let conn = hyper::server::conn::http1::Builder::new() + .serve_connection(io, service); + + if let Err(e) = conn.await { + debug!("Challenge server connection error: {}", e); + } + }); + } + Err(e) => { + error!("Challenge server accept error: {}", e); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + } + } + } + } + }); + + self.handle = Some(handle); + Ok(()) + } + + /// Stop the challenge server. + pub async fn stop(&mut self) { + self.cancel.cancel(); + if let Some(handle) = self.handle.take() { + let _ = tokio::time::timeout( + std::time::Duration::from_secs(5), + handle, + ).await; + } + self.challenges.clear(); + self.cancel = CancellationToken::new(); + info!("ACME challenge server stopped"); + } + + /// Handle an HTTP request for ACME challenges. + fn handle_request( + req: Request, + challenges: &DashMap, + ) -> Result>, hyper::Error> { + let path = req.uri().path(); + + if let Some(token) = path.strip_prefix("/.well-known/acme-challenge/") { + if let Some(key_auth) = challenges.get(token) { + debug!("Serving ACME challenge for token: {}", token); + return Ok(Response::builder() + .status(StatusCode::OK) + .header("content-type", "text/plain") + .body(Full::new(Bytes::from(key_auth.value().clone()))) + .unwrap()); + } + } + + Ok(Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Full::new(Bytes::from("Not Found"))) + .unwrap()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_challenge_server_lifecycle() { + let mut server = ChallengeServer::new(); + + // Set a challenge before starting + server.set_challenge("test-token".to_string(), "test-key-auth".to_string()); + + // Start on a random port + server.start(19900).await.unwrap(); + + // Give server a moment to start + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + + // Fetch the challenge + let client = tokio::net::TcpStream::connect("127.0.0.1:19900").await.unwrap(); + let io = TokioIo::new(client); + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap(); + tokio::spawn(async move { let _ = conn.await; }); + + let req = Request::get("/.well-known/acme-challenge/test-token") + .body(Full::new(Bytes::new())) + .unwrap(); + let resp = sender.send_request(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + // Test 404 for unknown token + let req = Request::get("/.well-known/acme-challenge/unknown") + .body(Full::new(Bytes::new())) + .unwrap(); + let resp = sender.send_request(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + + server.stop().await; + } +} diff --git a/rust/crates/rustproxy/src/lib.rs b/rust/crates/rustproxy/src/lib.rs new file mode 100644 index 0000000..90a522c --- /dev/null +++ b/rust/crates/rustproxy/src/lib.rs @@ -0,0 +1,931 @@ +//! # RustProxy +//! +//! High-performance multi-protocol proxy built on Rust, +//! compatible with SmartProxy configuration. +//! +//! ## Quick Start +//! +//! ```rust,no_run +//! use rustproxy::RustProxy; +//! use rustproxy_config::{RustProxyOptions, create_https_passthrough_route}; +//! +//! #[tokio::main] +//! async fn main() -> anyhow::Result<()> { +//! let options = RustProxyOptions { +//! routes: vec![ +//! create_https_passthrough_route("example.com", "backend", 443), +//! ], +//! ..Default::default() +//! }; +//! +//! let mut proxy = RustProxy::new(options)?; +//! proxy.start().await?; +//! Ok(()) +//! } +//! ``` + +pub mod challenge_server; +pub mod management; + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Instant; + +use arc_swap::ArcSwap; +use anyhow::Result; +use tracing::{info, warn, debug, error}; + +// Re-export key types +pub use rustproxy_config; +pub use rustproxy_routing; +pub use rustproxy_passthrough; +pub use rustproxy_tls; +pub use rustproxy_http; +pub use rustproxy_nftables; +pub use rustproxy_metrics; +pub use rustproxy_security; + +use rustproxy_config::{RouteConfig, RustProxyOptions, TlsMode, CertificateSpec, ForwardingEngine}; +use rustproxy_routing::RouteManager; +use rustproxy_passthrough::{TcpListenerManager, TlsCertConfig, ConnectionConfig}; +use rustproxy_metrics::{MetricsCollector, Metrics, Statistics}; +use rustproxy_tls::{CertManager, CertStore, CertBundle, CertMetadata, CertSource}; +use rustproxy_nftables::{NftManager, rule_builder}; + +/// Certificate status. +#[derive(Debug, Clone)] +pub struct CertStatus { + pub domain: String, + pub source: String, + pub expires_at: u64, + pub is_valid: bool, +} + +/// The main RustProxy struct. +/// This is the primary public API matching SmartProxy's interface. +pub struct RustProxy { + options: RustProxyOptions, + route_table: ArcSwap, + listener_manager: Option, + metrics: Arc, + cert_manager: Option>>, + challenge_server: Option, + renewal_handle: Option>, + nft_manager: Option, + started: bool, + started_at: Option, + /// Path to a Unix domain socket for relaying socket-handler connections back to TypeScript. + socket_handler_relay_path: Option, +} + +impl RustProxy { + /// Create a new RustProxy instance with the given configuration. + pub fn new(mut options: RustProxyOptions) -> Result { + // Apply defaults to routes before validation + Self::apply_defaults(&mut options); + + // Validate routes + if let Err(errors) = rustproxy_config::validate_routes(&options.routes) { + for err in &errors { + warn!("Route validation error: {}", err); + } + if !errors.is_empty() { + anyhow::bail!("Route validation failed with {} errors", errors.len()); + } + } + + let route_manager = RouteManager::new(options.routes.clone()); + + // Set up certificate manager if ACME is configured + let cert_manager = Self::build_cert_manager(&options) + .map(|cm| Arc::new(tokio::sync::Mutex::new(cm))); + + Ok(Self { + options, + route_table: ArcSwap::from(Arc::new(route_manager)), + listener_manager: None, + metrics: Arc::new(MetricsCollector::new()), + cert_manager, + challenge_server: None, + renewal_handle: None, + nft_manager: None, + started: false, + started_at: None, + socket_handler_relay_path: None, + }) + } + + /// Apply default configuration to routes that lack targets or security. + fn apply_defaults(options: &mut RustProxyOptions) { + let defaults = match &options.defaults { + Some(d) => d.clone(), + None => return, + }; + + for route in &mut options.routes { + // Apply default target if route has no targets + if route.action.targets.is_none() { + if let Some(ref default_target) = defaults.target { + debug!("Applying default target {}:{} to route {:?}", + default_target.host, default_target.port, + route.name.as_deref().unwrap_or("unnamed")); + route.action.targets = Some(vec![ + rustproxy_config::RouteTarget { + target_match: None, + host: rustproxy_config::HostSpec::Single(default_target.host.clone()), + port: rustproxy_config::PortSpec::Fixed(default_target.port), + tls: None, + websocket: None, + load_balancing: None, + send_proxy_protocol: None, + headers: None, + advanced: None, + priority: None, + } + ]); + } + } + + // Apply default security if route has no security + if route.security.is_none() { + if let Some(ref default_security) = defaults.security { + let mut security = rustproxy_config::RouteSecurity { + ip_allow_list: None, + ip_block_list: None, + max_connections: default_security.max_connections, + authentication: None, + rate_limit: None, + basic_auth: None, + jwt_auth: None, + }; + + if let Some(ref allow_list) = default_security.ip_allow_list { + security.ip_allow_list = Some(allow_list.clone()); + } + if let Some(ref block_list) = default_security.ip_block_list { + security.ip_block_list = Some(block_list.clone()); + } + + // Only apply if there's something meaningful + if security.ip_allow_list.is_some() || security.ip_block_list.is_some() { + debug!("Applying default security to route {:?}", + route.name.as_deref().unwrap_or("unnamed")); + route.security = Some(security); + } + } + } + } + } + + /// Build a CertManager from options. + fn build_cert_manager(options: &RustProxyOptions) -> Option { + let acme = options.acme.as_ref()?; + if !acme.enabled.unwrap_or(false) { + return None; + } + + let store_path = acme.certificate_store + .as_deref() + .unwrap_or("./certs"); + let email = acme.email.clone() + .or_else(|| acme.account_email.clone()); + let use_production = acme.use_production.unwrap_or(false); + let renew_before_days = acme.renew_threshold_days.unwrap_or(30); + + let store = CertStore::new(store_path); + Some(CertManager::new(store, email, use_production, renew_before_days)) + } + + /// Build ConnectionConfig from RustProxyOptions. + fn build_connection_config(options: &RustProxyOptions) -> ConnectionConfig { + ConnectionConfig { + connection_timeout_ms: options.effective_connection_timeout(), + initial_data_timeout_ms: options.effective_initial_data_timeout(), + socket_timeout_ms: options.effective_socket_timeout(), + max_connection_lifetime_ms: options.effective_max_connection_lifetime(), + graceful_shutdown_timeout_ms: options.graceful_shutdown_timeout.unwrap_or(30_000), + max_connections_per_ip: options.max_connections_per_ip, + connection_rate_limit_per_minute: options.connection_rate_limit_per_minute, + keep_alive_treatment: options.keep_alive_treatment.clone(), + keep_alive_inactivity_multiplier: options.keep_alive_inactivity_multiplier, + extended_keep_alive_lifetime_ms: options.extended_keep_alive_lifetime, + accept_proxy_protocol: options.accept_proxy_protocol.unwrap_or(false), + send_proxy_protocol: options.send_proxy_protocol.unwrap_or(false), + } + } + + /// Start the proxy, binding to all configured ports. + pub async fn start(&mut self) -> Result<()> { + if self.started { + anyhow::bail!("Proxy is already started"); + } + + info!("Starting RustProxy..."); + + // Load persisted certificates + if let Some(ref cm) = self.cert_manager { + let mut cm = cm.lock().await; + match cm.load_all() { + Ok(count) => { + if count > 0 { + info!("Loaded {} persisted certificates", count); + } + } + Err(e) => warn!("Failed to load persisted certificates: {}", e), + } + } + + // Auto-provision certificates for routes with certificate: 'auto' + self.auto_provision_certificates().await; + + let route_manager = self.route_table.load(); + let ports = route_manager.listening_ports(); + + info!("Configured {} routes on {} ports", route_manager.route_count(), ports.len()); + + // Create TCP listener manager with metrics + let mut listener = TcpListenerManager::with_metrics( + Arc::clone(&*route_manager), + Arc::clone(&self.metrics), + ); + + // Apply connection config from options + let conn_config = Self::build_connection_config(&self.options); + debug!("Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms", + conn_config.connection_timeout_ms, + conn_config.initial_data_timeout_ms, + conn_config.socket_timeout_ms, + conn_config.max_connection_lifetime_ms, + ); + listener.set_connection_config(conn_config); + + // Extract TLS configurations from routes and cert manager + let mut tls_configs = Self::extract_tls_configs(&self.options.routes); + + // Also load certs from cert manager into TLS config + if let Some(ref cm) = self.cert_manager { + let cm = cm.lock().await; + for (domain, bundle) in cm.store().iter() { + if !tls_configs.contains_key(domain) { + tls_configs.insert(domain.clone(), TlsCertConfig { + cert_pem: bundle.cert_pem.clone(), + key_pem: bundle.key_pem.clone(), + }); + } + } + } + + if !tls_configs.is_empty() { + debug!("Loaded TLS certificates for {} domains", tls_configs.len()); + listener.set_tls_configs(tls_configs); + } + + // Bind all ports + for port in &ports { + listener.add_port(*port).await?; + } + + self.listener_manager = Some(listener); + self.started = true; + self.started_at = Some(Instant::now()); + + // Apply NFTables rules for routes using nftables forwarding engine + self.apply_nftables_rules(&self.options.routes.clone()).await; + + // Start renewal timer if ACME is enabled + self.start_renewal_timer(); + + info!("RustProxy started successfully on ports: {:?}", ports); + Ok(()) + } + + /// Auto-provision certificates for routes that use certificate: 'auto'. + async fn auto_provision_certificates(&mut self) { + let cm_arc = match self.cert_manager { + Some(ref cm) => Arc::clone(cm), + None => return, + }; + + let mut domains_to_provision = Vec::new(); + + for route in &self.options.routes { + let tls_mode = route.tls_mode(); + let needs_cert = matches!( + tls_mode, + Some(TlsMode::Terminate) | Some(TlsMode::TerminateAndReencrypt) + ); + if !needs_cert { + continue; + } + + let cert_spec = route.action.tls.as_ref() + .and_then(|tls| tls.certificate.as_ref()); + + if let Some(CertificateSpec::Auto(_)) = cert_spec { + if let Some(ref domains) = route.route_match.domains { + for domain in domains.to_vec() { + let domain = domain.to_string(); + // Skip if we already have a valid cert + let cm = cm_arc.lock().await; + if cm.store().has(&domain) { + debug!("Already have cert for {}, skipping auto-provision", domain); + continue; + } + drop(cm); + domains_to_provision.push(domain); + } + } + } + } + + if domains_to_provision.is_empty() { + return; + } + + info!("Auto-provisioning certificates for {} domains", domains_to_provision.len()); + + // Start challenge server + let acme_port = self.options.acme.as_ref() + .and_then(|a| a.port) + .unwrap_or(80); + + let mut challenge_server = challenge_server::ChallengeServer::new(); + if let Err(e) = challenge_server.start(acme_port).await { + error!("Failed to start ACME challenge server on port {}: {}", acme_port, e); + return; + } + + for domain in &domains_to_provision { + info!("Provisioning certificate for {}", domain); + + let cm = cm_arc.lock().await; + let acme_client = cm.acme_client(); + drop(cm); + + if let Some(acme_client) = acme_client { + let challenge_server_ref = &challenge_server; + let result = acme_client.provision(domain, |pending| { + challenge_server_ref.set_challenge( + pending.token.clone(), + pending.key_authorization.clone(), + ); + async move { Ok(()) } + }).await; + + match result { + Ok((cert_pem, key_pem)) => { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + let bundle = CertBundle { + cert_pem, + key_pem, + ca_pem: None, + metadata: CertMetadata { + domain: domain.clone(), + source: CertSource::Acme, + issued_at: now, + expires_at: now + 90 * 86400, // 90 days + renewed_at: None, + }, + }; + + let mut cm = cm_arc.lock().await; + if let Err(e) = cm.load_static(domain.clone(), bundle) { + error!("Failed to store certificate for {}: {}", domain, e); + } + + info!("Certificate provisioned for {}", domain); + } + Err(e) => { + error!("Failed to provision certificate for {}: {}", domain, e); + } + } + } + } + + challenge_server.stop().await; + } + + /// Start the renewal timer background task. + /// The background task checks for expiring certificates and renews them. + fn start_renewal_timer(&mut self) { + let cm_arc = match self.cert_manager { + Some(ref cm) => Arc::clone(cm), + None => return, + }; + + let auto_renew = self.options.acme.as_ref() + .and_then(|a| a.auto_renew) + .unwrap_or(true); + + if !auto_renew { + return; + } + + let check_interval_hours = self.options.acme.as_ref() + .and_then(|a| a.renew_check_interval_hours) + .unwrap_or(24); + + let acme_port = self.options.acme.as_ref() + .and_then(|a| a.port) + .unwrap_or(80); + + let interval = std::time::Duration::from_secs(check_interval_hours as u64 * 3600); + + let handle = tokio::spawn(async move { + loop { + tokio::time::sleep(interval).await; + debug!("Certificate renewal check triggered (interval: {}h)", check_interval_hours); + + // Check which domains need renewal + let domains = { + let cm = cm_arc.lock().await; + cm.check_renewals() + }; + + if domains.is_empty() { + debug!("No certificates need renewal"); + continue; + } + + info!("Renewing {} certificate(s)", domains.len()); + + // Start challenge server for renewals + let mut cs = challenge_server::ChallengeServer::new(); + if let Err(e) = cs.start(acme_port).await { + error!("Failed to start challenge server for renewal: {}", e); + continue; + } + + for domain in &domains { + let cs_ref = &cs; + let mut cm = cm_arc.lock().await; + let result = cm.renew_domain(domain, |token, key_auth| { + cs_ref.set_challenge(token, key_auth); + async {} + }).await; + + match result { + Ok(_bundle) => { + info!("Successfully renewed certificate for {}", domain); + } + Err(e) => { + error!("Failed to renew certificate for {}: {}", domain, e); + } + } + } + + cs.stop().await; + } + }); + + self.renewal_handle = Some(handle); + } + + /// Stop the proxy gracefully. + pub async fn stop(&mut self) -> Result<()> { + if !self.started { + return Ok(()); + } + + info!("Stopping RustProxy..."); + + // Stop renewal timer + if let Some(handle) = self.renewal_handle.take() { + handle.abort(); + } + + // Stop challenge server if running + if let Some(ref mut cs) = self.challenge_server { + cs.stop().await; + } + self.challenge_server = None; + + // Clean up NFTables rules + if let Some(ref mut nft) = self.nft_manager { + if let Err(e) = nft.cleanup().await { + warn!("NFTables cleanup failed: {}", e); + } + } + self.nft_manager = None; + + if let Some(ref mut listener) = self.listener_manager { + listener.graceful_stop().await; + } + self.listener_manager = None; + self.started = false; + + info!("RustProxy stopped"); + Ok(()) + } + + /// Update routes atomically (hot-reload). + pub async fn update_routes(&mut self, routes: Vec) -> Result<()> { + // Validate new routes + rustproxy_config::validate_routes(&routes) + .map_err(|errors| { + let msgs: Vec = errors.iter().map(|e| e.to_string()).collect(); + anyhow::anyhow!("Route validation failed: {}", msgs.join(", ")) + })?; + + let new_manager = RouteManager::new(routes.clone()); + let new_ports = new_manager.listening_ports(); + + info!("Updating routes: {} routes on {} ports", + new_manager.route_count(), new_ports.len()); + + // Get old ports + let old_ports: Vec = if let Some(ref listener) = self.listener_manager { + listener.listening_ports() + } else { + vec![] + }; + + // Atomically swap the route table + let new_manager = Arc::new(new_manager); + self.route_table.store(Arc::clone(&new_manager)); + + // Update listener manager + if let Some(ref mut listener) = self.listener_manager { + listener.update_route_manager(Arc::clone(&new_manager)); + + // Update TLS configs + let mut tls_configs = Self::extract_tls_configs(&routes); + if let Some(ref cm_arc) = self.cert_manager { + let cm = cm_arc.lock().await; + for (domain, bundle) in cm.store().iter() { + if !tls_configs.contains_key(domain) { + tls_configs.insert(domain.clone(), TlsCertConfig { + cert_pem: bundle.cert_pem.clone(), + key_pem: bundle.key_pem.clone(), + }); + } + } + } + listener.set_tls_configs(tls_configs); + + // Add new ports + for port in &new_ports { + if !old_ports.contains(port) { + listener.add_port(*port).await?; + } + } + + // Remove old ports no longer needed + for port in &old_ports { + if !new_ports.contains(port) { + listener.remove_port(*port); + } + } + } + + // Update NFTables rules: remove old, apply new + self.update_nftables_rules(&routes).await; + + self.options.routes = routes; + Ok(()) + } + + /// Provision a certificate for a named route. + pub async fn provision_certificate(&mut self, route_name: &str) -> Result<()> { + let cm_arc = self.cert_manager.as_ref() + .ok_or_else(|| anyhow::anyhow!("No certificate manager configured (ACME not enabled)"))?; + + // Find the route by name + let route = self.options.routes.iter() + .find(|r| r.name.as_deref() == Some(route_name)) + .ok_or_else(|| anyhow::anyhow!("Route '{}' not found", route_name))?; + + let domain = route.route_match.domains.as_ref() + .and_then(|d| d.to_vec().first().map(|s| s.to_string())) + .ok_or_else(|| anyhow::anyhow!("Route '{}' has no domain", route_name))?; + + info!("Provisioning certificate for route '{}' (domain: {})", route_name, domain); + + // Start challenge server + let acme_port = self.options.acme.as_ref() + .and_then(|a| a.port) + .unwrap_or(80); + + let mut cs = challenge_server::ChallengeServer::new(); + cs.start(acme_port).await + .map_err(|e| anyhow::anyhow!("Failed to start challenge server: {}", e))?; + + let cs_ref = &cs; + let mut cm = cm_arc.lock().await; + let result = cm.renew_domain(&domain, |token, key_auth| { + cs_ref.set_challenge(token, key_auth); + async {} + }).await; + drop(cm); + + cs.stop().await; + + let bundle = result + .map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?; + + // Hot-swap into TLS configs + if let Some(ref mut listener) = self.listener_manager { + let mut tls_configs = Self::extract_tls_configs(&self.options.routes); + tls_configs.insert(domain.clone(), TlsCertConfig { + cert_pem: bundle.cert_pem.clone(), + key_pem: bundle.key_pem.clone(), + }); + let cm = cm_arc.lock().await; + for (d, b) in cm.store().iter() { + if !tls_configs.contains_key(d) { + tls_configs.insert(d.clone(), TlsCertConfig { + cert_pem: b.cert_pem.clone(), + key_pem: b.key_pem.clone(), + }); + } + } + listener.set_tls_configs(tls_configs); + } + + info!("Certificate provisioned and loaded for route '{}'", route_name); + Ok(()) + } + + /// Renew a certificate for a named route. + pub async fn renew_certificate(&mut self, route_name: &str) -> Result<()> { + // Renewal is just re-provisioning + self.provision_certificate(route_name).await + } + + /// Get the status of a certificate for a named route. + pub async fn get_certificate_status(&self, route_name: &str) -> Option { + let route = self.options.routes.iter() + .find(|r| r.name.as_deref() == Some(route_name))?; + + let domain = route.route_match.domains.as_ref() + .and_then(|d| d.to_vec().first().map(|s| s.to_string()))?; + + if let Some(ref cm_arc) = self.cert_manager { + let cm = cm_arc.lock().await; + if let Some(bundle) = cm.get_cert(&domain) { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + return Some(CertStatus { + domain, + source: format!("{:?}", bundle.metadata.source), + expires_at: bundle.metadata.expires_at, + is_valid: bundle.metadata.expires_at > now, + }); + } + } + + None + } + + /// Get current metrics snapshot. + pub fn get_metrics(&self) -> Metrics { + self.metrics.snapshot() + } + + /// Add a listening port at runtime. + pub async fn add_listening_port(&mut self, port: u16) -> Result<()> { + if let Some(ref mut listener) = self.listener_manager { + listener.add_port(port).await?; + } + Ok(()) + } + + /// Remove a listening port at runtime. + pub async fn remove_listening_port(&mut self, port: u16) -> Result<()> { + if let Some(ref mut listener) = self.listener_manager { + listener.remove_port(port); + } + Ok(()) + } + + /// Get all currently listening ports. + pub fn get_listening_ports(&self) -> Vec { + self.listener_manager + .as_ref() + .map(|l| l.listening_ports()) + .unwrap_or_default() + } + + /// Get statistics snapshot. + pub fn get_statistics(&self) -> Statistics { + let uptime = self.started_at + .map(|t| t.elapsed().as_secs()) + .unwrap_or(0); + + Statistics { + active_connections: self.metrics.active_connections(), + total_connections: self.metrics.total_connections(), + routes_count: self.route_table.load().route_count() as u64, + listening_ports: self.get_listening_ports(), + uptime_seconds: uptime, + } + } + + /// Set the Unix domain socket path for relaying socket-handler connections to TypeScript. + pub fn set_socket_handler_relay_path(&mut self, path: Option) { + info!("Socket handler relay path set to: {:?}", path); + self.socket_handler_relay_path = path; + } + + /// Get the current socket handler relay path. + pub fn get_socket_handler_relay_path(&self) -> Option<&str> { + self.socket_handler_relay_path.as_deref() + } + + /// Load a certificate for a domain and hot-swap the TLS configuration. + pub async fn load_certificate( + &mut self, + domain: &str, + cert_pem: String, + key_pem: String, + ca_pem: Option, + ) -> Result<()> { + info!("Loading certificate for domain: {}", domain); + + // Store in cert manager if available + if let Some(ref cm_arc) = self.cert_manager { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + let bundle = CertBundle { + cert_pem: cert_pem.clone(), + key_pem: key_pem.clone(), + ca_pem: ca_pem.clone(), + metadata: CertMetadata { + domain: domain.to_string(), + source: CertSource::Static, + issued_at: now, + expires_at: now + 90 * 86400, // assume 90 days + renewed_at: None, + }, + }; + + let mut cm = cm_arc.lock().await; + cm.load_static(domain.to_string(), bundle) + .map_err(|e| anyhow::anyhow!("Failed to store certificate: {}", e))?; + } + + // Hot-swap TLS config on the listener + if let Some(ref mut listener) = self.listener_manager { + let mut tls_configs = Self::extract_tls_configs(&self.options.routes); + + // Add the new cert + tls_configs.insert(domain.to_string(), TlsCertConfig { + cert_pem: cert_pem.clone(), + key_pem: key_pem.clone(), + }); + + // Also include all existing certs from cert manager + if let Some(ref cm_arc) = self.cert_manager { + let cm = cm_arc.lock().await; + for (d, b) in cm.store().iter() { + if !tls_configs.contains_key(d) { + tls_configs.insert(d.clone(), TlsCertConfig { + cert_pem: b.cert_pem.clone(), + key_pem: b.key_pem.clone(), + }); + } + } + } + + listener.set_tls_configs(tls_configs); + } + + info!("Certificate loaded and TLS config updated for {}", domain); + Ok(()) + } + + /// Get NFTables status. + pub async fn get_nftables_status(&self) -> Result> { + match &self.nft_manager { + Some(nft) => Ok(nft.status()), + None => Ok(HashMap::new()), + } + } + + /// Apply NFTables rules for routes using the nftables forwarding engine. + async fn apply_nftables_rules(&mut self, routes: &[RouteConfig]) { + let nft_routes: Vec<&RouteConfig> = routes.iter() + .filter(|r| r.action.forwarding_engine.as_ref() == Some(&ForwardingEngine::Nftables)) + .collect(); + + if nft_routes.is_empty() { + return; + } + + info!("Applying NFTables rules for {} routes", nft_routes.len()); + + let table_name = nft_routes.iter() + .find_map(|r| r.action.nftables.as_ref()?.table_name.clone()) + .unwrap_or_else(|| "rustproxy".to_string()); + + let mut nft = NftManager::new(Some(table_name)); + + for route in &nft_routes { + let route_id = route.id.as_deref() + .or(route.name.as_deref()) + .unwrap_or("unnamed"); + + let nft_options = match &route.action.nftables { + Some(opts) => opts.clone(), + None => rustproxy_config::NfTablesOptions { + preserve_source_ip: None, + protocol: None, + max_rate: None, + priority: None, + table_name: None, + use_ip_sets: None, + use_advanced_nat: None, + }, + }; + + let targets = match &route.action.targets { + Some(targets) => targets, + None => { + warn!("NFTables route '{}' has no targets, skipping", route_id); + continue; + } + }; + + let source_ports = route.route_match.ports.to_ports(); + for target in targets { + let target_host = target.host.first().to_string(); + let target_port_spec = &target.port; + + for &source_port in &source_ports { + let resolved_port = target_port_spec.resolve(source_port); + let rules = rule_builder::build_dnat_rule( + nft.table_name(), + "prerouting", + source_port, + &target_host, + resolved_port, + &nft_options, + ); + + let rule_id = format!("{}-{}-{}", route_id, source_port, resolved_port); + if let Err(e) = nft.apply_rules(&rule_id, rules).await { + error!("Failed to apply NFTables rules for route '{}': {}", route_id, e); + } + } + } + } + + self.nft_manager = Some(nft); + } + + /// Update NFTables rules when routes change. + async fn update_nftables_rules(&mut self, new_routes: &[RouteConfig]) { + // Clean up old rules + if let Some(ref mut nft) = self.nft_manager { + if let Err(e) = nft.cleanup().await { + warn!("NFTables cleanup during update failed: {}", e); + } + } + self.nft_manager = None; + + // Apply new rules + self.apply_nftables_rules(new_routes).await; + } + + /// Extract TLS configurations from route configs. + fn extract_tls_configs(routes: &[RouteConfig]) -> HashMap { + let mut configs = HashMap::new(); + + for route in routes { + let tls_mode = route.tls_mode(); + let needs_cert = matches!( + tls_mode, + Some(TlsMode::Terminate) | Some(TlsMode::TerminateAndReencrypt) + ); + if !needs_cert { + continue; + } + + let cert_spec = route.action.tls.as_ref() + .and_then(|tls| tls.certificate.as_ref()); + + if let Some(CertificateSpec::Static(cert_config)) = cert_spec { + if let Some(ref domains) = route.route_match.domains { + for domain in domains.to_vec() { + configs.insert(domain.to_string(), TlsCertConfig { + cert_pem: cert_config.cert.clone(), + key_pem: cert_config.key.clone(), + }); + } + } + } + } + + configs + } +} diff --git a/rust/crates/rustproxy/src/main.rs b/rust/crates/rustproxy/src/main.rs new file mode 100644 index 0000000..bf6b8a3 --- /dev/null +++ b/rust/crates/rustproxy/src/main.rs @@ -0,0 +1,90 @@ +use clap::Parser; +use tracing_subscriber::EnvFilter; +use anyhow::Result; + +use rustproxy::RustProxy; +use rustproxy::management; +use rustproxy_config::RustProxyOptions; + +/// RustProxy - High-performance multi-protocol proxy +#[derive(Parser, Debug)] +#[command(name = "rustproxy", version, about)] +struct Cli { + /// Path to JSON configuration file + #[arg(short, long, default_value = "config.json")] + config: String, + + /// Log level (trace, debug, info, warn, error) + #[arg(short, long, default_value = "info")] + log_level: String, + + /// Validate configuration without starting + #[arg(long)] + validate: bool, + + /// Run in management mode (JSON-over-stdin IPC for TypeScript wrapper) + #[arg(long)] + management: bool, +} + +#[tokio::main] +async fn main() -> Result<()> { + let cli = Cli::parse(); + + // Initialize tracing - write to stderr so stdout is reserved for management IPC + tracing_subscriber::fmt() + .with_writer(std::io::stderr) + .with_env_filter( + EnvFilter::try_from_default_env() + .unwrap_or_else(|_| EnvFilter::new(&cli.log_level)) + ) + .init(); + + // Management mode: JSON IPC over stdin/stdout + if cli.management { + tracing::info!("RustProxy starting in management mode..."); + return management::management_loop().await; + } + + tracing::info!("RustProxy starting..."); + + // Load configuration + let options = RustProxyOptions::from_file(&cli.config) + .map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?; + + tracing::info!( + "Loaded {} routes from {}", + options.routes.len(), + cli.config + ); + + // Validate-only mode + if cli.validate { + match rustproxy_config::validate_routes(&options.routes) { + Ok(()) => { + tracing::info!("Configuration is valid"); + return Ok(()); + } + Err(errors) => { + for err in &errors { + tracing::error!("Validation error: {}", err); + } + anyhow::bail!("{} validation errors found", errors.len()); + } + } + } + + // Create and start proxy + let mut proxy = RustProxy::new(options)?; + proxy.start().await?; + + // Wait for shutdown signal + tracing::info!("RustProxy is running. Press Ctrl+C to stop."); + tokio::signal::ctrl_c().await?; + + tracing::info!("Shutdown signal received"); + proxy.stop().await?; + + tracing::info!("RustProxy shutdown complete"); + Ok(()) +} diff --git a/rust/crates/rustproxy/src/management.rs b/rust/crates/rustproxy/src/management.rs new file mode 100644 index 0000000..a3c53f9 --- /dev/null +++ b/rust/crates/rustproxy/src/management.rs @@ -0,0 +1,470 @@ +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use tokio::io::{AsyncBufReadExt, BufReader}; +use tracing::{info, error}; + +use crate::RustProxy; +use rustproxy_config::RustProxyOptions; + +/// A management request from the TypeScript wrapper. +#[derive(Debug, Deserialize)] +pub struct ManagementRequest { + pub id: String, + pub method: String, + #[serde(default)] + pub params: serde_json::Value, +} + +/// A management response back to the TypeScript wrapper. +#[derive(Debug, Serialize)] +pub struct ManagementResponse { + pub id: String, + pub success: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +/// An unsolicited event from the proxy to the TypeScript wrapper. +#[derive(Debug, Serialize)] +pub struct ManagementEvent { + pub event: String, + pub data: serde_json::Value, +} + +impl ManagementResponse { + fn ok(id: String, result: serde_json::Value) -> Self { + Self { + id, + success: true, + result: Some(result), + error: None, + } + } + + fn err(id: String, message: String) -> Self { + Self { + id, + success: false, + result: None, + error: Some(message), + } + } +} + +fn send_line(line: &str) { + // Use blocking stdout write - we're writing short JSON lines + use std::io::Write; + let stdout = std::io::stdout(); + let mut handle = stdout.lock(); + let _ = handle.write_all(line.as_bytes()); + let _ = handle.write_all(b"\n"); + let _ = handle.flush(); +} + +fn send_response(response: &ManagementResponse) { + match serde_json::to_string(response) { + Ok(json) => send_line(&json), + Err(e) => error!("Failed to serialize management response: {}", e), + } +} + +fn send_event(event: &str, data: serde_json::Value) { + let evt = ManagementEvent { + event: event.to_string(), + data, + }; + match serde_json::to_string(&evt) { + Ok(json) => send_line(&json), + Err(e) => error!("Failed to serialize management event: {}", e), + } +} + +/// Run the management loop, reading JSON commands from stdin and writing responses to stdout. +pub async fn management_loop() -> Result<()> { + let stdin = BufReader::new(tokio::io::stdin()); + let mut lines = stdin.lines(); + let mut proxy: Option = None; + + send_event("ready", serde_json::json!({})); + + loop { + let line = match lines.next_line().await { + Ok(Some(line)) => line, + Ok(None) => { + // stdin closed - parent process exited + info!("Management stdin closed, shutting down"); + if let Some(ref mut p) = proxy { + let _ = p.stop().await; + } + break; + } + Err(e) => { + error!("Error reading management stdin: {}", e); + break; + } + }; + + let line = line.trim().to_string(); + if line.is_empty() { + continue; + } + + let request: ManagementRequest = match serde_json::from_str(&line) { + Ok(r) => r, + Err(e) => { + error!("Failed to parse management request: {}", e); + // Send error response without an ID + send_response(&ManagementResponse::err( + "unknown".to_string(), + format!("Failed to parse request: {}", e), + )); + continue; + } + }; + + let response = handle_request(&request, &mut proxy).await; + send_response(&response); + } + + Ok(()) +} + +async fn handle_request( + request: &ManagementRequest, + proxy: &mut Option, +) -> ManagementResponse { + let id = request.id.clone(); + + match request.method.as_str() { + "start" => handle_start(&id, &request.params, proxy).await, + "stop" => handle_stop(&id, proxy).await, + "updateRoutes" => handle_update_routes(&id, &request.params, proxy).await, + "getMetrics" => handle_get_metrics(&id, proxy), + "getStatistics" => handle_get_statistics(&id, proxy), + "provisionCertificate" => handle_provision_certificate(&id, &request.params, proxy).await, + "renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await, + "getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await, + "getListeningPorts" => handle_get_listening_ports(&id, proxy), + "getNftablesStatus" => handle_get_nftables_status(&id, proxy).await, + "setSocketHandlerRelay" => handle_set_socket_handler_relay(&id, &request.params, proxy).await, + "addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await, + "removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await, + "loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await, + _ => ManagementResponse::err(id, format!("Unknown method: {}", request.method)), + } +} + +async fn handle_start( + id: &str, + params: &serde_json::Value, + proxy: &mut Option, +) -> ManagementResponse { + if proxy.is_some() { + return ManagementResponse::err(id.to_string(), "Proxy is already running".to_string()); + } + + let config = match params.get("config") { + Some(config) => config, + None => return ManagementResponse::err(id.to_string(), "Missing 'config' parameter".to_string()), + }; + + let options: RustProxyOptions = match serde_json::from_value(config.clone()) { + Ok(o) => o, + Err(e) => return ManagementResponse::err(id.to_string(), format!("Invalid config: {}", e)), + }; + + match RustProxy::new(options) { + Ok(mut p) => { + match p.start().await { + Ok(()) => { + send_event("started", serde_json::json!({})); + *proxy = Some(p); + ManagementResponse::ok(id.to_string(), serde_json::json!({})) + } + Err(e) => { + send_event("error", serde_json::json!({"message": format!("{}", e)})); + ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e)) + } + } + } + Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to create proxy: {}", e)), + } +} + +async fn handle_stop( + id: &str, + proxy: &mut Option, +) -> ManagementResponse { + match proxy.as_mut() { + Some(p) => { + match p.stop().await { + Ok(()) => { + *proxy = None; + send_event("stopped", serde_json::json!({})); + ManagementResponse::ok(id.to_string(), serde_json::json!({})) + } + Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)), + } + } + None => ManagementResponse::ok(id.to_string(), serde_json::json!({})), + } +} + +async fn handle_update_routes( + id: &str, + params: &serde_json::Value, + proxy: &mut Option, +) -> ManagementResponse { + let p = match proxy.as_mut() { + Some(p) => p, + None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), + }; + + let routes = match params.get("routes") { + Some(routes) => routes, + None => return ManagementResponse::err(id.to_string(), "Missing 'routes' parameter".to_string()), + }; + + let routes: Vec = match serde_json::from_value(routes.clone()) { + Ok(r) => r, + Err(e) => return ManagementResponse::err(id.to_string(), format!("Invalid routes: {}", e)), + }; + + match p.update_routes(routes).await { + Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), + Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e)), + } +} + +fn handle_get_metrics( + id: &str, + proxy: &Option, +) -> ManagementResponse { + match proxy.as_ref() { + Some(p) => { + let metrics = p.get_metrics(); + match serde_json::to_value(&metrics) { + Ok(v) => ManagementResponse::ok(id.to_string(), v), + Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize metrics: {}", e)), + } + } + None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), + } +} + +fn handle_get_statistics( + id: &str, + proxy: &Option, +) -> ManagementResponse { + match proxy.as_ref() { + Some(p) => { + let stats = p.get_statistics(); + match serde_json::to_value(&stats) { + Ok(v) => ManagementResponse::ok(id.to_string(), v), + Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize statistics: {}", e)), + } + } + None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), + } +} + +async fn handle_provision_certificate( + id: &str, + params: &serde_json::Value, + proxy: &mut Option, +) -> ManagementResponse { + let p = match proxy.as_mut() { + Some(p) => p, + None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), + }; + + let route_name = match params.get("routeName").and_then(|v| v.as_str()) { + Some(name) => name.to_string(), + None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()), + }; + + match p.provision_certificate(&route_name).await { + Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), + Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to provision certificate: {}", e)), + } +} + +async fn handle_renew_certificate( + id: &str, + params: &serde_json::Value, + proxy: &mut Option, +) -> ManagementResponse { + let p = match proxy.as_mut() { + Some(p) => p, + None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), + }; + + let route_name = match params.get("routeName").and_then(|v| v.as_str()) { + Some(name) => name.to_string(), + None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()), + }; + + match p.renew_certificate(&route_name).await { + Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), + Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to renew certificate: {}", e)), + } +} + +async fn handle_get_certificate_status( + id: &str, + params: &serde_json::Value, + proxy: &Option, +) -> ManagementResponse { + let p = match proxy.as_ref() { + Some(p) => p, + None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), + }; + + let route_name = match params.get("routeName").and_then(|v| v.as_str()) { + Some(name) => name, + None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()), + }; + + match p.get_certificate_status(route_name).await { + Some(status) => ManagementResponse::ok(id.to_string(), serde_json::json!({ + "domain": status.domain, + "source": status.source, + "expiresAt": status.expires_at, + "isValid": status.is_valid, + })), + None => ManagementResponse::ok(id.to_string(), serde_json::Value::Null), + } +} + +fn handle_get_listening_ports( + id: &str, + proxy: &Option, +) -> ManagementResponse { + match proxy.as_ref() { + Some(p) => { + let ports = p.get_listening_ports(); + ManagementResponse::ok(id.to_string(), serde_json::json!({ "ports": ports })) + } + None => ManagementResponse::ok(id.to_string(), serde_json::json!({ "ports": [] })), + } +} + +async fn handle_get_nftables_status( + id: &str, + proxy: &Option, +) -> ManagementResponse { + match proxy.as_ref() { + Some(p) => { + match p.get_nftables_status().await { + Ok(status) => { + match serde_json::to_value(&status) { + Ok(v) => ManagementResponse::ok(id.to_string(), v), + Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize: {}", e)), + } + } + Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to get status: {}", e)), + } + } + None => ManagementResponse::ok(id.to_string(), serde_json::json!({})), + } +} + +async fn handle_set_socket_handler_relay( + id: &str, + params: &serde_json::Value, + proxy: &mut Option, +) -> ManagementResponse { + let p = match proxy.as_mut() { + Some(p) => p, + None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), + }; + + let socket_path = params.get("socketPath") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + info!("setSocketHandlerRelay: socket_path={:?}", socket_path); + p.set_socket_handler_relay_path(socket_path); + + ManagementResponse::ok(id.to_string(), serde_json::json!({})) +} + +async fn handle_add_listening_port( + id: &str, + params: &serde_json::Value, + proxy: &mut Option, +) -> ManagementResponse { + let p = match proxy.as_mut() { + Some(p) => p, + None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), + }; + + let port = match params.get("port").and_then(|v| v.as_u64()) { + Some(port) => port as u16, + None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()), + }; + + match p.add_listening_port(port).await { + Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), + Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to add port {}: {}", port, e)), + } +} + +async fn handle_remove_listening_port( + id: &str, + params: &serde_json::Value, + proxy: &mut Option, +) -> ManagementResponse { + let p = match proxy.as_mut() { + Some(p) => p, + None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), + }; + + let port = match params.get("port").and_then(|v| v.as_u64()) { + Some(port) => port as u16, + None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()), + }; + + match p.remove_listening_port(port).await { + Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), + Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to remove port {}: {}", port, e)), + } +} + +async fn handle_load_certificate( + id: &str, + params: &serde_json::Value, + proxy: &mut Option, +) -> ManagementResponse { + let p = match proxy.as_mut() { + Some(p) => p, + None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), + }; + + let domain = match params.get("domain").and_then(|v| v.as_str()) { + Some(d) => d.to_string(), + None => return ManagementResponse::err(id.to_string(), "Missing 'domain' parameter".to_string()), + }; + + let cert = match params.get("cert").and_then(|v| v.as_str()) { + Some(c) => c.to_string(), + None => return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string()), + }; + + let key = match params.get("key").and_then(|v| v.as_str()) { + Some(k) => k.to_string(), + None => return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string()), + }; + + let ca = params.get("ca").and_then(|v| v.as_str()).map(|s| s.to_string()); + + info!("loadCertificate: domain={}", domain); + + // Load cert into cert manager and hot-swap TLS config + match p.load_certificate(&domain, cert, key, ca).await { + Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), + Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to load certificate for {}: {}", domain, e)), + } +} diff --git a/rust/crates/rustproxy/tests/common/mod.rs b/rust/crates/rustproxy/tests/common/mod.rs new file mode 100644 index 0000000..578e3f0 --- /dev/null +++ b/rust/crates/rustproxy/tests/common/mod.rs @@ -0,0 +1,402 @@ +use std::sync::atomic::{AtomicU16, Ordering}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::task::JoinHandle; + +/// Atomic port allocator starting at 19000 to avoid collisions. +static PORT_COUNTER: AtomicU16 = AtomicU16::new(19000); + +/// Get the next available port for testing. +pub fn next_port() -> u16 { + PORT_COUNTER.fetch_add(1, Ordering::SeqCst) +} + +/// Start a simple TCP echo server that echoes back whatever it receives. +/// Returns the join handle for the server task. +pub async fn start_echo_server(port: u16) -> JoinHandle<()> { + let listener = TcpListener::bind(format!("127.0.0.1:{}", port)) + .await + .expect("Failed to bind echo server"); + + tokio::spawn(async move { + loop { + let (mut stream, _) = match listener.accept().await { + Ok(conn) => conn, + Err(_) => break, + }; + tokio::spawn(async move { + let mut buf = vec![0u8; 65536]; + loop { + let n = match stream.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + if stream.write_all(&buf[..n]).await.is_err() { + break; + } + } + }); + } + }) +} + +/// Start a TCP echo server that prefixes responses to identify which backend responded. +pub async fn start_prefix_echo_server(port: u16, prefix: &str) -> JoinHandle<()> { + let prefix = prefix.to_string(); + let listener = TcpListener::bind(format!("127.0.0.1:{}", port)) + .await + .expect("Failed to bind prefix echo server"); + + tokio::spawn(async move { + loop { + let (mut stream, _) = match listener.accept().await { + Ok(conn) => conn, + Err(_) => break, + }; + let pfx = prefix.clone(); + tokio::spawn(async move { + let mut buf = vec![0u8; 65536]; + loop { + let n = match stream.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + let mut response = pfx.as_bytes().to_vec(); + response.extend_from_slice(&buf[..n]); + if stream.write_all(&response).await.is_err() { + break; + } + } + }); + } + }) +} + +/// Start a simple HTTP server that responds with a fixed status and body. +pub async fn start_http_server(port: u16, status: u16, body: &str) -> JoinHandle<()> { + let body = body.to_string(); + let listener = TcpListener::bind(format!("127.0.0.1:{}", port)) + .await + .expect("Failed to bind HTTP server"); + + tokio::spawn(async move { + loop { + let (mut stream, _) = match listener.accept().await { + Ok(conn) => conn, + Err(_) => break, + }; + let b = body.clone(); + tokio::spawn(async move { + let mut buf = vec![0u8; 8192]; + // Read the request + let _n = stream.read(&mut buf).await.unwrap_or(0); + // Send response + let response = format!( + "HTTP/1.1 {} OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + status, + b.len(), + b, + ); + let _ = stream.write_all(response.as_bytes()).await; + let _ = stream.shutdown().await; + }); + } + }) +} + +/// Start an HTTP backend server that echoes back request details as JSON. +/// The response body contains: {"method":"GET","path":"/foo","host":"example.com","backend":""} +/// Supports keep-alive by reading HTTP requests properly. +pub async fn start_http_echo_backend(port: u16, backend_name: &str) -> JoinHandle<()> { + let name = backend_name.to_string(); + let listener = TcpListener::bind(format!("127.0.0.1:{}", port)) + .await + .unwrap_or_else(|_| panic!("Failed to bind HTTP echo backend on port {}", port)); + + tokio::spawn(async move { + loop { + let (mut stream, _) = match listener.accept().await { + Ok(conn) => conn, + Err(_) => break, + }; + let backend = name.clone(); + tokio::spawn(async move { + let mut buf = vec![0u8; 16384]; + // Read request data + let n = match stream.read(&mut buf).await { + Ok(0) | Err(_) => return, + Ok(n) => n, + }; + let req_str = String::from_utf8_lossy(&buf[..n]); + + // Parse first line: METHOD PATH HTTP/x.x + let first_line = req_str.lines().next().unwrap_or(""); + let parts: Vec<&str> = first_line.split_whitespace().collect(); + let method = parts.first().copied().unwrap_or("UNKNOWN"); + let path = parts.get(1).copied().unwrap_or("/"); + + // Extract Host header + let host = req_str.lines() + .find(|l| l.to_lowercase().starts_with("host:")) + .map(|l| l[5..].trim()) + .unwrap_or("unknown"); + + let body = format!( + r#"{{"method":"{}","path":"{}","host":"{}","backend":"{}"}}"#, + method, path, host, backend + ); + + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body, + ); + let _ = stream.write_all(response.as_bytes()).await; + let _ = stream.shutdown().await; + }); + } + }) +} + +/// Wrap a future with a timeout, preventing tests from hanging. +pub async fn with_timeout(future: F, secs: u64) -> Result +where + F: std::future::Future, +{ + match tokio::time::timeout(std::time::Duration::from_secs(secs), future).await { + Ok(result) => Ok(result), + Err(_) => Err("Test timed out"), + } +} + +/// Wait briefly for a server to be ready by attempting TCP connections. +pub async fn wait_for_port(port: u16, timeout_ms: u64) -> bool { + let start = std::time::Instant::now(); + let timeout = std::time::Duration::from_millis(timeout_ms); + while start.elapsed() < timeout { + if tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .is_ok() + { + return true; + } + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + false +} + +/// Helper to create a minimal route config for testing. +pub fn make_test_route( + port: u16, + domain: Option<&str>, + target_host: &str, + target_port: u16, +) -> rustproxy_config::RouteConfig { + rustproxy_config::RouteConfig { + id: None, + route_match: rustproxy_config::RouteMatch { + ports: rustproxy_config::PortRange::Single(port), + domains: domain.map(|d| rustproxy_config::DomainSpec::Single(d.to_string())), + path: None, + client_ip: None, + tls_version: None, + headers: None, + }, + action: rustproxy_config::RouteAction { + action_type: rustproxy_config::RouteActionType::Forward, + targets: Some(vec![rustproxy_config::RouteTarget { + target_match: None, + host: rustproxy_config::HostSpec::Single(target_host.to_string()), + port: rustproxy_config::PortSpec::Fixed(target_port), + tls: None, + websocket: None, + load_balancing: None, + send_proxy_protocol: None, + headers: None, + advanced: None, + priority: None, + }]), + tls: None, + websocket: None, + load_balancing: None, + advanced: None, + options: None, + forwarding_engine: None, + nftables: None, + send_proxy_protocol: None, + }, + headers: None, + security: None, + name: None, + description: None, + priority: None, + tags: None, + enabled: None, + } +} + +/// Start a simple WebSocket echo backend. +/// +/// Accepts WebSocket upgrade requests (HTTP Upgrade: websocket), sends 101 back, +/// then echoes all data received on the connection. +pub async fn start_ws_echo_backend(port: u16) -> JoinHandle<()> { + let listener = TcpListener::bind(format!("127.0.0.1:{}", port)) + .await + .unwrap_or_else(|_| panic!("Failed to bind WS echo backend on port {}", port)); + + tokio::spawn(async move { + loop { + let (mut stream, _) = match listener.accept().await { + Ok(conn) => conn, + Err(_) => break, + }; + tokio::spawn(async move { + // Read the HTTP upgrade request + let mut buf = vec![0u8; 4096]; + let n = match stream.read(&mut buf).await { + Ok(0) | Err(_) => return, + Ok(n) => n, + }; + + let req_str = String::from_utf8_lossy(&buf[..n]); + + // Extract Sec-WebSocket-Key for proper handshake + let ws_key = req_str.lines() + .find(|l| l.to_lowercase().starts_with("sec-websocket-key:")) + .map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string()) + .unwrap_or_default(); + + // Compute Sec-WebSocket-Accept (simplified - just echo for test purposes) + // Real implementation would compute SHA-1 + base64 + let accept_response = format!( + "HTTP/1.1 101 Switching Protocols\r\n\ + Upgrade: websocket\r\n\ + Connection: Upgrade\r\n\ + Sec-WebSocket-Accept: {}\r\n\ + \r\n", + ws_key + ); + + if stream.write_all(accept_response.as_bytes()).await.is_err() { + return; + } + + // Echo all data back (raw TCP after upgrade) + let mut echo_buf = vec![0u8; 65536]; + loop { + let n = match stream.read(&mut echo_buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + if stream.write_all(&echo_buf[..n]).await.is_err() { + break; + } + } + }); + } + }) +} + +/// Generate a self-signed certificate for testing using rcgen. +/// Returns (cert_pem, key_pem). +pub fn generate_self_signed_cert(domain: &str) -> (String, String) { + use rcgen::{CertificateParams, KeyPair}; + + let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap(); + params.distinguished_name.push(rcgen::DnType::CommonName, domain); + + let key_pair = KeyPair::generate().unwrap(); + let cert = params.self_signed(&key_pair).unwrap(); + + (cert.pem(), key_pair.serialize_pem()) +} + +/// Start a TLS echo server using the given cert/key. +/// Returns the join handle. +pub async fn start_tls_echo_server(port: u16, cert_pem: &str, key_pem: &str) -> JoinHandle<()> { + use std::sync::Arc; + + let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem) + .expect("Failed to build TLS acceptor"); + let acceptor = Arc::new(acceptor); + + let listener = TcpListener::bind(format!("127.0.0.1:{}", port)) + .await + .expect("Failed to bind TLS echo server"); + + tokio::spawn(async move { + loop { + let (stream, _) = match listener.accept().await { + Ok(conn) => conn, + Err(_) => break, + }; + let acc = acceptor.clone(); + tokio::spawn(async move { + let mut tls_stream = match acc.accept(stream).await { + Ok(s) => s, + Err(_) => return, + }; + let mut buf = vec![0u8; 65536]; + loop { + let n = match tls_stream.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => n, + }; + if tls_stream.write_all(&buf[..n]).await.is_err() { + break; + } + } + }); + } + }) +} + +/// Helper to create a TLS terminate route with static cert for testing. +pub fn make_tls_terminate_route( + port: u16, + domain: &str, + target_host: &str, + target_port: u16, + cert_pem: &str, + key_pem: &str, +) -> rustproxy_config::RouteConfig { + let mut route = make_test_route(port, Some(domain), target_host, target_port); + route.action.tls = Some(rustproxy_config::RouteTls { + mode: rustproxy_config::TlsMode::Terminate, + certificate: Some(rustproxy_config::CertificateSpec::Static( + rustproxy_config::CertificateConfig { + cert: cert_pem.to_string(), + key: key_pem.to_string(), + ca: None, + key_file: None, + cert_file: None, + }, + )), + acme: None, + versions: None, + ciphers: None, + honor_cipher_order: None, + session_timeout: None, + }); + route +} + +/// Helper to create a TLS passthrough route for testing. +pub fn make_tls_passthrough_route( + port: u16, + domain: Option<&str>, + target_host: &str, + target_port: u16, +) -> rustproxy_config::RouteConfig { + let mut route = make_test_route(port, domain, target_host, target_port); + route.action.tls = Some(rustproxy_config::RouteTls { + mode: rustproxy_config::TlsMode::Passthrough, + certificate: None, + acme: None, + versions: None, + ciphers: None, + honor_cipher_order: None, + session_timeout: None, + }); + route +} diff --git a/rust/crates/rustproxy/tests/integration_http_proxy.rs b/rust/crates/rustproxy/tests/integration_http_proxy.rs new file mode 100644 index 0000000..6651f54 --- /dev/null +++ b/rust/crates/rustproxy/tests/integration_http_proxy.rs @@ -0,0 +1,453 @@ +mod common; + +use common::*; +use rustproxy::RustProxy; +use rustproxy_config::RustProxyOptions; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +/// Send a raw HTTP request and return the full response as a string. +async fn send_http_request(port: u16, host: &str, method: &str, path: &str) -> String { + let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .unwrap(); + + let request = format!( + "{} {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n", + method, path, host, + ); + stream.write_all(request.as_bytes()).await.unwrap(); + + let mut response = Vec::new(); + stream.read_to_end(&mut response).await.unwrap(); + String::from_utf8_lossy(&response).to_string() +} + +/// Extract the body from a raw HTTP response string (after the \r\n\r\n). +fn extract_body(response: &str) -> &str { + response.split("\r\n\r\n").nth(1).unwrap_or("") +} + +#[tokio::test] +async fn test_http_forward_basic() { + let backend_port = next_port(); + let proxy_port = next_port(); + + let _backend = start_http_echo_backend(backend_port, "main").await; + + let options = RustProxyOptions { + routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + let result = with_timeout(async { + let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await; + let body = extract_body(&response); + body.to_string() + }, 10) + .await + .unwrap(); + + assert!(result.contains(r#""method":"GET"#), "Expected GET method, got: {}", result); + assert!(result.contains(r#""path":"/hello"#), "Expected /hello path, got: {}", result); + assert!(result.contains(r#""backend":"main"#), "Expected main backend, got: {}", result); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_http_forward_host_routing() { + let backend1_port = next_port(); + let backend2_port = next_port(); + let proxy_port = next_port(); + + let _b1 = start_http_echo_backend(backend1_port, "alpha").await; + let _b2 = start_http_echo_backend(backend2_port, "beta").await; + + let options = RustProxyOptions { + routes: vec![ + make_test_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", backend1_port), + make_test_route(proxy_port, Some("beta.example.com"), "127.0.0.1", backend2_port), + ], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + // Test alpha domain + let alpha_result = with_timeout(async { + let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await; + extract_body(&response).to_string() + }, 10) + .await + .unwrap(); + + assert!(alpha_result.contains(r#""backend":"alpha"#), "Expected alpha backend, got: {}", alpha_result); + + // Test beta domain + let beta_result = with_timeout(async { + let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await; + extract_body(&response).to_string() + }, 10) + .await + .unwrap(); + + assert!(beta_result.contains(r#""backend":"beta"#), "Expected beta backend, got: {}", beta_result); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_http_forward_path_routing() { + let backend1_port = next_port(); + let backend2_port = next_port(); + let proxy_port = next_port(); + + let _b1 = start_http_echo_backend(backend1_port, "api").await; + let _b2 = start_http_echo_backend(backend2_port, "web").await; + + let mut api_route = make_test_route(proxy_port, None, "127.0.0.1", backend1_port); + api_route.route_match.path = Some("/api/**".to_string()); + api_route.priority = Some(10); + + let web_route = make_test_route(proxy_port, None, "127.0.0.1", backend2_port); + + let options = RustProxyOptions { + routes: vec![api_route, web_route], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + // Test API path + let api_result = with_timeout(async { + let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await; + extract_body(&response).to_string() + }, 10) + .await + .unwrap(); + + assert!(api_result.contains(r#""backend":"api"#), "Expected api backend, got: {}", api_result); + + // Test web path (no /api prefix) + let web_result = with_timeout(async { + let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await; + extract_body(&response).to_string() + }, 10) + .await + .unwrap(); + + assert!(web_result.contains(r#""backend":"web"#), "Expected web backend, got: {}", web_result); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_http_forward_cors_preflight() { + let backend_port = next_port(); + let proxy_port = next_port(); + + let _backend = start_http_echo_backend(backend_port, "main").await; + + let options = RustProxyOptions { + routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + let result = with_timeout(async { + let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) + .await + .unwrap(); + + // Send CORS preflight request + let request = format!( + "OPTIONS /api/data HTTP/1.1\r\nHost: example.com\r\nOrigin: http://localhost:3000\r\nAccess-Control-Request-Method: POST\r\nConnection: close\r\n\r\n", + ); + stream.write_all(request.as_bytes()).await.unwrap(); + + let mut response = Vec::new(); + stream.read_to_end(&mut response).await.unwrap(); + String::from_utf8_lossy(&response).to_string() + }, 10) + .await + .unwrap(); + + // Should get 204 No Content with CORS headers + assert!(result.contains("204"), "Expected 204 status, got: {}", result); + assert!(result.to_lowercase().contains("access-control-allow-origin"), + "Expected CORS header, got: {}", result); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_http_forward_backend_error() { + let backend_port = next_port(); + let proxy_port = next_port(); + + // Start an HTTP server that returns 500 + let _backend = start_http_server(backend_port, 500, "Internal Error").await; + + let options = RustProxyOptions { + routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + let result = with_timeout(async { + let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await; + response + }, 10) + .await + .unwrap(); + + // Proxy should relay the 500 from backend + assert!(result.contains("500"), "Expected 500 status, got: {}", result); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_http_forward_no_route_matched() { + let proxy_port = next_port(); + + // Create a route only for a specific domain + let options = RustProxyOptions { + routes: vec![make_test_route(proxy_port, Some("known.example.com"), "127.0.0.1", 9999)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + let result = with_timeout(async { + let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await; + response + }, 10) + .await + .unwrap(); + + // Should get 502 Bad Gateway (no route matched) + assert!(result.contains("502"), "Expected 502 status, got: {}", result); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_http_forward_backend_unavailable() { + let proxy_port = next_port(); + let dead_port = next_port(); // No server running here + + let options = RustProxyOptions { + routes: vec![make_test_route(proxy_port, None, "127.0.0.1", dead_port)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + let result = with_timeout(async { + let response = send_http_request(proxy_port, "example.com", "GET", "/").await; + response + }, 10) + .await + .unwrap(); + + // Should get 502 Bad Gateway (backend unavailable) + assert!(result.contains("502"), "Expected 502 status, got: {}", result); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_https_terminate_http_forward() { + let backend_port = next_port(); + let proxy_port = next_port(); + let domain = "httpproxy.example.com"; + + let (cert_pem, key_pem) = generate_self_signed_cert(domain); + let _backend = start_http_echo_backend(backend_port, "tls-backend").await; + + let options = RustProxyOptions { + routes: vec![make_tls_terminate_route( + proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem, + )], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + let result = with_timeout(async { + let _ = rustls::crypto::ring::default_provider().install_default(); + let tls_config = rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier)) + .with_no_client_auth(); + let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config)); + + let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) + .await + .unwrap(); + let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap(); + let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); + + // Send HTTP request through TLS + let request = format!( + "GET /api/data HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n", + domain + ); + tls_stream.write_all(request.as_bytes()).await.unwrap(); + + let mut response = Vec::new(); + tls_stream.read_to_end(&mut response).await.unwrap(); + String::from_utf8_lossy(&response).to_string() + }, 10) + .await + .unwrap(); + + let body = extract_body(&result); + assert!(body.contains(r#""method":"GET"#), "Expected GET, got: {}", body); + assert!(body.contains(r#""path":"/api/data"#), "Expected /api/data, got: {}", body); + assert!(body.contains(r#""backend":"tls-backend"#), "Expected tls-backend, got: {}", body); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_websocket_through_proxy() { + let backend_port = next_port(); + let proxy_port = next_port(); + + let _backend = start_ws_echo_backend(backend_port).await; + + let options = RustProxyOptions { + routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + let result = with_timeout(async { + let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) + .await + .unwrap(); + + // Send WebSocket upgrade request + let request = format!( + "GET /ws HTTP/1.1\r\n\ + Host: example.com\r\n\ + Upgrade: websocket\r\n\ + Connection: Upgrade\r\n\ + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ + Sec-WebSocket-Version: 13\r\n\ + \r\n" + ); + stream.write_all(request.as_bytes()).await.unwrap(); + + // Read the 101 response + let mut response_buf = Vec::with_capacity(4096); + let mut temp = [0u8; 1]; + loop { + let n = stream.read(&mut temp).await.unwrap(); + if n == 0 { break; } + response_buf.push(temp[0]); + if response_buf.len() >= 4 { + let len = response_buf.len(); + if response_buf[len-4..] == *b"\r\n\r\n" { + break; + } + } + } + + let response_str = String::from_utf8_lossy(&response_buf).to_string(); + assert!(response_str.contains("101"), "Expected 101 Switching Protocols, got: {}", response_str); + assert!( + response_str.to_lowercase().contains("upgrade: websocket"), + "Expected Upgrade header, got: {}", + response_str + ); + + // After upgrade, send data and verify echo + let test_data = b"Hello WebSocket!"; + stream.write_all(test_data).await.unwrap(); + + // Read echoed data + let mut echo_buf = vec![0u8; 256]; + let n = stream.read(&mut echo_buf).await.unwrap(); + let echoed = &echo_buf[..n]; + + assert_eq!(echoed, test_data, "Expected echo of sent data"); + + "ok".to_string() + }, 10) + .await + .unwrap(); + + assert_eq!(result, "ok"); + proxy.stop().await.unwrap(); +} + +/// InsecureVerifier for test TLS client connections. +#[derive(Debug)] +struct InsecureVerifier; + +impl rustls::client::danger::ServerCertVerifier for InsecureVerifier { + fn verify_server_cert( + &self, + _end_entity: &rustls::pki_types::CertificateDer<'_>, + _intermediates: &[rustls::pki_types::CertificateDer<'_>], + _server_name: &rustls::pki_types::ServerName<'_>, + _ocsp_response: &[u8], + _now: rustls::pki_types::UnixTime, + ) -> Result { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + vec![ + rustls::SignatureScheme::RSA_PKCS1_SHA256, + rustls::SignatureScheme::ECDSA_NISTP256_SHA256, + rustls::SignatureScheme::ECDSA_NISTP384_SHA384, + rustls::SignatureScheme::ED25519, + rustls::SignatureScheme::RSA_PSS_SHA256, + ] + } +} diff --git a/rust/crates/rustproxy/tests/integration_proxy_lifecycle.rs b/rust/crates/rustproxy/tests/integration_proxy_lifecycle.rs new file mode 100644 index 0000000..3beee5a --- /dev/null +++ b/rust/crates/rustproxy/tests/integration_proxy_lifecycle.rs @@ -0,0 +1,250 @@ +mod common; + +use common::*; +use rustproxy::RustProxy; +use rustproxy_config::RustProxyOptions; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +#[tokio::test] +async fn test_start_and_stop() { + let port = next_port(); + + let options = RustProxyOptions { + routes: vec![make_test_route(port, None, "127.0.0.1", 8080)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + + // Not listening before start + assert!(!wait_for_port(port, 200).await); + + proxy.start().await.unwrap(); + assert!(wait_for_port(port, 2000).await, "Port should be listening after start"); + + proxy.stop().await.unwrap(); + + // Give the OS a moment to release the port + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + assert!(!wait_for_port(port, 200).await, "Port should not be listening after stop"); +} + +#[tokio::test] +async fn test_double_start_fails() { + let port = next_port(); + + let options = RustProxyOptions { + routes: vec![make_test_route(port, None, "127.0.0.1", 8080)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + + // Second start should fail + let result = proxy.start().await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("already started")); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_update_routes_hot_reload() { + let port = next_port(); + + let options = RustProxyOptions { + routes: vec![make_test_route(port, Some("old.example.com"), "127.0.0.1", 8080)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + + // Update routes atomically + let new_routes = vec![ + make_test_route(port, Some("new.example.com"), "127.0.0.1", 9090), + ]; + let result = proxy.update_routes(new_routes).await; + assert!(result.is_ok()); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_add_remove_listening_port() { + let port1 = next_port(); + let port2 = next_port(); + + let options = RustProxyOptions { + routes: vec![make_test_route(port1, None, "127.0.0.1", 8080)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(port1, 2000).await); + + // Add a new port + proxy.add_listening_port(port2).await.unwrap(); + assert!(wait_for_port(port2, 2000).await, "New port should be listening"); + + // Remove the port + proxy.remove_listening_port(port2).await.unwrap(); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + assert!(!wait_for_port(port2, 200).await, "Removed port should not be listening"); + + // Original port should still be listening + assert!(wait_for_port(port1, 200).await, "Original port should still be listening"); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_get_statistics() { + let port = next_port(); + + let options = RustProxyOptions { + routes: vec![make_test_route(port, None, "127.0.0.1", 8080)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + + let stats = proxy.get_statistics(); + assert_eq!(stats.routes_count, 1); + assert!(stats.listening_ports.contains(&port)); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_invalid_routes_rejected() { + let options = RustProxyOptions { + routes: vec![{ + let mut route = make_test_route(80, None, "127.0.0.1", 8080); + route.action.targets = None; // Invalid: forward without targets + route + }], + ..Default::default() + }; + + let result = RustProxy::new(options); + assert!(result.is_err()); +} + +#[tokio::test] +async fn test_metrics_track_connections() { + let backend_port = next_port(); + let proxy_port = next_port(); + + let _backend = start_echo_server(backend_port).await; + + let options = RustProxyOptions { + routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + // No connections yet + let stats = proxy.get_statistics(); + assert_eq!(stats.total_connections, 0); + + // Make a connection and send data + { + let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) + .await + .unwrap(); + stream.write_all(b"hello").await.unwrap(); + let mut buf = vec![0u8; 16]; + let _ = stream.read(&mut buf).await; + } + + // Small delay for metrics to update + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let stats = proxy.get_statistics(); + assert!(stats.total_connections > 0, "Expected total_connections > 0, got {}", stats.total_connections); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_metrics_track_bytes() { + let backend_port = next_port(); + let proxy_port = next_port(); + + let _backend = start_http_echo_backend(backend_port, "metrics-test").await; + + let options = RustProxyOptions { + routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + // Send HTTP request through proxy + { + let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) + .await + .unwrap(); + let request = b"GET /test HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n"; + stream.write_all(request).await.unwrap(); + let mut response = Vec::new(); + stream.read_to_end(&mut response).await.unwrap(); + assert!(!response.is_empty(), "Expected non-empty response"); + } + + // Small delay for metrics to update + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let stats = proxy.get_statistics(); + assert!(stats.total_connections > 0, + "Expected some connections tracked, got {}", stats.total_connections); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_hot_reload_port_changes() { + let port1 = next_port(); + let port2 = next_port(); + let backend_port = next_port(); + + let _backend = start_echo_server(backend_port).await; + + // Start with port1 + let options = RustProxyOptions { + routes: vec![make_test_route(port1, None, "127.0.0.1", backend_port)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(port1, 2000).await); + assert!(!wait_for_port(port2, 200).await, "port2 should not be listening yet"); + + // Update routes to use port2 instead + let new_routes = vec![ + make_test_route(port2, None, "127.0.0.1", backend_port), + ]; + proxy.update_routes(new_routes).await.unwrap(); + + // Port2 should now be listening, port1 should be closed + assert!(wait_for_port(port2, 2000).await, "port2 should be listening after reload"); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + assert!(!wait_for_port(port1, 200).await, "port1 should be closed after reload"); + + // Verify port2 works + let ports = proxy.get_listening_ports(); + assert!(ports.contains(&port2), "Expected port2 in listening ports: {:?}", ports); + assert!(!ports.contains(&port1), "port1 should not be in listening ports: {:?}", ports); + + proxy.stop().await.unwrap(); +} diff --git a/rust/crates/rustproxy/tests/integration_tcp_passthrough.rs b/rust/crates/rustproxy/tests/integration_tcp_passthrough.rs new file mode 100644 index 0000000..5c56bfb --- /dev/null +++ b/rust/crates/rustproxy/tests/integration_tcp_passthrough.rs @@ -0,0 +1,197 @@ +mod common; + +use common::*; +use rustproxy::RustProxy; +use rustproxy_config::RustProxyOptions; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +#[tokio::test] +async fn test_tcp_forward_echo() { + let backend_port = next_port(); + let proxy_port = next_port(); + + // Start echo backend + let _backend = start_echo_server(backend_port).await; + + // Configure proxy + let options = RustProxyOptions { + routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + + // Wait for proxy to be ready + assert!(wait_for_port(proxy_port, 2000).await, "Proxy port not ready"); + + // Connect and send data + let result = with_timeout(async { + let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) + .await + .unwrap(); + stream.write_all(b"hello world").await.unwrap(); + + let mut buf = vec![0u8; 1024]; + let n = stream.read(&mut buf).await.unwrap(); + String::from_utf8_lossy(&buf[..n]).to_string() + }, 5) + .await + .unwrap(); + + assert_eq!(result, "hello world"); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_tcp_forward_large_payload() { + let backend_port = next_port(); + let proxy_port = next_port(); + + let _backend = start_echo_server(backend_port).await; + + let options = RustProxyOptions { + routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + let result = with_timeout(async { + let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) + .await + .unwrap(); + + // Send 1MB of data + let data = vec![b'A'; 1_000_000]; + stream.write_all(&data).await.unwrap(); + stream.shutdown().await.unwrap(); + + // Read all back + let mut received = Vec::new(); + stream.read_to_end(&mut received).await.unwrap(); + received.len() + }, 10) + .await + .unwrap(); + + assert_eq!(result, 1_000_000); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_tcp_forward_multiple_connections() { + let backend_port = next_port(); + let proxy_port = next_port(); + + let _backend = start_echo_server(backend_port).await; + + let options = RustProxyOptions { + routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + let result = with_timeout(async { + let mut handles = Vec::new(); + for i in 0..10 { + let port = proxy_port; + handles.push(tokio::spawn(async move { + let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .unwrap(); + let msg = format!("connection-{}", i); + stream.write_all(msg.as_bytes()).await.unwrap(); + + let mut buf = vec![0u8; 1024]; + let n = stream.read(&mut buf).await.unwrap(); + String::from_utf8_lossy(&buf[..n]).to_string() + })); + } + + let mut results = Vec::new(); + for handle in handles { + results.push(handle.await.unwrap()); + } + results + }, 10) + .await + .unwrap(); + + assert_eq!(result.len(), 10); + for (i, r) in result.iter().enumerate() { + assert_eq!(r, &format!("connection-{}", i)); + } + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_tcp_forward_backend_unreachable() { + let proxy_port = next_port(); + let dead_port = next_port(); // No server on this port + + let options = RustProxyOptions { + routes: vec![make_test_route(proxy_port, None, "127.0.0.1", dead_port)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + // Connection should complete (proxy accepts it) but data should not flow + let result = with_timeout(async { + let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await; + stream.is_ok() + }, 5) + .await + .unwrap(); + + assert!(result, "Should be able to connect to proxy even if backend is down"); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_tcp_forward_bidirectional() { + let backend_port = next_port(); + let proxy_port = next_port(); + + // Start a prefix echo server to verify data flows in both directions + let _backend = start_prefix_echo_server(backend_port, "REPLY:").await; + + let options = RustProxyOptions { + routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + let result = with_timeout(async { + let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) + .await + .unwrap(); + stream.write_all(b"test data").await.unwrap(); + + let mut buf = vec![0u8; 1024]; + let n = stream.read(&mut buf).await.unwrap(); + String::from_utf8_lossy(&buf[..n]).to_string() + }, 5) + .await + .unwrap(); + + assert_eq!(result, "REPLY:test data"); + + proxy.stop().await.unwrap(); +} diff --git a/rust/crates/rustproxy/tests/integration_tls_passthrough.rs b/rust/crates/rustproxy/tests/integration_tls_passthrough.rs new file mode 100644 index 0000000..1e24dac --- /dev/null +++ b/rust/crates/rustproxy/tests/integration_tls_passthrough.rs @@ -0,0 +1,247 @@ +mod common; + +use common::*; +use rustproxy::RustProxy; +use rustproxy_config::RustProxyOptions; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +/// Build a minimal TLS ClientHello with the given SNI domain. +/// This is enough for the proxy's SNI parser to extract the domain. +fn build_client_hello(domain: &str) -> Vec { + let domain_bytes = domain.as_bytes(); + let sni_length = domain_bytes.len() as u16; + + // Server Name extension (type 0x0000) + let mut sni_ext = Vec::new(); + sni_ext.extend_from_slice(&[0x00, 0x00]); // extension type: server_name + let sni_list_len = sni_length + 5; // 2 (list len) + 1 (type) + 2 (name len) + name + sni_ext.extend_from_slice(&(sni_list_len as u16).to_be_bytes()); // extension data length + sni_ext.extend_from_slice(&((sni_list_len - 2) as u16).to_be_bytes()); // server name list length + sni_ext.push(0x00); // host_name type + sni_ext.extend_from_slice(&sni_length.to_be_bytes()); + sni_ext.extend_from_slice(domain_bytes); + + let extensions_length = sni_ext.len() as u16; + + // ClientHello message + let mut client_hello = Vec::new(); + client_hello.extend_from_slice(&[0x03, 0x03]); // TLS 1.2 version + client_hello.extend_from_slice(&[0x00; 32]); // random + client_hello.push(0x00); // session_id length + client_hello.extend_from_slice(&[0x00, 0x02, 0x00, 0xff]); // cipher suites (1 suite) + client_hello.extend_from_slice(&[0x01, 0x00]); // compression methods (null) + client_hello.extend_from_slice(&extensions_length.to_be_bytes()); + client_hello.extend_from_slice(&sni_ext); + + let hello_len = client_hello.len() as u32; + + // Handshake wrapper (type 1 = ClientHello) + let mut handshake = Vec::new(); + handshake.push(0x01); // ClientHello + handshake.extend_from_slice(&hello_len.to_be_bytes()[1..4]); // 3-byte length + handshake.extend_from_slice(&client_hello); + + let hs_len = handshake.len() as u16; + + // TLS record + let mut record = Vec::new(); + record.push(0x16); // ContentType: Handshake + record.extend_from_slice(&[0x03, 0x01]); // TLS 1.0 (record version) + record.extend_from_slice(&hs_len.to_be_bytes()); + record.extend_from_slice(&handshake); + + record +} + +#[tokio::test] +async fn test_tls_passthrough_sni_routing() { + let backend1_port = next_port(); + let backend2_port = next_port(); + let proxy_port = next_port(); + + let _b1 = start_prefix_echo_server(backend1_port, "BACKEND1:").await; + let _b2 = start_prefix_echo_server(backend2_port, "BACKEND2:").await; + + let options = RustProxyOptions { + routes: vec![ + make_tls_passthrough_route(proxy_port, Some("one.example.com"), "127.0.0.1", backend1_port), + make_tls_passthrough_route(proxy_port, Some("two.example.com"), "127.0.0.1", backend2_port), + ], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + // Send a fake ClientHello with SNI "one.example.com" + let result = with_timeout(async { + let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) + .await + .unwrap(); + let hello = build_client_hello("one.example.com"); + stream.write_all(&hello).await.unwrap(); + + let mut buf = vec![0u8; 4096]; + let n = stream.read(&mut buf).await.unwrap(); + String::from_utf8_lossy(&buf[..n]).to_string() + }, 5) + .await + .unwrap(); + + // Backend1 should have received the ClientHello and prefixed its response + assert!(result.starts_with("BACKEND1:"), "Expected BACKEND1 prefix, got: {}", result); + + // Now test routing to backend2 + let result2 = with_timeout(async { + let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) + .await + .unwrap(); + let hello = build_client_hello("two.example.com"); + stream.write_all(&hello).await.unwrap(); + + let mut buf = vec![0u8; 4096]; + let n = stream.read(&mut buf).await.unwrap(); + String::from_utf8_lossy(&buf[..n]).to_string() + }, 5) + .await + .unwrap(); + + assert!(result2.starts_with("BACKEND2:"), "Expected BACKEND2 prefix, got: {}", result2); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_tls_passthrough_unknown_sni() { + let backend_port = next_port(); + let proxy_port = next_port(); + + let _backend = start_echo_server(backend_port).await; + + let options = RustProxyOptions { + routes: vec![ + make_tls_passthrough_route(proxy_port, Some("known.example.com"), "127.0.0.1", backend_port), + ], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + // Send ClientHello with unknown SNI - should get no response (connection dropped) + let result = with_timeout(async { + let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) + .await + .unwrap(); + let hello = build_client_hello("unknown.example.com"); + stream.write_all(&hello).await.unwrap(); + + let mut buf = vec![0u8; 4096]; + // Should either get 0 bytes (closed) or an error + match stream.read(&mut buf).await { + Ok(0) => true, // Connection closed = no route matched + Ok(_) => false, // Got data = route shouldn't have matched + Err(_) => true, // Error = connection dropped + } + }, 5) + .await + .unwrap(); + + assert!(result, "Unknown SNI should result in dropped connection"); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_tls_passthrough_wildcard_domain() { + let backend_port = next_port(); + let proxy_port = next_port(); + + let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await; + + let options = RustProxyOptions { + routes: vec![ + make_tls_passthrough_route(proxy_port, Some("*.example.com"), "127.0.0.1", backend_port), + ], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + // Should match any subdomain of example.com + let result = with_timeout(async { + let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) + .await + .unwrap(); + let hello = build_client_hello("anything.example.com"); + stream.write_all(&hello).await.unwrap(); + + let mut buf = vec![0u8; 4096]; + let n = stream.read(&mut buf).await.unwrap(); + String::from_utf8_lossy(&buf[..n]).to_string() + }, 5) + .await + .unwrap(); + + assert!(result.starts_with("WILDCARD:"), "Expected WILDCARD prefix, got: {}", result); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_tls_passthrough_multiple_domains() { + let b1_port = next_port(); + let b2_port = next_port(); + let b3_port = next_port(); + let proxy_port = next_port(); + + let _b1 = start_prefix_echo_server(b1_port, "B1:").await; + let _b2 = start_prefix_echo_server(b2_port, "B2:").await; + let _b3 = start_prefix_echo_server(b3_port, "B3:").await; + + let options = RustProxyOptions { + routes: vec![ + make_tls_passthrough_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", b1_port), + make_tls_passthrough_route(proxy_port, Some("beta.example.com"), "127.0.0.1", b2_port), + make_tls_passthrough_route(proxy_port, Some("gamma.example.com"), "127.0.0.1", b3_port), + ], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + for (domain, expected_prefix) in [ + ("alpha.example.com", "B1:"), + ("beta.example.com", "B2:"), + ("gamma.example.com", "B3:"), + ] { + let result = with_timeout(async { + let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) + .await + .unwrap(); + let hello = build_client_hello(domain); + stream.write_all(&hello).await.unwrap(); + + let mut buf = vec![0u8; 4096]; + let n = stream.read(&mut buf).await.unwrap(); + String::from_utf8_lossy(&buf[..n]).to_string() + }, 5) + .await + .unwrap(); + + assert!( + result.starts_with(expected_prefix), + "Domain {} should route to {}, got: {}", + domain, expected_prefix, result + ); + } + + proxy.stop().await.unwrap(); +} diff --git a/rust/crates/rustproxy/tests/integration_tls_terminate.rs b/rust/crates/rustproxy/tests/integration_tls_terminate.rs new file mode 100644 index 0000000..e01e9c6 --- /dev/null +++ b/rust/crates/rustproxy/tests/integration_tls_terminate.rs @@ -0,0 +1,324 @@ +mod common; + +use common::*; +use rustproxy::RustProxy; +use rustproxy_config::RustProxyOptions; +use std::sync::Arc; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +/// Create a rustls client config that trusts self-signed certs. +fn make_insecure_tls_client_config() -> Arc { + let _ = rustls::crypto::ring::default_provider().install_default(); + let config = rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(InsecureVerifier)) + .with_no_client_auth(); + Arc::new(config) +} + +#[derive(Debug)] +struct InsecureVerifier; + +impl rustls::client::danger::ServerCertVerifier for InsecureVerifier { + fn verify_server_cert( + &self, + _end_entity: &rustls::pki_types::CertificateDer<'_>, + _intermediates: &[rustls::pki_types::CertificateDer<'_>], + _server_name: &rustls::pki_types::ServerName<'_>, + _ocsp_response: &[u8], + _now: rustls::pki_types::UnixTime, + ) -> Result { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &rustls::pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + vec![ + rustls::SignatureScheme::RSA_PKCS1_SHA256, + rustls::SignatureScheme::ECDSA_NISTP256_SHA256, + rustls::SignatureScheme::ECDSA_NISTP384_SHA384, + rustls::SignatureScheme::ED25519, + rustls::SignatureScheme::RSA_PSS_SHA256, + ] + } +} + +#[tokio::test] +async fn test_tls_terminate_basic() { + let backend_port = next_port(); + let proxy_port = next_port(); + let domain = "test.example.com"; + + // Generate self-signed cert + let (cert_pem, key_pem) = generate_self_signed_cert(domain); + + // Start plain TCP echo backend (proxy terminates TLS, sends plain to backend) + let _backend = start_echo_server(backend_port).await; + + let options = RustProxyOptions { + routes: vec![make_tls_terminate_route( + proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem, + )], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + // Connect with TLS client + let result = with_timeout(async { + let tls_config = make_insecure_tls_client_config(); + let connector = tokio_rustls::TlsConnector::from(tls_config); + + let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) + .await + .unwrap(); + + let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap(); + let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); + + tls_stream.write_all(b"hello TLS").await.unwrap(); + + let mut buf = vec![0u8; 1024]; + let n = tls_stream.read(&mut buf).await.unwrap(); + String::from_utf8_lossy(&buf[..n]).to_string() + }, 10) + .await + .unwrap(); + + assert_eq!(result, "hello TLS"); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_tls_terminate_and_reencrypt() { + let backend_port = next_port(); + let proxy_port = next_port(); + let domain = "reencrypt.example.com"; + let backend_domain = "backend.internal"; + + // Generate certs + let (proxy_cert, proxy_key) = generate_self_signed_cert(domain); + let (backend_cert, backend_key) = generate_self_signed_cert(backend_domain); + + // Start TLS echo backend + let _backend = start_tls_echo_server(backend_port, &backend_cert, &backend_key).await; + + // Create terminate-and-reencrypt route + let mut route = make_tls_terminate_route( + proxy_port, domain, "127.0.0.1", backend_port, &proxy_cert, &proxy_key, + ); + route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt; + + let options = RustProxyOptions { + routes: vec![route], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + let result = with_timeout(async { + let tls_config = make_insecure_tls_client_config(); + let connector = tokio_rustls::TlsConnector::from(tls_config); + + let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) + .await + .unwrap(); + + let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap(); + let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); + + tls_stream.write_all(b"hello reencrypt").await.unwrap(); + + let mut buf = vec![0u8; 1024]; + let n = tls_stream.read(&mut buf).await.unwrap(); + String::from_utf8_lossy(&buf[..n]).to_string() + }, 10) + .await + .unwrap(); + + assert_eq!(result, "hello reencrypt"); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_tls_terminate_sni_cert_selection() { + let backend1_port = next_port(); + let backend2_port = next_port(); + let proxy_port = next_port(); + + let (cert1, key1) = generate_self_signed_cert("alpha.example.com"); + let (cert2, key2) = generate_self_signed_cert("beta.example.com"); + + let _b1 = start_prefix_echo_server(backend1_port, "ALPHA:").await; + let _b2 = start_prefix_echo_server(backend2_port, "BETA:").await; + + let options = RustProxyOptions { + routes: vec![ + make_tls_terminate_route(proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1), + make_tls_terminate_route(proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2), + ], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + // Test alpha domain + let result = with_timeout(async { + let tls_config = make_insecure_tls_client_config(); + let connector = tokio_rustls::TlsConnector::from(tls_config); + + let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) + .await + .unwrap(); + + let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap(); + let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); + + tls_stream.write_all(b"test").await.unwrap(); + + let mut buf = vec![0u8; 1024]; + let n = tls_stream.read(&mut buf).await.unwrap(); + String::from_utf8_lossy(&buf[..n]).to_string() + }, 10) + .await + .unwrap(); + + assert!(result.starts_with("ALPHA:"), "Expected ALPHA prefix, got: {}", result); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_tls_terminate_large_payload() { + let backend_port = next_port(); + let proxy_port = next_port(); + let domain = "large.example.com"; + + let (cert_pem, key_pem) = generate_self_signed_cert(domain); + let _backend = start_echo_server(backend_port).await; + + let options = RustProxyOptions { + routes: vec![make_tls_terminate_route( + proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem, + )], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + let result = with_timeout(async { + let tls_config = make_insecure_tls_client_config(); + let connector = tokio_rustls::TlsConnector::from(tls_config); + + let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) + .await + .unwrap(); + + let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap(); + let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); + + // Send 1MB of data + let data = vec![b'X'; 1_000_000]; + tls_stream.write_all(&data).await.unwrap(); + tls_stream.shutdown().await.unwrap(); + + let mut received = Vec::new(); + tls_stream.read_to_end(&mut received).await.unwrap(); + received.len() + }, 15) + .await + .unwrap(); + + assert_eq!(result, 1_000_000); + + proxy.stop().await.unwrap(); +} + +#[tokio::test] +async fn test_tls_terminate_concurrent() { + let backend_port = next_port(); + let proxy_port = next_port(); + let domain = "concurrent.example.com"; + + let (cert_pem, key_pem) = generate_self_signed_cert(domain); + let _backend = start_echo_server(backend_port).await; + + let options = RustProxyOptions { + routes: vec![make_tls_terminate_route( + proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem, + )], + ..Default::default() + }; + + let mut proxy = RustProxy::new(options).unwrap(); + proxy.start().await.unwrap(); + assert!(wait_for_port(proxy_port, 2000).await); + + let result = with_timeout(async { + let mut handles = Vec::new(); + for i in 0..10 { + let port = proxy_port; + let dom = domain.to_string(); + handles.push(tokio::spawn(async move { + let tls_config = make_insecure_tls_client_config(); + let connector = tokio_rustls::TlsConnector::from(tls_config); + + let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) + .await + .unwrap(); + + let server_name = rustls::pki_types::ServerName::try_from(dom).unwrap(); + let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); + + let msg = format!("conn-{}", i); + tls_stream.write_all(msg.as_bytes()).await.unwrap(); + + let mut buf = vec![0u8; 1024]; + let n = tls_stream.read(&mut buf).await.unwrap(); + String::from_utf8_lossy(&buf[..n]).to_string() + })); + } + + let mut results = Vec::new(); + for handle in handles { + results.push(handle.await.unwrap()); + } + results + }, 15) + .await + .unwrap(); + + assert_eq!(result.len(), 10); + for (i, r) in result.iter().enumerate() { + assert_eq!(r, &format!("conn-{}", i)); + } + + proxy.stop().await.unwrap(); +} diff --git a/test/test.acme-route-creation.ts b/test/test.acme-route-creation.ts deleted file mode 100644 index 1e86b4f..0000000 --- a/test/test.acme-route-creation.ts +++ /dev/null @@ -1,218 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import { SmartProxy } from '../ts/index.js'; -import * as plugins from '../ts/plugins.js'; - -/** - * Test that verifies ACME challenge routes are properly created - */ -tap.test('should create ACME challenge route', async (tools) => { - tools.timeout(5000); - - // Create a challenge route manually to test its structure - const challengeRoute = { - name: 'acme-challenge', - priority: 1000, - match: { - ports: 18080, - path: '/.well-known/acme-challenge/*' - }, - action: { - type: 'socket-handler' as const, - socketHandler: (socket: any, context: any) => { - socket.once('data', (data: Buffer) => { - const request = data.toString(); - const lines = request.split('\r\n'); - const [method, path] = lines[0].split(' '); - const token = path?.split('/').pop() || ''; - - const response = [ - 'HTTP/1.1 200 OK', - 'Content-Type: text/plain', - `Content-Length: ${token.length}`, - 'Connection: close', - '', - token - ].join('\r\n'); - - socket.write(response); - socket.end(); - }); - } - } - }; - - // Test that the challenge route has the correct structure - expect(challengeRoute).toBeDefined(); - expect(challengeRoute.match.path).toEqual('/.well-known/acme-challenge/*'); - expect(challengeRoute.match.ports).toEqual(18080); - expect(challengeRoute.action.type).toEqual('socket-handler'); - expect(challengeRoute.priority).toEqual(1000); - - // Create a proxy with the challenge route - const settings = { - routes: [ - { - name: 'secure-route', - match: { - ports: [18443], - domains: 'test.local' - }, - action: { - type: 'forward' as const, - targets: [{ host: 'localhost', port: 8080 }] - } - }, - challengeRoute - ] - }; - - const proxy = new SmartProxy(settings); - - // Mock NFTables manager - (proxy as any).nftablesManager = { - ensureNFTablesSetup: async () => {}, - stop: async () => {} - }; - - // Mock certificate manager to prevent real ACME initialization - (proxy as any).createCertificateManager = async function() { - return { - setUpdateRoutesCallback: () => {}, - setHttpProxy: () => {}, - setGlobalAcmeDefaults: () => {}, - setAcmeStateManager: () => {}, - initialize: async () => {}, - provisionAllCertificates: async () => {}, - stop: async () => {}, - getAcmeOptions: () => ({}), - getState: () => ({ challengeRouteActive: false }) - }; - }; - - await proxy.start(); - - // Verify the challenge route is in the proxy's routes - const proxyRoutes = proxy.routeManager.getRoutes(); - const foundChallengeRoute = proxyRoutes.find((r: any) => r.name === 'acme-challenge'); - - expect(foundChallengeRoute).toBeDefined(); - expect(foundChallengeRoute?.match.path).toEqual('/.well-known/acme-challenge/*'); - - await proxy.stop(); -}); - -tap.test('should handle HTTP request parsing correctly', async (tools) => { - tools.timeout(5000); - - let handlerCalled = false; - let receivedContext: any; - let parsedRequest: any = {}; - - const settings = { - routes: [ - { - name: 'test-static', - match: { - ports: [18090], - path: '/test/*' - }, - action: { - type: 'socket-handler' as const, - socketHandler: (socket, context) => { - handlerCalled = true; - receivedContext = context; - - // Parse HTTP request from socket - socket.once('data', (data) => { - const request = data.toString(); - const lines = request.split('\r\n'); - const [method, path, protocol] = lines[0].split(' '); - - // Parse headers - const headers: any = {}; - for (let i = 1; i < lines.length; i++) { - if (lines[i] === '') break; - const [key, value] = lines[i].split(': '); - if (key && value) { - headers[key.toLowerCase()] = value; - } - } - - // Store parsed request data - parsedRequest = { method, path, headers }; - - // Send HTTP response - const response = [ - 'HTTP/1.1 200 OK', - 'Content-Type: text/plain', - 'Content-Length: 2', - 'Connection: close', - '', - 'OK' - ].join('\r\n'); - - socket.write(response); - socket.end(); - }); - } - } - } - ] - }; - - const proxy = new SmartProxy(settings); - - // Mock NFTables manager - (proxy as any).nftablesManager = { - ensureNFTablesSetup: async () => {}, - stop: async () => {} - }; - - await proxy.start(); - - // Create a simple HTTP request - const client = new plugins.net.Socket(); - - await new Promise((resolve, reject) => { - client.connect(18090, 'localhost', () => { - // Send HTTP request - const request = [ - 'GET /test/example HTTP/1.1', - 'Host: localhost:18090', - 'User-Agent: test-client', - '', - '' - ].join('\r\n'); - - client.write(request); - - // Wait for response - client.on('data', (data) => { - const response = data.toString(); - expect(response).toContain('HTTP/1.1 200'); - expect(response).toContain('OK'); - client.end(); - resolve(); - }); - }); - - client.on('error', reject); - }); - - // Verify handler was called - expect(handlerCalled).toBeTrue(); - expect(receivedContext).toBeDefined(); - - // The context passed to socket handlers is IRouteContext, not HTTP request data - expect(receivedContext.port).toEqual(18090); - expect(receivedContext.routeName).toEqual('test-static'); - - // Verify the parsed HTTP request data - expect(parsedRequest.path).toEqual('/test/example'); - expect(parsedRequest.method).toEqual('GET'); - expect(parsedRequest.headers.host).toEqual('localhost:18090'); - - await proxy.stop(); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.acme-state-manager.node.ts b/test/test.acme-state-manager.node.ts deleted file mode 100644 index d0969c8..0000000 --- a/test/test.acme-state-manager.node.ts +++ /dev/null @@ -1,188 +0,0 @@ -import { expect, tap } from '@git.zone/tstest/tapbundle'; -import { AcmeStateManager } from '../ts/proxies/smart-proxy/acme-state-manager.js'; -import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js'; - -tap.test('AcmeStateManager should track challenge routes correctly', async (tools) => { - const stateManager = new AcmeStateManager(); - - const challengeRoute: IRouteConfig = { - name: 'acme-challenge', - priority: 1000, - match: { - ports: 80, - path: '/.well-known/acme-challenge/*' - }, - action: { - type: 'socket-handler', - socketHandler: async (socket, context) => { - // Mock handler that would write the challenge response - socket.end('challenge response'); - } - } - }; - - // Initially no challenge routes - expect(stateManager.isChallengeRouteActive()).toBeFalse(); - expect(stateManager.getActiveChallengeRoutes()).toEqual([]); - - // Add challenge route - stateManager.addChallengeRoute(challengeRoute); - expect(stateManager.isChallengeRouteActive()).toBeTrue(); - expect(stateManager.getActiveChallengeRoutes()).toHaveProperty("length", 1); - expect(stateManager.getPrimaryChallengeRoute()).toEqual(challengeRoute); - - // Remove challenge route - stateManager.removeChallengeRoute('acme-challenge'); - expect(stateManager.isChallengeRouteActive()).toBeFalse(); - expect(stateManager.getActiveChallengeRoutes()).toEqual([]); - expect(stateManager.getPrimaryChallengeRoute()).toBeNull(); -}); - -tap.test('AcmeStateManager should track port allocations', async (tools) => { - const stateManager = new AcmeStateManager(); - - const challengeRoute1: IRouteConfig = { - name: 'acme-challenge-1', - priority: 1000, - match: { - ports: 80, - path: '/.well-known/acme-challenge/*' - }, - action: { - type: 'socket-handler' - } - }; - - const challengeRoute2: IRouteConfig = { - name: 'acme-challenge-2', - priority: 900, - match: { - ports: [80, 8080], - path: '/.well-known/acme-challenge/*' - }, - action: { - type: 'socket-handler' - } - }; - - // Add first route - stateManager.addChallengeRoute(challengeRoute1); - expect(stateManager.isPortAllocatedForAcme(80)).toBeTrue(); - expect(stateManager.isPortAllocatedForAcme(8080)).toBeFalse(); - expect(stateManager.getAcmePorts()).toEqual([80]); - - // Add second route - stateManager.addChallengeRoute(challengeRoute2); - expect(stateManager.isPortAllocatedForAcme(80)).toBeTrue(); - expect(stateManager.isPortAllocatedForAcme(8080)).toBeTrue(); - expect(stateManager.getAcmePorts()).toContain(80); - expect(stateManager.getAcmePorts()).toContain(8080); - - // Remove first route - port 80 should still be allocated - stateManager.removeChallengeRoute('acme-challenge-1'); - expect(stateManager.isPortAllocatedForAcme(80)).toBeTrue(); - expect(stateManager.isPortAllocatedForAcme(8080)).toBeTrue(); - - // Remove second route - all ports should be deallocated - stateManager.removeChallengeRoute('acme-challenge-2'); - expect(stateManager.isPortAllocatedForAcme(80)).toBeFalse(); - expect(stateManager.isPortAllocatedForAcme(8080)).toBeFalse(); - expect(stateManager.getAcmePorts()).toEqual([]); -}); - -tap.test('AcmeStateManager should select primary route by priority', async (tools) => { - const stateManager = new AcmeStateManager(); - - const lowPriorityRoute: IRouteConfig = { - name: 'low-priority', - priority: 100, - match: { - ports: 80 - }, - action: { - type: 'socket-handler' - } - }; - - const highPriorityRoute: IRouteConfig = { - name: 'high-priority', - priority: 2000, - match: { - ports: 80 - }, - action: { - type: 'socket-handler' - } - }; - - const defaultPriorityRoute: IRouteConfig = { - name: 'default-priority', - // No priority specified - should default to 0 - match: { - ports: 80 - }, - action: { - type: 'socket-handler' - } - }; - - // Add low priority first - stateManager.addChallengeRoute(lowPriorityRoute); - expect(stateManager.getPrimaryChallengeRoute()?.name).toEqual('low-priority'); - - // Add high priority - should become primary - stateManager.addChallengeRoute(highPriorityRoute); - expect(stateManager.getPrimaryChallengeRoute()?.name).toEqual('high-priority'); - - // Add default priority - primary should remain high priority - stateManager.addChallengeRoute(defaultPriorityRoute); - expect(stateManager.getPrimaryChallengeRoute()?.name).toEqual('high-priority'); - - // Remove high priority - primary should fall back to low priority - stateManager.removeChallengeRoute('high-priority'); - expect(stateManager.getPrimaryChallengeRoute()?.name).toEqual('low-priority'); -}); - -tap.test('AcmeStateManager should handle clear operation', async (tools) => { - const stateManager = new AcmeStateManager(); - - const challengeRoute1: IRouteConfig = { - name: 'route-1', - match: { - ports: [80, 443] - }, - action: { - type: 'socket-handler' - } - }; - - const challengeRoute2: IRouteConfig = { - name: 'route-2', - match: { - ports: 8080 - }, - action: { - type: 'socket-handler' - } - }; - - // Add routes - stateManager.addChallengeRoute(challengeRoute1); - stateManager.addChallengeRoute(challengeRoute2); - - // Verify state before clear - expect(stateManager.isChallengeRouteActive()).toBeTrue(); - expect(stateManager.getActiveChallengeRoutes()).toHaveProperty("length", 2); - expect(stateManager.getAcmePorts()).toHaveProperty("length", 3); - - // Clear all state - stateManager.clear(); - - // Verify state after clear - expect(stateManager.isChallengeRouteActive()).toBeFalse(); - expect(stateManager.getActiveChallengeRoutes()).toEqual([]); - expect(stateManager.getAcmePorts()).toEqual([]); - expect(stateManager.getPrimaryChallengeRoute()).toBeNull(); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.acme-timing-simple.ts b/test/test.acme-timing-simple.ts deleted file mode 100644 index 39b738b..0000000 --- a/test/test.acme-timing-simple.ts +++ /dev/null @@ -1,122 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import { SmartProxy } from '../ts/index.js'; - -// Test that certificate provisioning is deferred until after ports are listening -tap.test('should defer certificate provisioning until ports are ready', async (tapTest) => { - // Track when operations happen - let portsListening = false; - let certProvisioningStarted = false; - let operationOrder: string[] = []; - - // Create proxy with certificate route but without real ACME - const proxy = new SmartProxy({ - routes: [{ - name: 'test-route', - match: { - ports: 8443, - domains: ['test.local'] - }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 8181 }], - tls: { - mode: 'terminate', - certificate: 'auto', - acme: { - email: 'test@local.dev', - useProduction: false - } - } - } - }] - }); - - // Override the certificate manager creation to avoid real ACME - const originalCreateCertManager = proxy['createCertificateManager']; - proxy['createCertificateManager'] = async function(...args: any[]) { - console.log('Creating mock cert manager'); - operationOrder.push('create-cert-manager'); - const mockCertManager = { - certStore: null, - smartAcme: null, - httpProxy: null, - renewalTimer: null, - pendingChallenges: new Map(), - challengeRoute: null, - certStatus: new Map(), - globalAcmeDefaults: null, - updateRoutesCallback: undefined, - challengeRouteActive: false, - isProvisioning: false, - acmeStateManager: null, - initialize: async () => { - operationOrder.push('cert-manager-init'); - console.log('Mock cert manager initialized'); - }, - provisionAllCertificates: async () => { - operationOrder.push('cert-provisioning'); - certProvisioningStarted = true; - // Check that ports are listening when provisioning starts - if (!portsListening) { - throw new Error('Certificate provisioning started before ports ready!'); - } - console.log('Mock certificate provisioning (ports are ready)'); - }, - stop: async () => {}, - setHttpProxy: () => {}, - setGlobalAcmeDefaults: () => {}, - setAcmeStateManager: () => {}, - setUpdateRoutesCallback: () => {}, - getAcmeOptions: () => ({}), - getState: () => ({ challengeRouteActive: false }), - getCertStatus: () => new Map(), - checkAndRenewCertificates: async () => {}, - addChallengeRoute: async () => {}, - removeChallengeRoute: async () => {}, - getCertificate: async () => null, - isValidCertificate: () => false, - waitForProvisioning: async () => {} - } as any; - - // Call initialize immediately as the real createCertificateManager does - await mockCertManager.initialize(); - - return mockCertManager; - }; - - // Track port manager operations - const originalAddPorts = proxy['portManager'].addPorts; - proxy['portManager'].addPorts = async function(ports: number[]) { - operationOrder.push('ports-starting'); - const result = await originalAddPorts.call(this, ports); - operationOrder.push('ports-ready'); - portsListening = true; - console.log('Ports are now listening'); - return result; - }; - - // Start the proxy - await proxy.start(); - - // Log the operation order for debugging - console.log('Operation order:', operationOrder); - - // Verify operations happened in the correct order - expect(operationOrder).toContain('create-cert-manager'); - expect(operationOrder).toContain('cert-manager-init'); - expect(operationOrder).toContain('ports-starting'); - expect(operationOrder).toContain('ports-ready'); - expect(operationOrder).toContain('cert-provisioning'); - - // Verify ports were ready before certificate provisioning - const portsReadyIndex = operationOrder.indexOf('ports-ready'); - const certProvisioningIndex = operationOrder.indexOf('cert-provisioning'); - - expect(portsReadyIndex).toBeLessThan(certProvisioningIndex); - expect(certProvisioningStarted).toEqual(true); - expect(portsListening).toEqual(true); - - await proxy.stop(); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.acme-timing.ts b/test/test.acme-timing.ts deleted file mode 100644 index 33677d7..0000000 --- a/test/test.acme-timing.ts +++ /dev/null @@ -1,204 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import { SmartProxy } from '../ts/index.js'; -import * as net from 'net'; - -// Test that certificate provisioning waits for ports to be ready -tap.test('should defer certificate provisioning until after ports are listening', async (tapTest) => { - // Track the order of operations - const operationLog: string[] = []; - - // Create a mock server to verify ports are listening - let port80Listening = false; - - // Try to use port 8080 instead of 80 to avoid permission issues in testing - const acmePort = 8080; - - // Create proxy with ACME certificate requirement - const proxy = new SmartProxy({ - useHttpProxy: [acmePort], - httpProxyPort: 8845, // Use different port to avoid conflicts - acme: { - email: 'test@test.local', - useProduction: false, - port: acmePort - }, - routes: [{ - name: 'test-acme-route', - match: { - ports: 8443, - domains: ['test.local'] - }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 8181 }], - tls: { - mode: 'terminate', - certificate: 'auto', - acme: { - email: 'test@test.local', - useProduction: false - } - } - } - }] - }); - - // Mock some internal methods to track operation order - const originalAddPorts = proxy['portManager'].addPorts; - proxy['portManager'].addPorts = async function(ports: number[]) { - operationLog.push('Starting port listeners'); - const result = await originalAddPorts.call(this, ports); - operationLog.push('Port listeners started'); - port80Listening = true; - return result; - }; - - // Track that we created a certificate manager and SmartProxy will call provisionAllCertificates - let certManagerCreated = false; - - // Override createCertificateManager to set up our tracking - const originalCreateCertManager = (proxy as any).createCertificateManager; - (proxy as any).certManagerCreated = false; - - // Mock certificate manager to avoid real ACME initialization - (proxy as any).createCertificateManager = async function() { - operationLog.push('Creating certificate manager'); - const mockCertManager = { - setUpdateRoutesCallback: () => {}, - setHttpProxy: () => {}, - setGlobalAcmeDefaults: () => {}, - setAcmeStateManager: () => {}, - initialize: async () => { - operationLog.push('Certificate manager initialized'); - }, - provisionAllCertificates: async () => { - operationLog.push('Starting certificate provisioning'); - if (!port80Listening) { - operationLog.push('ERROR: Certificate provisioning started before ports ready'); - } - operationLog.push('Certificate provisioning completed'); - }, - stop: async () => {}, - getAcmeOptions: () => ({ email: 'test@test.local', useProduction: false }), - getState: () => ({ challengeRouteActive: false }) - }; - certManagerCreated = true; - (proxy as any).certManager = mockCertManager; - return mockCertManager; - }; - - // Start the proxy - await proxy.start(); - - // Verify the order of operations - expect(operationLog).toContain('Starting port listeners'); - expect(operationLog).toContain('Port listeners started'); - expect(operationLog).toContain('Starting certificate provisioning'); - - // Ensure port listeners started before certificate provisioning - const portStartIndex = operationLog.indexOf('Port listeners started'); - const certStartIndex = operationLog.indexOf('Starting certificate provisioning'); - - expect(portStartIndex).toBeLessThan(certStartIndex); - expect(operationLog).not.toContain('ERROR: Certificate provisioning started before ports ready'); - - await proxy.stop(); -}); - -// Test that ACME challenge route is available when certificate is requested -tap.test('should have ACME challenge route ready before certificate provisioning', async (tapTest) => { - let challengeRouteActive = false; - let certificateProvisioningStarted = false; - - const proxy = new SmartProxy({ - useHttpProxy: [8080], - httpProxyPort: 8846, // Use different port to avoid conflicts - acme: { - email: 'test@test.local', - useProduction: false, - port: 8080 - }, - routes: [{ - name: 'test-route', - match: { - ports: 8443, - domains: ['test.example.com'] - }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 8181 }], - tls: { - mode: 'terminate', - certificate: 'auto' - } - } - }] - }); - - // Mock the certificate manager to track operations - const originalInitialize = proxy['certManager'] ? - proxy['certManager'].initialize : null; - - if (proxy['certManager']) { - const certManager = proxy['certManager']; - - // Track when challenge route is added - const originalAddChallenge = certManager['addChallengeRoute']; - certManager['addChallengeRoute'] = async function() { - await originalAddChallenge.call(this); - challengeRouteActive = true; - }; - - // Track when certificate provisioning starts - const originalProvisionAcme = certManager['provisionAcmeCertificate']; - certManager['provisionAcmeCertificate'] = async function(...args: any[]) { - certificateProvisioningStarted = true; - // Verify challenge route is active - expect(challengeRouteActive).toEqual(true); - // Don't actually provision in test - return; - }; - } - - // Mock certificate manager to avoid real ACME initialization - (proxy as any).createCertificateManager = async function() { - const mockCertManager = { - setUpdateRoutesCallback: () => {}, - setHttpProxy: () => {}, - setGlobalAcmeDefaults: () => {}, - setAcmeStateManager: () => {}, - initialize: async () => { - challengeRouteActive = true; - }, - provisionAllCertificates: async () => { - certificateProvisioningStarted = true; - expect(challengeRouteActive).toEqual(true); - }, - stop: async () => {}, - getAcmeOptions: () => ({ email: 'test@test.local', useProduction: false }), - getState: () => ({ challengeRouteActive: false }), - addChallengeRoute: async () => { - challengeRouteActive = true; - }, - provisionAcmeCertificate: async () => { - certificateProvisioningStarted = true; - expect(challengeRouteActive).toEqual(true); - } - }; - // Call initialize like the real createCertificateManager does - await mockCertManager.initialize(); - return mockCertManager; - }; - - await proxy.start(); - - // Give it a moment to complete initialization - await new Promise(resolve => setTimeout(resolve, 100)); - - // Verify challenge route was added before any certificate provisioning - expect(challengeRouteActive).toEqual(true); - - await proxy.stop(); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.certificate-acme-update.ts b/test/test.certificate-acme-update.ts deleted file mode 100644 index 29c9e42..0000000 --- a/test/test.certificate-acme-update.ts +++ /dev/null @@ -1,77 +0,0 @@ -import { expect, tap } from '@git.zone/tstest/tapbundle'; -import * as plugins from '../ts/plugins.js'; -import * as smartproxy from '../ts/index.js'; - -// This test verifies that SmartProxy correctly uses the updated SmartAcme v8.0.0 API -// with the optional wildcard parameter - -tap.test('SmartCertManager should call getCertificateForDomain with wildcard option', async () => { - console.log('Testing SmartCertManager with SmartAcme v8.0.0 API...'); - - // Create a mock route with ACME certificate configuration - const mockRoute: smartproxy.IRouteConfig = { - match: { - domains: ['test.example.com'], - ports: 443 - }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 8080 - }], - tls: { - mode: 'terminate', - certificate: 'auto', - acme: { - email: 'test@example.com', - useProduction: false - } - } - }, - name: 'test-route' - }; - - // Create a certificate manager - const certManager = new smartproxy.SmartCertManager( - [mockRoute], - './test-certs', - { - email: 'test@example.com', - useProduction: false - } - ); - - // Since we can't actually test ACME in a unit test, we'll just verify the logic - // The actual test would be that it builds and runs without errors - - // Test the wildcard logic for different domain types and challenge handlers - const testCases = [ - { domain: 'example.com', hasDnsChallenge: true, shouldIncludeWildcard: true }, - { domain: 'example.com', hasDnsChallenge: false, shouldIncludeWildcard: false }, - { domain: 'sub.example.com', hasDnsChallenge: true, shouldIncludeWildcard: true }, - { domain: 'sub.example.com', hasDnsChallenge: false, shouldIncludeWildcard: false }, - { domain: '*.example.com', hasDnsChallenge: true, shouldIncludeWildcard: false }, - { domain: '*.example.com', hasDnsChallenge: false, shouldIncludeWildcard: false }, - { domain: 'test', hasDnsChallenge: true, shouldIncludeWildcard: false }, // single label domain - { domain: 'test', hasDnsChallenge: false, shouldIncludeWildcard: false }, - { domain: 'my.sub.example.com', hasDnsChallenge: true, shouldIncludeWildcard: true }, - { domain: 'my.sub.example.com', hasDnsChallenge: false, shouldIncludeWildcard: false } - ]; - - for (const testCase of testCases) { - const shouldIncludeWildcard = !testCase.domain.startsWith('*.') && - testCase.domain.includes('.') && - testCase.domain.split('.').length >= 2 && - testCase.hasDnsChallenge; - - console.log(`Domain: ${testCase.domain}, DNS-01: ${testCase.hasDnsChallenge}, Should include wildcard: ${shouldIncludeWildcard}`); - expect(shouldIncludeWildcard).toEqual(testCase.shouldIncludeWildcard); - } - - console.log('All wildcard logic tests passed!'); -}); - -tap.start({ - throwOnError: true -}); \ No newline at end of file diff --git a/test/test.certificate-provision.ts b/test/test.certificate-provision.ts deleted file mode 100644 index f962b68..0000000 --- a/test/test.certificate-provision.ts +++ /dev/null @@ -1,423 +0,0 @@ -import { expect, tap } from '@git.zone/tstest/tapbundle'; -import { SmartProxy } from '../ts/index.js'; -import type { TSmartProxyCertProvisionObject } from '../ts/index.js'; -import * as fs from 'fs'; -import * as path from 'path'; -import { fileURLToPath } from 'url'; - -const __filename = fileURLToPath(import.meta.url); -const __dirname = path.dirname(__filename); - -let testProxy: SmartProxy; - -// Load test certificates from helpers -const testCert = fs.readFileSync(path.join(__dirname, 'helpers/test-cert.pem'), 'utf8'); -const testKey = fs.readFileSync(path.join(__dirname, 'helpers/test-key.pem'), 'utf8'); - -// Helper to create a fully mocked certificate manager that doesn't contact ACME servers -function createMockCertManager(options: { - onProvisionAll?: () => void; - onGetCertForDomain?: (domain: string) => void; -} = {}) { - return { - setUpdateRoutesCallback: function(callback: any) { - this.updateRoutesCallback = callback; - }, - updateRoutesCallback: null as any, - setHttpProxy: function() {}, - setGlobalAcmeDefaults: function() {}, - setAcmeStateManager: function() {}, - setRoutes: function(routes: any) {}, - initialize: async function() {}, - provisionAllCertificates: async function() { - if (options.onProvisionAll) { - options.onProvisionAll(); - } - }, - stop: async function() {}, - getAcmeOptions: function() { - return { email: 'test@example.com', useProduction: false }; - }, - getState: function() { - return { challengeRouteActive: false }; - }, - smartAcme: { - getCertificateForDomain: async (domain: string) => { - if (options.onGetCertForDomain) { - options.onGetCertForDomain(domain); - } - throw new Error('Mocked ACME - not calling real servers'); - } - } - }; -} - -tap.test('SmartProxy should support custom certificate provision function', async () => { - // Create test certificate object matching ICert interface - const testCertObject = { - id: 'test-cert-1', - domainName: 'test.example.com', - created: Date.now(), - validUntil: Date.now() + 90 * 24 * 60 * 60 * 1000, // 90 days - privateKey: testKey, - publicKey: testCert, - csr: '' - }; - - // Custom certificate store for testing - const customCerts = new Map(); - customCerts.set('test.example.com', testCertObject); - - // Create proxy with custom certificate provision - testProxy = new SmartProxy({ - certProvisionFunction: async (domain: string): Promise => { - console.log(`Custom cert provision called for domain: ${domain}`); - - // Return custom cert for known domains - if (customCerts.has(domain)) { - console.log(`Returning custom certificate for ${domain}`); - return customCerts.get(domain)!; - } - - // Fallback to Let's Encrypt for other domains - console.log(`Falling back to Let's Encrypt for ${domain}`); - return 'http01'; - }, - certProvisionFallbackToAcme: true, - acme: { - email: 'test@example.com', - useProduction: false - }, - routes: [ - { - name: 'test-route', - match: { - ports: [443], - domains: ['test.example.com'] - }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 8080 - }], - tls: { - mode: 'terminate', - certificate: 'auto' - } - } - } - ] - }); - - expect(testProxy).toBeInstanceOf(SmartProxy); -}); - -tap.test('Custom certificate provision function should be called', async () => { - let provisionCalled = false; - const provisionedDomains: string[] = []; - - const testProxy2 = new SmartProxy({ - certProvisionFunction: async (domain: string): Promise => { - provisionCalled = true; - provisionedDomains.push(domain); - - // Return a test certificate matching ICert interface - return { - id: `test-cert-${domain}`, - domainName: domain, - created: Date.now(), - validUntil: Date.now() + 90 * 24 * 60 * 60 * 1000, - privateKey: testKey, - publicKey: testCert, - csr: '' - }; - }, - acme: { - email: 'test@example.com', - useProduction: false, - port: 9080 - }, - routes: [ - { - name: 'custom-cert-route', - match: { - ports: [9443], - domains: ['custom.example.com'] - }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 8080 - }], - tls: { - mode: 'terminate', - certificate: 'auto' - } - } - } - ] - }); - - // Fully mock the certificate manager to avoid ACME server contact - let certManagerCalled = false; - (testProxy2 as any).createCertificateManager = async function() { - const mockCertManager = createMockCertManager({ - onProvisionAll: () => { - certManagerCalled = true; - // Simulate calling the provision function - testProxy2.settings.certProvisionFunction?.('custom.example.com'); - } - }); - - // Set callback as in real implementation - mockCertManager.setUpdateRoutesCallback(async (routes: any) => { - await this.updateRoutes(routes); - }); - - return mockCertManager; - }; - - // Start the proxy (this will trigger certificate provisioning) - await testProxy2.start(); - - expect(certManagerCalled).toBeTrue(); - expect(provisionCalled).toBeTrue(); - expect(provisionedDomains).toContain('custom.example.com'); - - await testProxy2.stop(); -}); - -tap.test('Should fallback to ACME when custom provision fails', async () => { - const failedDomains: string[] = []; - let acmeAttempted = false; - - const testProxy3 = new SmartProxy({ - certProvisionFunction: async (domain: string): Promise => { - failedDomains.push(domain); - throw new Error('Custom provision failed for testing'); - }, - certProvisionFallbackToAcme: true, - acme: { - email: 'test@example.com', - useProduction: false, - port: 9080 - }, - routes: [ - { - name: 'fallback-route', - match: { - ports: [9444], - domains: ['fallback.example.com'] - }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 8080 - }], - tls: { - mode: 'terminate', - certificate: 'auto' - } - } - } - ] - }); - - // Fully mock the certificate manager to avoid ACME server contact - (testProxy3 as any).createCertificateManager = async function() { - const mockCertManager = createMockCertManager({ - onProvisionAll: async () => { - // Simulate the provision logic: first try custom function, then ACME - try { - await testProxy3.settings.certProvisionFunction?.('fallback.example.com'); - } catch (e) { - // Custom provision failed, try ACME - acmeAttempted = true; - } - } - }); - - // Set callback as in real implementation - mockCertManager.setUpdateRoutesCallback(async (routes: any) => { - await this.updateRoutes(routes); - }); - - return mockCertManager; - }; - - // Start the proxy - await testProxy3.start(); - - // Custom provision should have failed - expect(failedDomains).toContain('fallback.example.com'); - - // ACME should have been attempted as fallback - expect(acmeAttempted).toBeTrue(); - - await testProxy3.stop(); -}); - -tap.test('Should not fallback when certProvisionFallbackToAcme is false', async () => { - let errorThrown = false; - let errorMessage = ''; - - const testProxy4 = new SmartProxy({ - certProvisionFunction: async (_domain: string): Promise => { - throw new Error('Custom provision failed for testing'); - }, - certProvisionFallbackToAcme: false, - acme: { - email: 'test@example.com', - useProduction: false, - port: 9082 - }, - routes: [ - { - name: 'no-fallback-route', - match: { - ports: [9449], - domains: ['no-fallback.example.com'] - }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 8080 - }], - tls: { - mode: 'terminate', - certificate: 'auto' - } - } - } - ] - }); - - // Fully mock the certificate manager to avoid ACME server contact - (testProxy4 as any).createCertificateManager = async function() { - const mockCertManager = createMockCertManager({ - onProvisionAll: async () => { - // Simulate the provision logic with no fallback - try { - await testProxy4.settings.certProvisionFunction?.('no-fallback.example.com'); - } catch (e: any) { - errorThrown = true; - errorMessage = e.message; - // With certProvisionFallbackToAcme=false, the error should propagate - if (!testProxy4.settings.certProvisionFallbackToAcme) { - throw e; - } - } - } - }); - - // Set callback as in real implementation - mockCertManager.setUpdateRoutesCallback(async (routes: any) => { - await this.updateRoutes(routes); - }); - - return mockCertManager; - }; - - try { - await testProxy4.start(); - } catch (e) { - // Expected to fail - } - - expect(errorThrown).toBeTrue(); - expect(errorMessage).toInclude('Custom provision failed for testing'); - - await testProxy4.stop(); -}); - -tap.test('Should return http01 for unknown domains', async () => { - let returnedHttp01 = false; - let acmeAttempted = false; - - const testProxy5 = new SmartProxy({ - certProvisionFunction: async (domain: string): Promise => { - if (domain === 'known.example.com') { - return { - id: `test-cert-${domain}`, - domainName: domain, - created: Date.now(), - validUntil: Date.now() + 90 * 24 * 60 * 60 * 1000, - privateKey: testKey, - publicKey: testCert, - csr: '' - }; - } - returnedHttp01 = true; - return 'http01'; - }, - acme: { - email: 'test@example.com', - useProduction: false, - port: 9081 - }, - routes: [ - { - name: 'unknown-domain-route', - match: { - ports: [9446], - domains: ['unknown.example.com'] - }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 8080 - }], - tls: { - mode: 'terminate', - certificate: 'auto' - } - } - } - ] - }); - - // Fully mock the certificate manager to avoid ACME server contact - (testProxy5 as any).createCertificateManager = async function() { - const mockCertManager = createMockCertManager({ - onProvisionAll: async () => { - // Simulate the provision logic: call provision function first - const result = await testProxy5.settings.certProvisionFunction?.('unknown.example.com'); - if (result === 'http01') { - // http01 means use ACME - acmeAttempted = true; - } - } - }); - - // Set callback as in real implementation - mockCertManager.setUpdateRoutesCallback(async (routes: any) => { - await this.updateRoutes(routes); - }); - - return mockCertManager; - }; - - await testProxy5.start(); - - // Should have returned http01 for unknown domain - expect(returnedHttp01).toBeTrue(); - - // ACME should have been attempted - expect(acmeAttempted).toBeTrue(); - - await testProxy5.stop(); -}); - -tap.test('cleanup', async () => { - // Clean up any test proxies - if (testProxy) { - await testProxy.stop(); - } -}); - -export default tap.start(); diff --git a/test/test.certificate-provisioning.ts b/test/test.certificate-provisioning.ts deleted file mode 100644 index 6cab058..0000000 --- a/test/test.certificate-provisioning.ts +++ /dev/null @@ -1,241 +0,0 @@ -import { SmartProxy } from '../ts/proxies/smart-proxy/index.js'; -import { expect, tap } from '@git.zone/tstest/tapbundle'; - -const testProxy = new SmartProxy({ - routes: [{ - name: 'test-route', - match: { ports: 9443, domains: 'test.local' }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 8080 }], - tls: { - mode: 'terminate', - certificate: 'auto', - acme: { - email: 'test@test.local', - useProduction: false - } - } - } - }], - acme: { - port: 9080 // Use high port for ACME challenges - } -}); - -tap.test('should provision certificate automatically', async () => { - // Mock certificate manager to avoid real ACME initialization - const mockCertStatus = { - domain: 'test-route', - status: 'valid' as const, - source: 'acme' as const, - expiryDate: new Date(Date.now() + 90 * 24 * 60 * 60 * 1000), - issueDate: new Date() - }; - - (testProxy as any).createCertificateManager = async function() { - return { - setUpdateRoutesCallback: () => {}, - setHttpProxy: () => {}, - setGlobalAcmeDefaults: () => {}, - setAcmeStateManager: () => {}, - initialize: async () => {}, - provisionAllCertificates: async () => {}, - stop: async () => {}, - getAcmeOptions: () => ({ email: 'test@test.local', useProduction: false }), - getState: () => ({ challengeRouteActive: false }), - getCertificateStatus: () => mockCertStatus - }; - }; - - (testProxy as any).getCertificateStatus = () => mockCertStatus; - - await testProxy.start(); - - const status = testProxy.getCertificateStatus('test-route'); - expect(status).toBeDefined(); - expect(status.status).toEqual('valid'); - expect(status.source).toEqual('acme'); - - await testProxy.stop(); -}); - -tap.test('should handle static certificates', async () => { - const proxy = new SmartProxy({ - routes: [{ - name: 'static-route', - match: { ports: 9444, domains: 'static.example.com' }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 8080 }], - tls: { - mode: 'terminate', - certificate: { - cert: '-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----', - key: '-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----' - } - } - } - }] - }); - - await proxy.start(); - - const status = proxy.getCertificateStatus('static-route'); - expect(status).toBeDefined(); - expect(status.status).toEqual('valid'); - expect(status.source).toEqual('static'); - - await proxy.stop(); -}); - -tap.test('should handle ACME challenge routes', async () => { - const proxy = new SmartProxy({ - routes: [{ - name: 'auto-cert-route', - match: { ports: 9445, domains: 'acme.local' }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 8080 }], - tls: { - mode: 'terminate', - certificate: 'auto', - acme: { - email: 'acme@test.local', - useProduction: false, - challengePort: 9081 - } - } - } - }, { - name: 'port-9081-route', - match: { ports: 9081, domains: 'acme.local' }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 8080 }] - } - }], - acme: { - port: 9081 // Use high port for ACME challenges - } - }); - - // Mock certificate manager to avoid real ACME initialization - (proxy as any).createCertificateManager = async function() { - return { - setUpdateRoutesCallback: () => {}, - setHttpProxy: () => {}, - setGlobalAcmeDefaults: () => {}, - setAcmeStateManager: () => {}, - initialize: async () => {}, - provisionAllCertificates: async () => {}, - stop: async () => {}, - getAcmeOptions: () => ({ email: 'acme@test.local', useProduction: false }), - getState: () => ({ challengeRouteActive: false }) - }; - }; - - await proxy.start(); - - // Verify the proxy is configured with routes including the necessary port - const routes = proxy.settings.routes; - - // Check that we have a route listening on the ACME challenge port - const acmeChallengePort = 9081; - const routesOnChallengePort = routes.filter((r: any) => { - const ports = Array.isArray(r.match.ports) ? r.match.ports : [r.match.ports]; - return ports.includes(acmeChallengePort); - }); - - expect(routesOnChallengePort.length).toBeGreaterThan(0); - expect(routesOnChallengePort[0].name).toEqual('port-9081-route'); - - // Verify the main route has ACME configuration - const mainRoute = routes.find((r: any) => r.name === 'auto-cert-route'); - expect(mainRoute).toBeDefined(); - expect(mainRoute?.action.tls?.certificate).toEqual('auto'); - expect(mainRoute?.action.tls?.acme?.email).toEqual('acme@test.local'); - expect(mainRoute?.action.tls?.acme?.challengePort).toEqual(9081); - - await proxy.stop(); -}); - -tap.test('should renew certificates', async () => { - const proxy = new SmartProxy({ - routes: [{ - name: 'renew-route', - match: { ports: 9446, domains: 'renew.local' }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 8080 }], - tls: { - mode: 'terminate', - certificate: 'auto', - acme: { - email: 'renew@test.local', - useProduction: false, - renewBeforeDays: 30 - } - } - } - }], - acme: { - port: 9082 // Use high port for ACME challenges - } - }); - - // Mock certificate manager with renewal capability - let renewCalled = false; - const mockCertStatus = { - domain: 'renew-route', - status: 'valid' as const, - source: 'acme' as const, - expiryDate: new Date(Date.now() + 90 * 24 * 60 * 60 * 1000), - issueDate: new Date() - }; - - (proxy as any).certManager = { - renewCertificate: async (routeName: string) => { - renewCalled = true; - expect(routeName).toEqual('renew-route'); - }, - getCertificateStatus: () => mockCertStatus, - setUpdateRoutesCallback: () => {}, - setHttpProxy: () => {}, - setGlobalAcmeDefaults: () => {}, - setAcmeStateManager: () => {}, - initialize: async () => {}, - provisionAllCertificates: async () => {}, - stop: async () => {}, - getAcmeOptions: () => ({ email: 'renew@test.local', useProduction: false }), - getState: () => ({ challengeRouteActive: false }) - }; - - (proxy as any).createCertificateManager = async function() { - return this.certManager; - }; - - (proxy as any).getCertificateStatus = function(routeName: string) { - return this.certManager.getCertificateStatus(routeName); - }; - - (proxy as any).renewCertificate = async function(routeName: string) { - if (this.certManager) { - await this.certManager.renewCertificate(routeName); - } - }; - - await proxy.start(); - - // Force renewal - await proxy.renewCertificate('renew-route'); - expect(renewCalled).toBeTrue(); - - const status = proxy.getCertificateStatus('renew-route'); - expect(status).toBeDefined(); - expect(status.status).toEqual('valid'); - - await proxy.stop(); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.cleanup-queue-bug.node.ts b/test/test.cleanup-queue-bug.node.ts deleted file mode 100644 index 544d026..0000000 --- a/test/test.cleanup-queue-bug.node.ts +++ /dev/null @@ -1,146 +0,0 @@ -import { expect, tap } from '@git.zone/tstest/tapbundle'; -import { SmartProxy } from '../ts/index.js'; - -tap.test('cleanup queue bug - verify queue processing handles more than batch size', async () => { - console.log('\n=== Cleanup Queue Bug Test ==='); - console.log('Purpose: Verify that the cleanup queue correctly processes all connections'); - console.log('even when there are more than the batch size (100)'); - - // Create proxy - const proxy = new SmartProxy({ - routes: [{ - name: 'test-route', - match: { ports: 8588 }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 9996 }] - } - }], - enableDetailedLogging: false, - }); - - await proxy.start(); - console.log('โœ“ Proxy started on port 8588'); - - // Access connection manager - const cm = (proxy as any).connectionManager; - - // Create mock connection records - console.log('\n--- Creating 150 mock connections ---'); - const mockConnections: any[] = []; - - for (let i = 0; i < 150; i++) { - // Create mock socket objects with necessary methods - const mockIncoming = { - destroyed: true, - writable: false, - remoteAddress: '127.0.0.1', - removeAllListeners: () => {}, - destroy: () => {}, - end: () => {}, - on: () => {}, - once: () => {}, - emit: () => {}, - pause: () => {}, - resume: () => {} - }; - - const mockOutgoing = { - destroyed: true, - writable: false, - removeAllListeners: () => {}, - destroy: () => {}, - end: () => {}, - on: () => {}, - once: () => {}, - emit: () => {} - }; - - const mockRecord = { - id: `mock-${i}`, - incoming: mockIncoming, - outgoing: mockOutgoing, - connectionClosed: false, - incomingStartTime: Date.now(), - lastActivity: Date.now(), - remoteIP: '127.0.0.1', - remotePort: 10000 + i, - localPort: 8588, - bytesReceived: 100, - bytesSent: 100, - incomingTerminationReason: null, - cleanupTimer: null - }; - - // Add to connection records - cm.connectionRecords.set(mockRecord.id, mockRecord); - mockConnections.push(mockRecord); - } - - console.log(`Created ${cm.getConnectionCount()} mock connections`); - expect(cm.getConnectionCount()).toEqual(150); - - // Queue all connections for cleanup - console.log('\n--- Queueing all connections for cleanup ---'); - - // The cleanup queue processes immediately when it reaches batch size (100) - // So after queueing 150, the first 100 will be processed immediately - for (const conn of mockConnections) { - cm.initiateCleanupOnce(conn, 'test_cleanup'); - } - - // After queueing 150, the first 100 should have been processed immediately - // leaving 50 in the queue - console.log(`Cleanup queue size after queueing: ${cm.cleanupQueue.size}`); - console.log(`Active connections after initial batch: ${cm.getConnectionCount()}`); - - // The first 100 should have been cleaned up immediately - expect(cm.cleanupQueue.size).toEqual(50); - expect(cm.getConnectionCount()).toEqual(50); - - // Wait for remaining cleanup to complete - console.log('\n--- Waiting for remaining cleanup batches to process ---'); - - // The remaining 50 connections should be cleaned up in the next batch - let waitTime = 0; - let lastCount = cm.getConnectionCount(); - - while (cm.getConnectionCount() > 0 || cm.cleanupQueue.size > 0) { - await new Promise(resolve => setTimeout(resolve, 100)); - waitTime += 100; - - const currentCount = cm.getConnectionCount(); - if (currentCount !== lastCount) { - console.log(`Active connections: ${currentCount}, Queue size: ${cm.cleanupQueue.size}`); - lastCount = currentCount; - } - - if (waitTime > 5000) { - console.log('Timeout waiting for cleanup to complete'); - break; - } - } - console.log(`All cleanup completed in ${waitTime}ms`); - - // Check final state - const finalCount = cm.getConnectionCount(); - console.log(`\nFinal connection count: ${finalCount}`); - console.log(`Final cleanup queue size: ${cm.cleanupQueue.size}`); - - // All connections should be cleaned up - expect(finalCount).toEqual(0); - expect(cm.cleanupQueue.size).toEqual(0); - - // Verify termination stats - all 150 should have been terminated - const stats = cm.getTerminationStats(); - console.log('Termination stats:', stats); - expect(stats.incoming.test_cleanup).toEqual(150); - - // Cleanup - console.log('\n--- Stopping proxy ---'); - await proxy.stop(); - - console.log('\nโœ“ Test complete: Cleanup queue now correctly processes all connections'); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.connect-disconnect-cleanup.node.ts b/test/test.connect-disconnect-cleanup.node.ts deleted file mode 100644 index a4e81de..0000000 --- a/test/test.connect-disconnect-cleanup.node.ts +++ /dev/null @@ -1,240 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import * as net from 'net'; -import * as plugins from '../ts/plugins.js'; - -// Import SmartProxy and configurations -import { SmartProxy } from '../ts/index.js'; - -tap.test('should handle clients that connect and immediately disconnect without sending data', async () => { - console.log('\n=== Testing Connect-Disconnect Cleanup ==='); - - // Create a SmartProxy instance - const proxy = new SmartProxy({ - enableDetailedLogging: false, - initialDataTimeout: 5000, // 5 second timeout for initial data - routes: [{ - name: 'test-route', - match: { ports: 8560 }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 9999 // Non-existent port - }] - } - }] - }); - - // Start the proxy - await proxy.start(); - console.log('โœ“ Proxy started on port 8560'); - - // Helper to get active connection count - const getActiveConnections = () => { - const connectionManager = (proxy as any).connectionManager; - return connectionManager ? connectionManager.getConnectionCount() : 0; - }; - - const initialCount = getActiveConnections(); - console.log(`Initial connection count: ${initialCount}`); - - // Test 1: Connect and immediately disconnect without sending data - console.log('\n--- Test 1: Immediate disconnect ---'); - const connectionCounts: number[] = []; - - for (let i = 0; i < 10; i++) { - const client = new net.Socket(); - - // Connect and immediately destroy - client.connect(8560, 'localhost', () => { - // Connected - immediately destroy without sending data - client.destroy(); - }); - - // Wait a tiny bit - await new Promise(resolve => setTimeout(resolve, 10)); - - const count = getActiveConnections(); - connectionCounts.push(count); - if ((i + 1) % 5 === 0) { - console.log(`After ${i + 1} connect/disconnect cycles: ${count} active connections`); - } - } - - // Wait a bit for cleanup - await new Promise(resolve => setTimeout(resolve, 500)); - - const afterImmediateDisconnect = getActiveConnections(); - console.log(`After immediate disconnect test: ${afterImmediateDisconnect} active connections`); - - // Test 2: Connect, wait a bit, then disconnect without sending data - console.log('\n--- Test 2: Delayed disconnect ---'); - - for (let i = 0; i < 5; i++) { - const client = new net.Socket(); - - client.on('error', () => { - // Ignore errors - }); - - client.connect(8560, 'localhost', () => { - // Wait 100ms then disconnect without sending data - setTimeout(() => { - if (!client.destroyed) { - client.destroy(); - } - }, 100); - }); - } - - // Check count immediately - const duringDelayed = getActiveConnections(); - console.log(`During delayed disconnect test: ${duringDelayed} active connections`); - - // Wait for cleanup - await new Promise(resolve => setTimeout(resolve, 1000)); - - const afterDelayedDisconnect = getActiveConnections(); - console.log(`After delayed disconnect test: ${afterDelayedDisconnect} active connections`); - - // Test 3: Mix of immediate and delayed disconnects - console.log('\n--- Test 3: Mixed disconnect patterns ---'); - - const promises = []; - for (let i = 0; i < 20; i++) { - promises.push(new Promise((resolve) => { - const client = new net.Socket(); - - client.on('error', () => { - resolve(); - }); - - client.on('close', () => { - resolve(); - }); - - client.connect(8560, 'localhost', () => { - if (i % 2 === 0) { - // Half disconnect immediately - client.destroy(); - } else { - // Half wait 50ms - setTimeout(() => { - if (!client.destroyed) { - client.destroy(); - } - }, 50); - } - }); - - // Failsafe timeout - setTimeout(() => resolve(), 200); - })); - } - - // Wait for all to complete - await Promise.all(promises); - - const duringMixed = getActiveConnections(); - console.log(`During mixed test: ${duringMixed} active connections`); - - // Final cleanup wait - await new Promise(resolve => setTimeout(resolve, 1000)); - - const finalCount = getActiveConnections(); - console.log(`\nFinal connection count: ${finalCount}`); - - // Stop the proxy - await proxy.stop(); - console.log('โœ“ Proxy stopped'); - - // Verify all connections were cleaned up - expect(finalCount).toEqual(initialCount); - expect(afterImmediateDisconnect).toEqual(initialCount); - expect(afterDelayedDisconnect).toEqual(initialCount); - - // Check that connections didn't accumulate during the test - const maxCount = Math.max(...connectionCounts); - console.log(`\nMax connection count during immediate disconnect test: ${maxCount}`); - expect(maxCount).toBeLessThan(3); // Should stay very low - - console.log('\nโœ… PASS: Connect-disconnect cleanup working correctly!'); -}); - -tap.test('should handle clients that error during connection', async () => { - console.log('\n=== Testing Connection Error Cleanup ==='); - - const proxy = new SmartProxy({ - enableDetailedLogging: false, - routes: [{ - name: 'test-route', - match: { ports: 8561 }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 9999 - }] - } - }] - }); - - await proxy.start(); - console.log('โœ“ Proxy started on port 8561'); - - const getActiveConnections = () => { - const connectionManager = (proxy as any).connectionManager; - return connectionManager ? connectionManager.getConnectionCount() : 0; - }; - - const initialCount = getActiveConnections(); - console.log(`Initial connection count: ${initialCount}`); - - // Create connections that will error - const promises = []; - for (let i = 0; i < 10; i++) { - promises.push(new Promise((resolve) => { - const client = new net.Socket(); - - client.on('error', () => { - resolve(); - }); - - client.on('close', () => { - resolve(); - }); - - // Connect to proxy - client.connect(8561, 'localhost', () => { - // Force an error by writing invalid data then destroying - try { - client.write(Buffer.alloc(1024 * 1024)); // Large write - client.destroy(); - } catch (e) { - // Ignore - } - }); - - // Timeout - setTimeout(() => resolve(), 500); - })); - } - - await Promise.all(promises); - console.log('โœ“ All error connections completed'); - - // Wait for cleanup - await new Promise(resolve => setTimeout(resolve, 500)); - - const finalCount = getActiveConnections(); - console.log(`Final connection count: ${finalCount}`); - - await proxy.stop(); - console.log('โœ“ Proxy stopped'); - - expect(finalCount).toEqual(initialCount); - - console.log('\nโœ… PASS: Connection error cleanup working correctly!'); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.connection-cleanup-comprehensive.node.ts b/test/test.connection-cleanup-comprehensive.node.ts deleted file mode 100644 index 7d723b5..0000000 --- a/test/test.connection-cleanup-comprehensive.node.ts +++ /dev/null @@ -1,277 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import * as net from 'net'; -import * as plugins from '../ts/plugins.js'; - -// Import SmartProxy and configurations -import { SmartProxy } from '../ts/index.js'; - -tap.test('comprehensive connection cleanup test - all scenarios', async () => { - console.log('\n=== Comprehensive Connection Cleanup Test ==='); - - // Create a SmartProxy instance - const proxy = new SmartProxy({ - enableDetailedLogging: false, - initialDataTimeout: 2000, - socketTimeout: 5000, - routes: [ - { - name: 'non-tls-route', - match: { ports: 8570 }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 9999 // Non-existent port - }] - } - }, - { - name: 'tls-route', - match: { ports: 8571 }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 9999 // Non-existent port - }], - tls: { - mode: 'passthrough' - } - } - } - ] - }); - - // Start the proxy - await proxy.start(); - console.log('โœ“ Proxy started on ports 8570 (non-TLS) and 8571 (TLS)'); - - // Helper to get active connection count - const getActiveConnections = () => { - const connectionManager = (proxy as any).connectionManager; - return connectionManager ? connectionManager.getConnectionCount() : 0; - }; - - const initialCount = getActiveConnections(); - console.log(`Initial connection count: ${initialCount}`); - - // Test 1: Rapid ECONNREFUSED retries (from original issue) - console.log('\n--- Test 1: Rapid ECONNREFUSED retries ---'); - for (let i = 0; i < 10; i++) { - await new Promise((resolve) => { - const client = new net.Socket(); - - client.on('error', () => { - client.destroy(); - resolve(); - }); - - client.on('close', () => { - resolve(); - }); - - client.connect(8570, 'localhost', () => { - // Send data to trigger routing - client.write('GET / HTTP/1.1\r\nHost: test.com\r\n\r\n'); - }); - - setTimeout(() => { - if (!client.destroyed) { - client.destroy(); - } - resolve(); - }, 100); - }); - - if ((i + 1) % 5 === 0) { - const count = getActiveConnections(); - console.log(`After ${i + 1} ECONNREFUSED retries: ${count} active connections`); - } - } - - // Test 2: Connect without sending data (immediate disconnect) - console.log('\n--- Test 2: Connect without sending data ---'); - for (let i = 0; i < 10; i++) { - const client = new net.Socket(); - - client.on('error', () => { - // Ignore - }); - - // Connect to non-TLS port and immediately disconnect - client.connect(8570, 'localhost', () => { - client.destroy(); - }); - - await new Promise(resolve => setTimeout(resolve, 10)); - } - - const afterNoData = getActiveConnections(); - console.log(`After connect-without-data test: ${afterNoData} active connections`); - - // Test 3: TLS connections that disconnect before handshake - console.log('\n--- Test 3: TLS early disconnect ---'); - for (let i = 0; i < 10; i++) { - const client = new net.Socket(); - - client.on('error', () => { - // Ignore - }); - - // Connect to TLS port but disconnect before sending handshake - client.connect(8571, 'localhost', () => { - // Wait 50ms then disconnect (before initial data timeout) - setTimeout(() => { - client.destroy(); - }, 50); - }); - - await new Promise(resolve => setTimeout(resolve, 100)); - } - - const afterTlsEarly = getActiveConnections(); - console.log(`After TLS early disconnect test: ${afterTlsEarly} active connections`); - - // Test 4: Mixed pattern - simulating real-world chaos - console.log('\n--- Test 4: Mixed chaos pattern ---'); - const promises = []; - - for (let i = 0; i < 30; i++) { - promises.push(new Promise((resolve) => { - const client = new net.Socket(); - const port = i % 2 === 0 ? 8570 : 8571; - - client.on('error', () => { - resolve(); - }); - - client.on('close', () => { - resolve(); - }); - - client.connect(port, 'localhost', () => { - const scenario = i % 5; - - switch (scenario) { - case 0: - // Immediate disconnect - client.destroy(); - break; - case 1: - // Send data then disconnect - client.write('GET / HTTP/1.1\r\nHost: test.com\r\n\r\n'); - setTimeout(() => client.destroy(), 20); - break; - case 2: - // Disconnect after delay - setTimeout(() => client.destroy(), 100); - break; - case 3: - // Send partial TLS handshake - if (port === 8571) { - client.write(Buffer.from([0x16, 0x03, 0x01])); // Partial TLS - } - setTimeout(() => client.destroy(), 50); - break; - case 4: - // Just let it timeout - break; - } - }); - - // Failsafe - setTimeout(() => { - if (!client.destroyed) { - client.destroy(); - } - resolve(); - }, 500); - })); - - // Small delay between connections - if (i % 5 === 0) { - await new Promise(resolve => setTimeout(resolve, 10)); - } - } - - await Promise.all(promises); - console.log('โœ“ Chaos test completed'); - - // Wait for any cleanup - await new Promise(resolve => setTimeout(resolve, 1000)); - - const afterChaos = getActiveConnections(); - console.log(`After chaos test: ${afterChaos} active connections`); - - // Test 5: NFTables route (should cleanup properly) - console.log('\n--- Test 5: NFTables route cleanup ---'); - const nftProxy = new SmartProxy({ - enableDetailedLogging: false, - routes: [{ - name: 'nftables-route', - match: { ports: 8572 }, - action: { - type: 'forward', - forwardingEngine: 'nftables', - targets: [{ - host: 'localhost', - port: 9999 - }] - } - }] - }); - - await nftProxy.start(); - - const getNftConnections = () => { - const connectionManager = (nftProxy as any).connectionManager; - return connectionManager ? connectionManager.getConnectionCount() : 0; - }; - - // Create NFTables connections - for (let i = 0; i < 5; i++) { - const client = new net.Socket(); - - client.on('error', () => { - // Ignore - }); - - client.connect(8572, 'localhost', () => { - setTimeout(() => client.destroy(), 50); - }); - - await new Promise(resolve => setTimeout(resolve, 100)); - } - - await new Promise(resolve => setTimeout(resolve, 500)); - - const nftFinal = getNftConnections(); - console.log(`NFTables connections after test: ${nftFinal}`); - - await nftProxy.stop(); - - // Final check on main proxy - const finalCount = getActiveConnections(); - console.log(`\nFinal connection count: ${finalCount}`); - - // Stop the proxy - await proxy.stop(); - console.log('โœ“ Proxy stopped'); - - // Verify all connections were cleaned up - expect(finalCount).toEqual(initialCount); - expect(afterNoData).toEqual(initialCount); - expect(afterTlsEarly).toEqual(initialCount); - expect(afterChaos).toEqual(initialCount); - expect(nftFinal).toEqual(0); - - console.log('\nโœ… PASS: Comprehensive connection cleanup test passed!'); - console.log('All connection scenarios properly cleaned up:'); - console.log('- ECONNREFUSED rapid retries'); - console.log('- Connect without sending data'); - console.log('- TLS early disconnect'); - console.log('- Mixed chaos patterns'); - console.log('- NFTables connections'); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.connection-limits.node.ts b/test/test.connection-limits.node.ts deleted file mode 100644 index c316d52..0000000 --- a/test/test.connection-limits.node.ts +++ /dev/null @@ -1,304 +0,0 @@ -import { expect, tap } from '@git.zone/tstest/tapbundle'; -import * as net from 'net'; -import { SmartProxy } from '../ts/proxies/smart-proxy/index.js'; -import { HttpProxy } from '../ts/proxies/http-proxy/index.js'; - -let testServer: net.Server; -let smartProxy: SmartProxy; -let httpProxy: HttpProxy; -const TEST_SERVER_PORT = 5100; -const PROXY_PORT = 5101; -const HTTP_PROXY_PORT = 5102; - -// Track all created servers and connections for cleanup -const allServers: net.Server[] = []; -const allProxies: (SmartProxy | HttpProxy)[] = []; -const activeConnections: net.Socket[] = []; - -// Helper: Creates a test TCP server -function createTestServer(port: number): Promise { - return new Promise((resolve) => { - const server = net.createServer((socket) => { - socket.on('data', (data) => { - socket.write(`Echo: ${data.toString()}`); - }); - socket.on('error', () => {}); - }); - server.listen(port, 'localhost', () => { - console.log(`[Test Server] Listening on localhost:${port}`); - allServers.push(server); - resolve(server); - }); - }); -} - -// Helper: Creates multiple concurrent connections -// If waitForData is true, waits for the connection to be fully established (can receive data) -async function createConcurrentConnections( - port: number, - count: number, - waitForData: boolean = false -): Promise { - const connections: net.Socket[] = []; - const promises: Promise[] = []; - - for (let i = 0; i < count; i++) { - promises.push( - new Promise((resolve, reject) => { - const client = new net.Socket(); - const timeout = setTimeout(() => { - client.destroy(); - reject(new Error(`Connection ${i} timeout`)); - }, 5000); - - client.connect(port, 'localhost', () => { - if (!waitForData) { - clearTimeout(timeout); - activeConnections.push(client); - connections.push(client); - resolve(client); - } - // If waitForData, we wait for the close event to see if connection was rejected - }); - - if (waitForData) { - // Wait a bit to see if connection gets closed by server - client.once('close', () => { - clearTimeout(timeout); - reject(new Error('Connection closed by server')); - }); - - // If we can write and get a response, connection is truly established - setTimeout(() => { - if (!client.destroyed) { - clearTimeout(timeout); - activeConnections.push(client); - connections.push(client); - resolve(client); - } - }, 100); - } - - client.on('error', (err) => { - clearTimeout(timeout); - reject(err); - }); - }) - ); - } - - await Promise.all(promises); - return connections; -} - -// Helper: Clean up connections -function cleanupConnections(connections: net.Socket[]): void { - connections.forEach(conn => { - if (!conn.destroyed) { - conn.destroy(); - } - }); -} - -tap.test('Setup test environment', async () => { - testServer = await createTestServer(TEST_SERVER_PORT); - - // Create SmartProxy with low connection limits for testing - smartProxy = new SmartProxy({ - routes: [{ - name: 'test-route', - match: { - ports: PROXY_PORT - }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: TEST_SERVER_PORT - }] - }, - security: { - maxConnections: 5 // Low limit for testing - } - }], - maxConnectionsPerIP: 3, // Low per-IP limit - connectionRateLimitPerMinute: 10, // Low rate limit - defaults: { - security: { - maxConnections: 10 // Low global limit - } - } - }); - - await smartProxy.start(); - allProxies.push(smartProxy); -}); - -tap.test('Per-IP connection limits', async () => { - // Test that we can create up to the per-IP limit - const connections1 = await createConcurrentConnections(PROXY_PORT, 3); - expect(connections1.length).toEqual(3); - - // Allow server-side processing to complete - await new Promise(resolve => setTimeout(resolve, 50)); - - // Try to create one more connection - should fail - // Use waitForData=true to detect if server closes the connection after accepting it - try { - await createConcurrentConnections(PROXY_PORT, 1, true); - // If we get here, the 4th connection was truly established - throw new Error('Should not allow more than 3 connections per IP'); - } catch (err) { - console.log(`Per-IP limit error received: ${err.message}`); - // Connection should be rejected - either reset, refused, or closed by server - const isRejected = err.message.includes('ECONNRESET') || - err.message.includes('ECONNREFUSED') || - err.message.includes('closed'); - expect(isRejected).toBeTrue(); - } - - // Clean up first set of connections - cleanupConnections(connections1); - await new Promise(resolve => setTimeout(resolve, 100)); - - // Should be able to create new connections after cleanup - const connections2 = await createConcurrentConnections(PROXY_PORT, 2); - expect(connections2.length).toEqual(2); - - cleanupConnections(connections2); -}); - -tap.test('Route-level connection limits', async () => { - // Create multiple connections up to route limit - const connections = await createConcurrentConnections(PROXY_PORT, 5); - expect(connections.length).toEqual(5); - - // Try to exceed route limit - try { - await createConcurrentConnections(PROXY_PORT, 1); - throw new Error('Should not allow more than 5 connections for this route'); - } catch (err) { - // Connection should be rejected - either reset or refused - console.log('Connection limit error:', err.message); - const isRejected = err.message.includes('ECONNRESET') || - err.message.includes('ECONNREFUSED') || - err.message.includes('closed') || - err.message.includes('5 connections'); - expect(isRejected).toBeTrue(); - } - - cleanupConnections(connections); -}); - -tap.test('Connection rate limiting', async () => { - // Create connections rapidly - const connections: net.Socket[] = []; - - // Create 10 connections rapidly (at rate limit) - for (let i = 0; i < 10; i++) { - try { - const conn = await createConcurrentConnections(PROXY_PORT, 1); - connections.push(...conn); - // Small delay to avoid per-IP limit - if (connections.length >= 3) { - cleanupConnections(connections.splice(0, 3)); - await new Promise(resolve => setTimeout(resolve, 50)); - } - } catch (err) { - // Expected to fail at some point due to rate limit - expect(i).toBeGreaterThan(0); - break; - } - } - - cleanupConnections(connections); -}); - -tap.test('HttpProxy per-IP validation', async () => { - // Skip complex HttpProxy integration test - focus on SmartProxy connection limits - // The HttpProxy has its own per-IP validation that's tested separately - // This test would require TLS certificates and more complex setup - console.log('Skipping HttpProxy per-IP validation - tested separately'); -}); - -tap.test('IP tracking cleanup', async (tools) => { - // Wait for any previous test cleanup to complete - await tools.delayFor(300); - - // Create and close connections - const connections: net.Socket[] = []; - - for (let i = 0; i < 2; i++) { - try { - const conn = await createConcurrentConnections(PROXY_PORT, 1); - connections.push(...conn); - } catch { - // Ignore rejections - } - } - - // Close all connections - cleanupConnections(connections); - - // Wait for cleanup to process - await tools.delayFor(500); - - // Verify that IP tracking has been cleaned up - const securityManager = (smartProxy as any).securityManager; - const ipCount = securityManager.getConnectionCountByIP('::ffff:127.0.0.1'); - - // Should have no connections tracked for this IP after cleanup - // Note: Due to asynchronous cleanup, we allow for some variance - expect(ipCount).toBeLessThanOrEqual(1); -}); - -tap.test('Cleanup queue race condition handling', async () => { - // Wait for previous test cleanup - await new Promise(resolve => setTimeout(resolve, 300)); - - // Create connections sequentially to avoid hitting per-IP limit - const allConnections: net.Socket[] = []; - for (let i = 0; i < 2; i++) { - try { - const conn = await createConcurrentConnections(PROXY_PORT, 1); - allConnections.push(...conn); - } catch { - // Ignore connection rejections - } - } - - // Close all connections rapidly - allConnections.forEach(conn => conn.destroy()); - - // Give cleanup queue time to process - await new Promise(resolve => setTimeout(resolve, 500)); - - // Verify all connections were cleaned up - const connectionManager = (smartProxy as any).connectionManager; - const remainingConnections = connectionManager.getConnectionCount(); - - // Allow for some variance due to async cleanup - expect(remainingConnections).toBeLessThanOrEqual(1); -}); - -tap.test('Cleanup and shutdown', async () => { - // Clean up any remaining connections - cleanupConnections(activeConnections); - activeConnections.length = 0; - - // Stop all proxies - for (const proxy of allProxies) { - await proxy.stop(); - } - allProxies.length = 0; - - // Close all test servers - for (const server of allServers) { - await new Promise((resolve) => { - server.close(() => resolve()); - }); - } - allServers.length = 0; -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.fix-verification.ts b/test/test.fix-verification.ts deleted file mode 100644 index edc54fe..0000000 --- a/test/test.fix-verification.ts +++ /dev/null @@ -1,83 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import { SmartProxy } from '../ts/index.js'; - -tap.test('should verify certificate manager callback is preserved on updateRoutes', async () => { - // Create proxy with initial cert routes - const proxy = new SmartProxy({ - routes: [{ - name: 'cert-route', - match: { ports: [18443], domains: ['test.local'] }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 3000 }], - tls: { - mode: 'terminate', - certificate: 'auto', - acme: { email: 'test@local.test' } - } - } - }], - acme: { email: 'test@local.test', port: 18080 } - }); - - // Track callback preservation - let initialCallbackSet = false; - let updateCallbackSet = false; - - // Mock certificate manager creation - (proxy as any).createCertificateManager = async function(...args: any[]) { - const certManager = { - updateRoutesCallback: null as any, - setUpdateRoutesCallback: function(callback: any) { - this.updateRoutesCallback = callback; - if (!initialCallbackSet) { - initialCallbackSet = true; - } else { - updateCallbackSet = true; - } - }, - setHttpProxy: () => {}, - setGlobalAcmeDefaults: () => {}, - setAcmeStateManager: () => {}, - setRoutes: (routes: any) => {}, - initialize: async () => {}, - provisionAllCertificates: async () => {}, - stop: async () => {}, - getAcmeOptions: () => ({ email: 'test@local.test' }), - getState: () => ({ challengeRouteActive: false }) - }; - - // Set callback as in real implementation - certManager.setUpdateRoutesCallback(async (routes) => { - await this.updateRoutes(routes); - }); - - return certManager; - }; - - await proxy.start(); - expect(initialCallbackSet).toEqual(true); - - // Update routes - this should preserve the callback - await proxy.updateRoutes([{ - name: 'updated-route', - match: { ports: [18444], domains: ['test2.local'] }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 3001 }], - tls: { - mode: 'terminate', - certificate: 'auto', - acme: { email: 'test@local.test' } - } - } - }]); - - expect(updateCallbackSet).toEqual(true); - - await proxy.stop(); - - console.log('Fix verified: Certificate manager callback is preserved on updateRoutes'); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.http-fix-unit.ts b/test/test.http-fix-unit.ts deleted file mode 100644 index 0b2de7a..0000000 --- a/test/test.http-fix-unit.ts +++ /dev/null @@ -1,183 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import * as net from 'net'; - -// Unit test for the HTTP forwarding fix -tap.test('should forward non-TLS connections on HttpProxy ports', async (tapTest) => { - // Test configuration - const testPort = 8080; - const httpProxyPort = 8844; - - // Track forwarding logic - let forwardedToHttpProxy = false; - let setupDirectConnection = false; - - // Create mock settings - const mockSettings = { - useHttpProxy: [testPort], - httpProxyPort: httpProxyPort, - routes: [{ - name: 'test-route', - match: { ports: testPort }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 8181 }] - } - }] - }; - - // Create mock connection record - const mockRecord = { - id: 'test-connection', - localPort: testPort, - remoteIP: '127.0.0.1', - isTLS: false - }; - - // Mock HttpProxyBridge - const mockHttpProxyBridge = { - getHttpProxy: () => ({ available: true }), - forwardToHttpProxy: async () => { - forwardedToHttpProxy = true; - } - }; - - // Test the logic from handleForwardAction - const route = mockSettings.routes[0]; - const action = route.action as any; - - // Simulate the fixed logic - if (!action.tls) { - // No TLS settings - check if this port should use HttpProxy - const isHttpProxyPort = mockSettings.useHttpProxy?.includes(mockRecord.localPort); - - if (isHttpProxyPort && mockHttpProxyBridge.getHttpProxy()) { - // Forward non-TLS connections to HttpProxy if configured - console.log(`Using HttpProxy for non-TLS connection on port ${mockRecord.localPort}`); - await mockHttpProxyBridge.forwardToHttpProxy(); - } else { - // Basic forwarding - console.log(`Using basic forwarding`); - setupDirectConnection = true; - } - } - - // Verify the fix works correctly - expect(forwardedToHttpProxy).toEqual(true); - expect(setupDirectConnection).toEqual(false); - - console.log('Test passed: Non-TLS connections on HttpProxy ports are forwarded correctly'); -}); - -// Test that non-HttpProxy ports still use direct connection -tap.test('should use direct connection for non-HttpProxy ports', async (tapTest) => { - let forwardedToHttpProxy = false; - let setupDirectConnection = false; - - const mockSettings = { - useHttpProxy: [80, 443], // Different ports - httpProxyPort: 8844, - routes: [{ - name: 'test-route', - match: { ports: 8080 }, // Not in useHttpProxy - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 8181 }] - } - }] - }; - - const mockRecord = { - id: 'test-connection-2', - localPort: 8080, // Not in useHttpProxy - remoteIP: '127.0.0.1', - isTLS: false - }; - - const mockHttpProxyBridge = { - getHttpProxy: () => ({ available: true }), - forwardToHttpProxy: async () => { - forwardedToHttpProxy = true; - } - }; - - const route = mockSettings.routes[0]; - const action = route.action as any; - - // Test the logic - if (!action.tls) { - const isHttpProxyPort = mockSettings.useHttpProxy?.includes(mockRecord.localPort); - - if (isHttpProxyPort && mockHttpProxyBridge.getHttpProxy()) { - console.log(`Using HttpProxy for non-TLS connection on port ${mockRecord.localPort}`); - await mockHttpProxyBridge.forwardToHttpProxy(); - } else { - console.log(`Using basic forwarding for port ${mockRecord.localPort}`); - setupDirectConnection = true; - } - } - - // Verify port 8080 uses direct connection when not in useHttpProxy - expect(forwardedToHttpProxy).toEqual(false); - expect(setupDirectConnection).toEqual(true); - - console.log('Test passed: Non-HttpProxy ports use direct connection'); -}); - -// Test HTTP-01 ACME challenge scenario -tap.test('should handle ACME HTTP-01 challenges on port 80 with HttpProxy', async (tapTest) => { - let forwardedToHttpProxy = false; - - const mockSettings = { - useHttpProxy: [80], // Port 80 configured for HttpProxy - httpProxyPort: 8844, - acme: { - port: 80, - email: 'test@example.com' - }, - routes: [{ - name: 'acme-challenge', - match: { - ports: 80, - paths: ['/.well-known/acme-challenge/*'] - }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 8080 }] - } - }] - }; - - const mockRecord = { - id: 'acme-connection', - localPort: 80, - remoteIP: '127.0.0.1', - isTLS: false - }; - - const mockHttpProxyBridge = { - getHttpProxy: () => ({ available: true }), - forwardToHttpProxy: async () => { - forwardedToHttpProxy = true; - } - }; - - const route = mockSettings.routes[0]; - const action = route.action as any; - - // Test the fix for ACME HTTP-01 challenges - if (!action.tls) { - const isHttpProxyPort = mockSettings.useHttpProxy?.includes(mockRecord.localPort); - - if (isHttpProxyPort && mockHttpProxyBridge.getHttpProxy()) { - console.log(`Using HttpProxy for ACME challenge on port ${mockRecord.localPort}`); - await mockHttpProxyBridge.forwardToHttpProxy(); - } - } - - // Verify HTTP-01 challenges on port 80 go through HttpProxy - expect(forwardedToHttpProxy).toEqual(true); - - console.log('Test passed: ACME HTTP-01 challenges on port 80 use HttpProxy'); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.http-fix-verification.ts b/test/test.http-fix-verification.ts deleted file mode 100644 index 57206fc..0000000 --- a/test/test.http-fix-verification.ts +++ /dev/null @@ -1,256 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import { RouteConnectionHandler } from '../ts/proxies/smart-proxy/route-connection-handler.js'; -import type { ISmartProxyOptions } from '../ts/proxies/smart-proxy/models/interfaces.js'; -import * as net from 'net'; - -// Direct test of the fix in RouteConnectionHandler -tap.test('should detect and forward non-TLS connections on useHttpProxy ports', async (tapTest) => { - // Create mock objects - const mockSettings: ISmartProxyOptions = { - useHttpProxy: [8080], - httpProxyPort: 8844, - routes: [{ - name: 'test-route', - match: { ports: 8080 }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 8181 }] - } - }] - }; - - let httpProxyForwardCalled = false; - let directConnectionCalled = false; - - // Create mocks for dependencies - const mockHttpProxyBridge = { - getHttpProxy: () => ({ available: true }), - forwardToHttpProxy: async (...args: any[]) => { - console.log('Mock: forwardToHttpProxy called'); - httpProxyForwardCalled = true; - } - }; - - // Mock connection manager - const mockConnectionManager = { - createConnection: (socket: any) => ({ - id: 'test-connection', - localPort: 8080, - remoteIP: '127.0.0.1', - isTLS: false - }), - generateConnectionId: () => 'test-connection-id', - initiateCleanupOnce: () => {}, - cleanupConnection: () => {}, - getConnectionCount: () => 1, - trackConnectionByRoute: (routeId: string, connectionId: string) => {}, - handleError: (type: string, record: any) => { - return (error: Error) => { - console.log(`Mock: Error handled for ${type}: ${error.message}`); - }; - } - }; - - // Mock route manager that returns a matching route - const mockRouteManager = { - findMatchingRoute: (criteria: any) => ({ - route: mockSettings.routes[0] - }), - getRoutes: () => mockSettings.routes, - getRoutesForPort: (port: number) => mockSettings.routes.filter(r => { - const ports = Array.isArray(r.match.ports) ? r.match.ports : [r.match.ports]; - return ports.some(p => { - if (typeof p === 'number') { - return p === port; - } else if (p && typeof p === 'object' && 'from' in p && 'to' in p) { - return port >= p.from && port <= p.to; - } - return false; - }); - }) - }; - - // Mock security manager - const mockSecurityManager = { - validateAndTrackIP: () => ({ allowed: true }) - }; - - // Create a mock SmartProxy instance with necessary properties - const mockSmartProxy = { - settings: mockSettings, - connectionManager: mockConnectionManager, - securityManager: mockSecurityManager, - httpProxyBridge: mockHttpProxyBridge, - routeManager: mockRouteManager - } as any; - - // Create route connection handler instance - const handler = new RouteConnectionHandler(mockSmartProxy); - - // Override setupDirectConnection to track if it's called - handler['setupDirectConnection'] = (...args: any[]) => { - console.log('Mock: setupDirectConnection called'); - directConnectionCalled = true; - }; - - // Test: Create a mock socket representing non-TLS connection on port 8080 - const mockSocket = { - localPort: 8080, - remoteAddress: '127.0.0.1', - on: function(event: string, handler: Function) { return this; }, - once: function(event: string, handler: Function) { - // Capture the data handler - if (event === 'data') { - this._dataHandler = handler; - } - return this; - }, - end: () => {}, - destroy: () => {}, - pause: () => {}, - resume: () => {}, - removeListener: function() { return this; }, - emit: () => {}, - setNoDelay: () => {}, - setKeepAlive: () => {}, - _dataHandler: null as any - } as any; - - // Simulate the handler processing the connection - handler.handleConnection(mockSocket); - - // Simulate receiving non-TLS data - if (mockSocket._dataHandler) { - mockSocket._dataHandler(Buffer.from('GET / HTTP/1.1\r\nHost: test.local\r\n\r\n')); - } - - // Give it a moment to process - await new Promise(resolve => setTimeout(resolve, 100)); - - // Verify that the connection was forwarded to HttpProxy, not direct connection - expect(httpProxyForwardCalled).toEqual(true); - expect(directConnectionCalled).toEqual(false); -}); - -// Test that verifies TLS connections still work normally -tap.test('should handle TLS connections normally', async (tapTest) => { - const mockSettings: ISmartProxyOptions = { - useHttpProxy: [443], - httpProxyPort: 8844, - routes: [{ - name: 'tls-route', - match: { ports: 443 }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 8443 }], - tls: { mode: 'terminate' } - } - }] - }; - - let httpProxyForwardCalled = false; - - const mockHttpProxyBridge = { - getHttpProxy: () => ({ available: true }), - forwardToHttpProxy: async (...args: any[]) => { - httpProxyForwardCalled = true; - } - }; - - const mockConnectionManager = { - createConnection: (socket: any) => ({ - id: 'test-tls-connection', - localPort: 443, - remoteIP: '127.0.0.1', - isTLS: true, - tlsHandshakeComplete: false - }), - generateConnectionId: () => 'test-tls-connection-id', - initiateCleanupOnce: () => {}, - cleanupConnection: () => {}, - getConnectionCount: () => 1, - trackConnectionByRoute: (routeId: string, connectionId: string) => {}, - handleError: (type: string, record: any) => { - return (error: Error) => { - console.log(`Mock: Error handled for ${type}: ${error.message}`); - }; - } - }; - - const mockTlsManager = { - isTlsHandshake: (chunk: Buffer) => true, - isClientHello: (chunk: Buffer) => true, - extractSNI: (chunk: Buffer) => 'test.local' - }; - - const mockRouteManager = { - findMatchingRoute: (criteria: any) => ({ - route: mockSettings.routes[0] - }), - getRoutes: () => mockSettings.routes, - getRoutesForPort: (port: number) => mockSettings.routes.filter(r => { - const ports = Array.isArray(r.match.ports) ? r.match.ports : [r.match.ports]; - return ports.some(p => { - if (typeof p === 'number') { - return p === port; - } else if (p && typeof p === 'object' && 'from' in p && 'to' in p) { - return port >= p.from && port <= p.to; - } - return false; - }); - }) - }; - - const mockSecurityManager = { - validateAndTrackIP: () => ({ allowed: true }) - }; - - // Create a mock SmartProxy instance with necessary properties - const mockSmartProxy = { - settings: mockSettings, - connectionManager: mockConnectionManager, - securityManager: mockSecurityManager, - tlsManager: mockTlsManager, - httpProxyBridge: mockHttpProxyBridge, - routeManager: mockRouteManager - } as any; - - const handler = new RouteConnectionHandler(mockSmartProxy); - - const mockSocket = { - localPort: 443, - remoteAddress: '127.0.0.1', - on: function(event: string, handler: Function) { return this; }, - once: function(event: string, handler: Function) { - // Capture the data handler - if (event === 'data') { - this._dataHandler = handler; - } - return this; - }, - end: () => {}, - destroy: () => {}, - pause: () => {}, - resume: () => {}, - removeListener: function() { return this; }, - emit: () => {}, - setNoDelay: () => {}, - setKeepAlive: () => {}, - _dataHandler: null as any - } as any; - - handler.handleConnection(mockSocket); - - // Simulate TLS handshake - if (mockSocket._dataHandler) { - const tlsHandshake = Buffer.from([0x16, 0x03, 0x01, 0x00, 0x05]); - mockSocket._dataHandler(tlsHandshake); - } - - await new Promise(resolve => setTimeout(resolve, 100)); - - // TLS connections with 'terminate' mode should go to HttpProxy - expect(httpProxyForwardCalled).toEqual(true); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.http-forwarding-fix.ts b/test/test.http-forwarding-fix.ts deleted file mode 100644 index d473e75..0000000 --- a/test/test.http-forwarding-fix.ts +++ /dev/null @@ -1,189 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import { SmartProxy } from '../ts/index.js'; -import * as net from 'net'; - -// Test that verifies HTTP connections on ports configured in useHttpProxy are properly forwarded -tap.test('should detect and forward non-TLS connections on HttpProxy ports', async (tapTest) => { - // Track whether the connection was forwarded to HttpProxy - let forwardedToHttpProxy = false; - let connectionPath = ''; - - // Create a SmartProxy instance first - const proxy = new SmartProxy({ - useHttpProxy: [8081], // Use different port to avoid conflicts - httpProxyPort: 8847, // Use different port to avoid conflicts - routes: [{ - name: 'test-http-forward', - match: { ports: 8081 }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 8181 }] - } - }] - }); - - // Add detailed logging to the existing proxy instance - proxy.settings.enableDetailedLogging = true; - - // Override the HttpProxy initialization to avoid actual HttpProxy setup - proxy['httpProxyBridge'].initialize = async () => { - console.log('Mock: HttpProxyBridge initialized'); - }; - proxy['httpProxyBridge'].start = async () => { - console.log('Mock: HttpProxyBridge started'); - }; - proxy['httpProxyBridge'].stop = async () => { - console.log('Mock: HttpProxyBridge stopped'); - return Promise.resolve(); // Ensure it returns a resolved promise - }; - - await proxy.start(); - - // Mock the HttpProxy forwarding AFTER start to ensure it's not overridden - const originalForward = (proxy as any).httpProxyBridge.forwardToHttpProxy; - (proxy as any).httpProxyBridge.forwardToHttpProxy = async function(...args: any[]) { - forwardedToHttpProxy = true; - connectionPath = 'httpproxy'; - console.log('Mock: Connection forwarded to HttpProxy with args:', args[0], 'on port:', args[2]?.localPort); - // Properly close the connection for the test - const socket = args[1]; - socket.end(); - socket.destroy(); - }; - - // Mock getHttpProxy to indicate HttpProxy is available - (proxy as any).httpProxyBridge.getHttpProxy = () => ({ available: true }); - - // Make a connection to port 8080 - const client = new net.Socket(); - - await new Promise((resolve, reject) => { - client.connect(8081, 'localhost', () => { - console.log('Client connected to proxy on port 8081'); - // Send a non-TLS HTTP request - client.write('GET / HTTP/1.1\r\nHost: test.local\r\n\r\n'); - // Add a small delay to ensure data is sent - setTimeout(() => resolve(), 50); - }); - - client.on('error', reject); - }); - - // Give it a moment to process - await new Promise(resolve => setTimeout(resolve, 100)); - - // Verify the connection was forwarded to HttpProxy - expect(forwardedToHttpProxy).toEqual(true); - expect(connectionPath).toEqual('httpproxy'); - - client.destroy(); - - // Restore original method before stopping - (proxy as any).httpProxyBridge.forwardToHttpProxy = originalForward; - - console.log('About to stop proxy...'); - await proxy.stop(); - console.log('Proxy stopped'); - - // Wait a bit to ensure port is released - await new Promise(resolve => setTimeout(resolve, 100)); -}); - -// Test that verifies the fix detects non-TLS connections -tap.test('should properly detect non-TLS connections on HttpProxy ports', async (tapTest) => { - const targetPort = 8182; - let receivedConnection = false; - - // Create a target server that never receives the connection (because it goes to HttpProxy) - const targetServer = net.createServer((socket) => { - receivedConnection = true; - socket.end(); - }); - - await new Promise((resolve) => { - targetServer.listen(targetPort, () => { - console.log(`Target server listening on port ${targetPort}`); - resolve(); - }); - }); - - // Mock HttpProxyBridge to track forwarding - let httpProxyForwardCalled = false; - - const proxy = new SmartProxy({ - useHttpProxy: [8082], // Use different port to avoid conflicts - httpProxyPort: 8848, // Use different port to avoid conflicts - routes: [{ - name: 'test-route', - match: { - ports: 8082 - }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: targetPort }] - } - }] - }); - - // Override the forwardToHttpProxy method to track calls - const originalForward = proxy['httpProxyBridge'].forwardToHttpProxy; - proxy['httpProxyBridge'].forwardToHttpProxy = async function(...args: any[]) { - httpProxyForwardCalled = true; - console.log('HttpProxy forward called with connectionId:', args[0]); - // Properly close the connection - const socket = args[1]; - socket.end(); - socket.destroy(); - }; - - // Mock HttpProxyBridge methods - proxy['httpProxyBridge'].initialize = async () => { - console.log('Mock: HttpProxyBridge initialized'); - }; - proxy['httpProxyBridge'].start = async () => { - console.log('Mock: HttpProxyBridge started'); - }; - proxy['httpProxyBridge'].stop = async () => { - console.log('Mock: HttpProxyBridge stopped'); - return Promise.resolve(); // Ensure it returns a resolved promise - }; - - // Mock getHttpProxy to return a truthy value - proxy['httpProxyBridge'].getHttpProxy = () => ({} as any); - - await proxy.start(); - - // Make a non-TLS connection - const client = new net.Socket(); - - await new Promise((resolve, reject) => { - client.connect(8082, 'localhost', () => { - console.log('Connected to proxy'); - client.write('GET / HTTP/1.1\r\nHost: test.local\r\n\r\n'); - // Add a small delay to ensure data is sent - setTimeout(() => resolve(), 50); - }); - - client.on('error', () => resolve()); // Ignore errors since we're ending the connection - }); - - await new Promise(resolve => setTimeout(resolve, 100)); - - // Verify that HttpProxy was called, not direct connection - expect(httpProxyForwardCalled).toEqual(true); - expect(receivedConnection).toEqual(false); // Target should not receive direct connection - - client.destroy(); - await proxy.stop(); - await new Promise((resolve) => { - targetServer.close(() => resolve()); - }); - - // Wait a bit to ensure port is released - await new Promise(resolve => setTimeout(resolve, 100)); - - // Restore original method - proxy['httpProxyBridge'].forwardToHttpProxy = originalForward; -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.http-port8080-simple.ts b/test/test.http-port8080-simple.ts deleted file mode 100644 index 90be620..0000000 --- a/test/test.http-port8080-simple.ts +++ /dev/null @@ -1,246 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import { SmartProxy } from '../ts/index.js'; -import * as plugins from '../ts/plugins.js'; -import * as net from 'net'; -import * as http from 'http'; - -/** - * This test verifies our improved port binding intelligence for ACME challenges. - * It specifically tests: - * 1. Using port 8080 instead of 80 for ACME HTTP challenges - * 2. Correctly handling shared port bindings between regular routes and challenge routes - * 3. Avoiding port conflicts when updating routes - */ - -tap.test('should handle ACME challenges on port 8080 with improved port binding intelligence', async (tapTest) => { - // Create a simple echo server to act as our target - const targetPort = 9001; - let receivedData = ''; - - const targetServer = net.createServer((socket) => { - console.log('Target server received connection'); - - socket.on('data', (data) => { - receivedData += data.toString(); - console.log('Target server received data:', data.toString().split('\n')[0]); - - // Send a simple HTTP response - const response = 'HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: 13\r\n\r\nHello, World!'; - socket.write(response); - }); - }); - - await new Promise((resolve) => { - targetServer.listen(targetPort, () => { - console.log(`Target server listening on port ${targetPort}`); - resolve(); - }); - }); - - // In this test we will NOT create a mock ACME server on the same port - // as SmartProxy will use, instead we'll let SmartProxy handle it - const acmeServerPort = 9009; - const acmeRequests: string[] = []; - let acmeServer: http.Server | null = null; - - // We'll assume the ACME port is available for SmartProxy - let acmePortAvailable = true; - - // Create SmartProxy with ACME configured to use port 8080 - console.log('Creating SmartProxy with ACME port 8080...'); - const tempCertDir = './temp-certs'; - - try { - await plugins.smartfile.fs.ensureDir(tempCertDir); - } catch (error) { - // Directory may already exist, that's ok - } - - const proxy = new SmartProxy({ - enableDetailedLogging: true, - routes: [ - { - name: 'test-route', - match: { - ports: [9003], - domains: ['test.example.com'] - }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: targetPort }], - tls: { - mode: 'terminate', - certificate: 'auto' // Use ACME for certificate - } - } - }, - // Also add a route for port 8080 to test port sharing - { - name: 'http-route', - match: { - ports: [9009], - domains: ['test.example.com'] - }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: targetPort }] - } - } - ], - acme: { - email: 'test@example.com', - useProduction: false, - port: 9009, // Use 9009 instead of default 80 - certificateStore: tempCertDir - } - }); - - // Mock the certificate manager to avoid actual ACME operations - console.log('Mocking certificate manager...'); - const createCertManager = (proxy as any).createCertificateManager; - (proxy as any).createCertificateManager = async function(...args: any[]) { - // Create a completely mocked certificate manager that doesn't use ACME at all - return { - initialize: async () => {}, - getCertPair: async () => { - return { - publicKey: 'MOCK CERTIFICATE', - privateKey: 'MOCK PRIVATE KEY' - }; - }, - getAcmeOptions: () => { - return { - port: 9009 - }; - }, - getState: () => { - return { - initializing: false, - ready: true, - port: 9009 - }; - }, - provisionAllCertificates: async () => { - console.log('Mock: Provisioning certificates'); - return []; - }, - stop: async () => {}, - setRoutes: (routes: any) => {}, - smartAcme: { - getCertificateForDomain: async () => { - // Return a mock certificate - return { - publicKey: 'MOCK CERTIFICATE', - privateKey: 'MOCK PRIVATE KEY', - validUntil: Date.now() + 90 * 24 * 60 * 60 * 1000, - created: Date.now() - }; - }, - start: async () => {}, - stop: async () => {} - } - }; - }; - - // Track port binding attempts to verify intelligence - const portBindAttempts: number[] = []; - const originalAddPort = (proxy as any).portManager.addPort; - (proxy as any).portManager.addPort = async function(port: number) { - portBindAttempts.push(port); - return originalAddPort.call(this, port); - }; - - try { - console.log('Starting SmartProxy...'); - await proxy.start(); - - console.log('Port binding attempts:', portBindAttempts); - - // Check that we tried to bind to port 9009 - // Should attempt to bind to port 9009 - expect(portBindAttempts.includes(9009)).toEqual(true); - // Should attempt to bind to port 9003 - expect(portBindAttempts.includes(9003)).toEqual(true); - - // Get actual bound ports - const boundPorts = proxy.getListeningPorts(); - console.log('Actually bound ports:', boundPorts); - - // If port 9009 was available, we should be bound to it - if (acmePortAvailable) { - // Should be bound to port 9009 if available - expect(boundPorts.includes(9009)).toEqual(true); - } - - // Should be bound to port 9003 - expect(boundPorts.includes(9003)).toEqual(true); - - // Test adding a new route on port 8080 - console.log('Testing route update with port reuse...'); - - // Reset tracking - portBindAttempts.length = 0; - - // Add a new route on port 8080 - const newRoutes = [ - ...proxy.settings.routes, - { - name: 'additional-route', - match: { - ports: [9009], - path: '/additional' - }, - action: { - type: 'forward' as const, - targets: [{ host: 'localhost', port: targetPort }] - } - } - ]; - - // Update routes - this should NOT try to rebind port 8080 - await proxy.updateRoutes(newRoutes); - - console.log('Port binding attempts after update:', portBindAttempts); - - // We should not try to rebind port 9009 since it's already bound - // Should not attempt to rebind port 9009 - expect(portBindAttempts.includes(9009)).toEqual(false); - - // We should still be listening on both ports - const portsAfterUpdate = proxy.getListeningPorts(); - console.log('Bound ports after update:', portsAfterUpdate); - - if (acmePortAvailable) { - // Should still be bound to port 9009 - expect(portsAfterUpdate.includes(9009)).toEqual(true); - } - // Should still be bound to port 9003 - expect(portsAfterUpdate.includes(9003)).toEqual(true); - - // The test is successful at this point - we've verified the port binding intelligence - console.log('Port binding intelligence verified successfully!'); - // We'll skip the actual connection test to avoid timeouts - } finally { - // Clean up - console.log('Cleaning up...'); - await proxy.stop(); - - if (targetServer) { - await new Promise((resolve) => { - targetServer.close(() => resolve()); - }); - } - - // No acmeServer to close in this test - - // Clean up temp directory - try { - // Remove temp directory - await plugins.smartfile.fs.remove(tempCertDir); - } catch (error) { - console.error('Failed to remove temp directory:', error); - } - } -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.http-proxy-security-limits.node.ts b/test/test.http-proxy-security-limits.node.ts deleted file mode 100644 index 2a28f30..0000000 --- a/test/test.http-proxy-security-limits.node.ts +++ /dev/null @@ -1,114 +0,0 @@ -import { expect, tap } from '@git.zone/tstest/tapbundle'; -import { SecurityManager } from '../ts/proxies/http-proxy/security-manager.js'; -import { createLogger } from '../ts/proxies/http-proxy/models/types.js'; - -let securityManager: SecurityManager; -const logger = createLogger('error'); // Quiet logger for tests - -tap.test('Setup HttpProxy SecurityManager', async () => { - securityManager = new SecurityManager(logger, [], 3, 10); // Low limits for testing -}); - -tap.test('HttpProxy IP connection tracking', async () => { - const testIP = '10.0.0.1'; - - // Track connections - securityManager.trackConnectionByIP(testIP, 'http-conn1'); - securityManager.trackConnectionByIP(testIP, 'http-conn2'); - - expect(securityManager.getConnectionCountByIP(testIP)).toEqual(2); - - // Validate IP should pass - let result = securityManager.validateIP(testIP); - expect(result.allowed).toBeTrue(); - - // Add one more to reach limit - securityManager.trackConnectionByIP(testIP, 'http-conn3'); - - // Should now reject new connections - result = securityManager.validateIP(testIP); - expect(result.allowed).toBeFalse(); - expect(result.reason).toInclude('Maximum connections per IP (3) exceeded'); - - // Remove a connection - securityManager.removeConnectionByIP(testIP, 'http-conn1'); - - // Should allow connections again - result = securityManager.validateIP(testIP); - expect(result.allowed).toBeTrue(); - - // Clean up - securityManager.removeConnectionByIP(testIP, 'http-conn2'); - securityManager.removeConnectionByIP(testIP, 'http-conn3'); -}); - -tap.test('HttpProxy connection rate limiting', async () => { - const testIP = '10.0.0.2'; - - // Make 10 connection attempts rapidly (at rate limit) - // Note: We don't track connections here as we're testing rate limiting, not per-IP limiting - for (let i = 0; i < 10; i++) { - const result = securityManager.validateIP(testIP); - expect(result.allowed).toBeTrue(); - } - - // 11th connection should be rate limited - const result = securityManager.validateIP(testIP); - expect(result.allowed).toBeFalse(); - expect(result.reason).toInclude('Connection rate limit (10/min) exceeded'); -}); - -tap.test('HttpProxy CLIENT_IP header handling', async () => { - // This tests the scenario where SmartProxy forwards the real client IP - const realClientIP = '203.0.113.1'; - const proxyIP = '127.0.0.1'; - - // Simulate SmartProxy tracking the real client IP - securityManager.trackConnectionByIP(realClientIP, 'forwarded-conn1'); - securityManager.trackConnectionByIP(realClientIP, 'forwarded-conn2'); - securityManager.trackConnectionByIP(realClientIP, 'forwarded-conn3'); - - // Real client IP should be at limit - let result = securityManager.validateIP(realClientIP); - expect(result.allowed).toBeFalse(); - - // But proxy IP should still be allowed - result = securityManager.validateIP(proxyIP); - expect(result.allowed).toBeTrue(); - - // Clean up - securityManager.removeConnectionByIP(realClientIP, 'forwarded-conn1'); - securityManager.removeConnectionByIP(realClientIP, 'forwarded-conn2'); - securityManager.removeConnectionByIP(realClientIP, 'forwarded-conn3'); -}); - -tap.test('HttpProxy automatic cleanup', async (tools) => { - const testIP = '10.0.0.3'; - - // Create and immediately remove connections - for (let i = 0; i < 5; i++) { - securityManager.trackConnectionByIP(testIP, `cleanup-conn${i}`); - securityManager.removeConnectionByIP(testIP, `cleanup-conn${i}`); - } - - // Add rate limit entries - for (let i = 0; i < 5; i++) { - securityManager.validateIP(testIP); - } - - // Wait a bit (cleanup runs every 60 seconds in production) - // For testing, we'll just verify the cleanup logic works - await tools.delayFor(100); - - // Manually trigger cleanup (in production this happens automatically) - (securityManager as any).performIpCleanup(); - - // IP should be cleaned up - expect(securityManager.getConnectionCountByIP(testIP)).toEqual(0); -}); - -tap.test('Cleanup HttpProxy SecurityManager', async () => { - securityManager.clearIPTracking(); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.httpproxy.function-targets.ts b/test/test.httpproxy.function-targets.ts deleted file mode 100644 index 926e489..0000000 --- a/test/test.httpproxy.function-targets.ts +++ /dev/null @@ -1,405 +0,0 @@ -import { expect, tap } from '@git.zone/tstest/tapbundle'; -import * as plugins from '../ts/plugins.js'; -import { HttpProxy } from '../ts/proxies/http-proxy/index.js'; -import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js'; -import type { IRouteContext } from '../ts/core/models/route-context.js'; - -// Declare variables for tests -let httpProxy: HttpProxy; -let testServer: plugins.http.Server; -let testServerHttp2: plugins.http2.Http2Server; -let serverPort: number; -let serverPortHttp2: number; - -// Setup test environment -tap.test('setup HttpProxy function-based targets test environment', async (tools) => { - // Set a reasonable timeout for the test - tools.timeout(30000); // 30 seconds - // Create simple HTTP server to respond to requests - testServer = plugins.http.createServer((req, res) => { - res.writeHead(200, { 'Content-Type': 'application/json' }); - res.end(JSON.stringify({ - url: req.url, - headers: req.headers, - method: req.method, - message: 'HTTP/1.1 Response' - })); - }); - - // Create simple HTTP/2 server to respond to requests - testServerHttp2 = plugins.http2.createServer(); - testServerHttp2.on('stream', (stream, headers) => { - stream.respond({ - 'content-type': 'application/json', - ':status': 200 - }); - stream.end(JSON.stringify({ - path: headers[':path'], - headers, - method: headers[':method'], - message: 'HTTP/2 Response' - })); - }); - - // Handle HTTP/2 errors - testServerHttp2.on('error', (err) => { - console.error('HTTP/2 server error:', err); - }); - - // Start the servers - await new Promise(resolve => { - testServer.listen(0, () => { - const address = testServer.address() as { port: number }; - serverPort = address.port; - resolve(); - }); - }); - - await new Promise(resolve => { - testServerHttp2.listen(0, () => { - const address = testServerHttp2.address() as { port: number }; - serverPortHttp2 = address.port; - resolve(); - }); - }); - - // Create HttpProxy instance - httpProxy = new HttpProxy({ - port: 0, // Use dynamic port - logLevel: 'info', // Use info level to see more logs - // Disable ACME to avoid trying to bind to port 80 - acme: { - enabled: false - } - }); - - await httpProxy.start(); - - // Log the actual port being used - const actualPort = httpProxy.getListeningPort(); - console.log(`HttpProxy actual listening port: ${actualPort}`); -}); - -// Test static host/port routes -tap.test('should support static host/port routes', async () => { - // Get proxy port first - const proxyPort = httpProxy.getListeningPort(); - - const routes: IRouteConfig[] = [ - { - name: 'static-route', - priority: 100, - match: { - domains: 'example.com', - ports: proxyPort - }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: serverPort - }] - } - } - ]; - - await httpProxy.updateRouteConfigs(routes); - - // Make request to proxy - const response = await makeRequest({ - hostname: 'localhost', - port: proxyPort, - path: '/test', - method: 'GET', - headers: { - 'Host': 'example.com' - } - }); - - expect(response.statusCode).toEqual(200); - const body = JSON.parse(response.body); - expect(body.url).toEqual('/test'); - expect(body.headers.host).toEqual(`localhost:${serverPort}`); -}); - -// Test function-based host -tap.test('should support function-based host', async () => { - const proxyPort = httpProxy.getListeningPort(); - const routes: IRouteConfig[] = [ - { - name: 'function-host-route', - priority: 100, - match: { - domains: 'function.example.com', - ports: proxyPort - }, - action: { - type: 'forward', - targets: [{ - host: (context: IRouteContext) => { - // Return localhost always in this test - return 'localhost'; - }, - port: serverPort - }] - } - } - ]; - - await httpProxy.updateRouteConfigs(routes); - - // Make request to proxy - const response = await makeRequest({ - hostname: 'localhost', - port: proxyPort, - path: '/function-host', - method: 'GET', - headers: { - 'Host': 'function.example.com' - } - }); - - expect(response.statusCode).toEqual(200); - const body = JSON.parse(response.body); - expect(body.url).toEqual('/function-host'); - expect(body.headers.host).toEqual(`localhost:${serverPort}`); -}); - -// Test function-based port -tap.test('should support function-based port', async () => { - const proxyPort = httpProxy.getListeningPort(); - const routes: IRouteConfig[] = [ - { - name: 'function-port-route', - priority: 100, - match: { - domains: 'function-port.example.com', - ports: proxyPort - }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: (context: IRouteContext) => { - // Return test server port - return serverPort; - } - }] - } - } - ]; - - await httpProxy.updateRouteConfigs(routes); - - // Make request to proxy - const response = await makeRequest({ - hostname: 'localhost', - port: proxyPort, - path: '/function-port', - method: 'GET', - headers: { - 'Host': 'function-port.example.com' - } - }); - - expect(response.statusCode).toEqual(200); - const body = JSON.parse(response.body); - expect(body.url).toEqual('/function-port'); - expect(body.headers.host).toEqual(`localhost:${serverPort}`); -}); - -// Test function-based host AND port -tap.test('should support function-based host AND port', async () => { - const proxyPort = httpProxy.getListeningPort(); - const routes: IRouteConfig[] = [ - { - name: 'function-both-route', - priority: 100, - match: { - domains: 'function-both.example.com', - ports: proxyPort - }, - action: { - type: 'forward', - targets: [{ - host: (context: IRouteContext) => { - return 'localhost'; - }, - port: (context: IRouteContext) => { - return serverPort; - } - }] - } - } - ]; - - await httpProxy.updateRouteConfigs(routes); - - // Make request to proxy - const response = await makeRequest({ - hostname: 'localhost', - port: proxyPort, - path: '/function-both', - method: 'GET', - headers: { - 'Host': 'function-both.example.com' - } - }); - - expect(response.statusCode).toEqual(200); - const body = JSON.parse(response.body); - expect(body.url).toEqual('/function-both'); - expect(body.headers.host).toEqual(`localhost:${serverPort}`); -}); - -// Test context-based routing with path -tap.test('should support context-based routing with path', async () => { - const proxyPort = httpProxy.getListeningPort(); - const routes: IRouteConfig[] = [ - { - name: 'context-path-route', - priority: 100, - match: { - domains: 'context.example.com', - ports: proxyPort - }, - action: { - type: 'forward', - targets: [{ - host: (context: IRouteContext) => { - // Use path to determine host - if (context.path?.startsWith('/api')) { - return 'localhost'; - } else { - return '127.0.0.1'; // Another way to reference localhost - } - }, - port: serverPort - }] - } - } - ]; - - await httpProxy.updateRouteConfigs(routes); - - // Make request to proxy with /api path - const apiResponse = await makeRequest({ - hostname: 'localhost', - port: proxyPort, - path: '/api/test', - method: 'GET', - headers: { - 'Host': 'context.example.com' - } - }); - - expect(apiResponse.statusCode).toEqual(200); - const apiBody = JSON.parse(apiResponse.body); - expect(apiBody.url).toEqual('/api/test'); - - // Make request to proxy with non-api path - const nonApiResponse = await makeRequest({ - hostname: 'localhost', - port: proxyPort, - path: '/web/test', - method: 'GET', - headers: { - 'Host': 'context.example.com' - } - }); - - expect(nonApiResponse.statusCode).toEqual(200); - const nonApiBody = JSON.parse(nonApiResponse.body); - expect(nonApiBody.url).toEqual('/web/test'); -}); - -// Cleanup test environment -tap.test('cleanup HttpProxy function-based targets test environment', async () => { - // Skip cleanup if setup failed - if (!httpProxy && !testServer && !testServerHttp2) { - console.log('Skipping cleanup - setup failed'); - return; - } - - // Stop test servers first - if (testServer) { - await new Promise((resolve, reject) => { - testServer.close((err) => { - if (err) { - console.error('Error closing test server:', err); - reject(err); - } else { - console.log('Test server closed successfully'); - resolve(); - } - }); - }); - } - - if (testServerHttp2) { - await new Promise((resolve, reject) => { - testServerHttp2.close((err) => { - if (err) { - console.error('Error closing HTTP/2 test server:', err); - reject(err); - } else { - console.log('HTTP/2 test server closed successfully'); - resolve(); - } - }); - }); - } - - // Stop HttpProxy last - if (httpProxy) { - console.log('Stopping HttpProxy...'); - await httpProxy.stop(); - console.log('HttpProxy stopped successfully'); - } - - // Force exit after a short delay to ensure cleanup - const cleanupTimeout = setTimeout(() => { - console.log('Cleanup completed, exiting'); - }, 100); - - // Don't keep the process alive just for this timeout - if (cleanupTimeout.unref) { - cleanupTimeout.unref(); - } -}); - -// Helper function to make HTTPS requests with self-signed certificate support -async function makeRequest(options: plugins.http.RequestOptions): Promise<{ statusCode: number, headers: plugins.http.IncomingHttpHeaders, body: string }> { - return new Promise((resolve, reject) => { - // Use HTTPS with rejectUnauthorized: false to accept self-signed certificates - const req = plugins.https.request({ - ...options, - rejectUnauthorized: false, // Accept self-signed certificates - }, (res) => { - let body = ''; - res.on('data', (chunk) => { - body += chunk; - }); - res.on('end', () => { - resolve({ - statusCode: res.statusCode || 0, - headers: res.headers, - body - }); - }); - }); - - req.on('error', (err) => { - console.error(`Request error: ${err.message}`); - reject(err); - }); - - req.end(); - }); -} - -// Start the tests -tap.start().then(() => { - // Ensure process exits after tests complete - process.exit(0); -}); \ No newline at end of file diff --git a/test/test.httpproxy.ts b/test/test.httpproxy.ts deleted file mode 100644 index fc1fe4a..0000000 --- a/test/test.httpproxy.ts +++ /dev/null @@ -1,596 +0,0 @@ -import { expect, tap } from '@git.zone/tstest/tapbundle'; -import * as smartproxy from '../ts/index.js'; -import { loadTestCertificates } from './helpers/certificates.js'; -import * as https from 'https'; -import * as http from 'http'; -import { WebSocket, WebSocketServer } from 'ws'; - -let testProxy: smartproxy.HttpProxy; -let testServer: http.Server; -let wsServer: WebSocketServer; -let testCertificates: { privateKey: string; publicKey: string }; - -// Helper function to make HTTPS requests -async function makeHttpsRequest( - options: https.RequestOptions, -): Promise<{ statusCode: number; headers: http.IncomingHttpHeaders; body: string }> { - console.log('[TEST] Making HTTPS request:', { - hostname: options.hostname, - port: options.port, - path: options.path, - method: options.method, - headers: options.headers, - }); - return new Promise((resolve, reject) => { - const req = https.request(options, (res) => { - console.log('[TEST] Received HTTPS response:', { - statusCode: res.statusCode, - headers: res.headers, - }); - let data = ''; - res.on('data', (chunk) => (data += chunk)); - res.on('end', () => { - console.log('[TEST] Response completed:', { data }); - // Ensure the socket is destroyed to prevent hanging connections - res.socket?.destroy(); - resolve({ - statusCode: res.statusCode!, - headers: res.headers, - body: data, - }); - }); - }); - req.on('error', (error) => { - console.error('[TEST] Request error:', error); - reject(error); - }); - req.end(); - }); -} - -// Setup test environment -tap.test('setup test environment', async () => { - // Load and validate certificates - console.log('[TEST] Loading and validating certificates'); - testCertificates = loadTestCertificates(); - console.log('[TEST] Certificates loaded and validated'); - - // Create a test HTTP server - testServer = http.createServer((req, res) => { - console.log('[TEST SERVER] Received HTTP request:', { - url: req.url, - method: req.method, - headers: req.headers, - }); - res.writeHead(200, { 'Content-Type': 'text/plain' }); - res.end('Hello from test server!'); - }); - - // Handle WebSocket upgrade requests - testServer.on('upgrade', (request, socket, head) => { - console.log('[TEST SERVER] Received WebSocket upgrade request:', { - url: request.url, - method: request.method, - headers: { - host: request.headers.host, - upgrade: request.headers.upgrade, - connection: request.headers.connection, - 'sec-websocket-key': request.headers['sec-websocket-key'], - 'sec-websocket-version': request.headers['sec-websocket-version'], - 'sec-websocket-protocol': request.headers['sec-websocket-protocol'], - }, - }); - - if (request.headers.upgrade?.toLowerCase() !== 'websocket') { - console.log('[TEST SERVER] Not a WebSocket upgrade request'); - socket.destroy(); - return; - } - - console.log('[TEST SERVER] Handling WebSocket upgrade'); - wsServer.handleUpgrade(request, socket, head, (ws) => { - console.log('[TEST SERVER] WebSocket connection upgraded'); - wsServer.emit('connection', ws, request); - }); - }); - - // Create a WebSocket server (for the test HTTP server) - console.log('[TEST SERVER] Creating WebSocket server'); - wsServer = new WebSocketServer({ - noServer: true, - perMessageDeflate: false, - clientTracking: true, - handleProtocols: () => 'echo-protocol', - }); - - wsServer.on('connection', (ws, request) => { - console.log('[TEST SERVER] WebSocket connection established:', { - url: request.url, - headers: { - host: request.headers.host, - upgrade: request.headers.upgrade, - connection: request.headers.connection, - 'sec-websocket-key': request.headers['sec-websocket-key'], - 'sec-websocket-version': request.headers['sec-websocket-version'], - 'sec-websocket-protocol': request.headers['sec-websocket-protocol'], - }, - }); - - // Set up connection timeout - const connectionTimeout = setTimeout(() => { - console.error('[TEST SERVER] WebSocket connection timed out'); - ws.terminate(); - }, 5000); - - // Clear timeout when connection is properly closed - const clearConnectionTimeout = () => { - clearTimeout(connectionTimeout); - }; - - ws.on('message', (message) => { - const msg = message.toString(); - console.log('[TEST SERVER] Received WebSocket message:', msg); - try { - const response = `Echo: ${msg}`; - console.log('[TEST SERVER] Sending WebSocket response:', response); - ws.send(response); - // Clear timeout on successful message exchange - clearConnectionTimeout(); - } catch (error) { - console.error('[TEST SERVER] Error sending WebSocket message:', error); - } - }); - - ws.on('error', (error) => { - console.error('[TEST SERVER] WebSocket error:', error); - clearConnectionTimeout(); - }); - - ws.on('close', (code, reason) => { - console.log('[TEST SERVER] WebSocket connection closed:', { - code, - reason: reason.toString(), - wasClean: code === 1000 || code === 1001, - }); - clearConnectionTimeout(); - }); - - ws.on('ping', (data) => { - try { - console.log('[TEST SERVER] Received ping, sending pong'); - ws.pong(data); - } catch (error) { - console.error('[TEST SERVER] Error sending pong:', error); - } - }); - - ws.on('pong', (data) => { - console.log('[TEST SERVER] Received pong'); - }); - }); - - wsServer.on('error', (error) => { - console.error('Test server: WebSocket server error:', error); - }); - - wsServer.on('headers', (headers) => { - console.log('Test server: WebSocket headers:', headers); - }); - - wsServer.on('close', () => { - console.log('Test server: WebSocket server closed'); - }); - - await new Promise((resolve) => testServer.listen(3100, resolve)); - console.log('Test server listening on port 3100'); -}); - -tap.test('should create proxy instance', async () => { - // Test with the original minimal options (only port) - testProxy = new smartproxy.HttpProxy({ - port: 3001, - }); - expect(testProxy).toEqual(testProxy); // Instance equality check -}); - -tap.test('should create proxy instance with extended options', async () => { - // Test with extended options to verify backward compatibility - testProxy = new smartproxy.HttpProxy({ - port: 3001, - maxConnections: 5000, - keepAliveTimeout: 120000, - headersTimeout: 60000, - logLevel: 'info', - cors: { - allowOrigin: '*', - allowMethods: 'GET, POST, OPTIONS', - allowHeaders: 'Content-Type', - maxAge: 3600 - } - }); - expect(testProxy).toEqual(testProxy); // Instance equality check - expect(testProxy.options.port).toEqual(3001); -}); - -tap.test('should start the proxy server', async () => { - // Create a new proxy instance - testProxy = new smartproxy.HttpProxy({ - port: 3001, - maxConnections: 5000, - backendProtocol: 'http1', - acme: { - enabled: false // Disable ACME for testing - } - }); - - // Configure routes for the proxy - await testProxy.updateRouteConfigs([ - { - match: { - ports: [3001], - domains: ['push.rocks', 'localhost'] - }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 3100 - }], - tls: { - mode: 'terminate' - }, - websocket: { - enabled: true, - subprotocols: ['echo-protocol'] - } - } - } - ]); - - // Start the proxy - await testProxy.start(); - - // Verify the proxy is listening on the correct port - expect(testProxy.getListeningPort()).toEqual(3001); -}); - -tap.test('should route HTTPS requests based on host header', async () => { - // IMPORTANT: Connect to localhost (where the proxy is listening) but use the Host header "push.rocks" - const response = await makeHttpsRequest({ - hostname: 'localhost', // changed from 'push.rocks' to 'localhost' - port: 3001, - path: '/', - method: 'GET', - headers: { - host: 'push.rocks', // virtual host for routing - }, - rejectUnauthorized: false, - }); - - expect(response.statusCode).toEqual(200); - expect(response.body).toEqual('Hello from test server!'); -}); - -tap.test('should handle unknown host headers', async () => { - // Connect to localhost but use an unknown host header. - const response = await makeHttpsRequest({ - hostname: 'localhost', // connecting to localhost - port: 3001, - path: '/', - method: 'GET', - headers: { - host: 'unknown.host', // this should not match any proxy config - }, - rejectUnauthorized: false, - }); - - // Expect a 404 response with the appropriate error message. - expect(response.statusCode).toEqual(404); -}); - -tap.test('should support WebSocket connections', async () => { - // Create a WebSocket client - console.log('[TEST] Testing WebSocket connection'); - - console.log('[TEST] Creating WebSocket to wss://localhost:3001/ with host header: push.rocks'); - const ws = new WebSocket('wss://localhost:3001/', { - protocol: 'echo-protocol', - rejectUnauthorized: false, - headers: { - host: 'push.rocks' - } - }); - - const connectionTimeout = setTimeout(() => { - console.error('[TEST] WebSocket connection timeout'); - ws.terminate(); - }, 5000); - - const timeouts: NodeJS.Timeout[] = [connectionTimeout]; - - try { - // Wait for connection with timeout - await Promise.race([ - new Promise((resolve, reject) => { - ws.on('open', () => { - console.log('[TEST] WebSocket connected'); - clearTimeout(connectionTimeout); - resolve(); - }); - ws.on('error', (err) => { - console.error('[TEST] WebSocket connection error:', err); - clearTimeout(connectionTimeout); - reject(err); - }); - }), - new Promise((_, reject) => { - const timeout = setTimeout(() => reject(new Error('Connection timeout')), 3000); - timeouts.push(timeout); - }) - ]); - - // Send a message and receive echo with timeout - await Promise.race([ - new Promise((resolve, reject) => { - const testMessage = 'Hello WebSocket!'; - let messageReceived = false; - - ws.on('message', (data) => { - messageReceived = true; - const message = data.toString(); - console.log('[TEST] Received WebSocket message:', message); - expect(message).toEqual(`Echo: ${testMessage}`); - resolve(); - }); - - ws.on('error', (err) => { - console.error('[TEST] WebSocket message error:', err); - reject(err); - }); - - console.log('[TEST] Sending WebSocket message:', testMessage); - ws.send(testMessage); - - // Add additional debug logging - const debugTimeout = setTimeout(() => { - if (!messageReceived) { - console.log('[TEST] No message received after 2 seconds'); - } - }, 2000); - timeouts.push(debugTimeout); - }), - new Promise((_, reject) => { - const timeout = setTimeout(() => reject(new Error('Message timeout')), 3000); - timeouts.push(timeout); - }) - ]); - - // Close the connection properly - await Promise.race([ - new Promise((resolve) => { - ws.on('close', () => { - console.log('[TEST] WebSocket closed'); - resolve(); - }); - ws.close(); - }), - new Promise((resolve) => { - const timeout = setTimeout(() => { - console.log('[TEST] Force closing WebSocket'); - ws.terminate(); - resolve(); - }, 2000); - timeouts.push(timeout); - }) - ]); - } catch (error) { - console.error('[TEST] WebSocket test error:', error); - try { - ws.terminate(); - } catch (terminateError) { - console.error('[TEST] Error during terminate:', terminateError); - } - // Skip if WebSocket fails for now - console.log('[TEST] WebSocket test failed, continuing with other tests'); - } finally { - // Clean up all timeouts - timeouts.forEach(timeout => clearTimeout(timeout)); - } -}); - -tap.test('should handle custom headers', async () => { - await testProxy.addDefaultHeaders({ - 'X-Proxy-Header': 'test-value', - }); - - const response = await makeHttpsRequest({ - hostname: 'localhost', // changed to 'localhost' - port: 3001, - path: '/', - method: 'GET', - headers: { - host: 'push.rocks', // still routing to push.rocks - }, - rejectUnauthorized: false, - }); - - expect(response.headers['x-proxy-header']).toEqual('test-value'); -}); - -tap.test('should handle CORS preflight requests', async () => { - // Test OPTIONS request (CORS preflight) - const response = await makeHttpsRequest({ - hostname: 'localhost', - port: 3001, - path: '/', - method: 'OPTIONS', - headers: { - host: 'push.rocks', - origin: 'https://example.com', - 'access-control-request-method': 'POST', - 'access-control-request-headers': 'content-type' - }, - rejectUnauthorized: false, - }); - - // Should get appropriate CORS headers - expect(response.statusCode).toBeLessThan(300); // 200 or 204 - expect(response.headers['access-control-allow-origin']).toEqual('*'); - expect(response.headers['access-control-allow-methods']).toContain('GET'); - expect(response.headers['access-control-allow-methods']).toContain('POST'); -}); - -tap.test('should track connections and metrics', async () => { - // Get metrics from the proxy - const metrics = testProxy.getMetrics(); - - // Verify metrics structure and some values - expect(metrics).toHaveProperty('activeConnections'); - expect(metrics).toHaveProperty('totalRequests'); - expect(metrics).toHaveProperty('failedRequests'); - expect(metrics).toHaveProperty('uptime'); - expect(metrics).toHaveProperty('memoryUsage'); - expect(metrics).toHaveProperty('activeWebSockets'); - - // Should have served at least some requests from previous tests - expect(metrics.totalRequests).toBeGreaterThan(0); - expect(metrics.uptime).toBeGreaterThan(0); -}); - -tap.test('should update capacity settings', async () => { - // Update proxy capacity settings - testProxy.updateCapacity(2000, 60000, 25); - - // Verify settings were updated - expect(testProxy.options.maxConnections).toEqual(2000); - expect(testProxy.options.keepAliveTimeout).toEqual(60000); - expect(testProxy.options.connectionPoolSize).toEqual(25); -}); - -tap.test('should handle certificate requests', async () => { - // Test certificate request (this won't actually issue a cert in test mode) - const result = await testProxy.requestCertificate('test.example.com'); - - // In test mode with ACME disabled, this should return false - expect(result).toEqual(false); -}); - -tap.test('should update certificates directly', async () => { - // Test certificate update - const testCert = '-----BEGIN CERTIFICATE-----\nMIIB...test...'; - const testKey = '-----BEGIN PRIVATE KEY-----\nMIIE...test...'; - - // This should not throw - expect(() => { - testProxy.updateCertificate('test.example.com', testCert, testKey); - }).not.toThrow(); -}); - -tap.test('cleanup', async () => { - console.log('[TEST] Starting cleanup'); - - try { - // 1. Close WebSocket clients if server exists - if (wsServer && wsServer.clients) { - console.log(`[TEST] Terminating ${wsServer.clients.size} WebSocket clients`); - wsServer.clients.forEach((client) => { - try { - client.terminate(); - } catch (err) { - console.error('[TEST] Error terminating client:', err); - } - }); - } - - // 2. Close WebSocket server with timeout - if (wsServer) { - console.log('[TEST] Closing WebSocket server'); - await Promise.race([ - new Promise((resolve, reject) => { - wsServer.close((err) => { - if (err) { - console.error('[TEST] Error closing WebSocket server:', err); - reject(err); - } else { - console.log('[TEST] WebSocket server closed'); - resolve(); - } - }); - }).catch((err) => { - console.error('[TEST] Caught error closing WebSocket server:', err); - }), - new Promise((resolve) => { - setTimeout(() => { - console.log('[TEST] WebSocket server close timeout'); - resolve(); - }, 1000); - }) - ]); - } - - // 3. Close test server with timeout - if (testServer) { - console.log('[TEST] Closing test server'); - // First close all connections - testServer.closeAllConnections(); - - await Promise.race([ - new Promise((resolve, reject) => { - testServer.close((err) => { - if (err) { - console.error('[TEST] Error closing test server:', err); - reject(err); - } else { - console.log('[TEST] Test server closed'); - resolve(); - } - }); - }).catch((err) => { - console.error('[TEST] Caught error closing test server:', err); - }), - new Promise((resolve) => { - setTimeout(() => { - console.log('[TEST] Test server close timeout'); - resolve(); - }, 1000); - }) - ]); - } - - // 4. Stop the proxy with timeout - if (testProxy) { - console.log('[TEST] Stopping proxy'); - await Promise.race([ - testProxy.stop() - .then(() => { - console.log('[TEST] Proxy stopped successfully'); - }) - .catch((error) => { - console.error('[TEST] Error stopping proxy:', error); - }), - new Promise((resolve) => { - setTimeout(() => { - console.log('[TEST] Proxy stop timeout'); - resolve(); - }, 2000); - }) - ]); - } - } catch (error) { - console.error('[TEST] Error during cleanup:', error); - } - - console.log('[TEST] Cleanup complete'); - - // Add debugging to see what might be keeping the process alive - if (process.env.DEBUG_HANDLES) { - console.log('[TEST] Active handles:', (process as any)._getActiveHandles?.().length); - console.log('[TEST] Active requests:', (process as any)._getActiveRequests?.().length); - } -}); - -// Exit handler removed to prevent interference with test cleanup - -// Teardown test removed - let tap handle proper cleanup - -export default tap.start(); \ No newline at end of file diff --git a/test/test.keepalive-support.node.ts b/test/test.keepalive-support.node.ts deleted file mode 100644 index 5c30401..0000000 --- a/test/test.keepalive-support.node.ts +++ /dev/null @@ -1,250 +0,0 @@ -import { expect, tap } from '@git.zone/tstest/tapbundle'; -import * as net from 'net'; -import { SmartProxy } from '../ts/index.js'; -import * as plugins from '../ts/plugins.js'; - -tap.test('keepalive support - verify keepalive connections are properly handled', async (tools) => { - console.log('\n=== KeepAlive Support Test ==='); - console.log('Purpose: Verify that keepalive connections are not prematurely cleaned up'); - - // Create a simple echo backend - const echoBackend = net.createServer((socket) => { - socket.on('data', (data) => { - // Echo back received data - try { - socket.write(data); - } catch (err) { - // Ignore write errors during shutdown - } - }); - - socket.on('error', (err: NodeJS.ErrnoException) => { - // Ignore errors from backend sockets - console.log(`Backend socket error (expected during cleanup): ${err.code}`); - }); - }); - - await new Promise((resolve) => { - echoBackend.listen(9998, () => { - console.log('โœ“ Echo backend started on port 9998'); - resolve(); - }); - }); - - // Test 1: Standard keepalive treatment - console.log('\n--- Test 1: Standard KeepAlive Treatment ---'); - - const proxy1 = new SmartProxy({ - routes: [{ - name: 'keepalive-route', - match: { ports: 8590 }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 9998 }] - } - }], - keepAlive: true, - keepAliveTreatment: 'standard', - inactivityTimeout: 5000, // 5 seconds for faster testing - enableDetailedLogging: false, - }); - - await proxy1.start(); - console.log('โœ“ Proxy with standard keepalive started on port 8590'); - - // Create a keepalive connection - const client1 = net.connect(8590, 'localhost'); - - // Add error handler to prevent unhandled errors - client1.on('error', (err: NodeJS.ErrnoException) => { - console.log(`Client1 error (expected during cleanup): ${err.code}`); - }); - - await new Promise((resolve) => { - client1.on('connect', () => { - console.log('Client connected'); - client1.setKeepAlive(true, 1000); - resolve(); - }); - }); - - // Send initial data - client1.write('Hello keepalive\n'); - - // Wait for echo - await new Promise((resolve) => { - client1.once('data', (data) => { - console.log(`Received echo: ${data.toString().trim()}`); - resolve(); - }); - }); - - // Check connection is marked as keepalive - const cm1 = (proxy1 as any).connectionManager; - const connections1 = cm1.getConnections(); - let keepAliveCount = 0; - - for (const [id, record] of connections1) { - if (record.hasKeepAlive) { - keepAliveCount++; - console.log(`KeepAlive connection ${id}: hasKeepAlive=${record.hasKeepAlive}`); - } - } - - expect(keepAliveCount).toEqual(1); - - // Wait to ensure it's not cleaned up prematurely - await plugins.smartdelay.delayFor(6000); - - const afterWaitCount1 = cm1.getConnectionCount(); - console.log(`Connections after 6s wait: ${afterWaitCount1}`); - expect(afterWaitCount1).toEqual(1); // Should still be connected - - // Send more data to keep it alive - client1.write('Still alive\n'); - - // Clean up test 1 - client1.destroy(); - await proxy1.stop(); - await plugins.smartdelay.delayFor(500); // Wait for port to be released - - // Test 2: Extended keepalive treatment - console.log('\n--- Test 2: Extended KeepAlive Treatment ---'); - - const proxy2 = new SmartProxy({ - routes: [{ - name: 'keepalive-extended', - match: { ports: 8591 }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 9998 }] - } - }], - keepAlive: true, - keepAliveTreatment: 'extended', - keepAliveInactivityMultiplier: 6, - inactivityTimeout: 2000, // 2 seconds base, 12 seconds with multiplier - enableDetailedLogging: false, - }); - - await proxy2.start(); - console.log('โœ“ Proxy with extended keepalive started on port 8591'); - - const client2 = net.connect(8591, 'localhost'); - - // Add error handler to prevent unhandled errors - client2.on('error', (err: NodeJS.ErrnoException) => { - console.log(`Client2 error (expected during cleanup): ${err.code}`); - }); - - await new Promise((resolve) => { - client2.on('connect', () => { - console.log('Client connected with extended timeout'); - client2.setKeepAlive(true, 1000); - resolve(); - }); - }); - - // Send initial data - client2.write('Extended keepalive\n'); - - // Check connection - const cm2 = (proxy2 as any).connectionManager; - await plugins.smartdelay.delayFor(1000); - - const connections2 = cm2.getConnections(); - for (const [id, record] of connections2) { - console.log(`Extended connection ${id}: hasKeepAlive=${record.hasKeepAlive}, treatment=extended`); - } - - // Wait 3 seconds (would timeout with standard treatment) - await plugins.smartdelay.delayFor(3000); - - const midWaitCount = cm2.getConnectionCount(); - console.log(`Connections after 3s (base timeout exceeded): ${midWaitCount}`); - expect(midWaitCount).toEqual(1); // Should still be connected due to extended treatment - - // Clean up test 2 - client2.destroy(); - await proxy2.stop(); - await plugins.smartdelay.delayFor(500); // Wait for port to be released - - // Test 3: Immortal keepalive treatment - console.log('\n--- Test 3: Immortal KeepAlive Treatment ---'); - - const proxy3 = new SmartProxy({ - routes: [{ - name: 'keepalive-immortal', - match: { ports: 8592 }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 9998 }] - } - }], - keepAlive: true, - keepAliveTreatment: 'immortal', - inactivityTimeout: 1000, // 1 second - should be ignored for immortal - enableDetailedLogging: false, - }); - - await proxy3.start(); - console.log('โœ“ Proxy with immortal keepalive started on port 8592'); - - const client3 = net.connect(8592, 'localhost'); - - // Add error handler to prevent unhandled errors - client3.on('error', (err: NodeJS.ErrnoException) => { - console.log(`Client3 error (expected during cleanup): ${err.code}`); - }); - - await new Promise((resolve) => { - client3.on('connect', () => { - console.log('Client connected with immortal treatment'); - client3.setKeepAlive(true, 1000); - resolve(); - }); - }); - - // Send initial data - client3.write('Immortal connection\n'); - - // Wait well beyond normal timeout - await plugins.smartdelay.delayFor(5000); - - const cm3 = (proxy3 as any).connectionManager; - const immortalCount = cm3.getConnectionCount(); - console.log(`Immortal connections after 5s inactivity: ${immortalCount}`); - expect(immortalCount).toEqual(1); // Should never timeout - - // Verify zombie detection doesn't affect immortal connections - console.log('\n--- Verifying zombie detection respects keepalive ---'); - - // Manually trigger inactivity check - cm3.performOptimizedInactivityCheck(); - - await plugins.smartdelay.delayFor(1000); - - const afterCheckCount = cm3.getConnectionCount(); - console.log(`Connections after manual inactivity check: ${afterCheckCount}`); - expect(afterCheckCount).toEqual(1); // Should still be alive - - // Clean up - client3.destroy(); - await proxy3.stop(); - - // Close backend and wait for it to fully close - await new Promise((resolve) => { - echoBackend.close(() => { - console.log('Echo backend closed'); - resolve(); - }); - }); - - console.log('\nโœ“ All keepalive tests passed:'); - console.log(' - Standard treatment works correctly'); - console.log(' - Extended treatment applies multiplier'); - console.log(' - Immortal treatment never times out'); - console.log(' - Zombie detection respects keepalive settings'); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.memory-leak-check.node.ts b/test/test.memory-leak-check.node.ts deleted file mode 100644 index 0131e1f..0000000 --- a/test/test.memory-leak-check.node.ts +++ /dev/null @@ -1,151 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import { SmartProxy, createHttpRoute } from '../ts/index.js'; -import * as http from 'http'; - -tap.test('should not have memory leaks in long-running operations', async (tools) => { - // Get initial memory usage - const getMemoryUsage = () => { - if (global.gc) { - global.gc(); - } - const usage = process.memoryUsage(); - return { - heapUsed: Math.round(usage.heapUsed / 1024 / 1024), // MB - external: Math.round(usage.external / 1024 / 1024), // MB - rss: Math.round(usage.rss / 1024 / 1024) // MB - }; - }; - - // Create a target server - const targetServer = http.createServer((req, res) => { - res.writeHead(200, { 'Content-Type': 'text/plain' }); - res.end('OK'); - }); - await new Promise((resolve) => targetServer.listen(3100, resolve)); - - // Create the proxy - use non-privileged port - const routes = [ - createHttpRoute(['test1.local', 'test2.local', 'test3.local'], { host: 'localhost', port: 3100 }), - ]; - // Update route to use port 8080 - routes[0].match.ports = 8080; - - const proxy = new SmartProxy({ - routes: routes - }); - await proxy.start(); - - console.log('Starting memory leak test...'); - const initialMemory = getMemoryUsage(); - console.log('Initial memory:', initialMemory); - - // Function to make requests - const makeRequest = (domain: string): Promise => { - return new Promise((resolve, reject) => { - const req = http.request({ - hostname: 'localhost', - port: 8080, - path: '/', - method: 'GET', - headers: { - 'Host': domain - } - }, (res) => { - res.on('data', () => {}); - res.on('end', resolve); - }); - req.on('error', reject); - req.end(); - }); - }; - - // Test 1: Many requests to the same routes - console.log('Test 1: Making 1000 requests to same routes...'); - for (let i = 0; i < 1000; i++) { - await makeRequest(`test${(i % 3) + 1}.local`); - if (i % 100 === 0) { - console.log(` Progress: ${i}/1000`); - } - } - - const afterSameRoutesMemory = getMemoryUsage(); - console.log('Memory after same routes:', afterSameRoutesMemory); - - // Test 2: Many requests to different routes (tests routeContextCache) - console.log('Test 2: Making 1000 requests to different routes...'); - for (let i = 0; i < 1000; i++) { - // Create unique domain to test cache growth - await makeRequest(`test${i}.local`); - if (i % 100 === 0) { - console.log(` Progress: ${i}/1000`); - } - } - - const afterDifferentRoutesMemory = getMemoryUsage(); - console.log('Memory after different routes:', afterDifferentRoutesMemory); - - // Test 3: Check metrics collector memory - console.log('Test 3: Checking metrics collector...'); - const metrics = proxy.getMetrics(); - console.log(`Active connections: ${metrics.connections.active()}`); - console.log(`Total connections: ${metrics.connections.total()}`); - console.log(`RPS: ${metrics.requests.perSecond()}`); - - // Test 4: Many rapid connections (tests requestTimestamps array) - console.log('Test 4: Making 500 rapid requests...'); - const rapidRequests = []; - for (let i = 0; i < 500; i++) { - rapidRequests.push(makeRequest('test1.local')); - if (i % 50 === 0) { - // Wait a bit to let some complete - await Promise.all(rapidRequests); - rapidRequests.length = 0; - // Add delay to allow connections to close - await new Promise(resolve => setTimeout(resolve, 100)); - console.log(` Progress: ${i}/500`); - } - } - await Promise.all(rapidRequests); - - const afterRapidMemory = getMemoryUsage(); - console.log('Memory after rapid requests:', afterRapidMemory); - - // Force garbage collection and check final memory - await new Promise(resolve => setTimeout(resolve, 1000)); - const finalMemory = getMemoryUsage(); - console.log('Final memory:', finalMemory); - - // Memory leak checks - const memoryGrowth = finalMemory.heapUsed - initialMemory.heapUsed; - console.log(`Total memory growth: ${memoryGrowth} MB`); - - // Check for excessive memory growth - // Allow some growth but not excessive (e.g., more than 50MB for this test) - expect(memoryGrowth).toBeLessThan(50); - - // Check specific potential leaks - // 1. Route context cache should not grow unbounded - const routeHandler = proxy.routeConnectionHandler as any; - if (routeHandler.routeContextCache) { - console.log(`Route context cache size: ${routeHandler.routeContextCache.size}`); - // Should not have 1000 entries from different routes test - expect(routeHandler.routeContextCache.size).toBeLessThan(100); - } - - // 2. Metrics collector should clean up old timestamps - const metricsCollector = (proxy as any).metricsCollector; - if (metricsCollector && metricsCollector.requestTimestamps) { - console.log(`Request timestamps array length: ${metricsCollector.requestTimestamps.length}`); - // Should clean up old timestamps periodically - expect(metricsCollector.requestTimestamps.length).toBeLessThanOrEqual(10000); - } - - // Cleanup - await proxy.stop(); - await new Promise((resolve) => targetServer.close(() => resolve())); - - console.log('Memory leak test completed successfully'); -}); - -// Run with: node --expose-gc test.memory-leak-check.node.ts -export default tap.start(); \ No newline at end of file diff --git a/test/test.memory-leak-simple.ts b/test/test.memory-leak-simple.ts deleted file mode 100644 index 05c2f52..0000000 --- a/test/test.memory-leak-simple.ts +++ /dev/null @@ -1,59 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import { SmartProxy, createHttpRoute } from '../ts/index.js'; -import * as http from 'http'; - -tap.test('memory leak fixes verification', async () => { - // Test 1: MetricsCollector requestTimestamps cleanup - console.log('\n=== Test 1: MetricsCollector requestTimestamps cleanup ==='); - const proxy = new SmartProxy({ - routes: [ - createHttpRoute('test.local', { host: 'localhost', port: 3200 }, { - match: { - ports: 8081, - domains: 'test.local' - } - }), - ] - }); - - await proxy.start(); - - const metricsCollector = (proxy as any).metricsCollector; - - // Check initial state - console.log('Initial timestamps:', metricsCollector.requestTimestamps.length); - - // Simulate many requests to test cleanup - for (let i = 0; i < 6000; i++) { - metricsCollector.recordRequest(); - } - - // Should be cleaned up to MAX_TIMESTAMPS (5000) - console.log('After 6000 requests:', metricsCollector.requestTimestamps.length); - expect(metricsCollector.requestTimestamps.length).toBeLessThanOrEqual(5000); - - await proxy.stop(); - - // Test 2: Verify intervals are cleaned up - console.log('\n=== Test 2: Verify cleanup methods exist ==='); - - // Check RequestHandler has destroy method - const { RequestHandler } = await import('../ts/proxies/http-proxy/request-handler.js'); - const requestHandler = new RequestHandler({ port: 8080 }, null as any); - expect(typeof requestHandler.destroy).toEqual('function'); - console.log('โœ“ RequestHandler has destroy method'); - - // Check FunctionCache has destroy method - const { FunctionCache } = await import('../ts/proxies/http-proxy/function-cache.js'); - const functionCache = new FunctionCache({ debug: () => {}, info: () => {} } as any); - expect(typeof functionCache.destroy).toEqual('function'); - console.log('โœ“ FunctionCache has destroy method'); - - // Cleanup - requestHandler.destroy(); - functionCache.destroy(); - - console.log('\nโœ… All memory leak fixes verified!'); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.memory-leak-unit.ts b/test/test.memory-leak-unit.ts deleted file mode 100644 index 1629fe7..0000000 --- a/test/test.memory-leak-unit.ts +++ /dev/null @@ -1,131 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; - -tap.test('memory leak fixes - unit tests', async () => { - console.log('\n=== Testing MetricsCollector memory management ==='); - - // Import and test MetricsCollector directly - const { MetricsCollector } = await import('../ts/proxies/smart-proxy/metrics-collector.js'); - - // Create a mock SmartProxy with minimal required properties - const mockProxy = { - connectionManager: { - getConnectionCount: () => 0, - getConnections: () => new Map(), - getTerminationStats: () => ({ incoming: {} }) - }, - routeConnectionHandler: { - newConnectionSubject: { - subscribe: () => ({ unsubscribe: () => {} }) - } - }, - settings: {} - }; - - const collector = new MetricsCollector(mockProxy as any); - collector.start(); - - // Test timestamp cleanup - console.log('Testing requestTimestamps cleanup...'); - - // Add 6000 timestamps - for (let i = 0; i < 6000; i++) { - collector.recordRequest(`conn-${i}`, 'test-route', '127.0.0.1'); - } - - // Access private property for testing - let timestamps = (collector as any).requestTimestamps; - console.log(`Timestamps after 6000 requests: ${timestamps.length}`); - - // Force one more request to trigger cleanup - collector.recordRequest('conn-final', 'test-route', '127.0.0.1'); - timestamps = (collector as any).requestTimestamps; - console.log(`Timestamps after cleanup trigger: ${timestamps.length}`); - - // Now check the RPS window - all timestamps are within 1 minute so they won't be cleaned - const now = Date.now(); - const oldestTimestamp = Math.min(...timestamps); - const windowAge = now - oldestTimestamp; - console.log(`Window age: ${windowAge}ms (should be < 60000ms for all to be kept)`); - - // Since all timestamps are recent (within RPS window), they won't be cleaned by window - // But the array size should still be limited - console.log(`MAX_TIMESTAMPS: ${(collector as any).MAX_TIMESTAMPS}`); - - // The issue is our rapid-fire test - all timestamps are within the window - // Let's test with older timestamps - console.log('\nTesting with mixed old/new timestamps...'); - (collector as any).requestTimestamps = []; - - // Add some old timestamps (older than window) - const oldTime = now - 70000; // 70 seconds ago - for (let i = 0; i < 3000; i++) { - (collector as any).requestTimestamps.push(oldTime); - } - - // Add new timestamps to exceed limit - for (let i = 0; i < 3000; i++) { - collector.recordRequest(`conn-new-${i}`, 'test-route', '127.0.0.1'); - } - - timestamps = (collector as any).requestTimestamps; - console.log(`After mixed timestamps: ${timestamps.length} (old ones should be cleaned)`); - - // Old timestamps should be cleaned when we exceed MAX_TIMESTAMPS - expect(timestamps.length).toBeLessThanOrEqual(5000); - - // Stop the collector - collector.stop(); - - console.log('\n=== Testing FunctionCache cleanup ==='); - - const { FunctionCache } = await import('../ts/proxies/http-proxy/function-cache.js'); - - const mockLogger = { - debug: () => {}, - info: () => {}, - warn: () => {}, - error: () => {} - }; - - const cache = new FunctionCache(mockLogger as any); - - // Check that cleanup interval was set - expect((cache as any).cleanupInterval).toBeTruthy(); - - // Test destroy method - cache.destroy(); - - // Cleanup interval should be cleared - expect((cache as any).cleanupInterval).toBeNull(); - - console.log('โœ“ FunctionCache properly cleans up interval'); - - console.log('\n=== Testing RequestHandler cleanup ==='); - - const { RequestHandler } = await import('../ts/proxies/http-proxy/request-handler.js'); - - const mockConnectionPool = { - getConnection: () => null, - releaseConnection: () => {} - }; - - const handler = new RequestHandler( - { port: 8080, logLevel: 'error' }, - mockConnectionPool as any - ); - - // Check that cleanup interval was set - expect((handler as any).rateLimitCleanupInterval).toBeTruthy(); - - // Test destroy method - handler.destroy(); - - // Cleanup interval should be cleared - expect((handler as any).rateLimitCleanupInterval).toBeNull(); - - console.log('โœ“ RequestHandler properly cleans up interval'); - - console.log('\nโœ… All memory leak fixes verified!'); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.metrics-collector.ts b/test/test.metrics-collector.ts deleted file mode 100644 index 7f4bc25..0000000 --- a/test/test.metrics-collector.ts +++ /dev/null @@ -1,280 +0,0 @@ -import { expect, tap } from '@git.zone/tstest/tapbundle'; -import { SmartProxy } from '../ts/index.js'; -import * as net from 'net'; -import * as plugins from '../ts/plugins.js'; - -tap.test('MetricsCollector provides accurate metrics', async (tools) => { - console.log('\n=== MetricsCollector Test ==='); - - // Create a simple echo server for testing - const echoServer = net.createServer((socket) => { - socket.on('data', (data) => { - socket.write(data); - }); - socket.on('error', () => {}); // Ignore errors - }); - - await new Promise((resolve) => { - echoServer.listen(9995, () => { - console.log('โœ“ Echo server started on port 9995'); - resolve(); - }); - }); - - // Create SmartProxy with test routes - const proxy = new SmartProxy({ - routes: [ - { - name: 'test-route-1', - match: { ports: 8700 }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 9995 }] - } - }, - { - name: 'test-route-2', - match: { ports: 8701 }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 9995 }] - } - } - ], - enableDetailedLogging: true, - }); - - await proxy.start(); - console.log('โœ“ Proxy started on ports 8700 and 8701'); - - // Get metrics interface - const metrics = proxy.getMetrics(); - - // Test 1: Initial state - console.log('\n--- Test 1: Initial State ---'); - expect(metrics.connections.active()).toEqual(0); - expect(metrics.connections.total()).toEqual(0); - expect(metrics.requests.perSecond()).toEqual(0); - expect(metrics.connections.byRoute().size).toEqual(0); - expect(metrics.connections.byIP().size).toEqual(0); - - const throughput = metrics.throughput.instant(); - expect(throughput.in).toEqual(0); - expect(throughput.out).toEqual(0); - console.log('โœ“ Initial metrics are all zero'); - - // Test 2: Create connections and verify metrics - console.log('\n--- Test 2: Active Connections ---'); - const clients: net.Socket[] = []; - - // Create 3 connections to route 1 - for (let i = 0; i < 3; i++) { - const client = net.connect(8700, 'localhost'); - clients.push(client); - await new Promise((resolve) => { - client.on('connect', resolve); - client.on('error', () => resolve()); - }); - } - - // Create 2 connections to route 2 - for (let i = 0; i < 2; i++) { - const client = net.connect(8701, 'localhost'); - clients.push(client); - await new Promise((resolve) => { - client.on('connect', resolve); - client.on('error', () => resolve()); - }); - } - - // Wait for connections to be fully established and routed - await plugins.smartdelay.delayFor(300); - - // Verify connection counts - expect(metrics.connections.active()).toEqual(5); - expect(metrics.connections.total()).toEqual(5); - console.log(`โœ“ Active connections: ${metrics.connections.active()}`); - console.log(`โœ“ Total connections: ${metrics.connections.total()}`); - - // Test 3: Connections by route - console.log('\n--- Test 3: Connections by Route ---'); - const routeConnections = metrics.connections.byRoute(); - console.log('Route connections:', Array.from(routeConnections.entries())); - - // Check if we have the expected counts - let route1Count = 0; - let route2Count = 0; - for (const [routeName, count] of routeConnections) { - if (routeName === 'test-route-1') route1Count = count; - if (routeName === 'test-route-2') route2Count = count; - } - - expect(route1Count).toEqual(3); - expect(route2Count).toEqual(2); - console.log('โœ“ Route test-route-1 has 3 connections'); - console.log('โœ“ Route test-route-2 has 2 connections'); - - // Test 4: Connections by IP - console.log('\n--- Test 4: Connections by IP ---'); - const ipConnections = metrics.connections.byIP(); - // All connections are from localhost (127.0.0.1 or ::1) - let totalIPConnections = 0; - for (const [ip, count] of ipConnections) { - console.log(` IP ${ip}: ${count} connections`); - totalIPConnections += count; - } - expect(totalIPConnections).toEqual(5); - console.log('โœ“ Total connections by IP matches active connections'); - - // Test 5: RPS calculation - console.log('\n--- Test 5: Requests Per Second ---'); - const rps = metrics.requests.perSecond(); - console.log(` Current RPS: ${rps.toFixed(2)}`); - // We created 5 connections, so RPS should be > 0 - expect(rps).toBeGreaterThan(0); - console.log('โœ“ RPS is greater than 0'); - - // Test 6: Throughput - console.log('\n--- Test 6: Throughput ---'); - // Send some data through connections - for (const client of clients) { - if (!client.destroyed) { - client.write('Hello metrics!\n'); - } - } - - // Wait for data to be transmitted and for sampling to occur - await plugins.smartdelay.delayFor(1100); // Wait for at least one sampling interval - - const throughputAfter = metrics.throughput.instant(); - console.log(` Bytes in: ${throughputAfter.in}`); - console.log(` Bytes out: ${throughputAfter.out}`); - // Throughput might still be 0 if no samples were taken, so just check it's defined - expect(throughputAfter.in).toBeDefined(); - expect(throughputAfter.out).toBeDefined(); - console.log('โœ“ Throughput shows bytes transferred'); - - // Test 7: Close some connections - console.log('\n--- Test 7: Connection Cleanup ---'); - // Close first 2 clients - clients[0].destroy(); - clients[1].destroy(); - - await plugins.smartdelay.delayFor(100); - - expect(metrics.connections.active()).toEqual(3); - // Note: total() includes active connections + terminated connections from stats - // The terminated connections might not be counted immediately - const totalConns = metrics.connections.total(); - expect(totalConns).toBeGreaterThanOrEqual(3); // At least the active connections - console.log(`โœ“ Active connections reduced to ${metrics.connections.active()}`); - console.log(`โœ“ Total connections: ${totalConns}`); - - // Test 8: Helper methods - console.log('\n--- Test 8: Helper Methods ---'); - - // Test getTopIPs - const topIPs = metrics.connections.topIPs(5); - expect(topIPs.length).toBeGreaterThan(0); - console.log('โœ“ getTopIPs returns IP list'); - - // Test throughput rate - const throughputRate = metrics.throughput.recent(); - console.log(` Throughput rate: ${throughputRate.in} bytes/sec in, ${throughputRate.out} bytes/sec out`); - console.log('โœ“ Throughput rates calculated'); - - // Cleanup - console.log('\n--- Cleanup ---'); - for (const client of clients) { - if (!client.destroyed) { - client.destroy(); - } - } - - await proxy.stop(); - echoServer.close(); - - console.log('\nโœ“ All MetricsCollector tests passed'); -}); - -// Test with mock data for unit testing -tap.test('MetricsCollector unit test with mock data', async () => { - console.log('\n=== MetricsCollector Unit Test ==='); - - // Create a mock SmartProxy with mock ConnectionManager - const mockConnections = new Map([ - ['conn1', { - remoteIP: '192.168.1.1', - routeName: 'api', - bytesReceived: 1000, - bytesSent: 500, - incomingStartTime: Date.now() - 5000 - }], - ['conn2', { - remoteIP: '192.168.1.1', - routeName: 'web', - bytesReceived: 2000, - bytesSent: 1500, - incomingStartTime: Date.now() - 10000 - }], - ['conn3', { - remoteIP: '192.168.1.2', - routeName: 'api', - bytesReceived: 500, - bytesSent: 250, - incomingStartTime: Date.now() - 3000 - }] - ]); - - const mockSmartProxy = { - connectionManager: { - getConnectionCount: () => mockConnections.size, - getConnections: () => mockConnections, - getTerminationStats: () => ({ - incoming: { normal: 10, timeout: 2, error: 1 } - }) - } - }; - - // Import MetricsCollector directly - const { MetricsCollector } = await import('../ts/proxies/smart-proxy/metrics-collector.js'); - const metrics = new MetricsCollector(mockSmartProxy as any); - - // Test metrics calculation - console.log('\n--- Testing with Mock Data ---'); - - expect(metrics.connections.active()).toEqual(3); - console.log(`โœ“ Active connections: ${metrics.connections.active()}`); - - expect(metrics.connections.total()).toEqual(16); // 3 active + 13 terminated - console.log(`โœ“ Total connections: ${metrics.connections.total()}`); - - const routeConns = metrics.connections.byRoute(); - expect(routeConns.get('api')).toEqual(2); - expect(routeConns.get('web')).toEqual(1); - console.log('โœ“ Connections by route calculated correctly'); - - const ipConns = metrics.connections.byIP(); - expect(ipConns.get('192.168.1.1')).toEqual(2); - expect(ipConns.get('192.168.1.2')).toEqual(1); - console.log('โœ“ Connections by IP calculated correctly'); - - // Throughput tracker returns rates, not totals - just verify it returns something - const throughput = metrics.throughput.instant(); - expect(throughput.in).toBeDefined(); - expect(throughput.out).toBeDefined(); - console.log(`โœ“ Throughput rates calculated: ${throughput.in} bytes/sec in, ${throughput.out} bytes/sec out`); - - // Test RPS tracking - metrics.recordRequest('test-1', 'test-route', '192.168.1.1'); - metrics.recordRequest('test-2', 'test-route', '192.168.1.1'); - metrics.recordRequest('test-3', 'test-route', '192.168.1.2'); - - const rps = metrics.requests.perSecond(); - expect(rps).toBeGreaterThan(0); - console.log(`โœ“ RPS tracking works: ${rps.toFixed(2)} req/sec`); - - console.log('\nโœ“ All unit tests passed'); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.nftables-manager.ts b/test/test.nftables-manager.ts deleted file mode 100644 index a876660..0000000 --- a/test/test.nftables-manager.ts +++ /dev/null @@ -1,188 +0,0 @@ -import { expect, tap } from '@git.zone/tstest/tapbundle'; -import { NFTablesManager } from '../ts/proxies/smart-proxy/nftables-manager.js'; -import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js'; -import type { ISmartProxyOptions } from '../ts/proxies/smart-proxy/models/interfaces.js'; -import * as child_process from 'child_process'; -import { promisify } from 'util'; - -const exec = promisify(child_process.exec); - -// Check if we have root privileges -async function checkRootPrivileges(): Promise { - try { - const { stdout } = await exec('id -u'); - return stdout.trim() === '0'; - } catch (err) { - return false; - } -} - -// Skip tests if not root -const isRoot = await checkRootPrivileges(); -if (!isRoot) { - console.log(''); - console.log('========================================'); - console.log('NFTablesManager tests require root privileges'); - console.log('Skipping NFTablesManager tests'); - console.log('========================================'); - console.log(''); - // Skip tests when not running as root - tests are marked with tap.skip.test -} - -/** - * Tests for the NFTablesManager class - */ - -// Sample route configurations for testing -const sampleRoute: IRouteConfig = { - name: 'test-nftables-route', - match: { - ports: 8080, - domains: 'test.example.com' - }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 8000 - }], - forwardingEngine: 'nftables', - nftables: { - protocol: 'tcp', - preserveSourceIP: true, - useIPSets: true - } - } -}; - -// Sample SmartProxy options -const sampleOptions: ISmartProxyOptions = { - routes: [sampleRoute], - enableDetailedLogging: true -}; - -// Instance of NFTablesManager for testing -let manager: NFTablesManager; - -// Skip these tests by default since they require root privileges to run NFTables commands -// When running as root, change this to false -const SKIP_TESTS = true; - -tap.skip.test('NFTablesManager setup test', async () => { - // Test will be skipped if not running as root due to tap.skip.test - - // Create a SmartProxy instance first - const { SmartProxy } = await import('../ts/proxies/smart-proxy/smart-proxy.js'); - const proxy = new SmartProxy(sampleOptions); - - // Create a new instance of NFTablesManager - manager = new NFTablesManager(proxy); - - // Verify the instance was created successfully - expect(manager).toBeTruthy(); -}); - -tap.skip.test('NFTablesManager route provisioning test', async () => { - // Test will be skipped if not running as root due to tap.skip.test - - // Provision the sample route - const result = await manager.provisionRoute(sampleRoute); - - // Verify the route was provisioned successfully - expect(result).toEqual(true); - - // Verify the route is listed as provisioned - expect(manager.isRouteProvisioned(sampleRoute)).toEqual(true); -}); - -tap.skip.test('NFTablesManager status test', async () => { - // Test will be skipped if not running as root due to tap.skip.test - - // Get the status of the managed rules - const status = await manager.getStatus(); - - // Verify status includes our route - const keys = Object.keys(status); - expect(keys.length).toBeGreaterThan(0); - - // Check the status of the first rule - const firstStatus = status[keys[0]]; - expect(firstStatus.active).toEqual(true); - expect(firstStatus.ruleCount.added).toBeGreaterThan(0); -}); - -tap.skip.test('NFTablesManager route updating test', async () => { - // Test will be skipped if not running as root due to tap.skip.test - - // Create an updated version of the sample route - const updatedRoute: IRouteConfig = { - ...sampleRoute, - action: { - ...sampleRoute.action, - targets: [{ - host: 'localhost', - port: 9000 // Different port - }], - nftables: { - ...sampleRoute.action.nftables, - protocol: 'all' // Different protocol - } - } - }; - - // Update the route - const result = await manager.updateRoute(sampleRoute, updatedRoute); - - // Verify the route was updated successfully - expect(result).toEqual(true); - - // Verify the old route is no longer provisioned - expect(manager.isRouteProvisioned(sampleRoute)).toEqual(false); - - // Verify the new route is provisioned - expect(manager.isRouteProvisioned(updatedRoute)).toEqual(true); -}); - -tap.skip.test('NFTablesManager route deprovisioning test', async () => { - // Test will be skipped if not running as root due to tap.skip.test - - // Create an updated version of the sample route from the previous test - const updatedRoute: IRouteConfig = { - ...sampleRoute, - action: { - ...sampleRoute.action, - targets: [{ - host: 'localhost', - port: 9000 // Different port from original test - }], - nftables: { - ...sampleRoute.action.nftables, - protocol: 'all' // Different protocol from original test - } - } - }; - - // Deprovision the route - const result = await manager.deprovisionRoute(updatedRoute); - - // Verify the route was deprovisioned successfully - expect(result).toEqual(true); - - // Verify the route is no longer provisioned - expect(manager.isRouteProvisioned(updatedRoute)).toEqual(false); -}); - -tap.skip.test('NFTablesManager cleanup test', async () => { - // Test will be skipped if not running as root due to tap.skip.test - - // Stop all NFTables rules - await manager.stop(); - - // Get the status of the managed rules - const status = await manager.getStatus(); - - // Verify there are no active rules - expect(Object.keys(status).length).toEqual(0); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.nftables-status.ts b/test/test.nftables-status.ts deleted file mode 100644 index d7883c6..0000000 --- a/test/test.nftables-status.ts +++ /dev/null @@ -1,166 +0,0 @@ -import { SmartProxy } from '../ts/proxies/smart-proxy/index.js'; -import { NFTablesManager } from '../ts/proxies/smart-proxy/nftables-manager.js'; -import { createNfTablesRoute } from '../ts/proxies/smart-proxy/utils/route-helpers.js'; -import { expect, tap } from '@git.zone/tstest/tapbundle'; -import * as child_process from 'child_process'; -import { promisify } from 'util'; - -const exec = promisify(child_process.exec); - -// Check if we have root privileges -async function checkRootPrivileges(): Promise { - try { - const { stdout } = await exec('id -u'); - return stdout.trim() === '0'; - } catch (err) { - return false; - } -} - -// Skip tests if not root -const isRoot = await checkRootPrivileges(); -if (!isRoot) { - console.log(''); - console.log('========================================'); - console.log('NFTables status tests require root privileges'); - console.log('Skipping NFTables status tests'); - console.log('========================================'); - console.log(''); -} - -// Define the test function based on root privileges -const testFn = isRoot ? tap.test : tap.skip.test; - -testFn('NFTablesManager status functionality', async () => { - const { SmartProxy } = await import('../ts/proxies/smart-proxy/smart-proxy.js'); - const proxy = new SmartProxy({ routes: [] }); - const nftablesManager = new NFTablesManager(proxy); - - // Create test routes - const testRoutes = [ - createNfTablesRoute('test-route-1', { host: 'localhost', port: 8080 }, { ports: 9080 }), - createNfTablesRoute('test-route-2', { host: 'localhost', port: 8081 }, { ports: 9081 }), - createNfTablesRoute('test-route-3', { host: 'localhost', port: 8082 }, { - ports: 9082, - ipAllowList: ['127.0.0.1', '192.168.1.0/24'] - }) - ]; - - // Get initial status (should be empty) - let status = await nftablesManager.getStatus(); - expect(Object.keys(status).length).toEqual(0); - - // Provision routes - for (const route of testRoutes) { - await nftablesManager.provisionRoute(route); - } - - // Get status after provisioning - status = await nftablesManager.getStatus(); - expect(Object.keys(status).length).toEqual(3); - - // Check status structure - for (const routeStatus of Object.values(status)) { - expect(routeStatus).toHaveProperty('active'); - expect(routeStatus).toHaveProperty('ruleCount'); - expect(routeStatus).toHaveProperty('lastUpdate'); - expect(routeStatus.active).toBeTrue(); - } - - // Deprovision one route - await nftablesManager.deprovisionRoute(testRoutes[0]); - - // Check status after deprovisioning - status = await nftablesManager.getStatus(); - expect(Object.keys(status).length).toEqual(2); - - // Cleanup remaining routes - await nftablesManager.stop(); - - // Final status should be empty - status = await nftablesManager.getStatus(); - expect(Object.keys(status).length).toEqual(0); -}); - -testFn('SmartProxy getNfTablesStatus functionality', async () => { - const smartProxy = new SmartProxy({ - routes: [ - createNfTablesRoute('proxy-test-1', { host: 'localhost', port: 3000 }, { ports: 3001 }), - createNfTablesRoute('proxy-test-2', { host: 'localhost', port: 3002 }, { ports: 3003 }), - // Include a non-NFTables route to ensure it's not included in the status - { - name: 'non-nftables-route', - match: { ports: 3004 }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 3005 }] - } - } - ] - }); - - // Start the proxy - await smartProxy.start(); - - // Get NFTables status - const status = await smartProxy.getNfTablesStatus(); - - // Should only have 2 NFTables routes - const statusKeys = Object.keys(status); - expect(statusKeys.length).toEqual(2); - - // Check that both NFTables routes are in the status - const routeIds = statusKeys.sort(); - expect(routeIds).toContain('proxy-test-1:3001'); - expect(routeIds).toContain('proxy-test-2:3003'); - - // Verify status structure - for (const [routeId, routeStatus] of Object.entries(status)) { - expect(routeStatus).toHaveProperty('active', true); - expect(routeStatus).toHaveProperty('ruleCount'); - expect(routeStatus.ruleCount).toHaveProperty('total'); - expect(routeStatus.ruleCount.total).toBeGreaterThan(0); - } - - // Stop the proxy - await smartProxy.stop(); - - // After stopping, status should be empty - const finalStatus = await smartProxy.getNfTablesStatus(); - expect(Object.keys(finalStatus).length).toEqual(0); -}); - -testFn('NFTables route update status tracking', async () => { - const smartProxy = new SmartProxy({ - routes: [ - createNfTablesRoute('update-test', { host: 'localhost', port: 4000 }, { ports: 4001 }) - ] - }); - - await smartProxy.start(); - - // Get initial status - let status = await smartProxy.getNfTablesStatus(); - expect(Object.keys(status).length).toEqual(1); - const initialUpdate = status['update-test:4001'].lastUpdate; - - // Wait a moment - await new Promise(resolve => setTimeout(resolve, 10)); - - // Update the route - await smartProxy.updateRoutes([ - createNfTablesRoute('update-test', { host: 'localhost', port: 4002 }, { ports: 4001 }) - ]); - - // Get status after update - status = await smartProxy.getNfTablesStatus(); - expect(Object.keys(status).length).toEqual(1); - const updatedTime = status['update-test:4001'].lastUpdate; - - // The update time should be different - expect(updatedTime.getTime()).toBeGreaterThan(initialUpdate.getTime()); - - await smartProxy.stop(); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.port80-management.node.ts b/test/test.port80-management.node.ts deleted file mode 100644 index ba413e6..0000000 --- a/test/test.port80-management.node.ts +++ /dev/null @@ -1,281 +0,0 @@ -import { expect, tap } from '@git.zone/tstest/tapbundle'; -import { SmartProxy } from '../ts/index.js'; - -/** - * Test that verifies port 80 is not double-registered when both - * user routes and ACME challenges use the same port - */ -tap.test('should not double-register port 80 when user route and ACME use same port', async (tools) => { - tools.timeout(5000); - - let port80AddCount = 0; - const activePorts = new Set(); - - const settings = { - port: 9901, - routes: [ - { - name: 'user-route', - match: { - ports: [80] - }, - action: { - type: 'forward' as const, - targets: [{ host: 'localhost', port: 3000 }] - } - }, - { - name: 'secure-route', - match: { - ports: [443] - }, - action: { - type: 'forward' as const, - targets: [{ host: 'localhost', port: 3001 }], - tls: { - mode: 'terminate' as const, - certificate: 'auto' as const - } - } - } - ], - acme: { - email: 'test@test.com', - port: 80 // ACME on same port as user route - } - }; - - const proxy = new SmartProxy(settings); - - // Mock the port manager to track port additions - const mockPortManager = { - addPort: async (port: number) => { - if (activePorts.has(port)) { - return; // Simulate deduplication - } - activePorts.add(port); - if (port === 80) { - port80AddCount++; - } - }, - addPorts: async (ports: number[]) => { - for (const port of ports) { - await mockPortManager.addPort(port); - } - }, - updatePorts: async (requiredPorts: Set) => { - for (const port of requiredPorts) { - await mockPortManager.addPort(port); - } - }, - setShuttingDown: () => {}, - closeAll: async () => { activePorts.clear(); }, - stop: async () => { await mockPortManager.closeAll(); } - }; - - // Inject mock - (proxy as any).portManager = mockPortManager; - - // Mock certificate manager to prevent ACME calls - (proxy as any).createCertificateManager = async function(routes: any[], certDir: string, acmeOptions: any, initialState?: any) { - const mockCertManager = { - setUpdateRoutesCallback: function(callback: any) { /* noop */ }, - setHttpProxy: function() {}, - setGlobalAcmeDefaults: function() {}, - setAcmeStateManager: function() {}, - initialize: async function() { - // Simulate ACME route addition - const challengeRoute = { - name: 'acme-challenge', - priority: 1000, - match: { - ports: acmeOptions?.port || 80, - path: '/.well-known/acme-challenge/*' - }, - action: { - type: 'static' - } - }; - // This would trigger route update in real implementation - }, - provisionAllCertificates: async function() { - // Mock implementation to satisfy the call in SmartProxy.start() - // Add the ACME challenge port here too in case initialize was skipped - const challengePort = acmeOptions?.port || 80; - await mockPortManager.addPort(challengePort); - console.log(`Added ACME challenge port from provisionAllCertificates: ${challengePort}`); - }, - getAcmeOptions: () => acmeOptions, - getState: () => ({ challengeRouteActive: false }), - stop: async () => {} - }; - return mockCertManager; - }; - - // Mock NFTables - (proxy as any).nftablesManager = { - ensureNFTablesSetup: async () => {}, - stop: async () => {} - }; - - // Mock admin server - (proxy as any).startAdminServer = async function() { - (this as any).servers.set(this.settings.port, { - port: this.settings.port, - close: async () => {} - }); - }; - - await proxy.start(); - - // Verify that port 80 was added only once - expect(port80AddCount).toEqual(1); - - await proxy.stop(); -}); - -/** - * Test that verifies ACME can use a different port than user routes - */ -tap.test('should handle ACME on different port than user routes', async (tools) => { - tools.timeout(5000); - - const portAddHistory: number[] = []; - const activePorts = new Set(); - - const settings = { - port: 9902, - routes: [ - { - name: 'user-route', - match: { - ports: [80] - }, - action: { - type: 'forward' as const, - targets: [{ host: 'localhost', port: 3000 }] - } - }, - { - name: 'secure-route', - match: { - ports: [443] - }, - action: { - type: 'forward' as const, - targets: [{ host: 'localhost', port: 3001 }], - tls: { - mode: 'terminate' as const, - certificate: 'auto' as const - } - } - } - ], - acme: { - email: 'test@test.com', - port: 8080 // ACME on different port than user routes - } - }; - - const proxy = new SmartProxy(settings); - - // Mock the port manager - const mockPortManager = { - addPort: async (port: number) => { - console.log(`Attempting to add port: ${port}`); - if (!activePorts.has(port)) { - activePorts.add(port); - portAddHistory.push(port); - console.log(`Port ${port} added to history`); - } else { - console.log(`Port ${port} already active, not adding to history`); - } - }, - addPorts: async (ports: number[]) => { - for (const port of ports) { - await mockPortManager.addPort(port); - } - }, - updatePorts: async (requiredPorts: Set) => { - for (const port of requiredPorts) { - await mockPortManager.addPort(port); - } - }, - setShuttingDown: () => {}, - closeAll: async () => { activePorts.clear(); }, - stop: async () => { await mockPortManager.closeAll(); } - }; - - // Inject mocks - (proxy as any).portManager = mockPortManager; - - // Mock certificate manager - (proxy as any).createCertificateManager = async function(routes: any[], certDir: string, acmeOptions: any, initialState?: any) { - const mockCertManager = { - setUpdateRoutesCallback: function(callback: any) { /* noop */ }, - setHttpProxy: function() {}, - setGlobalAcmeDefaults: function() {}, - setAcmeStateManager: function() {}, - initialize: async function() { - // Simulate ACME route addition on different port - const challengePort = acmeOptions?.port || 80; - const challengeRoute = { - name: 'acme-challenge', - priority: 1000, - match: { - ports: challengePort, - path: '/.well-known/acme-challenge/*' - }, - action: { - type: 'static' - } - }; - - // Add the ACME port to our port tracking - await mockPortManager.addPort(challengePort); - - // For debugging - console.log(`Added ACME challenge port: ${challengePort}`); - }, - provisionAllCertificates: async function() { - // Mock implementation to satisfy the call in SmartProxy.start() - // Add the ACME challenge port here too in case initialize was skipped - const challengePort = acmeOptions?.port || 80; - await mockPortManager.addPort(challengePort); - console.log(`Added ACME challenge port from provisionAllCertificates: ${challengePort}`); - }, - getAcmeOptions: () => acmeOptions, - getState: () => ({ challengeRouteActive: false }), - stop: async () => {} - }; - return mockCertManager; - }; - - // Mock NFTables - (proxy as any).nftablesManager = { - ensureNFTablesSetup: async () => {}, - stop: async () => {} - }; - - // Mock admin server - (proxy as any).startAdminServer = async function() { - (this as any).servers.set(this.settings.port, { - port: this.settings.port, - close: async () => {} - }); - }; - - await proxy.start(); - - // Log the port history for debugging - console.log('Port add history:', portAddHistory); - - // Verify that all expected ports were added - expect(portAddHistory.includes(80)).toBeTrue(); // User route - expect(portAddHistory.includes(443)).toBeTrue(); // TLS route - expect(portAddHistory.includes(8080)).toBeTrue(); // ACME challenge on different port - - await proxy.stop(); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.proxy-chain-cleanup.node.ts b/test/test.proxy-chain-cleanup.node.ts deleted file mode 100644 index 7c57027..0000000 --- a/test/test.proxy-chain-cleanup.node.ts +++ /dev/null @@ -1,182 +0,0 @@ -import { expect, tap } from '@git.zone/tstest/tapbundle'; -import * as plugins from '../ts/plugins.js'; -import { SmartProxy } from '../ts/index.js'; - -let outerProxy: SmartProxy; -let innerProxy: SmartProxy; - -tap.test('setup two smartproxies in a chain configuration', async () => { - // Setup inner proxy (backend proxy) - innerProxy = new SmartProxy({ - routes: [ - { - name: 'inner-backend', - match: { - ports: 8002 - }, - action: { - type: 'forward', - targets: [{ - host: 'httpbin.org', - port: 443 - }] - } - } - ], - defaults: { - target: { - host: 'httpbin.org', - port: 443 - } - }, - acceptProxyProtocol: true, - sendProxyProtocol: false, - enableDetailedLogging: true, - inactivityTimeout: 10000 // Shorter timeout for testing - }); - await innerProxy.start(); - - // Setup outer proxy (frontend proxy) - outerProxy = new SmartProxy({ - routes: [ - { - name: 'outer-frontend', - match: { - ports: 8001 - }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 8002 - }], - sendProxyProtocol: true - } - } - ], - defaults: { - target: { - host: 'localhost', - port: 8002 - } - }, - sendProxyProtocol: true, - enableDetailedLogging: true, - inactivityTimeout: 10000 // Shorter timeout for testing - }); - await outerProxy.start(); -}); - -tap.test('should properly cleanup connections in proxy chain', async (tools) => { - const testDuration = 30000; // 30 seconds - const connectionInterval = 500; // Create new connection every 500ms - const connectionDuration = 2000; // Each connection lasts 2 seconds - - let connectionsCreated = 0; - let connectionsCompleted = 0; - - // Function to create a test connection - const createTestConnection = async () => { - connectionsCreated++; - const connectionId = connectionsCreated; - - try { - const socket = plugins.net.connect({ - port: 8001, - host: 'localhost' - }); - - await new Promise((resolve, reject) => { - socket.on('connect', () => { - console.log(`Connection ${connectionId} established`); - - // Send TLS Client Hello for httpbin.org - const clientHello = Buffer.from([ - 0x16, 0x03, 0x01, 0x00, 0xc8, // TLS handshake header - 0x01, 0x00, 0x00, 0xc4, // Client Hello - 0x03, 0x03, // TLS 1.2 - ...Array(32).fill(0), // Random bytes - 0x00, // Session ID length - 0x00, 0x02, 0x13, 0x01, // Cipher suites - 0x01, 0x00, // Compression methods - 0x00, 0x97, // Extensions length - 0x00, 0x00, 0x00, 0x0f, 0x00, 0x0d, // SNI extension - 0x00, 0x00, 0x0a, 0x68, 0x74, 0x74, 0x70, 0x62, 0x69, 0x6e, 0x2e, 0x6f, 0x72, 0x67 // "httpbin.org" - ]); - - socket.write(clientHello); - - // Keep connection alive for specified duration - setTimeout(() => { - socket.destroy(); - connectionsCompleted++; - console.log(`Connection ${connectionId} closed (completed: ${connectionsCompleted}/${connectionsCreated})`); - resolve(); - }, connectionDuration); - }); - - socket.on('error', (err) => { - console.log(`Connection ${connectionId} error: ${err.message}`); - connectionsCompleted++; - reject(err); - }); - }); - } catch (err) { - console.log(`Failed to create connection ${connectionId}: ${err.message}`); - connectionsCompleted++; - } - }; - - // Start creating connections - const startTime = Date.now(); - const connectionTimer = setInterval(() => { - if (Date.now() - startTime < testDuration) { - createTestConnection().catch(() => {}); - } else { - clearInterval(connectionTimer); - } - }, connectionInterval); - - // Monitor connection counts - const monitorInterval = setInterval(() => { - const outerConnections = (outerProxy as any).connectionManager.getConnectionCount(); - const innerConnections = (innerProxy as any).connectionManager.getConnectionCount(); - - console.log(`Active connections - Outer: ${outerConnections}, Inner: ${innerConnections}, Created: ${connectionsCreated}, Completed: ${connectionsCompleted}`); - }, 2000); - - // Wait for test duration + cleanup time - await tools.delayFor(testDuration + 10000); - - clearInterval(connectionTimer); - clearInterval(monitorInterval); - - // Wait for all connections to complete - while (connectionsCompleted < connectionsCreated) { - await tools.delayFor(100); - } - - // Give some time for cleanup - await tools.delayFor(5000); - - // Check final connection counts - const finalOuterConnections = (outerProxy as any).connectionManager.getConnectionCount(); - const finalInnerConnections = (innerProxy as any).connectionManager.getConnectionCount(); - - console.log(`\nFinal connection counts:`); - console.log(`Outer proxy: ${finalOuterConnections}`); - console.log(`Inner proxy: ${finalInnerConnections}`); - console.log(`Total created: ${connectionsCreated}`); - console.log(`Total completed: ${connectionsCompleted}`); - - // Both proxies should have cleaned up all connections - expect(finalOuterConnections).toEqual(0); - expect(finalInnerConnections).toEqual(0); -}); - -tap.test('cleanup proxies', async () => { - await outerProxy.stop(); - await innerProxy.stop(); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.proxy-chain-simple.node.ts b/test/test.proxy-chain-simple.node.ts deleted file mode 100644 index b89024f..0000000 --- a/test/test.proxy-chain-simple.node.ts +++ /dev/null @@ -1,193 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import * as net from 'net'; -import * as plugins from '../ts/plugins.js'; - -// Import SmartProxy and configurations -import { SmartProxy } from '../ts/index.js'; - -tap.test('simple proxy chain test - identify connection accumulation', async () => { - console.log('\n=== Simple Proxy Chain Test ==='); - console.log('Setup: Client โ†’ SmartProxy1 (8590) โ†’ SmartProxy2 (8591) โ†’ Backend (down)'); - - // Create backend server that accepts and immediately closes connections - const backend = net.createServer((socket) => { - console.log('Backend: Connection received, closing immediately'); - socket.destroy(); - }); - - await new Promise((resolve) => { - backend.listen(9998, () => { - console.log('โœ“ Backend server started on port 9998 (closes connections immediately)'); - resolve(); - }); - }); - - // Create SmartProxy2 (downstream) - const proxy2 = new SmartProxy({ - enableDetailedLogging: true, - socketTimeout: 5000, - routes: [{ - name: 'to-backend', - match: { ports: 8591 }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 9998 // Backend that closes immediately - }] - } - }] - }); - - // Create SmartProxy1 (upstream) - const proxy1 = new SmartProxy({ - enableDetailedLogging: true, - socketTimeout: 5000, - routes: [{ - name: 'to-proxy2', - match: { ports: 8590 }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 8591 // Forward to proxy2 - }] - } - }] - }); - - await proxy2.start(); - console.log('โœ“ SmartProxy2 started on port 8591'); - - await proxy1.start(); - console.log('โœ“ SmartProxy1 started on port 8590'); - - // Helper to get connection counts - const getConnectionCounts = () => { - const conn1 = (proxy1 as any).connectionManager; - const conn2 = (proxy2 as any).connectionManager; - return { - proxy1: conn1 ? conn1.getConnectionCount() : 0, - proxy2: conn2 ? conn2.getConnectionCount() : 0 - }; - }; - - console.log('\n--- Making 5 sequential connections ---'); - - for (let i = 0; i < 5; i++) { - console.log(`\n=== Connection ${i + 1} ===`); - - const counts = getConnectionCounts(); - console.log(`Before: Proxy1=${counts.proxy1}, Proxy2=${counts.proxy2}`); - - await new Promise((resolve) => { - const client = new net.Socket(); - let dataReceived = false; - - client.on('data', (data) => { - console.log(`Client received data: ${data.toString()}`); - dataReceived = true; - }); - - client.on('error', (err: NodeJS.ErrnoException) => { - console.log(`Client error: ${err.code}`); - resolve(); - }); - - client.on('close', () => { - console.log(`Client closed (data received: ${dataReceived})`); - resolve(); - }); - - client.connect(8590, 'localhost', () => { - console.log('Client connected to Proxy1'); - // Send HTTP request - client.write('GET / HTTP/1.1\r\nHost: test.com\r\n\r\n'); - }); - - // Timeout - setTimeout(() => { - if (!client.destroyed) { - console.log('Client timeout, destroying'); - client.destroy(); - } - resolve(); - }, 2000); - }); - - // Wait a bit and check counts - await new Promise(resolve => setTimeout(resolve, 500)); - - const afterCounts = getConnectionCounts(); - console.log(`After: Proxy1=${afterCounts.proxy1}, Proxy2=${afterCounts.proxy2}`); - - if (afterCounts.proxy1 > 0 || afterCounts.proxy2 > 0) { - console.log('โš ๏ธ WARNING: Connections not cleaned up!'); - } - } - - console.log('\n--- Test with backend completely down ---'); - - // Stop backend - backend.close(); - await new Promise(resolve => setTimeout(resolve, 100)); - console.log('โœ“ Backend stopped'); - - // Make more connections with backend down - for (let i = 0; i < 3; i++) { - console.log(`\n=== Connection ${i + 6} (backend down) ===`); - - const counts = getConnectionCounts(); - console.log(`Before: Proxy1=${counts.proxy1}, Proxy2=${counts.proxy2}`); - - await new Promise((resolve) => { - const client = new net.Socket(); - - client.on('error', () => { - resolve(); - }); - - client.on('close', () => { - resolve(); - }); - - client.connect(8590, 'localhost', () => { - client.write('GET / HTTP/1.1\r\nHost: test.com\r\n\r\n'); - }); - - setTimeout(() => { - if (!client.destroyed) { - client.destroy(); - } - resolve(); - }, 1000); - }); - - await new Promise(resolve => setTimeout(resolve, 500)); - - const afterCounts = getConnectionCounts(); - console.log(`After: Proxy1=${afterCounts.proxy1}, Proxy2=${afterCounts.proxy2}`); - } - - // Final check - console.log('\n--- Final Check ---'); - await new Promise(resolve => setTimeout(resolve, 1000)); - - const finalCounts = getConnectionCounts(); - console.log(`Final counts: Proxy1=${finalCounts.proxy1}, Proxy2=${finalCounts.proxy2}`); - - await proxy1.stop(); - await proxy2.stop(); - - // Verify - if (finalCounts.proxy1 > 0 || finalCounts.proxy2 > 0) { - console.log('\nโŒ FAIL: Connections accumulated!'); - } else { - console.log('\nโœ… PASS: No connection accumulation'); - } - - expect(finalCounts.proxy1).toEqual(0); - expect(finalCounts.proxy2).toEqual(0); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.proxy-chaining-accumulation.node.ts b/test/test.proxy-chaining-accumulation.node.ts deleted file mode 100644 index a345d0c..0000000 --- a/test/test.proxy-chaining-accumulation.node.ts +++ /dev/null @@ -1,364 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import * as net from 'net'; -import * as plugins from '../ts/plugins.js'; - -// Import SmartProxy and configurations -import { SmartProxy } from '../ts/index.js'; - -tap.test('should handle proxy chaining without connection accumulation', async () => { - console.log('\n=== Testing Proxy Chaining Connection Accumulation ==='); - console.log('Setup: Client โ†’ SmartProxy1 โ†’ SmartProxy2 โ†’ Backend (down)'); - - // Create SmartProxy2 (downstream proxy) - const proxy2 = new SmartProxy({ - enableDetailedLogging: false, - socketTimeout: 5000, - routes: [{ - name: 'backend-route', - match: { ports: 8581 }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 9999 // Non-existent backend - }] - } - }] - }); - - // Create SmartProxy1 (upstream proxy) - const proxy1 = new SmartProxy({ - enableDetailedLogging: false, - socketTimeout: 5000, - routes: [{ - name: 'chain-route', - match: { ports: 8580 }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 8581 // Forward to proxy2 - }] - } - }] - }); - - // Start both proxies - await proxy2.start(); - console.log('โœ“ SmartProxy2 started on port 8581'); - - await proxy1.start(); - console.log('โœ“ SmartProxy1 started on port 8580'); - - // Helper to get connection counts - const getConnectionCounts = () => { - const conn1 = (proxy1 as any).connectionManager; - const conn2 = (proxy2 as any).connectionManager; - return { - proxy1: conn1 ? conn1.getConnectionCount() : 0, - proxy2: conn2 ? conn2.getConnectionCount() : 0 - }; - }; - - const initialCounts = getConnectionCounts(); - console.log(`\nInitial connection counts - Proxy1: ${initialCounts.proxy1}, Proxy2: ${initialCounts.proxy2}`); - - // Test 1: Single connection attempt - console.log('\n--- Test 1: Single connection through chain ---'); - - await new Promise((resolve) => { - const client = new net.Socket(); - - client.on('error', (err: NodeJS.ErrnoException) => { - console.log(`Client received error: ${err.code}`); - resolve(); - }); - - client.on('close', () => { - console.log('Client connection closed'); - resolve(); - }); - - client.connect(8580, 'localhost', () => { - console.log('Client connected to Proxy1'); - // Send data to trigger routing - client.write('GET / HTTP/1.1\r\nHost: test.com\r\n\r\n'); - }); - - // Timeout - setTimeout(() => { - if (!client.destroyed) { - client.destroy(); - } - resolve(); - }, 1000); - }); - - // Check connections after single attempt - await new Promise(resolve => setTimeout(resolve, 500)); - let counts = getConnectionCounts(); - console.log(`After single connection - Proxy1: ${counts.proxy1}, Proxy2: ${counts.proxy2}`); - - // Test 2: Multiple simultaneous connections - console.log('\n--- Test 2: Multiple simultaneous connections ---'); - - const promises = []; - for (let i = 0; i < 10; i++) { - promises.push(new Promise((resolve) => { - const client = new net.Socket(); - - client.on('error', () => { - resolve(); - }); - - client.on('close', () => { - resolve(); - }); - - client.connect(8580, 'localhost', () => { - // Send data - client.write(`GET /test${i} HTTP/1.1\r\nHost: test.com\r\n\r\n`); - }); - - // Timeout - setTimeout(() => { - if (!client.destroyed) { - client.destroy(); - } - resolve(); - }, 500); - })); - } - - await Promise.all(promises); - console.log('โœ“ All simultaneous connections completed'); - - // Check connections - counts = getConnectionCounts(); - console.log(`After simultaneous connections - Proxy1: ${counts.proxy1}, Proxy2: ${counts.proxy2}`); - - // Test 3: Rapid serial connections (simulating retries) - console.log('\n--- Test 3: Rapid serial connections (retries) ---'); - - for (let i = 0; i < 20; i++) { - await new Promise((resolve) => { - const client = new net.Socket(); - - client.on('error', () => { - resolve(); - }); - - client.on('close', () => { - resolve(); - }); - - client.connect(8580, 'localhost', () => { - client.write('GET / HTTP/1.1\r\nHost: test.com\r\n\r\n'); - // Quick disconnect to simulate retry behavior - setTimeout(() => client.destroy(), 50); - }); - - // Timeout - setTimeout(() => { - if (!client.destroyed) { - client.destroy(); - } - resolve(); - }, 200); - }); - - if ((i + 1) % 5 === 0) { - counts = getConnectionCounts(); - console.log(`After ${i + 1} retries - Proxy1: ${counts.proxy1}, Proxy2: ${counts.proxy2}`); - } - - // Small delay between retries - await new Promise(resolve => setTimeout(resolve, 50)); - } - - // Test 4: Long-lived connection attempt - console.log('\n--- Test 4: Long-lived connection attempt ---'); - - await new Promise((resolve) => { - const client = new net.Socket(); - - client.on('error', () => { - resolve(); - }); - - client.on('close', () => { - console.log('Long-lived client closed'); - resolve(); - }); - - client.connect(8580, 'localhost', () => { - console.log('Long-lived client connected'); - // Send data periodically - const interval = setInterval(() => { - if (!client.destroyed && client.writable) { - client.write('PING\r\n'); - } else { - clearInterval(interval); - } - }, 100); - - // Close after 2 seconds - setTimeout(() => { - clearInterval(interval); - client.destroy(); - }, 2000); - }); - - // Timeout - setTimeout(() => { - if (!client.destroyed) { - client.destroy(); - } - resolve(); - }, 3000); - }); - - // Final check - await new Promise(resolve => setTimeout(resolve, 1000)); - - const finalCounts = getConnectionCounts(); - console.log(`\nFinal connection counts - Proxy1: ${finalCounts.proxy1}, Proxy2: ${finalCounts.proxy2}`); - - // Monitor for a bit to see if connections are cleaned up - console.log('\nMonitoring connection cleanup...'); - for (let i = 0; i < 3; i++) { - await new Promise(resolve => setTimeout(resolve, 500)); - counts = getConnectionCounts(); - console.log(`After ${(i + 1) * 0.5}s - Proxy1: ${counts.proxy1}, Proxy2: ${counts.proxy2}`); - } - - // Stop proxies - await proxy1.stop(); - console.log('\nโœ“ SmartProxy1 stopped'); - - await proxy2.stop(); - console.log('โœ“ SmartProxy2 stopped'); - - // Analysis - console.log('\n=== Analysis ==='); - if (finalCounts.proxy1 > 0 || finalCounts.proxy2 > 0) { - console.log('โŒ FAIL: Connections accumulated!'); - console.log(`Proxy1 leaked ${finalCounts.proxy1} connections`); - console.log(`Proxy2 leaked ${finalCounts.proxy2} connections`); - } else { - console.log('โœ… PASS: No connection accumulation detected'); - } - - // Verify - expect(finalCounts.proxy1).toEqual(0); - expect(finalCounts.proxy2).toEqual(0); -}); - -tap.test('should handle proxy chain with HTTP traffic', async () => { - console.log('\n=== Testing Proxy Chain with HTTP Traffic ==='); - - // Create SmartProxy2 with HTTP handling - const proxy2 = new SmartProxy({ - useHttpProxy: [8583], // Enable HTTP proxy handling - httpProxyPort: 8584, - enableDetailedLogging: false, - routes: [{ - name: 'http-backend', - match: { ports: 8583 }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 9999 // Non-existent backend - }] - } - }] - }); - - // Create SmartProxy1 with HTTP handling - const proxy1 = new SmartProxy({ - useHttpProxy: [8582], // Enable HTTP proxy handling - httpProxyPort: 8585, - enableDetailedLogging: false, - routes: [{ - name: 'http-chain', - match: { ports: 8582 }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 8583 // Forward to proxy2 - }] - } - }] - }); - - await proxy2.start(); - console.log('โœ“ SmartProxy2 (HTTP) started on port 8583'); - - await proxy1.start(); - console.log('โœ“ SmartProxy1 (HTTP) started on port 8582'); - - // Helper to get connection counts - const getConnectionCounts = () => { - const conn1 = (proxy1 as any).connectionManager; - const conn2 = (proxy2 as any).connectionManager; - return { - proxy1: conn1 ? conn1.getConnectionCount() : 0, - proxy2: conn2 ? conn2.getConnectionCount() : 0 - }; - }; - - console.log('\nSending HTTP requests through chain...'); - - // Make HTTP requests - for (let i = 0; i < 5; i++) { - await new Promise((resolve) => { - const client = new net.Socket(); - let responseData = ''; - - client.on('data', (data) => { - responseData += data.toString(); - // Check if we got a complete HTTP response - if (responseData.includes('\r\n\r\n')) { - console.log(`Response ${i + 1}: ${responseData.split('\r\n')[0]}`); - client.destroy(); - } - }); - - client.on('error', () => { - resolve(); - }); - - client.on('close', () => { - resolve(); - }); - - client.connect(8582, 'localhost', () => { - client.write(`GET /test${i} HTTP/1.1\r\nHost: test.com\r\nConnection: close\r\n\r\n`); - }); - - setTimeout(() => { - if (!client.destroyed) { - client.destroy(); - } - resolve(); - }, 1000); - }); - - await new Promise(resolve => setTimeout(resolve, 100)); - } - - await new Promise(resolve => setTimeout(resolve, 1000)); - - const finalCounts = getConnectionCounts(); - console.log(`\nFinal HTTP proxy counts - Proxy1: ${finalCounts.proxy1}, Proxy2: ${finalCounts.proxy2}`); - - await proxy1.stop(); - await proxy2.stop(); - - expect(finalCounts.proxy1).toEqual(0); - expect(finalCounts.proxy2).toEqual(0); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.rapid-retry-cleanup.node.ts b/test/test.rapid-retry-cleanup.node.ts deleted file mode 100644 index 60c064e..0000000 --- a/test/test.rapid-retry-cleanup.node.ts +++ /dev/null @@ -1,199 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import * as net from 'net'; -import * as plugins from '../ts/plugins.js'; - -// Import SmartProxy and configurations -import { SmartProxy } from '../ts/index.js'; - -tap.test('should handle rapid connection retries without leaking connections', async () => { - console.log('\n=== Testing Rapid Connection Retry Cleanup ==='); - - // Create a SmartProxy instance - const proxy = new SmartProxy({ - enableDetailedLogging: false, - maxConnectionLifetime: 10000, - socketTimeout: 5000, - routes: [{ - name: 'test-route', - match: { ports: 8550 }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 9999 // Non-existent port to force connection failures - }] - } - }] - }); - - // Start the proxy - await proxy.start(); - console.log('โœ“ Proxy started on port 8550'); - - // Helper to get active connection count - const getActiveConnections = () => { - const connectionManager = (proxy as any).connectionManager; - return connectionManager ? connectionManager.getConnectionCount() : 0; - }; - - // Track connection counts - const connectionCounts: number[] = []; - const initialCount = getActiveConnections(); - console.log(`Initial connection count: ${initialCount}`); - - // Simulate rapid retries - const retryCount = 20; - const retryDelay = 50; // 50ms between retries - let successfulConnections = 0; - let failedConnections = 0; - - console.log(`\nSimulating ${retryCount} rapid connection attempts...`); - - for (let i = 0; i < retryCount; i++) { - await new Promise((resolve) => { - const client = new net.Socket(); - - client.on('error', () => { - failedConnections++; - client.destroy(); - resolve(); - }); - - client.on('close', () => { - resolve(); - }); - - client.connect(8550, 'localhost', () => { - // Send some data to trigger routing - client.write('GET / HTTP/1.1\r\nHost: test.com\r\n\r\n'); - successfulConnections++; - }); - - // Force close after a short time - setTimeout(() => { - if (!client.destroyed) { - client.destroy(); - } - }, 100); - }); - - // Small delay between retries - await new Promise(resolve => setTimeout(resolve, retryDelay)); - - // Check connection count after each attempt - const currentCount = getActiveConnections(); - connectionCounts.push(currentCount); - - if ((i + 1) % 5 === 0) { - console.log(`After ${i + 1} attempts: ${currentCount} active connections`); - } - } - - console.log(`\nConnection attempts complete:`); - console.log(`- Successful: ${successfulConnections}`); - console.log(`- Failed: ${failedConnections}`); - - // Wait a bit for any pending cleanups - console.log('\nWaiting for cleanup...'); - await new Promise(resolve => setTimeout(resolve, 1000)); - - // Check final connection count - const finalCount = getActiveConnections(); - console.log(`\nFinal connection count: ${finalCount}`); - - // Analyze connection count trend - const maxCount = Math.max(...connectionCounts); - const avgCount = connectionCounts.reduce((a, b) => a + b, 0) / connectionCounts.length; - - console.log(`\nConnection count statistics:`); - console.log(`- Maximum: ${maxCount}`); - console.log(`- Average: ${avgCount.toFixed(2)}`); - console.log(`- Initial: ${initialCount}`); - console.log(`- Final: ${finalCount}`); - - // Stop the proxy - await proxy.stop(); - console.log('\nโœ“ Proxy stopped'); - - // Verify results - expect(finalCount).toEqual(initialCount); - expect(maxCount).toBeLessThan(10); // Should not accumulate many connections - - console.log('\nโœ… PASS: Connection cleanup working correctly under rapid retries!'); -}); - -tap.test('should handle routing failures without leaking connections', async () => { - console.log('\n=== Testing Routing Failure Cleanup ==='); - - // Create a SmartProxy instance with no routes - const proxy = new SmartProxy({ - enableDetailedLogging: false, - maxConnectionLifetime: 10000, - socketTimeout: 5000, - routes: [] // No routes - all connections will fail routing - }); - - // Start the proxy - await proxy.start(); - console.log('โœ“ Proxy started on port 8551 with no routes'); - - // Helper to get active connection count - const getActiveConnections = () => { - const connectionManager = (proxy as any).connectionManager; - return connectionManager ? connectionManager.getConnectionCount() : 0; - }; - - const initialCount = getActiveConnections(); - console.log(`Initial connection count: ${initialCount}`); - - // Create multiple connections that will fail routing - const connectionPromises = []; - for (let i = 0; i < 10; i++) { - connectionPromises.push(new Promise((resolve) => { - const client = new net.Socket(); - - client.on('error', () => { - client.destroy(); - resolve(); - }); - - client.on('close', () => { - resolve(); - }); - - client.connect(8551, 'localhost', () => { - // Send data to trigger routing (which will fail) - client.write('GET / HTTP/1.1\r\nHost: test.com\r\n\r\n'); - }); - - // Force close after a short time - setTimeout(() => { - if (!client.destroyed) { - client.destroy(); - } - resolve(); - }, 500); - })); - } - - // Wait for all connections to complete - await Promise.all(connectionPromises); - console.log('โœ“ All connection attempts completed'); - - // Wait for cleanup - await new Promise(resolve => setTimeout(resolve, 500)); - - const finalCount = getActiveConnections(); - console.log(`Final connection count: ${finalCount}`); - - // Stop the proxy - await proxy.stop(); - console.log('โœ“ Proxy stopped'); - - // Verify no connections leaked - expect(finalCount).toEqual(initialCount); - - console.log('\nโœ… PASS: Routing failures cleaned up correctly!'); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.route-callback-simple.ts b/test/test.route-callback-simple.ts deleted file mode 100644 index e23879f..0000000 --- a/test/test.route-callback-simple.ts +++ /dev/null @@ -1,117 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import { SmartProxy } from '../ts/index.js'; - -tap.test('should set update routes callback on certificate manager', async () => { - // Create a simple proxy with a route requiring certificates - const proxy = new SmartProxy({ - acme: { - email: 'test@local.dev', - useProduction: false, - port: 8080 // Use non-privileged port for ACME challenges globally - }, - routes: [{ - name: 'test-route', - match: { - ports: [8443], - domains: ['test.local'] - }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 3000 }], - tls: { - mode: 'terminate', - certificate: 'auto', - acme: { - email: 'test@local.dev', - useProduction: false - } - } - } - }] - }); - - // Track callback setting - let callbackSet = false; - - // Override createCertificateManager to track callback setting - (proxy as any).createCertificateManager = async function( - routes: any, - certStore: string, - acmeOptions?: any, - initialState?: any - ) { - // Create a mock certificate manager - const mockCertManager = { - setUpdateRoutesCallback: function(callback: any) { - callbackSet = true; - }, - setHttpProxy: function(proxy: any) {}, - setGlobalAcmeDefaults: function(defaults: any) {}, - setAcmeStateManager: function(manager: any) {}, - setRoutes: function(routes: any) {}, - initialize: async function() {}, - provisionAllCertificates: async function() {}, - stop: async function() {}, - getAcmeOptions: function() { return acmeOptions || {}; }, - getState: function() { return initialState || { challengeRouteActive: false }; } - }; - - // Mimic the real createCertificateManager behavior - // Always set up the route update callback for ACME challenges - mockCertManager.setUpdateRoutesCallback(async (routes) => { - await this.updateRoutes(routes); - }); - - // Connect with HttpProxy if available (mimic real behavior) - if ((this as any).httpProxyBridge.getHttpProxy()) { - mockCertManager.setHttpProxy((this as any).httpProxyBridge.getHttpProxy()); - } - - // Set the ACME state manager - mockCertManager.setAcmeStateManager((this as any).acmeStateManager); - - // Pass down the global ACME config if available - if ((this as any).settings.acme) { - mockCertManager.setGlobalAcmeDefaults((this as any).settings.acme); - } - - await mockCertManager.initialize(); - return mockCertManager; - }; - - await proxy.start(); - - // The callback should have been set during initialization - expect(callbackSet).toEqual(true); - - // Reset tracking - callbackSet = false; - - // Update routes - this should recreate the certificate manager - await proxy.updateRoutes([{ - name: 'new-route', - match: { - ports: [8444], - domains: ['new.local'] - }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 3001 }], - tls: { - mode: 'terminate', - certificate: 'auto', - acme: { - email: 'test@local.dev', - useProduction: false - } - } - } - }]); - - // The callback should have been set again after update - expect(callbackSet).toEqual(true); - - await proxy.stop(); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.route-update-callback.node.ts b/test/test.route-update-callback.node.ts deleted file mode 100644 index c5f4680..0000000 --- a/test/test.route-update-callback.node.ts +++ /dev/null @@ -1,343 +0,0 @@ -import * as plugins from '../ts/plugins.js'; -import { SmartProxy } from '../ts/index.js'; -import { tap, expect } from '@git.zone/tstest/tapbundle'; - -let testProxy: SmartProxy; - -// Create test routes using high ports to avoid permission issues -const createRoute = (id: number, domain: string, port: number = 8443) => ({ - name: `test-route-${id}`, - match: { - ports: [port], - domains: [domain] - }, - action: { - type: 'forward' as const, - targets: [{ - host: 'localhost', - port: 3000 + id - }], - tls: { - mode: 'terminate' as const, - certificate: 'auto' as const, - acme: { - email: 'test@testdomain.test', - useProduction: false - } - } - } -}); - -tap.test('should create SmartProxy instance', async () => { - testProxy = new SmartProxy({ - routes: [createRoute(1, 'test1.testdomain.test', 8443)], - acme: { - email: 'test@testdomain.test', - useProduction: false, - port: 8080 - } - }); - expect(testProxy).toBeInstanceOf(SmartProxy); -}); - -tap.test('should preserve route update callback after updateRoutes', async () => { - // Mock the certificate manager to avoid actual ACME initialization - const originalInitializeCertManager = (testProxy as any).initializeCertificateManager; - let certManagerInitialized = false; - - (testProxy as any).initializeCertificateManager = async function() { - certManagerInitialized = true; - // Create a minimal mock certificate manager - const mockCertManager = { - setUpdateRoutesCallback: function(callback: any) { - this.updateRoutesCallback = callback; - }, - updateRoutesCallback: null, - setHttpProxy: function() {}, - setGlobalAcmeDefaults: function() {}, - setAcmeStateManager: function() {}, - setRoutes: function(routes: any) {}, - initialize: async function() { - // This is where the callback is actually set in the real implementation - return Promise.resolve(); - }, - provisionAllCertificates: async function() { - return Promise.resolve(); - }, - stop: async function() {}, - getAcmeOptions: function() { - return { email: 'test@testdomain.test' }; - }, - getState: function() { - return { challengeRouteActive: false }; - } - }; - - (this as any).certManager = mockCertManager; - - // Simulate the real behavior where setUpdateRoutesCallback is called - mockCertManager.setUpdateRoutesCallback(async (routes: any) => { - await this.updateRoutes(routes); - }); - }; - - // Start the proxy (with mocked cert manager) - await testProxy.start(); - expect(certManagerInitialized).toEqual(true); - - // Get initial certificate manager reference - const initialCertManager = (testProxy as any).certManager; - expect(initialCertManager).toBeTruthy(); - expect(initialCertManager.updateRoutesCallback).toBeTruthy(); - - // Store the initial callback reference - const initialCallback = initialCertManager.updateRoutesCallback; - - // Update routes - this should recreate the cert manager with callback - const newRoutes = [ - createRoute(1, 'test1.testdomain.test', 8443), - createRoute(2, 'test2.testdomain.test', 8444) - ]; - - // Mock the updateRoutes to simulate the real implementation - testProxy.updateRoutes = async function(routes) { - // Update settings - this.settings.routes = routes; - - // Simulate what happens in the real code - recreate cert manager via createCertificateManager - if ((this as any).certManager) { - await (this as any).certManager.stop(); - - // Simulate createCertificateManager which creates a new cert manager - const newMockCertManager = { - setUpdateRoutesCallback: function(callback: any) { - this.updateRoutesCallback = callback; - }, - updateRoutesCallback: null, - setHttpProxy: function() {}, - setGlobalAcmeDefaults: function() {}, - setAcmeStateManager: function() {}, - setRoutes: function(routes: any) {}, - initialize: async function() {}, - provisionAllCertificates: async function() {}, - stop: async function() {}, - getAcmeOptions: function() { - return { email: 'test@testdomain.test' }; - }, - getState: function() { - return { challengeRouteActive: false }; - } - }; - - // Set the callback as done in createCertificateManager - newMockCertManager.setUpdateRoutesCallback(async (routes: any) => { - await this.updateRoutes(routes); - }); - - (this as any).certManager = newMockCertManager; - await (this as any).certManager.initialize(); - } - }; - - await testProxy.updateRoutes(newRoutes); - - // Get new certificate manager reference - const newCertManager = (testProxy as any).certManager; - expect(newCertManager).toBeTruthy(); - expect(newCertManager).not.toEqual(initialCertManager); // Should be a new instance - expect(newCertManager.updateRoutesCallback).toBeTruthy(); // Callback should be set - - // Test that the callback works - const testChallengeRoute = { - name: 'acme-challenge', - match: { - ports: [8080], - path: '/.well-known/acme-challenge/*' - }, - action: { - type: 'static' as const, - content: 'challenge-token' - } - }; - - // This should not throw "No route update callback set" error - let callbackWorked = false; - try { - // If callback is set, this should work - if (newCertManager.updateRoutesCallback) { - await newCertManager.updateRoutesCallback([...newRoutes, testChallengeRoute]); - callbackWorked = true; - } - } catch (error) { - throw new Error(`Route update callback failed: ${error.message}`); - } - - expect(callbackWorked).toEqual(true); - console.log('Route update callback successfully preserved and invoked'); -}); - -tap.test('should handle multiple sequential route updates', async () => { - // Continue with the mocked proxy from previous test - let updateCount = 0; - - // Perform multiple route updates - for (let i = 1; i <= 3; i++) { - const routes = []; - for (let j = 1; j <= i; j++) { - routes.push(createRoute(j, `test${j}.testdomain.test`, 8440 + j)); - } - - await testProxy.updateRoutes(routes); - updateCount++; - - // Verify cert manager is properly set up each time - const certManager = (testProxy as any).certManager; - expect(certManager).toBeTruthy(); - expect(certManager.updateRoutesCallback).toBeTruthy(); - - console.log(`Route update ${i} callback is properly set`); - } - - expect(updateCount).toEqual(3); -}); - -tap.test('should handle route updates when cert manager is not initialized', async () => { - // Create proxy without routes that need certificates - const proxyWithoutCerts = new SmartProxy({ - routes: [{ - name: 'no-cert-route', - match: { - ports: [9080] - }, - action: { - type: 'forward' as const, - targets: [{ - host: 'localhost', - port: 3000 - }] - } - }] - }); - - // Mock initializeCertificateManager to avoid ACME issues - (proxyWithoutCerts as any).initializeCertificateManager = async function() { - // Only create cert manager if routes need it - const autoRoutes = this.settings.routes.filter((r: any) => - r.action.tls?.certificate === 'auto' - ); - - if (autoRoutes.length === 0) { - console.log('No routes require certificate management'); - return; - } - - // Create mock cert manager - const mockCertManager = { - setUpdateRoutesCallback: function(callback: any) { - this.updateRoutesCallback = callback; - }, - updateRoutesCallback: null, - setHttpProxy: function() {}, - setRoutes: function(routes: any) {}, - initialize: async function() {}, - provisionAllCertificates: async function() {}, - stop: async function() {}, - getAcmeOptions: function() { - return { email: 'test@testdomain.test' }; - }, - getState: function() { - return { challengeRouteActive: false }; - } - }; - - (this as any).certManager = mockCertManager; - - // Set the callback - mockCertManager.setUpdateRoutesCallback(async (routes: any) => { - await this.updateRoutes(routes); - }); - }; - - await proxyWithoutCerts.start(); - - // This should not have a cert manager - const certManager = (proxyWithoutCerts as any).certManager; - expect(certManager).toBeFalsy(); - - // Update with routes that need certificates - await proxyWithoutCerts.updateRoutes([createRoute(1, 'cert-needed.testdomain.test', 9443)]); - - // In the real implementation, cert manager is not created by updateRoutes if it doesn't exist - // This is the expected behavior - cert manager is only created during start() or re-created if already exists - const newCertManager = (proxyWithoutCerts as any).certManager; - expect(newCertManager).toBeFalsy(); // Should still be null - - await proxyWithoutCerts.stop(); -}); - -tap.test('should clean up properly', async () => { - await testProxy.stop(); -}); - -tap.test('real code integration test - verify fix is applied', async () => { - // This test will start with routes that need certificates to test the fix - const realProxy = new SmartProxy({ - routes: [createRoute(1, 'test.example.com', 9999)], - acme: { - email: 'test@example.com', - useProduction: false, - port: 18080 - } - }); - - // Mock the certificate manager creation to track callback setting - let callbackSet = false; - (realProxy as any).createCertificateManager = async function(routes: any[], certDir: string, acmeOptions: any, initialState?: any) { - const mockCertManager = { - setUpdateRoutesCallback: function(callback: any) { - callbackSet = true; - this.updateRoutesCallback = callback; - }, - updateRoutesCallback: null as any, - setHttpProxy: function() {}, - setGlobalAcmeDefaults: function() {}, - setAcmeStateManager: function() {}, - setRoutes: function(routes: any) {}, - initialize: async function() {}, - provisionAllCertificates: async function() {}, - stop: async function() {}, - getAcmeOptions: function() { - return acmeOptions || { email: 'test@example.com', useProduction: false }; - }, - getState: function() { - return initialState || { challengeRouteActive: false }; - } - }; - - // Always set up the route update callback for ACME challenges - mockCertManager.setUpdateRoutesCallback(async (routes) => { - await this.updateRoutes(routes); - }); - - return mockCertManager; - }; - - await realProxy.start(); - - // The callback should have been set during initialization - expect(callbackSet).toEqual(true); - callbackSet = false; // Reset for update test - - // Update routes - this should recreate cert manager with callback preserved - const newRoute = createRoute(2, 'test2.example.com', 9999); - await realProxy.updateRoutes([createRoute(1, 'test.example.com', 9999), newRoute]); - - // The callback should have been set again during update - expect(callbackSet).toEqual(true); - - await realProxy.stop(); - - console.log('Real code integration test passed - fix is correctly applied!'); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.smartacme-integration.ts b/test/test.smartacme-integration.ts deleted file mode 100644 index e720557..0000000 --- a/test/test.smartacme-integration.ts +++ /dev/null @@ -1,54 +0,0 @@ -import * as plugins from '../ts/plugins.js'; -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import { SmartCertManager } from '../ts/proxies/smart-proxy/certificate-manager.js'; -import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js'; - -let certManager: SmartCertManager; - -tap.test('should create a SmartCertManager instance', async () => { - const routes: IRouteConfig[] = [ - { - name: 'test-acme-route', - match: { - domains: ['test.example.com'], - ports: [] - }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 3000 - }], - tls: { - mode: 'terminate', - certificate: 'auto', - acme: { - email: 'test@example.com' - } - } - } - } - ]; - - certManager = new SmartCertManager(routes, './test-certs', { - email: 'test@example.com', - useProduction: false - }); - - // Just verify it creates without error - expect(certManager).toBeInstanceOf(SmartCertManager); -}); - -tap.test('should verify SmartAcme handlers are accessible', async () => { - // Test that we can access SmartAcme handlers - const http01Handler = new plugins.smartacme.handlers.Http01MemoryHandler(); - expect(http01Handler).toBeDefined(); -}); - -tap.test('should verify SmartAcme cert managers are accessible', async () => { - // Test that we can access SmartAcme cert managers - const memoryCertManager = new plugins.smartacme.certmanagers.MemoryCertManager(); - expect(memoryCertManager).toBeDefined(); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.stuck-connection-cleanup.node.ts b/test/test.stuck-connection-cleanup.node.ts deleted file mode 100644 index 2bd0c6e..0000000 --- a/test/test.stuck-connection-cleanup.node.ts +++ /dev/null @@ -1,144 +0,0 @@ -import { expect, tap } from '@git.zone/tstest/tapbundle'; -import * as net from 'net'; -import { SmartProxy } from '../ts/index.js'; -import * as plugins from '../ts/plugins.js'; - -tap.test('stuck connection cleanup - verify connections to hanging backends are cleaned up', async (tools) => { - console.log('\n=== Stuck Connection Cleanup Test ==='); - console.log('Purpose: Verify that connections to backends that accept but never respond are cleaned up'); - - // Create a hanging backend that accepts connections but never responds - let backendConnections = 0; - const hangingBackend = net.createServer((socket) => { - backendConnections++; - console.log(`Hanging backend: Connection ${backendConnections} received`); - // Accept the connection but never send any data back - // This simulates a hung backend service - }); - - await new Promise((resolve) => { - hangingBackend.listen(9997, () => { - console.log('โœ“ Hanging backend started on port 9997'); - resolve(); - }); - }); - - // Create proxy that forwards to hanging backend - const proxy = new SmartProxy({ - routes: [{ - name: 'to-hanging-backend', - match: { ports: 8589 }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 9997 }] - } - }], - keepAlive: true, - enableDetailedLogging: false, - inactivityTimeout: 5000, // 5 second inactivity check interval for faster testing - }); - - await proxy.start(); - console.log('โœ“ Proxy started on port 8589'); - - // Create connections that will get stuck - console.log('\n--- Creating connections to hanging backend ---'); - const clients: net.Socket[] = []; - - for (let i = 0; i < 5; i++) { - const client = net.connect(8589, 'localhost'); - clients.push(client); - - await new Promise((resolve) => { - client.on('connect', () => { - console.log(`Client ${i} connected`); - // Send data that will never get a response - client.write(`GET / HTTP/1.1\r\nHost: localhost\r\n\r\n`); - resolve(); - }); - - client.on('error', (err) => { - console.log(`Client ${i} error: ${err.message}`); - resolve(); - }); - }); - } - - // Wait a moment for connections to establish - await plugins.smartdelay.delayFor(1000); - - // Check initial connection count - const initialCount = (proxy as any).connectionManager.getConnectionCount(); - console.log(`\nInitial connection count: ${initialCount}`); - expect(initialCount).toEqual(5); - - // Get connection details - const connections = (proxy as any).connectionManager.getConnections(); - let stuckCount = 0; - - for (const [id, record] of connections) { - if (record.bytesReceived > 0 && record.bytesSent === 0) { - stuckCount++; - console.log(`Stuck connection ${id}: received=${record.bytesReceived}, sent=${record.bytesSent}`); - } - } - - console.log(`Stuck connections found: ${stuckCount}`); - expect(stuckCount).toEqual(5); - - // Wait for inactivity check to run (it checks every 30s by default, but we set it to 5s) - console.log('\n--- Waiting for stuck connection detection (65 seconds) ---'); - console.log('Note: Stuck connections are cleaned up after 60 seconds with no response'); - - // Speed up time by manually triggering inactivity check after simulating time passage - // First, age the connections by updating their timestamps - const now = Date.now(); - for (const [id, record] of connections) { - // Simulate that these connections are 61 seconds old - record.incomingStartTime = now - 61000; - record.lastActivity = now - 61000; - } - - // Manually trigger inactivity check - console.log('Manually triggering inactivity check...'); - (proxy as any).connectionManager.performOptimizedInactivityCheck(); - - // Wait for cleanup to complete - await plugins.smartdelay.delayFor(1000); - - // Check connection count after cleanup - const afterCleanupCount = (proxy as any).connectionManager.getConnectionCount(); - console.log(`\nConnection count after cleanup: ${afterCleanupCount}`); - - // Verify termination stats - const stats = (proxy as any).connectionManager.getTerminationStats(); - console.log('\nTermination stats:', stats); - - // All connections should be cleaned up as "stuck_no_response" - expect(afterCleanupCount).toEqual(0); - - // The termination reason might be under incoming or general stats - const stuckCleanups = (stats.incoming.stuck_no_response || 0) + - (stats.outgoing?.stuck_no_response || 0); - console.log(`Stuck cleanups detected: ${stuckCleanups}`); - expect(stuckCleanups).toBeGreaterThan(0); - - // Verify clients were disconnected - let closedClients = 0; - for (const client of clients) { - if (client.destroyed) { - closedClients++; - } - } - console.log(`Closed clients: ${closedClients}/5`); - expect(closedClients).toEqual(5); - - // Cleanup - console.log('\n--- Cleanup ---'); - await proxy.stop(); - hangingBackend.close(); - - console.log('โœ“ Test complete: Stuck connections are properly detected and cleaned up'); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.websocket-keepalive.node.ts b/test/test.websocket-keepalive.node.ts deleted file mode 100644 index 8fac0a7..0000000 --- a/test/test.websocket-keepalive.node.ts +++ /dev/null @@ -1,157 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import { SmartProxy } from '../ts/index.js'; -import * as net from 'net'; - -tap.test('websocket keep-alive settings for SNI passthrough', async (tools) => { - // Test 1: Verify grace periods for TLS connections - console.log('\n=== Test 1: Grace periods for encrypted connections ==='); - - const proxy = new SmartProxy({ - keepAliveTreatment: 'extended', - keepAliveInactivityMultiplier: 10, - inactivityTimeout: 60000, // 1 minute for testing - routes: [ - { - name: 'test-passthrough', - match: { ports: 8443, domains: 'test.local' }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 9443 }], - tls: { mode: 'passthrough' } - } - } - ] - }); - - // Override route port - proxy.settings.routes[0].match.ports = 8443; - - await proxy.start(); - - // Access connection manager - const connectionManager = proxy.connectionManager; - - // Test 2: Verify longer grace periods are applied - console.log('\n=== Test 2: Checking grace period configuration ==='); - - // Create a mock connection record - const mockRecord = { - id: 'test-conn-1', - remoteIP: '127.0.0.1', - incomingStartTime: Date.now() - 120000, // 2 minutes old - isTLS: true, - incoming: { destroyed: false } as any, - outgoing: { destroyed: true } as any, // Half-zombie state - connectionClosed: false, - hasKeepAlive: true, - lastActivity: Date.now() - 60000 - }; - - // The grace period should be 5 minutes for TLS connections - const gracePeriod = mockRecord.isTLS ? 300000 : 30000; - console.log(`Grace period for TLS connection: ${gracePeriod}ms (${gracePeriod / 1000} seconds)`); - expect(gracePeriod).toEqual(300000); // 5 minutes - - // Test 3: Verify keep-alive treatment - console.log('\n=== Test 3: Keep-alive treatment configuration ==='); - - const settings = proxy.settings; - console.log(`Keep-alive treatment: ${settings.keepAliveTreatment}`); - console.log(`Keep-alive multiplier: ${settings.keepAliveInactivityMultiplier}`); - console.log(`Base inactivity timeout: ${settings.inactivityTimeout}ms`); - - // Calculate effective timeout - const effectiveTimeout = settings.inactivityTimeout! * (settings.keepAliveInactivityMultiplier || 6); - console.log(`Effective timeout for keep-alive connections: ${effectiveTimeout}ms (${effectiveTimeout / 1000} seconds)`); - - expect(settings.keepAliveTreatment).toEqual('extended'); - expect(effectiveTimeout).toEqual(600000); // 10 minutes with our test config - - // Test 4: Verify SNI passthrough doesn't get WebSocket heartbeat - console.log('\n=== Test 4: SNI passthrough handling ==='); - - // Check route configuration - const route = proxy.settings.routes[0]; - expect(route.action.tls?.mode).toEqual('passthrough'); - - // In passthrough mode, WebSocket-specific handling should be skipped - // The connection should be treated as a raw TCP connection - console.log('โœ“ SNI passthrough routes bypass WebSocket heartbeat checks'); - - await proxy.stop(); - - console.log('\nโœ… WebSocket keep-alive configuration test completed!'); -}); - -// Test actual long-lived connection behavior -tap.test('long-lived connection survival test', async (tools) => { - tools.timeout(60000); // This test waits 55 seconds - console.log('\n=== Testing long-lived connection survival ==='); - - // Create a simple echo server - const echoServer = net.createServer((socket) => { - console.log('Echo server: client connected'); - socket.on('data', (data) => { - socket.write(data); // Echo back - }); - }); - - await new Promise((resolve) => echoServer.listen(9444, resolve)); - - // Create proxy with immortal keep-alive - const proxy = new SmartProxy({ - keepAliveTreatment: 'immortal', // Never timeout - routes: [ - { - name: 'echo-passthrough', - match: { ports: 8444 }, - action: { - type: 'forward', - targets: [{ host: 'localhost', port: 9444 }] - } - } - ] - }); - - // Override route port - proxy.settings.routes[0].match.ports = 8444; - - await proxy.start(); - - // Create a client connection - const client = new net.Socket(); - await new Promise((resolve, reject) => { - client.connect(8444, 'localhost', () => { - console.log('Client connected to proxy'); - resolve(); - }); - client.on('error', reject); - }); - - // Keep connection alive with periodic data - let pingCount = 0; - const pingInterval = setInterval(() => { - if (client.writable) { - client.write(`ping ${++pingCount}\n`); - console.log(`Sent ping ${pingCount}`); - } - }, 20000); // Every 20 seconds - - // Wait 55 seconds to verify connection survives past old 30s timeout - await new Promise(resolve => setTimeout(resolve, 55000)); - - // Check if connection is still alive - const isAlive = client.writable && !client.destroyed; - console.log(`Connection alive after 55 seconds: ${isAlive}`); - expect(isAlive).toBeTrue(); - - // Clean up - clearInterval(pingInterval); - client.destroy(); - await proxy.stop(); - await new Promise((resolve) => echoServer.close(() => resolve())); - - console.log('โœ… Long-lived connection survived past 30-second timeout!'); -}); - -export default tap.start(); \ No newline at end of file diff --git a/test/test.wrapped-socket.ts b/test/test.wrapped-socket.ts index bc8669e..689c0d5 100644 --- a/test/test.wrapped-socket.ts +++ b/test/test.wrapped-socket.ts @@ -312,61 +312,4 @@ tap.test('WrappedSocket - should handle encoding and address methods', async () server.close(); }); -tap.test('WrappedSocket - should work with ConnectionManager', async () => { - // This test verifies that WrappedSocket can be used seamlessly with ConnectionManager - const { ConnectionManager } = await import('../ts/proxies/smart-proxy/connection-manager.js'); - - // Create minimal settings - const settings = { - routes: [], - defaults: { - security: { - maxConnections: 100 - } - } - }; - - // Create a mock SmartProxy instance - const mockSmartProxy = { - settings, - securityManager: { - trackConnectionByIP: () => {}, - untrackConnectionByIP: () => {}, - removeConnectionByIP: () => {} - } - } as any; - - const connectionManager = new ConnectionManager(mockSmartProxy); - - // Create a simple test server - const server = net.createServer(); - await new Promise((resolve) => { - server.listen(0, 'localhost', () => resolve()); - }); - - const serverPort = (server.address() as net.AddressInfo).port; - - // Create a client connection - const clientSocket = net.connect(serverPort, 'localhost'); - - // Wait for connection to establish - await new Promise((resolve) => { - clientSocket.once('connect', () => resolve()); - }); - - // Wrap with proxy info - const wrappedSocket = new WrappedSocket(clientSocket, '203.0.113.45', 65432); - - // Create connection using wrapped socket - const record = connectionManager.createConnection(wrappedSocket); - - expect(record).toBeTruthy(); - expect(record!.remoteIP).toEqual('203.0.113.45'); // Should use the real client IP - expect(record!.localPort).toEqual(clientSocket.localPort); - - // Clean up - connectionManager.cleanupConnection(record!, 'test-complete'); - server.close(); -}); - export default tap.start(); \ No newline at end of file diff --git a/test/test.zombie-connection-cleanup.node.ts b/test/test.zombie-connection-cleanup.node.ts deleted file mode 100644 index 31a6d08..0000000 --- a/test/test.zombie-connection-cleanup.node.ts +++ /dev/null @@ -1,304 +0,0 @@ -import { tap, expect } from '@git.zone/tstest/tapbundle'; -import * as net from 'net'; -import * as plugins from '../ts/plugins.js'; - -// Import SmartProxy -import { SmartProxy } from '../ts/index.js'; - -// Import types through type-only imports -import type { ConnectionManager } from '../ts/proxies/smart-proxy/connection-manager.js'; -import type { IConnectionRecord } from '../ts/proxies/smart-proxy/models/interfaces.js'; - -tap.test('zombie connection cleanup - verify inactivity check detects and cleans destroyed sockets', async () => { - console.log('\n=== Zombie Connection Cleanup Test ==='); - console.log('Purpose: Verify that connections with destroyed sockets are detected and cleaned up'); - console.log('Setup: Client โ†’ OuterProxy (8590) โ†’ InnerProxy (8591) โ†’ Backend (9998)'); - - // Create backend server that can be controlled - let acceptConnections = true; - let destroyImmediately = false; - const backendConnections: net.Socket[] = []; - - const backend = net.createServer((socket) => { - console.log('Backend: Connection received'); - backendConnections.push(socket); - - if (destroyImmediately) { - console.log('Backend: Destroying connection immediately'); - socket.destroy(); - } else { - socket.on('data', (data) => { - console.log('Backend: Received data, echoing back'); - socket.write(data); - }); - } - }); - - await new Promise((resolve) => { - backend.listen(9998, () => { - console.log('โœ“ Backend server started on port 9998'); - resolve(); - }); - }); - - // Create InnerProxy with faster inactivity check for testing - const innerProxy = new SmartProxy({ - enableDetailedLogging: true, - inactivityTimeout: 5000, // 5 seconds for faster testing - inactivityCheckInterval: 1000, // Check every second - routes: [{ - name: 'to-backend', - match: { ports: 8591 }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 9998 - }] - } - }] - }); - - // Create OuterProxy with faster inactivity check - const outerProxy = new SmartProxy({ - enableDetailedLogging: true, - inactivityTimeout: 5000, // 5 seconds for faster testing - inactivityCheckInterval: 1000, // Check every second - routes: [{ - name: 'to-inner', - match: { ports: 8590 }, - action: { - type: 'forward', - targets: [{ - host: 'localhost', - port: 8591 - }] - } - }] - }); - - await innerProxy.start(); - console.log('โœ“ InnerProxy started on port 8591'); - - await outerProxy.start(); - console.log('โœ“ OuterProxy started on port 8590'); - - // Helper to get connection details - const getConnectionDetails = () => { - const outerConnMgr = (outerProxy as any).connectionManager as ConnectionManager; - const innerConnMgr = (innerProxy as any).connectionManager as ConnectionManager; - - const outerRecords = Array.from((outerConnMgr as any).connectionRecords.values()) as IConnectionRecord[]; - const innerRecords = Array.from((innerConnMgr as any).connectionRecords.values()) as IConnectionRecord[]; - - return { - outer: { - count: outerConnMgr.getConnectionCount(), - records: outerRecords, - zombies: outerRecords.filter(r => - !r.connectionClosed && - r.incoming?.destroyed && - (r.outgoing?.destroyed ?? true) - ), - halfZombies: outerRecords.filter(r => - !r.connectionClosed && - (r.incoming?.destroyed || r.outgoing?.destroyed) && - !(r.incoming?.destroyed && (r.outgoing?.destroyed ?? true)) - ) - }, - inner: { - count: innerConnMgr.getConnectionCount(), - records: innerRecords, - zombies: innerRecords.filter(r => - !r.connectionClosed && - r.incoming?.destroyed && - (r.outgoing?.destroyed ?? true) - ), - halfZombies: innerRecords.filter(r => - !r.connectionClosed && - (r.incoming?.destroyed || r.outgoing?.destroyed) && - !(r.incoming?.destroyed && (r.outgoing?.destroyed ?? true)) - ) - } - }; - }; - - console.log('\n--- Test 1: Create zombie by destroying sockets without events ---'); - - // Create a connection and forcefully destroy sockets to create zombies - const client1 = new net.Socket(); - await new Promise((resolve) => { - client1.connect(8590, 'localhost', () => { - console.log('Client1 connected to OuterProxy'); - client1.write('GET / HTTP/1.1\r\nHost: test.com\r\n\r\n'); - - // Wait for connection to be established through the chain - setTimeout(() => { - console.log('Forcefully destroying backend connections to create zombies'); - - // Get connection details before destruction - const beforeDetails = getConnectionDetails(); - console.log(`Before destruction: Outer=${beforeDetails.outer.count}, Inner=${beforeDetails.inner.count}`); - - // Destroy all backend connections without proper close events - backendConnections.forEach(conn => { - if (!conn.destroyed) { - // Remove all listeners to prevent proper cleanup - conn.removeAllListeners(); - conn.destroy(); - } - }); - - // Also destroy the client socket abruptly - client1.removeAllListeners(); - client1.destroy(); - - resolve(); - }, 500); - }); - }); - - // Check immediately after destruction - await new Promise(resolve => setTimeout(resolve, 100)); - let details = getConnectionDetails(); - console.log(`\nAfter destruction:`); - console.log(` Outer: ${details.outer.count} connections, ${details.outer.zombies.length} zombies, ${details.outer.halfZombies.length} half-zombies`); - console.log(` Inner: ${details.inner.count} connections, ${details.inner.zombies.length} zombies, ${details.inner.halfZombies.length} half-zombies`); - - // Wait for inactivity check to run (should detect zombies) - console.log('\nWaiting for inactivity check to detect zombies...'); - await new Promise(resolve => setTimeout(resolve, 2000)); - - details = getConnectionDetails(); - console.log(`\nAfter first inactivity check:`); - console.log(` Outer: ${details.outer.count} connections, ${details.outer.zombies.length} zombies, ${details.outer.halfZombies.length} half-zombies`); - console.log(` Inner: ${details.inner.count} connections, ${details.inner.zombies.length} zombies, ${details.inner.halfZombies.length} half-zombies`); - - console.log('\n--- Test 2: Create half-zombie by destroying only one socket ---'); - - // Clear backend connections array - backendConnections.length = 0; - - const client2 = new net.Socket(); - await new Promise((resolve) => { - client2.connect(8590, 'localhost', () => { - console.log('Client2 connected to OuterProxy'); - client2.write('GET / HTTP/1.1\r\nHost: test.com\r\n\r\n'); - - setTimeout(() => { - console.log('Creating half-zombie by destroying only outgoing socket on outer proxy'); - - // Access the connection records directly - const outerConnMgr = (outerProxy as any).connectionManager as ConnectionManager; - const outerRecords = Array.from((outerConnMgr as any).connectionRecords.values()) as IConnectionRecord[]; - - // Find the active connection and destroy only its outgoing socket - const activeRecord = outerRecords.find(r => !r.connectionClosed && r.outgoing && !r.outgoing.destroyed); - if (activeRecord && activeRecord.outgoing) { - console.log('Found active connection, destroying outgoing socket'); - activeRecord.outgoing.removeAllListeners(); - activeRecord.outgoing.destroy(); - } - - resolve(); - }, 500); - }); - }); - - // Check half-zombie state - await new Promise(resolve => setTimeout(resolve, 100)); - details = getConnectionDetails(); - console.log(`\nAfter creating half-zombie:`); - console.log(` Outer: ${details.outer.count} connections, ${details.outer.zombies.length} zombies, ${details.outer.halfZombies.length} half-zombies`); - console.log(` Inner: ${details.inner.count} connections, ${details.inner.zombies.length} zombies, ${details.inner.halfZombies.length} half-zombies`); - - // Wait for 30-second grace period (simulated by multiple checks) - console.log('\nWaiting for half-zombie grace period (30 seconds simulated)...'); - - // Manually age the connection to trigger half-zombie cleanup - const outerConnMgr = (outerProxy as any).connectionManager as ConnectionManager; - const records = Array.from((outerConnMgr as any).connectionRecords.values()) as IConnectionRecord[]; - records.forEach(record => { - if (!record.connectionClosed) { - // Age the connection by 35 seconds - record.incomingStartTime -= 35000; - } - }); - - // Trigger inactivity check - await new Promise(resolve => setTimeout(resolve, 2000)); - - details = getConnectionDetails(); - console.log(`\nAfter half-zombie cleanup:`); - console.log(` Outer: ${details.outer.count} connections, ${details.outer.zombies.length} zombies, ${details.outer.halfZombies.length} half-zombies`); - console.log(` Inner: ${details.inner.count} connections, ${details.inner.zombies.length} zombies, ${details.inner.halfZombies.length} half-zombies`); - - // Clean up client2 properly - if (!client2.destroyed) { - client2.destroy(); - } - - console.log('\n--- Test 3: Rapid zombie creation under load ---'); - - // Create multiple connections rapidly and destroy them - const rapidClients: net.Socket[] = []; - - for (let i = 0; i < 5; i++) { - const client = new net.Socket(); - rapidClients.push(client); - - client.connect(8590, 'localhost', () => { - console.log(`Rapid client ${i} connected`); - client.write('GET / HTTP/1.1\r\nHost: test.com\r\n\r\n'); - - // Destroy after random delay - setTimeout(() => { - client.removeAllListeners(); - client.destroy(); - }, Math.random() * 500); - }); - - // Small delay between connections - await new Promise(resolve => setTimeout(resolve, 50)); - } - - // Wait a bit - await new Promise(resolve => setTimeout(resolve, 1000)); - - details = getConnectionDetails(); - console.log(`\nAfter rapid connections:`); - console.log(` Outer: ${details.outer.count} connections, ${details.outer.zombies.length} zombies, ${details.outer.halfZombies.length} half-zombies`); - console.log(` Inner: ${details.inner.count} connections, ${details.inner.zombies.length} zombies, ${details.inner.halfZombies.length} half-zombies`); - - // Wait for cleanup - console.log('\nWaiting for final cleanup...'); - await new Promise(resolve => setTimeout(resolve, 3000)); - - details = getConnectionDetails(); - console.log(`\nFinal state:`); - console.log(` Outer: ${details.outer.count} connections, ${details.outer.zombies.length} zombies, ${details.outer.halfZombies.length} half-zombies`); - console.log(` Inner: ${details.inner.count} connections, ${details.inner.zombies.length} zombies, ${details.inner.halfZombies.length} half-zombies`); - - // Cleanup - await outerProxy.stop(); - await innerProxy.stop(); - backend.close(); - - // Verify all connections are cleaned up - console.log('\n--- Verification ---'); - - if (details.outer.count === 0 && details.inner.count === 0) { - console.log('โœ… PASS: All zombie connections were cleaned up'); - } else { - console.log('โŒ FAIL: Some connections remain'); - } - - expect(details.outer.count).toEqual(0); - expect(details.inner.count).toEqual(0); - expect(details.outer.zombies.length).toEqual(0); - expect(details.inner.zombies.length).toEqual(0); - expect(details.outer.halfZombies.length).toEqual(0); - expect(details.inner.halfZombies.length).toEqual(0); -}); - -export default tap.start(); \ No newline at end of file diff --git a/ts/00_commitinfo_data.ts b/ts/00_commitinfo_data.ts index 91ee1d9..8d7afe4 100644 --- a/ts/00_commitinfo_data.ts +++ b/ts/00_commitinfo_data.ts @@ -3,6 +3,6 @@ */ export const commitinfo = { name: '@push.rocks/smartproxy', - version: '22.4.2', + version: '22.5.0', description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.' } diff --git a/ts/index.ts b/ts/index.ts index 15f3b06..3a60d3b 100644 --- a/ts/index.ts +++ b/ts/index.ts @@ -5,15 +5,10 @@ // NFTables proxy exports export * from './proxies/nftables-proxy/index.js'; -// Export HttpProxy elements -export { HttpProxy, CertificateManager, ConnectionPool, RequestHandler, WebSocketHandler } from './proxies/http-proxy/index.js'; -export type { IMetricsTracker, MetricsTracker } from './proxies/http-proxy/index.js'; -export type { IHttpProxyOptions, ICertificateEntry, ILogger } from './proxies/http-proxy/models/types.js'; -export { SharedRouteManager as HttpProxyRouteManager } from './core/routing/route-manager.js'; - -// Export SmartProxy elements selectively to avoid RouteManager ambiguity -export { SmartProxy, ConnectionManager, SecurityManager, TimeoutManager, TlsManager, HttpProxyBridge, RouteConnectionHandler, SmartCertManager } from './proxies/smart-proxy/index.js'; +// Export SmartProxy elements +export { SmartProxy } from './proxies/smart-proxy/index.js'; export { SharedRouteManager as RouteManager } from './core/routing/route-manager.js'; + // Export smart-proxy models export type { ISmartProxyOptions, IConnectionRecord, IRouteConfig, IRouteMatch, IRouteAction, IRouteTls, IRouteContext } from './proxies/smart-proxy/models/index.js'; export type { TSmartProxyCertProvisionObject } from './proxies/smart-proxy/models/interfaces.js'; @@ -22,8 +17,6 @@ export * from './proxies/smart-proxy/utils/index.js'; // Original: export * from './smartproxy/classes.pp.snihandler.js' // Now we export from the new module export { SniHandler } from './tls/sni/sni-handler.js'; -// Original: export * from './smartproxy/classes.pp.interfaces.js' -// Now we export from the new module (selectively to avoid conflicts) // Core types and utilities export * from './core/models/common-types.js'; @@ -32,8 +25,7 @@ export * from './core/models/common-types.js'; export type { IAcmeOptions } from './proxies/smart-proxy/models/interfaces.js'; // Modular exports for new architecture -// Certificate module has been removed - use SmartCertManager instead export * as tls from './tls/index.js'; export * as routing from './routing/index.js'; export * as detection from './detection/index.js'; -export * as protocols from './protocols/index.js'; \ No newline at end of file +export * as protocols from './protocols/index.js'; diff --git a/ts/proxies/http-proxy/connection-pool.ts b/ts/proxies/http-proxy/connection-pool.ts deleted file mode 100644 index 9aa62c0..0000000 --- a/ts/proxies/http-proxy/connection-pool.ts +++ /dev/null @@ -1,228 +0,0 @@ -import * as plugins from '../../plugins.js'; -import { type IHttpProxyOptions, type IConnectionEntry, type ILogger, createLogger } from './models/types.js'; -import { cleanupSocket } from '../../core/utils/socket-utils.js'; - -/** - * Manages a pool of backend connections for efficient reuse - */ -export class ConnectionPool { - private connectionPool: Map> = new Map(); - private roundRobinPositions: Map = new Map(); - private logger: ILogger; - - constructor(private options: IHttpProxyOptions) { - this.logger = createLogger(options.logLevel || 'info'); - } - - /** - * Get a connection from the pool or create a new one - */ - public getConnection(host: string, port: number): Promise { - return new Promise((resolve, reject) => { - const poolKey = `${host}:${port}`; - const connectionList = this.connectionPool.get(poolKey) || []; - - // Look for an idle connection - const idleConnectionIndex = connectionList.findIndex(c => c.isIdle); - - if (idleConnectionIndex >= 0) { - // Get existing connection from pool - const connection = connectionList[idleConnectionIndex]; - connection.isIdle = false; - connection.lastUsed = Date.now(); - this.logger.debug(`Reusing connection from pool for ${poolKey}`); - - // Update the pool - this.connectionPool.set(poolKey, connectionList); - - resolve(connection.socket); - return; - } - - // No idle connection available, create a new one if pool isn't full - const poolSize = this.options.connectionPoolSize || 50; - if (connectionList.length < poolSize) { - this.logger.debug(`Creating new connection to ${host}:${port}`); - - try { - const socket = plugins.net.connect({ - host, - port, - keepAlive: true, - keepAliveInitialDelay: 30000 // 30 seconds - }); - - socket.once('connect', () => { - // Add to connection pool - const connection = { - socket, - lastUsed: Date.now(), - isIdle: false - }; - - connectionList.push(connection); - this.connectionPool.set(poolKey, connectionList); - - // Setup cleanup when the connection is closed - socket.once('close', () => { - const idx = connectionList.findIndex(c => c.socket === socket); - if (idx >= 0) { - connectionList.splice(idx, 1); - this.connectionPool.set(poolKey, connectionList); - this.logger.debug(`Removed closed connection from pool for ${poolKey}`); - } - }); - - resolve(socket); - }); - - socket.once('error', (err) => { - this.logger.error(`Error creating connection to ${host}:${port}`, err); - reject(err); - }); - } catch (err) { - this.logger.error(`Failed to create connection to ${host}:${port}`, err); - reject(err); - } - } else { - // Pool is full, wait for an idle connection or reject - this.logger.warn(`Connection pool for ${poolKey} is full (${connectionList.length})`); - reject(new Error(`Connection pool for ${poolKey} is full`)); - } - }); - } - - /** - * Return a connection to the pool for reuse - */ - public returnConnection(socket: plugins.net.Socket, host: string, port: number): void { - const poolKey = `${host}:${port}`; - const connectionList = this.connectionPool.get(poolKey) || []; - - // Find this connection in the pool - const connectionIndex = connectionList.findIndex(c => c.socket === socket); - - if (connectionIndex >= 0) { - // Mark as idle and update last used time - connectionList[connectionIndex].isIdle = true; - connectionList[connectionIndex].lastUsed = Date.now(); - - this.logger.debug(`Returned connection to pool for ${poolKey}`); - } else { - this.logger.warn(`Attempted to return unknown connection to pool for ${poolKey}`); - } - } - - /** - * Cleanup the connection pool by removing idle connections - * or reducing pool size if it exceeds the configured maximum - */ - public cleanupConnectionPool(): void { - const now = Date.now(); - const idleTimeout = this.options.keepAliveTimeout || 120000; // 2 minutes default - - for (const [host, connections] of this.connectionPool.entries()) { - // Sort by last used time (oldest first) - connections.sort((a, b) => a.lastUsed - b.lastUsed); - - // Remove idle connections older than the idle timeout - let removed = 0; - while (connections.length > 0) { - const connection = connections[0]; - - // Remove if idle and exceeds timeout, or if pool is too large - if ((connection.isIdle && now - connection.lastUsed > idleTimeout) || - connections.length > (this.options.connectionPoolSize || 50)) { - - cleanupSocket(connection.socket, `pool-${host}-idle`, { immediate: true }).catch(() => {}); - - connections.shift(); // Remove from pool - removed++; - } else { - break; // Stop removing if we've reached active or recent connections - } - } - - if (removed > 0) { - this.logger.debug(`Removed ${removed} idle connections from pool for ${host}, ${connections.length} remaining`); - } - - // Update the pool with the remaining connections - if (connections.length === 0) { - this.connectionPool.delete(host); - } else { - this.connectionPool.set(host, connections); - } - } - } - - /** - * Close all connections in the pool - */ - public closeAllConnections(): void { - for (const [host, connections] of this.connectionPool.entries()) { - this.logger.debug(`Closing ${connections.length} connections to ${host}`); - - for (const connection of connections) { - cleanupSocket(connection.socket, `pool-${host}-close`, { immediate: true }).catch(() => {}); - } - } - - this.connectionPool.clear(); - this.roundRobinPositions.clear(); - } - - /** - * Get load balancing target using round-robin - */ - public getNextTarget(targets: string[], port: number): { host: string, port: number } { - const targetKey = targets.join(','); - - // Initialize position if not exists - if (!this.roundRobinPositions.has(targetKey)) { - this.roundRobinPositions.set(targetKey, 0); - } - - // Get current position and increment for next time - const currentPosition = this.roundRobinPositions.get(targetKey)!; - const nextPosition = (currentPosition + 1) % targets.length; - this.roundRobinPositions.set(targetKey, nextPosition); - - // Return the selected target - return { - host: targets[currentPosition], - port - }; - } - - /** - * Gets the connection pool status - */ - public getPoolStatus(): Record { - return Object.fromEntries( - Array.from(this.connectionPool.entries()).map(([host, connections]) => [ - host, - { - total: connections.length, - idle: connections.filter(c => c.isIdle).length - } - ]) - ); - } - - /** - * Setup a periodic cleanup task - */ - public setupPeriodicCleanup(interval: number = 60000): NodeJS.Timeout { - const timer = setInterval(() => { - this.cleanupConnectionPool(); - }, interval); - - // Don't prevent process exit - if (timer.unref) { - timer.unref(); - } - - return timer; - } -} \ No newline at end of file diff --git a/ts/proxies/http-proxy/context-creator.ts b/ts/proxies/http-proxy/context-creator.ts deleted file mode 100644 index 9b1a8a5..0000000 --- a/ts/proxies/http-proxy/context-creator.ts +++ /dev/null @@ -1,145 +0,0 @@ -import * as plugins from '../../plugins.js'; -import '../../core/models/socket-augmentation.js'; -import type { IRouteContext, IHttpRouteContext, IHttp2RouteContext } from '../../core/models/route-context.js'; - -/** - * Context creator for NetworkProxy - * Creates route contexts for matching and function evaluation - */ -export class ContextCreator { - /** - * Create a route context from HTTP request information - */ - public createHttpRouteContext(req: any, options: { - tlsVersion?: string; - connectionId: string; - clientIp: string; - serverIp: string; - }): IHttpRouteContext { - // Parse headers - const headers: Record = {}; - for (const [key, value] of Object.entries(req.headers)) { - if (typeof value === 'string') { - headers[key.toLowerCase()] = value; - } else if (Array.isArray(value) && value.length > 0) { - headers[key.toLowerCase()] = value[0]; - } - } - - // Parse domain from Host header - const domain = headers['host']?.split(':')[0] || ''; - - // Parse URL - const url = new URL(`http://${domain}${req.url || '/'}`); - - return { - // Connection basics - port: req.socket.localPort || 0, - domain, - clientIp: options.clientIp, - serverIp: options.serverIp, - - // HTTP specifics - path: url.pathname, - query: url.search ? url.search.substring(1) : '', - headers, - - // TLS information - isTls: !!req.socket.encrypted, - tlsVersion: options.tlsVersion, - - // Request objects - req, - - // Metadata - timestamp: Date.now(), - connectionId: options.connectionId - }; - } - - /** - * Create a route context from HTTP/2 stream and headers - */ - public createHttp2RouteContext( - stream: plugins.http2.ServerHttp2Stream, - headers: plugins.http2.IncomingHttpHeaders, - options: { - connectionId: string; - clientIp: string; - serverIp: string; - } - ): IHttp2RouteContext { - // Parse headers, excluding HTTP/2 pseudo-headers - const processedHeaders: Record = {}; - for (const [key, value] of Object.entries(headers)) { - if (!key.startsWith(':') && typeof value === 'string') { - processedHeaders[key.toLowerCase()] = value; - } - } - - // Get domain from :authority pseudo-header - const authority = headers[':authority'] as string || ''; - const domain = authority.split(':')[0]; - - // Get path from :path pseudo-header - const path = headers[':path'] as string || '/'; - - // Parse the path to extract query string - const pathParts = path.split('?'); - const pathname = pathParts[0]; - const query = pathParts.length > 1 ? pathParts[1] : ''; - - // Get the socket from the session - const socket = (stream.session as any)?.socket; - - return { - // Connection basics - port: socket?.localPort || 0, - domain, - clientIp: options.clientIp, - serverIp: options.serverIp, - - // HTTP specifics - path: pathname, - query, - headers: processedHeaders, - - // HTTP/2 specific properties - method: headers[':method'] as string, - stream, - - // TLS information - HTTP/2 is always on TLS in browsers - isTls: true, - tlsVersion: socket?.getTLSVersion?.() || 'TLSv1.3', - - // Metadata - timestamp: Date.now(), - connectionId: options.connectionId - }; - } - - /** - * Create a basic route context from socket information - */ - public createSocketRouteContext(socket: plugins.net.Socket, options: { - domain?: string; - tlsVersion?: string; - connectionId: string; - }): IRouteContext { - return { - // Connection basics - port: socket.localPort || 0, - domain: options.domain, - clientIp: socket.remoteAddress?.replace('::ffff:', '') || '0.0.0.0', - serverIp: socket.localAddress?.replace('::ffff:', '') || '0.0.0.0', - - // TLS information - isTls: options.tlsVersion !== undefined, - tlsVersion: options.tlsVersion, - - // Metadata - timestamp: Date.now(), - connectionId: options.connectionId - }; - } -} \ No newline at end of file diff --git a/ts/proxies/http-proxy/default-certificates.ts b/ts/proxies/http-proxy/default-certificates.ts deleted file mode 100644 index 3fe9812..0000000 --- a/ts/proxies/http-proxy/default-certificates.ts +++ /dev/null @@ -1,150 +0,0 @@ -import * as plugins from '../../plugins.js'; -import * as fs from 'fs'; -import * as path from 'path'; -import { fileURLToPath } from 'url'; -import { AsyncFileSystem } from '../../core/utils/fs-utils.js'; -import type { ILogger, ICertificateEntry } from './models/types.js'; - -/** - * Interface for default certificate data - */ -export interface IDefaultCertificates { - key: string; - cert: string; -} - -/** - * Provides default SSL certificates for HttpProxy. - * This is a minimal replacement for the deprecated CertificateManager. - * - * For production certificate management, use SmartCertManager instead. - */ -export class DefaultCertificateProvider { - private defaultCertificates: IDefaultCertificates | null = null; - private certificateCache: Map = new Map(); - private initialized = false; - - constructor(private logger?: ILogger) {} - - /** - * Load default certificates asynchronously (preferred) - */ - public async loadDefaultCertificatesAsync(): Promise { - if (this.defaultCertificates) { - return this.defaultCertificates; - } - - const __dirname = path.dirname(fileURLToPath(import.meta.url)); - const certPath = path.join(__dirname, '..', '..', '..', 'assets', 'certs'); - - try { - const [key, cert] = await Promise.all([ - AsyncFileSystem.readFile(path.join(certPath, 'key.pem')), - AsyncFileSystem.readFile(path.join(certPath, 'cert.pem')) - ]); - - this.defaultCertificates = { key, cert }; - this.logger?.info?.('Loaded default certificates from filesystem'); - this.initialized = true; - return this.defaultCertificates; - } catch (error) { - this.logger?.warn?.(`Failed to load default certificates: ${error}`); - this.defaultCertificates = this.generateFallbackCertificate(); - this.initialized = true; - return this.defaultCertificates; - } - } - - /** - * Load default certificates synchronously (for backward compatibility) - * @deprecated Use loadDefaultCertificatesAsync instead - */ - public loadDefaultCertificatesSync(): IDefaultCertificates { - if (this.defaultCertificates) { - return this.defaultCertificates; - } - - const __dirname = path.dirname(fileURLToPath(import.meta.url)); - const certPath = path.join(__dirname, '..', '..', '..', 'assets', 'certs'); - - try { - this.defaultCertificates = { - key: fs.readFileSync(path.join(certPath, 'key.pem'), 'utf8'), - cert: fs.readFileSync(path.join(certPath, 'cert.pem'), 'utf8') - }; - this.logger?.info?.('Loaded default certificates from filesystem (sync)'); - } catch (error) { - this.logger?.warn?.(`Failed to load default certificates: ${error}`); - this.defaultCertificates = this.generateFallbackCertificate(); - } - - this.initialized = true; - return this.defaultCertificates; - } - - /** - * Gets the default certificates (loads synchronously if not already loaded) - */ - public getDefaultCertificates(): IDefaultCertificates { - if (!this.defaultCertificates) { - return this.loadDefaultCertificatesSync(); - } - return this.defaultCertificates; - } - - /** - * Updates a certificate in the cache - */ - public updateCertificate(domain: string, cert: string, key: string): void { - this.certificateCache.set(domain, { - cert, - key, - expires: new Date(Date.now() + 90 * 24 * 60 * 60 * 1000) // 90 days - }); - - this.logger?.info?.(`Certificate updated for ${domain}`); - } - - /** - * Gets a cached certificate - */ - public getCachedCertificate(domain: string): ICertificateEntry | null { - return this.certificateCache.get(domain) || null; - } - - /** - * Gets statistics for metrics - */ - public getStats(): { cachedCertificates: number; defaultCertEnabled: boolean } { - return { - cachedCertificates: this.certificateCache.size, - defaultCertEnabled: this.defaultCertificates !== null - }; - } - - /** - * Generate a fallback self-signed certificate placeholder - * Note: This is just a placeholder - real apps should provide proper certificates - */ - private generateFallbackCertificate(): IDefaultCertificates { - this.logger?.warn?.('Using fallback self-signed certificate placeholder'); - - // Minimal self-signed certificate for fallback only - // In production, proper certificates should be provided via SmartCertManager - const selfSignedCert = `-----BEGIN CERTIFICATE----- -MIIBkTCB+wIJAKHHIgIIA0/cMA0GCSqGSIb3DQEBBQUAMA0xCzAJBgNVBAYTAlVT -MB4XDTE0MDEwMTAwMDAwMFoXDTI0MDEwMTAwMDAwMFowDTELMAkGA1UEBhMCVVMw -gZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAMRiH0VwnOH3jCV7c6JFZWYrvuqy ------END CERTIFICATE-----`; - - const selfSignedKey = `-----BEGIN PRIVATE KEY----- -MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMRiH0VwnOH3jCV7 -c6JFZWYrvuqyALCLXj0pcr1iqNdHjegNXnkl5zjdaUjq4edNOKl7M1AlFiYjG2xk ------END PRIVATE KEY-----`; - - return { - key: selfSignedKey, - cert: selfSignedCert - }; - } -} diff --git a/ts/proxies/http-proxy/function-cache.ts b/ts/proxies/http-proxy/function-cache.ts deleted file mode 100644 index 59fb632..0000000 --- a/ts/proxies/http-proxy/function-cache.ts +++ /dev/null @@ -1,279 +0,0 @@ -import type { IRouteContext } from '../../core/models/route-context.js'; -import type { ILogger } from './models/types.js'; - -/** - * Interface for cached function result - */ -interface ICachedResult { - value: T; - expiry: number; - hash: string; -} - -/** - * Function cache for NetworkProxy function-based targets - * - * This cache improves performance for function-based targets by storing - * the results of function evaluations and reusing them for similar contexts. - */ -export class FunctionCache { - // Cache storage - private hostCache: Map> = new Map(); - private portCache: Map> = new Map(); - - // Maximum number of entries to store in each cache - private maxCacheSize: number; - - // Default TTL for cache entries in milliseconds (default: 5 seconds) - private defaultTtl: number; - - // Logger - private logger: ILogger; - - // Cleanup interval timer - private cleanupInterval: NodeJS.Timeout | null = null; - - /** - * Creates a new function cache - * - * @param logger Logger for debug output - * @param options Cache options - */ - constructor( - logger: ILogger, - options: { - maxCacheSize?: number; - defaultTtl?: number; - } = {} - ) { - this.logger = logger; - this.maxCacheSize = options.maxCacheSize || 1000; - this.defaultTtl = options.defaultTtl || 5000; // 5 seconds default - - // Start the cache cleanup timer - this.cleanupInterval = setInterval(() => this.cleanupCache(), 30000); // Cleanup every 30 seconds - - // Make sure the interval doesn't keep the process alive - if (this.cleanupInterval.unref) { - this.cleanupInterval.unref(); - } - } - - /** - * Compute a hash for a context object - * This is used to identify similar contexts for caching - * - * @param context The route context to hash - * @param functionId Identifier for the function (usually route name or ID) - * @returns A string hash - */ - private computeContextHash(context: IRouteContext, functionId: string): string { - // Extract relevant properties for the hash - const hashBase = { - functionId, - port: context.port, - domain: context.domain, - clientIp: context.clientIp, - path: context.path, - query: context.query, - isTls: context.isTls, - tlsVersion: context.tlsVersion - }; - - // Generate a hash string - return JSON.stringify(hashBase); - } - - /** - * Get cached host result for a function and context - * - * @param context Route context - * @param functionId Identifier for the function - * @returns Cached host value or undefined if not found - */ - public getCachedHost(context: IRouteContext, functionId: string): string | string[] | undefined { - const hash = this.computeContextHash(context, functionId); - const cached = this.hostCache.get(hash); - - // Return if no cached value or expired - if (!cached || cached.expiry < Date.now()) { - if (cached) { - // If expired, remove from cache - this.hostCache.delete(hash); - this.logger.debug(`Cache miss (expired) for host function: ${functionId}`); - } else { - this.logger.debug(`Cache miss for host function: ${functionId}`); - } - return undefined; - } - - this.logger.debug(`Cache hit for host function: ${functionId}`); - return cached.value; - } - - /** - * Get cached port result for a function and context - * - * @param context Route context - * @param functionId Identifier for the function - * @returns Cached port value or undefined if not found - */ - public getCachedPort(context: IRouteContext, functionId: string): number | undefined { - const hash = this.computeContextHash(context, functionId); - const cached = this.portCache.get(hash); - - // Return if no cached value or expired - if (!cached || cached.expiry < Date.now()) { - if (cached) { - // If expired, remove from cache - this.portCache.delete(hash); - this.logger.debug(`Cache miss (expired) for port function: ${functionId}`); - } else { - this.logger.debug(`Cache miss for port function: ${functionId}`); - } - return undefined; - } - - this.logger.debug(`Cache hit for port function: ${functionId}`); - return cached.value; - } - - /** - * Store a host function result in the cache - * - * @param context Route context - * @param functionId Identifier for the function - * @param value Host value to cache - * @param ttl Optional TTL in milliseconds - */ - public cacheHost( - context: IRouteContext, - functionId: string, - value: string | string[], - ttl?: number - ): void { - const hash = this.computeContextHash(context, functionId); - const expiry = Date.now() + (ttl || this.defaultTtl); - - // Check if we need to prune the cache before adding - if (this.hostCache.size >= this.maxCacheSize) { - this.pruneOldestEntries(this.hostCache); - } - - // Store the result - this.hostCache.set(hash, { value, expiry, hash }); - this.logger.debug(`Cached host function result for: ${functionId}`); - } - - /** - * Store a port function result in the cache - * - * @param context Route context - * @param functionId Identifier for the function - * @param value Port value to cache - * @param ttl Optional TTL in milliseconds - */ - public cachePort( - context: IRouteContext, - functionId: string, - value: number, - ttl?: number - ): void { - const hash = this.computeContextHash(context, functionId); - const expiry = Date.now() + (ttl || this.defaultTtl); - - // Check if we need to prune the cache before adding - if (this.portCache.size >= this.maxCacheSize) { - this.pruneOldestEntries(this.portCache); - } - - // Store the result - this.portCache.set(hash, { value, expiry, hash }); - this.logger.debug(`Cached port function result for: ${functionId}`); - } - - /** - * Remove expired entries from the cache - */ - private cleanupCache(): void { - const now = Date.now(); - let expiredCount = 0; - - // Clean up host cache - for (const [hash, cached] of this.hostCache.entries()) { - if (cached.expiry < now) { - this.hostCache.delete(hash); - expiredCount++; - } - } - - // Clean up port cache - for (const [hash, cached] of this.portCache.entries()) { - if (cached.expiry < now) { - this.portCache.delete(hash); - expiredCount++; - } - } - - if (expiredCount > 0) { - this.logger.debug(`Cleaned up ${expiredCount} expired cache entries`); - } - } - - /** - * Prune oldest entries from a cache map - * Used when the cache exceeds the maximum size - * - * @param cache The cache map to prune - */ - private pruneOldestEntries(cache: Map>): void { - // Find the oldest entries - const now = Date.now(); - const itemsToRemove = Math.floor(this.maxCacheSize * 0.2); // Remove 20% of the cache - - // Convert to array for sorting - const entries = Array.from(cache.entries()); - - // Sort by expiry (oldest first) - entries.sort((a, b) => a[1].expiry - b[1].expiry); - - // Remove oldest entries - const toRemove = entries.slice(0, itemsToRemove); - for (const [hash] of toRemove) { - cache.delete(hash); - } - - this.logger.debug(`Pruned ${toRemove.length} oldest cache entries`); - } - - /** - * Get current cache stats - */ - public getStats(): { hostCacheSize: number; portCacheSize: number } { - return { - hostCacheSize: this.hostCache.size, - portCacheSize: this.portCache.size - }; - } - - /** - * Clear all cached entries - */ - public clearCache(): void { - this.hostCache.clear(); - this.portCache.clear(); - this.logger.info('Function cache cleared'); - } - - /** - * Destroy the cache and cleanup resources - */ - public destroy(): void { - if (this.cleanupInterval) { - clearInterval(this.cleanupInterval); - this.cleanupInterval = null; - } - this.clearCache(); - this.logger.debug('Function cache destroyed'); - } -} \ No newline at end of file diff --git a/ts/proxies/http-proxy/handlers/index.ts b/ts/proxies/http-proxy/handlers/index.ts deleted file mode 100644 index 586da59..0000000 --- a/ts/proxies/http-proxy/handlers/index.ts +++ /dev/null @@ -1,5 +0,0 @@ -/** - * HTTP handlers for various route types - */ - -// Empty - all handlers have been removed \ No newline at end of file diff --git a/ts/proxies/http-proxy/http-proxy.ts b/ts/proxies/http-proxy/http-proxy.ts deleted file mode 100644 index 11e80b3..0000000 --- a/ts/proxies/http-proxy/http-proxy.ts +++ /dev/null @@ -1,669 +0,0 @@ -import * as plugins from '../../plugins.js'; -import { - createLogger, -} from './models/types.js'; -import { SharedRouteManager as RouteManager } from '../../core/routing/route-manager.js'; -import type { - IHttpProxyOptions, - ILogger -} from './models/types.js'; -import type { IRouteConfig } from '../smart-proxy/models/route-types.js'; -import type { IRouteContext, IHttpRouteContext } from '../../core/models/route-context.js'; -import { createBaseRouteContext } from '../../core/models/route-context.js'; -import { DefaultCertificateProvider } from './default-certificates.js'; -import { ConnectionPool } from './connection-pool.js'; -import { RequestHandler, type IMetricsTracker } from './request-handler.js'; -import { WebSocketHandler } from './websocket-handler.js'; -import { HttpRouter } from '../../routing/router/index.js'; -import { cleanupSocket } from '../../core/utils/socket-utils.js'; -import { FunctionCache } from './function-cache.js'; -import { SecurityManager } from './security-manager.js'; -import { connectionLogDeduplicator } from '../../core/utils/log-deduplicator.js'; - -/** - * HttpProxy provides a reverse proxy with TLS termination, WebSocket support, - * automatic certificate management, and high-performance connection pooling. - * Handles all HTTP/HTTPS traffic including redirects, ACME challenges, and static routes. - */ -export class HttpProxy implements IMetricsTracker { - // Provide a minimal JSON representation to avoid circular references during deep equality checks - public toJSON(): any { - return {}; - } - // Configuration - public options: IHttpProxyOptions; - public routes: IRouteConfig[] = []; - - // Server instances (HTTP/2 with HTTP/1 fallback) - public httpsServer: plugins.http2.Http2SecureServer; - - // Core components - private defaultCertProvider: DefaultCertificateProvider; - private connectionPool: ConnectionPool; - private requestHandler: RequestHandler; - private webSocketHandler: WebSocketHandler; - private router = new HttpRouter(); // Unified HTTP router - private routeManager: RouteManager; - private functionCache: FunctionCache; - private securityManager: SecurityManager; - - // State tracking - public socketMap = new plugins.lik.ObjectMap(); - public activeContexts: Set = new Set(); - public connectedClients: number = 0; - public startTime: number = 0; - public requestsServed: number = 0; - public failedRequests: number = 0; - - // Tracking for SmartProxy integration - private portProxyConnections: number = 0; - private tlsTerminatedConnections: number = 0; - - // Timers - private metricsInterval: NodeJS.Timeout; - private connectionPoolCleanupInterval: NodeJS.Timeout; - - // Logger - private logger: ILogger; - - /** - * Creates a new HttpProxy instance - */ - constructor(optionsArg: IHttpProxyOptions) { - // Set default options - this.options = { - port: optionsArg.port, - maxConnections: optionsArg.maxConnections || 10000, - keepAliveTimeout: optionsArg.keepAliveTimeout || 120000, // 2 minutes - headersTimeout: optionsArg.headersTimeout || 60000, // 1 minute - logLevel: optionsArg.logLevel || 'info', - cors: optionsArg.cors || { - allowOrigin: '*', - allowMethods: 'GET, POST, PUT, DELETE, OPTIONS', - allowHeaders: 'Content-Type, Authorization', - maxAge: 86400 - }, - // Defaults for SmartProxy integration - connectionPoolSize: optionsArg.connectionPoolSize || 50, - portProxyIntegration: optionsArg.portProxyIntegration || false, - // Backend protocol (http1 or http2) - backendProtocol: optionsArg.backendProtocol || 'http1', - // Default ACME options - acme: { - enabled: optionsArg.acme?.enabled || false, - port: optionsArg.acme?.port || 80, - accountEmail: optionsArg.acme?.accountEmail || 'admin@example.com', - useProduction: optionsArg.acme?.useProduction || false, // Default to staging for safety - renewThresholdDays: optionsArg.acme?.renewThresholdDays || 30, - autoRenew: optionsArg.acme?.autoRenew !== false, // Default to true - certificateStore: optionsArg.acme?.certificateStore || './certs', - skipConfiguredCerts: optionsArg.acme?.skipConfiguredCerts || false - } - }; - - // Initialize logger - this.logger = createLogger(this.options.logLevel); - - // Initialize route manager - this.routeManager = new RouteManager({ - logger: this.logger, - enableDetailedLogging: this.options.logLevel === 'debug', - routes: [] - }); - - // Initialize function cache - this.functionCache = new FunctionCache(this.logger, { - maxCacheSize: this.options.functionCacheSize || 1000, - defaultTtl: this.options.functionCacheTtl || 5000 - }); - - // Initialize security manager - this.securityManager = new SecurityManager( - this.logger, - [], - this.options.maxConnectionsPerIP || 100, - this.options.connectionRateLimitPerMinute || 300 - ); - - // Initialize other components - this.defaultCertProvider = new DefaultCertificateProvider(this.logger); - this.connectionPool = new ConnectionPool(this.options); - this.requestHandler = new RequestHandler( - this.options, - this.connectionPool, - this.routeManager, - this.functionCache, - this.router - ); - this.webSocketHandler = new WebSocketHandler( - this.options, - this.connectionPool, - this.routes // Pass current routes to WebSocketHandler - ); - - // Connect request handler to this metrics tracker - this.requestHandler.setMetricsTracker(this); - - // Initialize with any provided routes - if (this.options.routes && this.options.routes.length > 0) { - this.updateRouteConfigs(this.options.routes); - } - } - - /** - * Implements IMetricsTracker interface to increment request counters - */ - public incrementRequestsServed(): void { - this.requestsServed++; - } - - /** - * Implements IMetricsTracker interface to increment failed request counters - */ - public incrementFailedRequests(): void { - this.failedRequests++; - } - - /** - * Returns the port number this HttpProxy is listening on - * Useful for SmartProxy to determine where to forward connections - */ - public getListeningPort(): number { - // If the server is running, get the actual listening port - if (this.httpsServer && this.httpsServer.address()) { - const address = this.httpsServer.address(); - if (address && typeof address === 'object' && 'port' in address) { - return address.port; - } - } - // Fallback to configured port - return this.options.port; - } - - /** - * Updates the server capacity settings - * @param maxConnections Maximum number of simultaneous connections - * @param keepAliveTimeout Keep-alive timeout in milliseconds - * @param connectionPoolSize Size of the connection pool per backend - */ - public updateCapacity(maxConnections?: number, keepAliveTimeout?: number, connectionPoolSize?: number): void { - if (maxConnections !== undefined) { - this.options.maxConnections = maxConnections; - this.logger.info(`Updated max connections to ${maxConnections}`); - } - - if (keepAliveTimeout !== undefined) { - this.options.keepAliveTimeout = keepAliveTimeout; - - if (this.httpsServer) { - // HTTP/2 servers have setTimeout method for timeout management - this.httpsServer.setTimeout(keepAliveTimeout); - this.logger.info(`Updated server timeout to ${keepAliveTimeout}ms`); - } - } - - if (connectionPoolSize !== undefined) { - this.options.connectionPoolSize = connectionPoolSize; - this.logger.info(`Updated connection pool size to ${connectionPoolSize}`); - - // Clean up excess connections in the pool - this.connectionPool.cleanupConnectionPool(); - } - } - - /** - * Returns current server metrics - * Useful for SmartProxy to determine which HttpProxy to use for load balancing - */ - public getMetrics(): any { - return { - activeConnections: this.connectedClients, - totalRequests: this.requestsServed, - failedRequests: this.failedRequests, - portProxyConnections: this.portProxyConnections, - tlsTerminatedConnections: this.tlsTerminatedConnections, - connectionPoolSize: this.connectionPool.getPoolStatus(), - uptime: Math.floor((Date.now() - this.startTime) / 1000), - memoryUsage: process.memoryUsage(), - activeWebSockets: this.webSocketHandler.getConnectionInfo().activeConnections, - functionCache: this.functionCache.getStats() - }; - } - - /** - * Starts the proxy server - */ - public async start(): Promise { - this.startTime = Date.now(); - - // Create HTTP/2 server with HTTP/1 fallback - const defaultCerts = this.defaultCertProvider.getDefaultCertificates(); - this.httpsServer = plugins.http2.createSecureServer( - { - key: defaultCerts.key, - cert: defaultCerts.cert, - allowHTTP1: true, - ALPNProtocols: ['h2', 'http/1.1'] - } - ); - - // Track raw TCP connections for metrics and limits - this.setupConnectionTracking(); - - // Handle incoming HTTP/2 streams - this.httpsServer.on('stream', (stream: plugins.http2.ServerHttp2Stream, headers: plugins.http2.IncomingHttpHeaders) => { - this.requestHandler.handleHttp2(stream, headers); - }); - // Handle HTTP/1.x fallback requests - this.httpsServer.on('request', (req: plugins.http.IncomingMessage, res: plugins.http.ServerResponse) => { - this.requestHandler.handleRequest(req, res); - }); - - // Setup WebSocket support on HTTP/1 fallback - this.webSocketHandler.initialize(this.httpsServer as any); - // Start metrics logging - this.setupMetricsCollection(); - // Start periodic connection pool cleanup - this.connectionPoolCleanupInterval = this.connectionPool.setupPeriodicCleanup(); - - // Start the server - return new Promise((resolve) => { - this.httpsServer.listen(this.options.port, () => { - this.logger.info(`HttpProxy started on port ${this.options.port}`); - resolve(); - }); - }); - } - - /** - * Check if an address is a loopback address (IPv4 or IPv6) - */ - private isLoopback(addr?: string): boolean { - if (!addr) return false; - // Check for IPv6 loopback - if (addr === '::1') return true; - // Handle IPv6-mapped IPv4 addresses - if (addr.startsWith('::ffff:')) { - addr = addr.substring(7); - } - // Check for IPv4 loopback range (127.0.0.0/8) - return addr.startsWith('127.'); - } - - /** - * Sets up tracking of TCP connections - */ - private setupConnectionTracking(): void { - this.httpsServer.on('connection', (connection: plugins.net.Socket) => { - let remoteIP = connection.remoteAddress || ''; - const connectionId = Math.random().toString(36).substring(2, 15); - const isFromSmartProxy = this.options.portProxyIntegration && this.isLoopback(connection.remoteAddress); - - // For SmartProxy connections, wait for CLIENT_IP header - if (isFromSmartProxy) { - const MAX_PREFACE = 256; // bytes - prevent DoS - const HEADER_TIMEOUT_MS = 2000; // timeout for header parsing (increased for slow networks) - let headerTimer: NodeJS.Timeout | undefined; - let buffered = Buffer.alloc(0); - - const onData = (chunk: Buffer) => { - buffered = Buffer.concat([buffered, chunk]); - - // Prevent unbounded growth - if (buffered.length > MAX_PREFACE) { - connection.removeListener('data', onData); - if (headerTimer) clearTimeout(headerTimer); - this.logger.warn('Header preface too large, closing connection'); - connection.destroy(); - return; - } - - const idx = buffered.indexOf('\r\n'); - if (idx !== -1) { - const headerLine = buffered.slice(0, idx).toString('utf8'); - if (headerLine.startsWith('CLIENT_IP:')) { - remoteIP = headerLine.substring(10).trim(); - this.logger.debug(`Extracted client IP from SmartProxy: ${remoteIP}`); - } - - // Clean up listener and timer - connection.removeListener('data', onData); - if (headerTimer) clearTimeout(headerTimer); - - // Put remaining data back onto the stream - const remaining = buffered.slice(idx + 2); - if (remaining.length > 0) { - connection.unshift(remaining); - } - - // Store the real IP on the connection - connection._realRemoteIP = remoteIP; - - // Validate the real IP - const ipValidation = this.securityManager.validateIP(remoteIP); - if (!ipValidation.allowed) { - connectionLogDeduplicator.log( - 'ip-rejected', - 'warn', - `HttpProxy connection rejected (via SmartProxy)`, - { remoteIP, reason: ipValidation.reason, component: 'http-proxy' }, - remoteIP - ); - connection.destroy(); - return; - } - - // Track connection by real IP - this.securityManager.trackConnectionByIP(remoteIP, connectionId); - } - }; - - // Set timeout for header parsing - headerTimer = setTimeout(() => { - connection.removeListener('data', onData); - this.logger.warn('Header parsing timeout, closing connection'); - connection.destroy(); - }, HEADER_TIMEOUT_MS); - - // Unref the timer so it doesn't keep the process alive - if (headerTimer.unref) headerTimer.unref(); - - // Use prependListener to get data first - connection.prependListener('data', onData); - } else { - // Direct connection - validate immediately - const ipValidation = this.securityManager.validateIP(remoteIP); - if (!ipValidation.allowed) { - connectionLogDeduplicator.log( - 'ip-rejected', - 'warn', - `HttpProxy connection rejected`, - { remoteIP, reason: ipValidation.reason, component: 'http-proxy' }, - remoteIP - ); - connection.destroy(); - return; - } - - // Track connection by IP - this.securityManager.trackConnectionByIP(remoteIP, connectionId); - } - - // Then check global max connections - if (this.socketMap.getArray().length >= this.options.maxConnections) { - connectionLogDeduplicator.log( - 'connection-rejected', - 'warn', - 'HttpProxy max connections reached', - { - reason: 'global-limit', - currentConnections: this.socketMap.getArray().length, - maxConnections: this.options.maxConnections, - component: 'http-proxy' - }, - 'http-proxy-global-limit' - ); - connection.destroy(); - return; - } - - // Add connection to tracking with metadata - connection._connectionId = connectionId; - connection._remoteIP = remoteIP; - this.socketMap.add(connection); - this.connectedClients = this.socketMap.getArray().length; - - // Check for connection from SmartProxy by inspecting the source port - const localPort = connection.localPort || 0; - const remotePort = connection.remotePort || 0; - - // If this connection is from a SmartProxy - if (isFromSmartProxy) { - this.portProxyConnections++; - this.logger.debug(`New connection from SmartProxy for client ${remoteIP} (local: ${localPort}, remote: ${remotePort})`); - } else { - this.logger.debug(`New direct connection from ${remoteIP} (local: ${localPort}, remote: ${remotePort})`); - } - - // Setup connection cleanup handlers - const cleanupConnection = () => { - if (this.socketMap.checkForObject(connection)) { - this.socketMap.remove(connection); - this.connectedClients = this.socketMap.getArray().length; - - // Remove IP tracking - const connId = connection._connectionId; - const connIP = connection._realRemoteIP || connection._remoteIP; - if (connId && connIP) { - this.securityManager.removeConnectionByIP(connIP, connId); - } - - // If this was a SmartProxy connection, decrement the counter - if (this.options.portProxyIntegration && connection.remoteAddress?.includes('127.0.0.1')) { - this.portProxyConnections--; - } - - this.logger.debug(`Connection closed from ${connIP || 'unknown'}. ${this.connectedClients} connections remaining`); - } - }; - - connection.on('close', cleanupConnection); - connection.on('error', (err) => { - this.logger.debug('Connection error', err); - cleanupConnection(); - }); - connection.on('end', cleanupConnection); - }); - - // Track TLS handshake completions - this.httpsServer.on('secureConnection', (tlsSocket) => { - this.tlsTerminatedConnections++; - this.logger.debug('TLS handshake completed, connection secured'); - }); - } - - /** - * Sets up metrics collection - */ - private setupMetricsCollection(): void { - this.metricsInterval = setInterval(() => { - const uptime = Math.floor((Date.now() - this.startTime) / 1000); - const metrics = { - uptime, - activeConnections: this.connectedClients, - totalRequests: this.requestsServed, - failedRequests: this.failedRequests, - portProxyConnections: this.portProxyConnections, - tlsTerminatedConnections: this.tlsTerminatedConnections, - activeWebSockets: this.webSocketHandler.getConnectionInfo().activeConnections, - memoryUsage: process.memoryUsage(), - activeContexts: Array.from(this.activeContexts), - connectionPool: this.connectionPool.getPoolStatus() - }; - - this.logger.debug('Proxy metrics', metrics); - }, 60000); // Log metrics every minute - - // Don't keep process alive just for metrics - if (this.metricsInterval.unref) { - this.metricsInterval.unref(); - } - } - - /** - * Updates the route configurations - this is the primary method for configuring HttpProxy - * @param routes The new route configurations to use - */ - public async updateRouteConfigs(routes: IRouteConfig[]): Promise { - this.logger.info(`Updating route configurations (${routes.length} routes)`); - - // Update routes in RouteManager, modern router, WebSocketHandler, and SecurityManager - this.routeManager.updateRoutes(routes); - this.router.setRoutes(routes); - this.webSocketHandler.setRoutes(routes); - this.requestHandler.securityManager.setRoutes(routes); - this.routes = routes; - - // Collect all domains and certificates for configuration - const currentHostnames = new Set(); - const certificateUpdates = new Map(); - - // Process each route to extract domain and certificate information - for (const route of routes) { - // Skip non-forward routes or routes without domains - if (route.action.type !== 'forward' || !route.match.domains) { - continue; - } - - // Get domains from route - const domains = Array.isArray(route.match.domains) - ? route.match.domains - : [route.match.domains]; - - // Process each domain - for (const domain of domains) { - // Skip wildcard domains for direct host configuration - if (domain.includes('*')) { - continue; - } - - currentHostnames.add(domain); - - // Check if we have a static certificate for this domain - if (route.action.tls?.certificate && route.action.tls.certificate !== 'auto') { - certificateUpdates.set(domain, { - cert: route.action.tls.certificate.cert, - key: route.action.tls.certificate.key - }); - } - } - } - - // Update certificate cache with any static certificates - for (const [domain, certData] of certificateUpdates.entries()) { - try { - this.defaultCertProvider.updateCertificate( - domain, - certData.cert, - certData.key - ); - - this.activeContexts.add(domain); - } catch (error) { - this.logger.error(`Failed to add SSL context for ${domain}`, error); - } - } - - // Clean up removed contexts - for (const hostname of this.activeContexts) { - if (!currentHostnames.has(hostname)) { - this.logger.info(`Hostname ${hostname} removed from configuration`); - this.activeContexts.delete(hostname); - } - } - - // Update the router with new routes - this.router.setRoutes(routes); - - // Update WebSocket handler with new routes - this.webSocketHandler.setRoutes(routes); - - this.logger.info(`Route configuration updated with ${routes.length} routes`); - } - - // Legacy methods have been removed. - // Please use updateRouteConfigs() directly with modern route-based configuration. - - /** - * Adds default headers to be included in all responses - */ - public async addDefaultHeaders(headersArg: { [key: string]: string }): Promise { - this.logger.info('Adding default headers', headersArg); - this.requestHandler.setDefaultHeaders(headersArg); - } - - /** - * Stops the proxy server - */ - public async stop(): Promise { - this.logger.info('Stopping HttpProxy server'); - - // Clear intervals - if (this.metricsInterval) { - clearInterval(this.metricsInterval); - } - - if (this.connectionPoolCleanupInterval) { - clearInterval(this.connectionPoolCleanupInterval); - } - - // Stop WebSocket handler - this.webSocketHandler.shutdown(); - - // Destroy request handler (cleans up intervals and caches) - if (this.requestHandler && typeof this.requestHandler.destroy === 'function') { - this.requestHandler.destroy(); - } - - // Close all tracked sockets - const socketCleanupPromises = this.socketMap.getArray().map(socket => - cleanupSocket(socket, 'http-proxy-stop', { immediate: true }) - ); - await Promise.all(socketCleanupPromises); - - // Close all connection pool connections - this.connectionPool.closeAllConnections(); - - // Certificate management cleanup is handled by SmartCertManager - - // Flush any pending deduplicated logs - connectionLogDeduplicator.flushAll(); - - // Close the HTTPS server - return new Promise((resolve) => { - this.httpsServer.close(() => { - this.logger.info('HttpProxy server stopped successfully'); - resolve(); - }); - }); - } - - /** - * Requests a new certificate for a domain - * This can be used to manually trigger certificate issuance - * @param domain The domain to request a certificate for - * @returns A promise that resolves when the request is submitted (not when the certificate is issued) - */ - public async requestCertificate(domain: string): Promise { - this.logger.warn('requestCertificate is deprecated - use SmartCertManager instead'); - return false; - } - - /** - * Update certificate for a domain - * - * This method allows direct updates of certificates from external sources - * like Port80Handler or custom certificate providers. - * - * @param domain The domain to update certificate for - * @param certificate The new certificate (public key) - * @param privateKey The new private key - * @param expiryDate Optional expiry date - */ - public updateCertificate( - domain: string, - certificate: string, - privateKey: string, - expiryDate?: Date - ): void { - this.logger.info(`Updating certificate for ${domain}`); - this.defaultCertProvider.updateCertificate(domain, certificate, privateKey); - } - - /** - * Gets all route configurations currently in use - */ - public getRouteConfigs(): IRouteConfig[] { - return this.routeManager.getRoutes(); - } -} \ No newline at end of file diff --git a/ts/proxies/http-proxy/http-request-handler.ts b/ts/proxies/http-proxy/http-request-handler.ts deleted file mode 100644 index 840212d..0000000 --- a/ts/proxies/http-proxy/http-request-handler.ts +++ /dev/null @@ -1,331 +0,0 @@ -import * as plugins from '../../plugins.js'; -import '../../core/models/socket-augmentation.js'; -import type { IHttpRouteContext, IRouteContext } from '../../core/models/route-context.js'; -import type { ILogger } from './models/types.js'; -import type { IMetricsTracker } from './request-handler.js'; -import type { IRouteConfig } from '../smart-proxy/models/route-types.js'; -import { TemplateUtils } from '../../core/utils/template-utils.js'; - -/** - * HTTP Request Handler Helper - handles requests with specific destinations - * This is a helper class for the main RequestHandler - */ -export class HttpRequestHandler { - /** - * Handle HTTP request with a specific destination - */ - public static async handleHttpRequestWithDestination( - req: plugins.http.IncomingMessage, - res: plugins.http.ServerResponse, - destination: { host: string, port: number }, - routeContext: IHttpRouteContext, - startTime: number, - logger: ILogger, - metricsTracker?: IMetricsTracker | null, - route?: IRouteConfig - ): Promise { - try { - // Apply URL rewriting if route config is provided - if (route) { - HttpRequestHandler.applyUrlRewriting(req, route, routeContext, logger); - HttpRequestHandler.applyRouteHeaderModifications(route, req, res, logger); - } - - // Create options for the proxy request - const options: plugins.http.RequestOptions = { - hostname: destination.host, - port: destination.port, - path: req.url, - method: req.method, - headers: { ...req.headers } - }; - - // Optionally rewrite host header to match target - if (options.headers && 'host' in options.headers) { - // Only apply if host header rewrite is enabled or not explicitly disabled - const shouldRewriteHost = route?.action.options?.rewriteHostHeader !== false; - if (shouldRewriteHost) { - // Safely cast to OutgoingHttpHeaders to access host property - (options.headers as plugins.http.OutgoingHttpHeaders).host = `${destination.host}:${destination.port}`; - } - } - - logger.debug( - `Proxying request to ${destination.host}:${destination.port}${req.url}`, - { method: req.method } - ); - - // Create proxy request - const proxyReq = plugins.http.request(options, (proxyRes) => { - // Copy status code - res.statusCode = proxyRes.statusCode || 500; - - // Copy headers from proxy response to client response - for (const [key, value] of Object.entries(proxyRes.headers)) { - if (value !== undefined) { - res.setHeader(key, value); - } - } - - // Apply response header modifications if route config is provided - if (route && route.headers?.response) { - HttpRequestHandler.applyResponseHeaderModifications(route, res, logger, routeContext); - } - - // Pipe proxy response to client response - proxyRes.pipe(res); - - // Increment served requests counter when the response finishes - res.on('finish', () => { - if (metricsTracker) { - metricsTracker.incrementRequestsServed(); - } - - // Log the completed request - const duration = Date.now() - startTime; - logger.debug( - `Request completed in ${duration}ms: ${req.method} ${req.url} ${res.statusCode}`, - { duration, statusCode: res.statusCode } - ); - }); - }); - - // Handle proxy request errors - proxyReq.on('error', (error) => { - const duration = Date.now() - startTime; - logger.error( - `Proxy error for ${req.method} ${req.url}: ${error.message}`, - { duration, error: error.message } - ); - - // Increment failed requests counter - if (metricsTracker) { - metricsTracker.incrementFailedRequests(); - } - - // Check if headers have already been sent - if (!res.headersSent) { - res.statusCode = 502; - res.end(`Bad Gateway: ${error.message}`); - } else { - // If headers already sent, just close the connection - res.end(); - } - }); - - // Pipe request body to proxy request and handle client-side errors - req.pipe(proxyReq); - - // Handle client disconnection - req.on('error', (error) => { - logger.debug(`Client connection error: ${error.message}`); - proxyReq.destroy(); - - // Increment failed requests counter on client errors - if (metricsTracker) { - metricsTracker.incrementFailedRequests(); - } - }); - - // Handle response errors - res.on('error', (error) => { - logger.debug(`Response error: ${error.message}`); - proxyReq.destroy(); - - // Increment failed requests counter on response errors - if (metricsTracker) { - metricsTracker.incrementFailedRequests(); - } - }); - } catch (error) { - // Handle any unexpected errors - logger.error( - `Unexpected error handling request: ${error.message}`, - { error: error.stack } - ); - - // Increment failed requests counter - if (metricsTracker) { - metricsTracker.incrementFailedRequests(); - } - - if (!res.headersSent) { - res.statusCode = 500; - res.end('Internal Server Error'); - } else { - res.end(); - } - } - } - - /** - * Apply URL rewriting based on route configuration - * Implements Phase 5.2: URL rewriting using route context - * - * @param req The request with the URL to rewrite - * @param route The route configuration containing rewrite rules - * @param routeContext Context for template variable resolution - * @param logger Logger for debugging information - * @returns True if URL was rewritten, false otherwise - */ - private static applyUrlRewriting( - req: plugins.http.IncomingMessage, - route: IRouteConfig, - routeContext: IHttpRouteContext, - logger: ILogger - ): boolean { - // Check if route has URL rewriting configuration - if (!route.action.advanced?.urlRewrite) { - return false; - } - - const rewriteConfig = route.action.advanced.urlRewrite; - - // Store original URL for logging - const originalUrl = req.url; - - if (rewriteConfig.pattern && rewriteConfig.target) { - try { - // Create a RegExp from the pattern with optional flags - const regex = new RegExp(rewriteConfig.pattern, rewriteConfig.flags || ''); - - // Apply rewriting with template variable resolution - let target = rewriteConfig.target; - - // Replace template variables in target with values from context - target = TemplateUtils.resolveTemplateVariables(target, routeContext); - - // If onlyRewritePath is set, split URL into path and query parts - if (rewriteConfig.onlyRewritePath && req.url) { - const [path, query] = req.url.split('?'); - const rewrittenPath = path.replace(regex, target); - req.url = query ? `${rewrittenPath}?${query}` : rewrittenPath; - } else { - // Perform the replacement on the entire URL - req.url = req.url?.replace(regex, target); - } - - logger.debug(`URL rewritten: ${originalUrl} -> ${req.url}`); - return true; - } catch (err) { - logger.error(`Error in URL rewriting: ${err}`); - return false; - } - } - - return false; - } - - /** - * Apply header modifications from route configuration to request headers - * Implements Phase 5.1: Route-based header manipulation for requests - */ - private static applyRouteHeaderModifications( - route: IRouteConfig, - req: plugins.http.IncomingMessage, - res: plugins.http.ServerResponse, - logger: ILogger - ): void { - // Check if route has header modifications - if (!route.headers) { - return; - } - - // Apply request header modifications (these will be sent to the backend) - if (route.headers.request && req.headers) { - // Create routing context for template resolution - const routeContext: IRouteContext = { - domain: req.headers.host as string || '', - path: req.url || '', - clientIp: req.socket.remoteAddress?.replace('::ffff:', '') || '', - serverIp: req.socket.localAddress?.replace('::ffff:', '') || '', - port: parseInt(req.socket.localPort?.toString() || '0', 10), - isTls: !!req.socket.encrypted, - headers: req.headers as Record, - timestamp: Date.now(), - connectionId: `${Date.now()}-${Math.floor(Math.random() * 10000)}`, - }; - - for (const [key, value] of Object.entries(route.headers.request)) { - // Skip if header already exists and we're not overriding - if (req.headers[key.toLowerCase()] && !value.startsWith('!')) { - continue; - } - - // Handle special delete directive (!delete) - if (value === '!delete') { - delete req.headers[key.toLowerCase()]; - logger.debug(`Deleted request header: ${key}`); - continue; - } - - // Handle forced override (!value) - let finalValue: string; - if (value.startsWith('!')) { - // Keep the ! but resolve any templates in the rest - const templateValue = value.substring(1); - finalValue = '!' + TemplateUtils.resolveTemplateVariables(templateValue, routeContext); - } else { - // Resolve templates in the entire value - finalValue = TemplateUtils.resolveTemplateVariables(value, routeContext); - } - - // Set the header - req.headers[key.toLowerCase()] = finalValue; - logger.debug(`Modified request header: ${key}=${finalValue}`); - } - } - } - - /** - * Apply header modifications from route configuration to response headers - * Implements Phase 5.1: Route-based header manipulation for responses - */ - private static applyResponseHeaderModifications( - route: IRouteConfig, - res: plugins.http.ServerResponse, - logger: ILogger, - routeContext?: IRouteContext - ): void { - // Check if route has response header modifications - if (!route.headers?.response) { - return; - } - - // Apply response header modifications - for (const [key, value] of Object.entries(route.headers.response)) { - // Skip if header already exists and we're not overriding - if (res.hasHeader(key) && !value.startsWith('!')) { - continue; - } - - // Handle special delete directive (!delete) - if (value === '!delete') { - res.removeHeader(key); - logger.debug(`Deleted response header: ${key}`); - continue; - } - - // Handle forced override (!value) - let finalValue: string; - if (value.startsWith('!') && value !== '!delete') { - // Keep the ! but resolve any templates in the rest - const templateValue = value.substring(1); - finalValue = routeContext - ? '!' + TemplateUtils.resolveTemplateVariables(templateValue, routeContext) - : '!' + templateValue; - } else { - // Resolve templates in the entire value - finalValue = routeContext - ? TemplateUtils.resolveTemplateVariables(value, routeContext) - : value; - } - - // Set the header - res.setHeader(key, finalValue); - logger.debug(`Modified response header: ${key}=${finalValue}`); - } - } - - // Template resolution is now handled by the TemplateUtils class -} \ No newline at end of file diff --git a/ts/proxies/http-proxy/http2-request-handler.ts b/ts/proxies/http-proxy/http2-request-handler.ts deleted file mode 100644 index b40a2be..0000000 --- a/ts/proxies/http-proxy/http2-request-handler.ts +++ /dev/null @@ -1,255 +0,0 @@ -import * as plugins from '../../plugins.js'; -import type { IHttpRouteContext } from '../../core/models/route-context.js'; -import type { ILogger } from './models/types.js'; -import type { IMetricsTracker } from './request-handler.js'; - -/** - * HTTP/2 Request Handler Helper - handles HTTP/2 streams with specific destinations - * This is a helper class for the main RequestHandler - */ -export class Http2RequestHandler { - /** - * Handle HTTP/2 stream with direct HTTP/2 backend - */ - public static async handleHttp2WithHttp2Destination( - stream: plugins.http2.ServerHttp2Stream, - headers: plugins.http2.IncomingHttpHeaders, - destination: { host: string, port: number }, - routeContext: IHttpRouteContext, - sessions: Map, - logger: ILogger, - metricsTracker?: IMetricsTracker | null - ): Promise { - const key = `${destination.host}:${destination.port}`; - - // Get or create a client HTTP/2 session - let session = sessions.get(key); - if (!session || session.closed || (session as any).destroyed) { - try { - // Connect to the backend HTTP/2 server - session = plugins.http2.connect(`http://${destination.host}:${destination.port}`); - sessions.set(key, session); - - // Handle session errors and cleanup - session.on('error', (err) => { - logger.error(`HTTP/2 session error to ${key}: ${err.message}`); - sessions.delete(key); - }); - - session.on('close', () => { - logger.debug(`HTTP/2 session closed to ${key}`); - sessions.delete(key); - }); - } catch (err) { - logger.error(`Failed to establish HTTP/2 session to ${key}: ${err.message}`); - stream.respond({ ':status': 502 }); - stream.end('Bad Gateway: Failed to establish connection to backend'); - if (metricsTracker) metricsTracker.incrementFailedRequests(); - return; - } - } - - try { - // Build headers for backend HTTP/2 request - const h2Headers: Record = { - ':method': headers[':method'], - ':path': headers[':path'], - ':authority': `${destination.host}:${destination.port}` - }; - - // Copy other headers, excluding pseudo-headers - for (const [key, value] of Object.entries(headers)) { - if (!key.startsWith(':') && typeof value === 'string') { - h2Headers[key] = value; - } - } - - logger.debug( - `Proxying HTTP/2 request to ${destination.host}:${destination.port}${headers[':path']}`, - { method: headers[':method'] } - ); - - // Create HTTP/2 request stream to the backend - const h2Stream = session.request(h2Headers); - - // Pipe client stream to backend stream - stream.pipe(h2Stream); - - // Handle responses from the backend - h2Stream.on('response', (responseHeaders) => { - // Map status and headers to client response - const resp: Record = { - ':status': responseHeaders[':status'] as number - }; - - // Copy non-pseudo headers - for (const [key, value] of Object.entries(responseHeaders)) { - if (!key.startsWith(':') && value !== undefined) { - resp[key] = value; - } - } - - // Send headers to client - stream.respond(resp); - - // Pipe backend response to client - h2Stream.pipe(stream); - - // Track successful requests - stream.on('end', () => { - if (metricsTracker) metricsTracker.incrementRequestsServed(); - logger.debug( - `HTTP/2 request completed: ${headers[':method']} ${headers[':path']} ${responseHeaders[':status']}`, - { method: headers[':method'], status: responseHeaders[':status'] } - ); - }); - }); - - // Handle backend errors - h2Stream.on('error', (err) => { - logger.error(`HTTP/2 stream error: ${err.message}`); - - // Only send error response if headers haven't been sent - if (!stream.headersSent) { - stream.respond({ ':status': 502 }); - stream.end(`Bad Gateway: ${err.message}`); - } else { - stream.end(); - } - - if (metricsTracker) metricsTracker.incrementFailedRequests(); - }); - - // Handle client stream errors - stream.on('error', (err) => { - logger.debug(`Client HTTP/2 stream error: ${err.message}`); - h2Stream.destroy(); - if (metricsTracker) metricsTracker.incrementFailedRequests(); - }); - - } catch (err: any) { - logger.error(`Error handling HTTP/2 request: ${err.message}`); - - // Only send error response if headers haven't been sent - if (!stream.headersSent) { - stream.respond({ ':status': 500 }); - stream.end('Internal Server Error'); - } else { - stream.end(); - } - - if (metricsTracker) metricsTracker.incrementFailedRequests(); - } - } - - /** - * Handle HTTP/2 stream with HTTP/1 backend - */ - public static async handleHttp2WithHttp1Destination( - stream: plugins.http2.ServerHttp2Stream, - headers: plugins.http2.IncomingHttpHeaders, - destination: { host: string, port: number }, - routeContext: IHttpRouteContext, - logger: ILogger, - metricsTracker?: IMetricsTracker | null - ): Promise { - try { - // Build headers for HTTP/1 proxy request, excluding HTTP/2 pseudo-headers - const outboundHeaders: Record = {}; - for (const [key, value] of Object.entries(headers)) { - if (typeof key === 'string' && typeof value === 'string' && !key.startsWith(':')) { - outboundHeaders[key] = value; - } - } - - // Always rewrite host header to match target - outboundHeaders.host = `${destination.host}:${destination.port}`; - - logger.debug( - `Proxying HTTP/2 request to HTTP/1 backend ${destination.host}:${destination.port}${headers[':path']}`, - { method: headers[':method'] } - ); - - // Create HTTP/1 proxy request - const proxyReq = plugins.http.request( - { - hostname: destination.host, - port: destination.port, - path: headers[':path'] as string, - method: headers[':method'] as string, - headers: outboundHeaders - }, - (proxyRes) => { - // Map status and headers back to HTTP/2 - const responseHeaders: Record = { - ':status': proxyRes.statusCode || 500 - }; - - // Copy headers from HTTP/1 response to HTTP/2 response - for (const [key, value] of Object.entries(proxyRes.headers)) { - if (value !== undefined) { - responseHeaders[key] = value as string | string[]; - } - } - - // Send headers to client - stream.respond(responseHeaders); - - // Pipe HTTP/1 response to HTTP/2 stream - proxyRes.pipe(stream); - - // Clean up when client disconnects - stream.on('close', () => proxyReq.destroy()); - stream.on('error', () => proxyReq.destroy()); - - // Track successful requests - stream.on('end', () => { - if (metricsTracker) metricsTracker.incrementRequestsServed(); - logger.debug( - `HTTP/2 to HTTP/1 request completed: ${headers[':method']} ${headers[':path']} ${proxyRes.statusCode}`, - { method: headers[':method'], status: proxyRes.statusCode } - ); - }); - } - ); - - // Handle proxy request errors - proxyReq.on('error', (err) => { - logger.error(`HTTP/1 proxy error: ${err.message}`); - - // Only send error response if headers haven't been sent - if (!stream.headersSent) { - stream.respond({ ':status': 502 }); - stream.end(`Bad Gateway: ${err.message}`); - } else { - stream.end(); - } - - if (metricsTracker) metricsTracker.incrementFailedRequests(); - }); - - // Pipe client stream to proxy request - stream.pipe(proxyReq); - - // Handle client stream errors - stream.on('error', (err) => { - logger.debug(`Client HTTP/2 stream error: ${err.message}`); - proxyReq.destroy(); - if (metricsTracker) metricsTracker.incrementFailedRequests(); - }); - - } catch (err: any) { - logger.error(`Error handling HTTP/2 to HTTP/1 request: ${err.message}`); - - // Only send error response if headers haven't been sent - if (!stream.headersSent) { - stream.respond({ ':status': 500 }); - stream.end('Internal Server Error'); - } else { - stream.end(); - } - - if (metricsTracker) metricsTracker.incrementFailedRequests(); - } - } -} \ No newline at end of file diff --git a/ts/proxies/http-proxy/index.ts b/ts/proxies/http-proxy/index.ts deleted file mode 100644 index 5d4f51c..0000000 --- a/ts/proxies/http-proxy/index.ts +++ /dev/null @@ -1,18 +0,0 @@ -/** - * HttpProxy implementation - */ -// Re-export models -export * from './models/index.js'; - -// Export HttpProxy and supporting classes -export { HttpProxy } from './http-proxy.js'; -export { DefaultCertificateProvider } from './default-certificates.js'; -export { ConnectionPool } from './connection-pool.js'; -export { RequestHandler } from './request-handler.js'; -export type { IMetricsTracker, MetricsTracker } from './request-handler.js'; -export { WebSocketHandler } from './websocket-handler.js'; - -/** - * @deprecated Use DefaultCertificateProvider instead. This alias is for backward compatibility. - */ -export { DefaultCertificateProvider as CertificateManager } from './default-certificates.js'; diff --git a/ts/proxies/http-proxy/models/http-types.ts b/ts/proxies/http-proxy/models/http-types.ts deleted file mode 100644 index 8dffd50..0000000 --- a/ts/proxies/http-proxy/models/http-types.ts +++ /dev/null @@ -1,148 +0,0 @@ -import * as plugins from '../../../plugins.js'; -// Import from protocols for consistent status codes -import { HttpStatus as ProtocolHttpStatus, getStatusText as getProtocolStatusText } from '../../../protocols/http/index.js'; - -/** - * HTTP-specific event types - */ -export enum HttpEvents { - REQUEST_RECEIVED = 'request-received', - REQUEST_FORWARDED = 'request-forwarded', - REQUEST_HANDLED = 'request-handled', - REQUEST_ERROR = 'request-error', -} - - -// Re-export for backward compatibility with subset of commonly used codes -export const HttpStatus = { - OK: ProtocolHttpStatus.OK, - MOVED_PERMANENTLY: ProtocolHttpStatus.MOVED_PERMANENTLY, - FOUND: ProtocolHttpStatus.FOUND, - TEMPORARY_REDIRECT: ProtocolHttpStatus.TEMPORARY_REDIRECT, - PERMANENT_REDIRECT: ProtocolHttpStatus.PERMANENT_REDIRECT, - BAD_REQUEST: ProtocolHttpStatus.BAD_REQUEST, - UNAUTHORIZED: ProtocolHttpStatus.UNAUTHORIZED, - FORBIDDEN: ProtocolHttpStatus.FORBIDDEN, - NOT_FOUND: ProtocolHttpStatus.NOT_FOUND, - METHOD_NOT_ALLOWED: ProtocolHttpStatus.METHOD_NOT_ALLOWED, - REQUEST_TIMEOUT: ProtocolHttpStatus.REQUEST_TIMEOUT, - TOO_MANY_REQUESTS: ProtocolHttpStatus.TOO_MANY_REQUESTS, - INTERNAL_SERVER_ERROR: ProtocolHttpStatus.INTERNAL_SERVER_ERROR, - NOT_IMPLEMENTED: ProtocolHttpStatus.NOT_IMPLEMENTED, - BAD_GATEWAY: ProtocolHttpStatus.BAD_GATEWAY, - SERVICE_UNAVAILABLE: ProtocolHttpStatus.SERVICE_UNAVAILABLE, - GATEWAY_TIMEOUT: ProtocolHttpStatus.GATEWAY_TIMEOUT, -} as const; - -/** - * Base error class for HTTP-related errors - */ -export class HttpError extends Error { - constructor(message: string, public readonly statusCode: number = HttpStatus.INTERNAL_SERVER_ERROR) { - super(message); - this.name = 'HttpError'; - } -} - -/** - * Error related to certificate operations - */ -export class CertificateError extends HttpError { - constructor( - message: string, - public readonly domain: string, - public readonly isRenewal: boolean = false - ) { - super(`${message} for domain ${domain}${isRenewal ? ' (renewal)' : ''}`, HttpStatus.INTERNAL_SERVER_ERROR); - this.name = 'CertificateError'; - } -} - -/** - * Error related to server operations - */ -export class ServerError extends HttpError { - constructor(message: string, public readonly code?: string, statusCode: number = HttpStatus.INTERNAL_SERVER_ERROR) { - super(message, statusCode); - this.name = 'ServerError'; - } -} - -/** - * Error for bad requests - */ -export class BadRequestError extends HttpError { - constructor(message: string) { - super(message, HttpStatus.BAD_REQUEST); - this.name = 'BadRequestError'; - } -} - -/** - * Error for not found resources - */ -export class NotFoundError extends HttpError { - constructor(message: string = 'Resource not found') { - super(message, HttpStatus.NOT_FOUND); - this.name = 'NotFoundError'; - } -} - -/** - * Redirect configuration for HTTP requests - */ -export interface IRedirectConfig { - source: string; // Source path or pattern - destination: string; // Destination URL - type: number; // Redirect status code - preserveQuery?: boolean; // Whether to preserve query parameters -} - -/** - * HTTP router configuration - */ -export interface IRouterConfig { - routes: Array<{ - path: string; - method?: string; - handler: (req: plugins.http.IncomingMessage, res: plugins.http.ServerResponse) => void | Promise; - }>; - notFoundHandler?: (req: plugins.http.IncomingMessage, res: plugins.http.ServerResponse) => void; - errorHandler?: (error: Error, req: plugins.http.IncomingMessage, res: plugins.http.ServerResponse) => void; -} - -/** - * HTTP request method types - */ -export type HttpMethod = 'GET' | 'POST' | 'PUT' | 'DELETE' | 'PATCH' | 'HEAD' | 'OPTIONS' | 'CONNECT' | 'TRACE'; - - -/** - * Helper function to get HTTP status text - */ -export function getStatusText(status: number): string { - return getProtocolStatusText(status as ProtocolHttpStatus); -} - -// Legacy interfaces for backward compatibility -export interface IDomainOptions { - domainName: string; - sslRedirect: boolean; - acmeMaintenance: boolean; - forward?: { ip: string; port: number }; - acmeForward?: { ip: string; port: number }; -} - -export interface IDomainCertificate { - options: IDomainOptions; - certObtained: boolean; - obtainingInProgress: boolean; - certificate?: string; - privateKey?: string; - expiryDate?: Date; - lastRenewalAttempt?: Date; -} - -// Backward compatibility exports -export { HttpError as Port80HandlerError }; -export { CertificateError as CertError }; \ No newline at end of file diff --git a/ts/proxies/http-proxy/models/index.ts b/ts/proxies/http-proxy/models/index.ts deleted file mode 100644 index 8b6a627..0000000 --- a/ts/proxies/http-proxy/models/index.ts +++ /dev/null @@ -1,5 +0,0 @@ -/** - * HttpProxy models - */ -export * from './types.js'; -export * from './http-types.js'; diff --git a/ts/proxies/http-proxy/models/types.ts b/ts/proxies/http-proxy/models/types.ts deleted file mode 100644 index dd05c2c..0000000 --- a/ts/proxies/http-proxy/models/types.ts +++ /dev/null @@ -1,125 +0,0 @@ -import * as plugins from '../../../plugins.js'; -// Certificate types removed - define IAcmeOptions locally -export interface IAcmeOptions { - enabled: boolean; - email?: string; - accountEmail?: string; - port?: number; - certificateStore?: string; - environment?: 'production' | 'staging'; - useProduction?: boolean; - renewThresholdDays?: number; - autoRenew?: boolean; - skipConfiguredCerts?: boolean; -} -import type { IRouteConfig } from '../../smart-proxy/models/route-types.js'; - -/** - * Configuration options for HttpProxy - */ -export interface IHttpProxyOptions { - port: number; - maxConnections?: number; - keepAliveTimeout?: number; - headersTimeout?: number; - logLevel?: 'error' | 'warn' | 'info' | 'debug'; - cors?: { - allowOrigin?: string; - allowMethods?: string; - allowHeaders?: string; - maxAge?: number; - }; - - // Settings for SmartProxy integration - connectionPoolSize?: number; // Maximum connections to maintain in the pool to each backend - portProxyIntegration?: boolean; // Flag to indicate this proxy is used by SmartProxy - // Protocol to use when proxying to backends: HTTP/1.x or HTTP/2 - backendProtocol?: 'http1' | 'http2'; - - // Function cache options - functionCacheSize?: number; // Maximum number of cached function results (default: 1000) - functionCacheTtl?: number; // Time to live for cached function results in ms (default: 5000) - - // ACME certificate management options - acme?: IAcmeOptions; - - // Direct route configurations - routes?: IRouteConfig[]; - - // Rate limiting and security - maxConnectionsPerIP?: number; // Maximum simultaneous connections from a single IP - connectionRateLimitPerMinute?: number; // Max new connections per minute from a single IP -} - -/** - * Interface for a certificate entry in the cache - */ -export interface ICertificateEntry { - key: string; - cert: string; - expires?: Date; -} - - - -/** - * Interface for connection tracking in the pool - */ -export interface IConnectionEntry { - socket: plugins.net.Socket; - lastUsed: number; - isIdle: boolean; -} - -/** - * WebSocket with heartbeat interface - */ -export interface IWebSocketWithHeartbeat extends plugins.wsDefault { - lastPong: number; - isAlive: boolean; -} - -/** - * Logger interface for consistent logging across components - */ -export interface ILogger { - debug(message: string, data?: any): void; - info(message: string, data?: any): void; - warn(message: string, data?: any): void; - error(message: string, data?: any): void; -} - -/** - * Creates a logger based on the specified log level - */ -export function createLogger(logLevel: string = 'info'): ILogger { - const logLevels = { - error: 0, - warn: 1, - info: 2, - debug: 3 - }; - - return { - debug: (message: string, data?: any) => { - if (logLevels[logLevel] >= logLevels.debug) { - console.log(`[DEBUG] ${message}`, data || ''); - } - }, - info: (message: string, data?: any) => { - if (logLevels[logLevel] >= logLevels.info) { - console.log(`[INFO] ${message}`, data || ''); - } - }, - warn: (message: string, data?: any) => { - if (logLevels[logLevel] >= logLevels.warn) { - console.warn(`[WARN] ${message}`, data || ''); - } - }, - error: (message: string, data?: any) => { - if (logLevels[logLevel] >= logLevels.error) { - console.error(`[ERROR] ${message}`, data || ''); - } - } - }; -} \ No newline at end of file diff --git a/ts/proxies/http-proxy/request-handler.ts b/ts/proxies/http-proxy/request-handler.ts deleted file mode 100644 index e102f5c..0000000 --- a/ts/proxies/http-proxy/request-handler.ts +++ /dev/null @@ -1,878 +0,0 @@ -import * as plugins from '../../plugins.js'; -import '../../core/models/socket-augmentation.js'; -import { - type IHttpProxyOptions, - type ILogger, - createLogger, -} from './models/types.js'; -import { SharedRouteManager as RouteManager } from '../../core/routing/route-manager.js'; -import { ConnectionPool } from './connection-pool.js'; -import { ContextCreator } from './context-creator.js'; -import { HttpRequestHandler } from './http-request-handler.js'; -import { Http2RequestHandler } from './http2-request-handler.js'; -import type { IRouteConfig, IRouteTarget } from '../smart-proxy/models/route-types.js'; -import type { IRouteContext, IHttpRouteContext } from '../../core/models/route-context.js'; -import { toBaseContext } from '../../core/models/route-context.js'; -import { TemplateUtils } from '../../core/utils/template-utils.js'; -import { SecurityManager } from './security-manager.js'; - -/** - * Interface for tracking metrics - */ -export interface IMetricsTracker { - incrementRequestsServed(): void; - incrementFailedRequests(): void; -} - -// Backward compatibility -export type MetricsTracker = IMetricsTracker; - -/** - * Handles HTTP request processing and proxying - */ -export class RequestHandler { - private defaultHeaders: { [key: string]: string } = {}; - private logger: ILogger; - private metricsTracker: IMetricsTracker | null = null; - // HTTP/2 client sessions for backend proxying - private h2Sessions: Map = new Map(); - - // Context creator for route contexts - private contextCreator: ContextCreator = new ContextCreator(); - - // Security manager for IP filtering, rate limiting, etc. - public securityManager: SecurityManager; - - // Rate limit cleanup interval - private rateLimitCleanupInterval: NodeJS.Timeout | null = null; - - constructor( - private options: IHttpProxyOptions, - private connectionPool: ConnectionPool, - private routeManager?: RouteManager, - private functionCache?: any, // FunctionCache - using any to avoid circular dependency - private router?: any // HttpRouter - using any to avoid circular dependency - ) { - this.logger = createLogger(options.logLevel || 'info'); - this.securityManager = new SecurityManager(this.logger); - - // Schedule rate limit cleanup every minute - this.rateLimitCleanupInterval = setInterval(() => { - this.securityManager.cleanupExpiredRateLimits(); - }, 60000); - - // Make sure the interval doesn't keep the process alive - if (this.rateLimitCleanupInterval.unref) { - this.rateLimitCleanupInterval.unref(); - } - } - - /** - * Set the route manager instance - */ - public setRouteManager(routeManager: RouteManager): void { - this.routeManager = routeManager; - } - - /** - * Set the metrics tracker instance - */ - public setMetricsTracker(tracker: IMetricsTracker): void { - this.metricsTracker = tracker; - } - - /** - * Set default headers to be included in all responses - */ - public setDefaultHeaders(headers: { [key: string]: string }): void { - this.defaultHeaders = { - ...this.defaultHeaders, - ...headers - }; - this.logger.info('Updated default response headers'); - } - - /** - * Get all default headers - */ - public getDefaultHeaders(): { [key: string]: string } { - return { ...this.defaultHeaders }; - } - - /** - * Select the appropriate target from the targets array based on sub-matching criteria - */ - private selectTarget( - targets: IRouteTarget[], - context: { - port: number; - path?: string; - headers?: Record; - method?: string; - } - ): IRouteTarget | null { - // Sort targets by priority (higher first) - const sortedTargets = [...targets].sort((a, b) => (b.priority || 0) - (a.priority || 0)); - - // Find the first matching target - for (const target of sortedTargets) { - if (!target.match) { - // No match criteria means this is a default/fallback target - return target; - } - - // Check port match - if (target.match.ports && !target.match.ports.includes(context.port)) { - continue; - } - - // Check path match (supports wildcards) - if (target.match.path && context.path) { - const pathPattern = target.match.path.replace(/\*/g, '.*'); - const pathRegex = new RegExp(`^${pathPattern}$`); - if (!pathRegex.test(context.path)) { - continue; - } - } - - // Check method match - if (target.match.method && context.method && !target.match.method.includes(context.method)) { - continue; - } - - // Check headers match - if (target.match.headers && context.headers) { - let headersMatch = true; - for (const [key, pattern] of Object.entries(target.match.headers)) { - const headerValue = context.headers[key.toLowerCase()]; - if (!headerValue) { - headersMatch = false; - break; - } - - if (pattern instanceof RegExp) { - if (!pattern.test(headerValue)) { - headersMatch = false; - break; - } - } else if (headerValue !== pattern) { - headersMatch = false; - break; - } - } - if (!headersMatch) { - continue; - } - } - - // All criteria matched - return target; - } - - // No matching target found, return the first target without match criteria (default) - return sortedTargets.find(t => !t.match) || null; - } - - /** - * Apply CORS headers to response if configured - * Implements Phase 5.5: Context-aware CORS handling - * - * @param res The server response to apply headers to - * @param req The incoming request - * @param route Optional route config with CORS settings - */ - private applyCorsHeaders( - res: plugins.http.ServerResponse, - req: plugins.http.IncomingMessage, - route?: IRouteConfig - ): void { - // Use route-specific CORS config if available, otherwise use global config - let corsConfig: any = null; - - // Route CORS config takes precedence if enabled - if (route?.headers?.cors?.enabled) { - corsConfig = route.headers.cors; - this.logger.debug(`Using route-specific CORS config for ${route.name || 'unnamed route'}`); - } - // Fall back to global CORS config if available - else if (this.options.cors) { - corsConfig = this.options.cors; - this.logger.debug('Using global CORS config'); - } - - // If no CORS config available, skip - if (!corsConfig) { - return; - } - - // Get origin from request - const origin = req.headers.origin; - - // Apply Allow-Origin (with dynamic validation if needed) - if (corsConfig.allowOrigin) { - // Handle multiple origins in array format - if (Array.isArray(corsConfig.allowOrigin)) { - if (origin && corsConfig.allowOrigin.includes(origin)) { - // Match found, set specific origin - res.setHeader('Access-Control-Allow-Origin', origin); - res.setHeader('Vary', 'Origin'); // Important for caching - } else if (corsConfig.allowOrigin.includes('*')) { - // Wildcard match - res.setHeader('Access-Control-Allow-Origin', '*'); - } - } - // Handle single origin or wildcard - else if (corsConfig.allowOrigin === '*') { - res.setHeader('Access-Control-Allow-Origin', '*'); - } - // Match single origin against request - else if (origin && corsConfig.allowOrigin === origin) { - res.setHeader('Access-Control-Allow-Origin', origin); - res.setHeader('Vary', 'Origin'); - } - // Use template variables if present - else if (origin && corsConfig.allowOrigin.includes('{')) { - const resolvedOrigin = TemplateUtils.resolveTemplateVariables( - corsConfig.allowOrigin, - { domain: req.headers.host } as any - ); - if (resolvedOrigin === origin || resolvedOrigin === '*') { - res.setHeader('Access-Control-Allow-Origin', origin); - res.setHeader('Vary', 'Origin'); - } - } - } - - // Apply other CORS headers - if (corsConfig.allowMethods) { - res.setHeader('Access-Control-Allow-Methods', corsConfig.allowMethods); - } - - if (corsConfig.allowHeaders) { - res.setHeader('Access-Control-Allow-Headers', corsConfig.allowHeaders); - } - - if (corsConfig.allowCredentials) { - res.setHeader('Access-Control-Allow-Credentials', 'true'); - } - - if (corsConfig.exposeHeaders) { - res.setHeader('Access-Control-Expose-Headers', corsConfig.exposeHeaders); - } - - if (corsConfig.maxAge) { - res.setHeader('Access-Control-Max-Age', corsConfig.maxAge.toString()); - } - - // Handle CORS preflight requests if enabled (default: true) - if (req.method === 'OPTIONS' && corsConfig.preflight !== false) { - res.statusCode = 204; // No content - res.end(); - return; - } - } - - // First implementation of applyRouteHeaderModifications moved to the second implementation below - - /** - * Apply default headers to response - */ - private applyDefaultHeaders(res: plugins.http.ServerResponse): void { - // Apply default headers - for (const [key, value] of Object.entries(this.defaultHeaders)) { - if (!res.hasHeader(key)) { - res.setHeader(key, value); - } - } - - // Add server identifier if not already set - if (!res.hasHeader('Server')) { - res.setHeader('Server', 'NetworkProxy'); - } - } - - /** - * Apply URL rewriting based on route configuration - * Implements Phase 5.2: URL rewriting using route context - * - * @param req The request with the URL to rewrite - * @param route The route configuration containing rewrite rules - * @param routeContext Context for template variable resolution - * @returns True if URL was rewritten, false otherwise - */ - private applyUrlRewriting( - req: plugins.http.IncomingMessage, - route: IRouteConfig, - routeContext: IHttpRouteContext - ): boolean { - // Check if route has URL rewriting configuration - if (!route.action.advanced?.urlRewrite) { - return false; - } - - const rewriteConfig = route.action.advanced.urlRewrite; - - // Store original URL for logging - const originalUrl = req.url; - - if (rewriteConfig.pattern && rewriteConfig.target) { - try { - // Create a RegExp from the pattern - const regex = new RegExp(rewriteConfig.pattern, rewriteConfig.flags || ''); - - // Apply rewriting with template variable resolution - let target = rewriteConfig.target; - - // Replace template variables in target with values from context - target = TemplateUtils.resolveTemplateVariables(target, routeContext); - - // If onlyRewritePath is set, split URL into path and query parts - if (rewriteConfig.onlyRewritePath && req.url) { - const [path, query] = req.url.split('?'); - const rewrittenPath = path.replace(regex, target); - req.url = query ? `${rewrittenPath}?${query}` : rewrittenPath; - } else { - // Perform the replacement on the entire URL - req.url = req.url?.replace(regex, target); - } - - this.logger.debug(`URL rewritten: ${originalUrl} -> ${req.url}`); - return true; - } catch (err) { - this.logger.error(`Error in URL rewriting: ${err}`); - return false; - } - } - - return false; - } - - /** - * Apply header modifications from route configuration - * Implements Phase 5.1: Route-based header manipulation - */ - private applyRouteHeaderModifications( - route: IRouteConfig, - req: plugins.http.IncomingMessage, - res: plugins.http.ServerResponse - ): void { - // Check if route has header modifications - if (!route.headers) { - return; - } - - // Apply request header modifications (these will be sent to the backend) - if (route.headers.request && req.headers) { - for (const [key, value] of Object.entries(route.headers.request)) { - // Skip if header already exists and we're not overriding - if (req.headers[key.toLowerCase()] && !value.startsWith('!')) { - continue; - } - - // Handle special delete directive (!delete) - if (value === '!delete') { - delete req.headers[key.toLowerCase()]; - this.logger.debug(`Deleted request header: ${key}`); - continue; - } - - // Handle forced override (!value) - let finalValue: string; - if (value.startsWith('!') && value !== '!delete') { - // Keep the ! but resolve any templates in the rest - const templateValue = value.substring(1); - finalValue = '!' + TemplateUtils.resolveTemplateVariables(templateValue, {} as IRouteContext); - } else { - // Resolve templates in the entire value - finalValue = TemplateUtils.resolveTemplateVariables(value, {} as IRouteContext); - } - - // Set the header - req.headers[key.toLowerCase()] = finalValue; - this.logger.debug(`Modified request header: ${key}=${finalValue}`); - } - } - - // Apply response header modifications (these will be stored for later use) - if (route.headers.response) { - for (const [key, value] of Object.entries(route.headers.response)) { - // Skip if header already exists and we're not overriding - if (res.hasHeader(key) && !value.startsWith('!')) { - continue; - } - - // Handle special delete directive (!delete) - if (value === '!delete') { - res.removeHeader(key); - this.logger.debug(`Deleted response header: ${key}`); - continue; - } - - // Handle forced override (!value) - let finalValue: string; - if (value.startsWith('!') && value !== '!delete') { - // Keep the ! but resolve any templates in the rest - const templateValue = value.substring(1); - finalValue = '!' + TemplateUtils.resolveTemplateVariables(templateValue, {} as IRouteContext); - } else { - // Resolve templates in the entire value - finalValue = TemplateUtils.resolveTemplateVariables(value, {} as IRouteContext); - } - - // Set the header - res.setHeader(key, finalValue); - this.logger.debug(`Modified response header: ${key}=${finalValue}`); - } - } - } - - /** - * Handle an HTTP request - */ - public async handleRequest( - req: plugins.http.IncomingMessage, - res: plugins.http.ServerResponse - ): Promise { - // Record start time for logging - const startTime = Date.now(); - - // Get route before applying CORS (we might need its settings) - // Try to find a matching route using RouteManager - let matchingRoute: IRouteConfig | null = null; - if (this.routeManager) { - try { - // Create a connection ID for this request - const connectionId = `http-${Date.now()}-${Math.floor(Math.random() * 10000)}`; - - // Create route context for function-based targets - const routeContext = this.contextCreator.createHttpRouteContext(req, { - connectionId, - clientIp: req.socket.remoteAddress?.replace('::ffff:', '') || '0.0.0.0', - serverIp: req.socket.localAddress?.replace('::ffff:', '') || '0.0.0.0', - tlsVersion: req.socket.getTLSVersion?.() || undefined - }); - - const matchResult = this.routeManager.findMatchingRoute(toBaseContext(routeContext)); - matchingRoute = matchResult?.route || null; - } catch (err) { - this.logger.error('Error finding matching route', err); - } - } - - // Apply CORS headers with route-specific settings if available - this.applyCorsHeaders(res, req, matchingRoute); - - // If this is an OPTIONS request, the response has already been ended in applyCorsHeaders - // so we should return early to avoid trying to set more headers - if (req.method === 'OPTIONS') { - // Increment metrics for OPTIONS requests too - if (this.metricsTracker) { - this.metricsTracker.incrementRequestsServed(); - } - return; - } - - // Apply default headers - this.applyDefaultHeaders(res); - - // We already have the connection ID and routeContext from CORS handling - const connectionId = `http-${Date.now()}-${Math.floor(Math.random() * 10000)}`; - - // Create route context for function-based targets (if we don't already have one) - const routeContext = this.contextCreator.createHttpRouteContext(req, { - connectionId, - clientIp: req.socket.remoteAddress?.replace('::ffff:', '') || '0.0.0.0', - serverIp: req.socket.localAddress?.replace('::ffff:', '') || '0.0.0.0', - tlsVersion: req.socket.getTLSVersion?.() || undefined - }); - - // Check security restrictions if we have a matching route - if (matchingRoute) { - // Check IP filtering and rate limiting - if (!this.securityManager.isAllowed(matchingRoute, routeContext)) { - this.logger.warn(`Access denied for ${routeContext.clientIp} to ${matchingRoute.name || 'unnamed'}`); - res.statusCode = 403; - res.end('Forbidden: Access denied by security policy'); - if (this.metricsTracker) this.metricsTracker.incrementFailedRequests(); - return; - } - - // Check basic auth - if (matchingRoute.security?.basicAuth?.enabled) { - const authHeader = req.headers.authorization; - if (!authHeader || !authHeader.startsWith('Basic ')) { - // No auth header provided - send 401 with WWW-Authenticate header - res.statusCode = 401; - const realm = matchingRoute.security.basicAuth.realm || 'Protected Area'; - res.setHeader('WWW-Authenticate', `Basic realm="${realm}", charset="UTF-8"`); - res.end('Authentication Required'); - if (this.metricsTracker) this.metricsTracker.incrementFailedRequests(); - return; - } - - // Verify credentials - try { - const credentials = Buffer.from(authHeader.substring(6), 'base64').toString('utf-8'); - const [username, password] = credentials.split(':'); - - if (!this.securityManager.checkBasicAuth(matchingRoute, username, password)) { - res.statusCode = 401; - const realm = matchingRoute.security.basicAuth.realm || 'Protected Area'; - res.setHeader('WWW-Authenticate', `Basic realm="${realm}", charset="UTF-8"`); - res.end('Invalid Credentials'); - if (this.metricsTracker) this.metricsTracker.incrementFailedRequests(); - return; - } - } catch (err) { - this.logger.error(`Error verifying basic auth: ${err}`); - res.statusCode = 401; - res.end('Authentication Error'); - if (this.metricsTracker) this.metricsTracker.incrementFailedRequests(); - return; - } - } - - // Check JWT auth - if (matchingRoute.security?.jwtAuth?.enabled) { - const authHeader = req.headers.authorization; - if (!authHeader || !authHeader.startsWith('Bearer ')) { - // No auth header provided - send 401 - res.statusCode = 401; - res.end('Authentication Required: JWT token missing'); - if (this.metricsTracker) this.metricsTracker.incrementFailedRequests(); - return; - } - - // Verify token - const token = authHeader.substring(7); - if (!this.securityManager.verifyJwtToken(matchingRoute, token)) { - res.statusCode = 401; - res.end('Invalid or Expired JWT'); - if (this.metricsTracker) this.metricsTracker.incrementFailedRequests(); - return; - } - } - } - - // If we found a matching route with forward action, select appropriate target - if (matchingRoute && matchingRoute.action.type === 'forward' && matchingRoute.action.targets && matchingRoute.action.targets.length > 0) { - this.logger.debug(`Found matching route: ${matchingRoute.name || 'unnamed'}`); - - // Select the appropriate target from the targets array - const selectedTarget = this.selectTarget(matchingRoute.action.targets, { - port: routeContext.port, - path: routeContext.path, - headers: routeContext.headers, - method: routeContext.method - }); - - if (!selectedTarget) { - this.logger.error(`No matching target found for route ${matchingRoute.name}`); - req.socket.end(); - return; - } - - // Extract target information, resolving functions if needed - let targetHost: string | string[]; - let targetPort: number; - - try { - // Check function cache for host and resolve or use cached value - if (typeof selectedTarget.host === 'function') { - // Generate a function ID for caching (use route name or ID if available) - const functionId = `host-${matchingRoute.id || matchingRoute.name || 'unnamed'}`; - - // Check if we have a cached result - if (this.functionCache) { - const cachedHost = this.functionCache.getCachedHost(routeContext, functionId); - if (cachedHost !== undefined) { - targetHost = cachedHost; - this.logger.debug(`Using cached host value for ${functionId}`); - } else { - // Resolve the function and cache the result - const resolvedHost = selectedTarget.host(toBaseContext(routeContext)); - targetHost = resolvedHost; - - // Cache the result - this.functionCache.cacheHost(routeContext, functionId, resolvedHost); - this.logger.debug(`Resolved and cached function-based host to: ${Array.isArray(resolvedHost) ? resolvedHost.join(', ') : resolvedHost}`); - } - } else { - // No cache available, just resolve - const resolvedHost = selectedTarget.host(routeContext); - targetHost = resolvedHost; - this.logger.debug(`Resolved function-based host to: ${Array.isArray(resolvedHost) ? resolvedHost.join(', ') : resolvedHost}`); - } - } else { - targetHost = selectedTarget.host; - } - - // Check function cache for port and resolve or use cached value - if (typeof selectedTarget.port === 'function') { - // Generate a function ID for caching - const functionId = `port-${matchingRoute.id || matchingRoute.name || 'unnamed'}`; - - // Check if we have a cached result - if (this.functionCache) { - const cachedPort = this.functionCache.getCachedPort(routeContext, functionId); - if (cachedPort !== undefined) { - targetPort = cachedPort; - this.logger.debug(`Using cached port value for ${functionId}`); - } else { - // Resolve the function and cache the result - const resolvedPort = selectedTarget.port(toBaseContext(routeContext)); - targetPort = resolvedPort; - - // Cache the result - this.functionCache.cachePort(routeContext, functionId, resolvedPort); - this.logger.debug(`Resolved and cached function-based port to: ${resolvedPort}`); - } - } else { - // No cache available, just resolve - const resolvedPort = selectedTarget.port(routeContext); - targetPort = resolvedPort; - this.logger.debug(`Resolved function-based port to: ${resolvedPort}`); - } - } else { - targetPort = selectedTarget.port === 'preserve' ? routeContext.port : selectedTarget.port as number; - } - - // Select a single host if an array was provided - const selectedHost = Array.isArray(targetHost) - ? targetHost[Math.floor(Math.random() * targetHost.length)] - : targetHost; - - // Create a destination for the connection pool - const destination = { - host: selectedHost, - port: targetPort - }; - - // Apply URL rewriting if configured - this.applyUrlRewriting(req, matchingRoute, routeContext); - - // Apply header modifications if configured - this.applyRouteHeaderModifications(matchingRoute, req, res); - - // Continue with handling using the resolved destination - HttpRequestHandler.handleHttpRequestWithDestination( - req, - res, - destination, - routeContext, - startTime, - this.logger, - this.metricsTracker, - matchingRoute // Pass the route config for additional processing - ); - return; - } catch (err) { - this.logger.error(`Error evaluating function-based target: ${err}`); - res.statusCode = 500; - res.end('Internal Server Error: Failed to evaluate target functions'); - if (this.metricsTracker) this.metricsTracker.incrementFailedRequests(); - return; - } - } - - // If no route was found, return 404 - this.logger.warn(`No route configuration for host: ${req.headers.host}`); - res.statusCode = 404; - res.end('Not Found: No route configuration for this host'); - if (this.metricsTracker) this.metricsTracker.incrementFailedRequests(); - } - - /** - * Handle HTTP/2 stream requests with function-based target support - */ - public async handleHttp2(stream: plugins.http2.ServerHttp2Stream, headers: plugins.http2.IncomingHttpHeaders): Promise { - const startTime = Date.now(); - - // Create a connection ID for this HTTP/2 stream - const connectionId = `http2-${Date.now()}-${Math.floor(Math.random() * 10000)}`; - - // Get client IP and server IP from the socket - const socket = (stream.session as any)?.socket; - const clientIp = socket?.remoteAddress?.replace('::ffff:', '') || '0.0.0.0'; - const serverIp = socket?.localAddress?.replace('::ffff:', '') || '0.0.0.0'; - - // Create route context for function-based targets - const routeContext = this.contextCreator.createHttp2RouteContext(stream, headers, { - connectionId, - clientIp, - serverIp - }); - - // Try to find a matching route using RouteManager - let matchingRoute: IRouteConfig | null = null; - if (this.routeManager) { - try { - const matchResult = this.routeManager.findMatchingRoute(toBaseContext(routeContext)); - matchingRoute = matchResult?.route || null; - } catch (err) { - this.logger.error('Error finding matching route for HTTP/2 request', err); - } - } - - // If we found a matching route with forward action, select appropriate target - if (matchingRoute && matchingRoute.action.type === 'forward' && matchingRoute.action.targets && matchingRoute.action.targets.length > 0) { - this.logger.debug(`Found matching route for HTTP/2 request: ${matchingRoute.name || 'unnamed'}`); - - // Select the appropriate target from the targets array - const selectedTarget = this.selectTarget(matchingRoute.action.targets, { - port: routeContext.port, - path: routeContext.path, - headers: routeContext.headers, - method: routeContext.method - }); - - if (!selectedTarget) { - this.logger.error(`No matching target found for route ${matchingRoute.name}`); - stream.respond({ ':status': 502 }); - stream.end(); - return; - } - - // Extract target information, resolving functions if needed - let targetHost: string | string[]; - let targetPort: number; - - try { - // Check function cache for host and resolve or use cached value - if (typeof selectedTarget.host === 'function') { - // Generate a function ID for caching (use route name or ID if available) - const functionId = `host-http2-${matchingRoute.id || matchingRoute.name || 'unnamed'}`; - - // Check if we have a cached result - if (this.functionCache) { - const cachedHost = this.functionCache.getCachedHost(routeContext, functionId); - if (cachedHost !== undefined) { - targetHost = cachedHost; - this.logger.debug(`Using cached host value for HTTP/2: ${functionId}`); - } else { - // Resolve the function and cache the result - const resolvedHost = selectedTarget.host(toBaseContext(routeContext)); - targetHost = resolvedHost; - - // Cache the result - this.functionCache.cacheHost(routeContext, functionId, resolvedHost); - this.logger.debug(`Resolved and cached HTTP/2 function-based host to: ${Array.isArray(resolvedHost) ? resolvedHost.join(', ') : resolvedHost}`); - } - } else { - // No cache available, just resolve - const resolvedHost = selectedTarget.host(routeContext); - targetHost = resolvedHost; - this.logger.debug(`Resolved HTTP/2 function-based host to: ${Array.isArray(resolvedHost) ? resolvedHost.join(', ') : resolvedHost}`); - } - } else { - targetHost = selectedTarget.host; - } - - // Check function cache for port and resolve or use cached value - if (typeof selectedTarget.port === 'function') { - // Generate a function ID for caching - const functionId = `port-http2-${matchingRoute.id || matchingRoute.name || 'unnamed'}`; - - // Check if we have a cached result - if (this.functionCache) { - const cachedPort = this.functionCache.getCachedPort(routeContext, functionId); - if (cachedPort !== undefined) { - targetPort = cachedPort; - this.logger.debug(`Using cached port value for HTTP/2: ${functionId}`); - } else { - // Resolve the function and cache the result - const resolvedPort = selectedTarget.port(toBaseContext(routeContext)); - targetPort = resolvedPort; - - // Cache the result - this.functionCache.cachePort(routeContext, functionId, resolvedPort); - this.logger.debug(`Resolved and cached HTTP/2 function-based port to: ${resolvedPort}`); - } - } else { - // No cache available, just resolve - const resolvedPort = selectedTarget.port(routeContext); - targetPort = resolvedPort; - this.logger.debug(`Resolved HTTP/2 function-based port to: ${resolvedPort}`); - } - } else { - targetPort = selectedTarget.port === 'preserve' ? routeContext.port : selectedTarget.port as number; - } - - // Select a single host if an array was provided - const selectedHost = Array.isArray(targetHost) - ? targetHost[Math.floor(Math.random() * targetHost.length)] - : targetHost; - - // Create a destination for forwarding - const destination = { - host: selectedHost, - port: targetPort - }; - - // Handle HTTP/2 stream based on backend protocol - const backendProtocol = matchingRoute.action.options?.backendProtocol || this.options.backendProtocol; - - if (backendProtocol === 'http2') { - // Forward to HTTP/2 backend - return Http2RequestHandler.handleHttp2WithHttp2Destination( - stream, - headers, - destination, - routeContext, - this.h2Sessions, - this.logger, - this.metricsTracker - ); - } else { - // Forward to HTTP/1.1 backend - return Http2RequestHandler.handleHttp2WithHttp1Destination( - stream, - headers, - destination, - routeContext, - this.logger, - this.metricsTracker - ); - } - } catch (err) { - this.logger.error(`Error evaluating function-based target for HTTP/2: ${err}`); - stream.respond({ ':status': 500 }); - stream.end('Internal Server Error: Failed to evaluate target functions'); - if (this.metricsTracker) this.metricsTracker.incrementFailedRequests(); - return; - } - } - - // Fall back to legacy routing if no matching route found - const method = headers[':method'] || 'GET'; - const path = headers[':path'] || '/'; - - // No route was found - stream.respond({ ':status': 404 }); - stream.end('Not Found: No route configuration for this request'); - if (this.metricsTracker) this.metricsTracker.incrementFailedRequests(); - } - - /** - * Cleanup resources and stop intervals - */ - public destroy(): void { - if (this.rateLimitCleanupInterval) { - clearInterval(this.rateLimitCleanupInterval); - this.rateLimitCleanupInterval = null; - } - - // Close all HTTP/2 sessions - for (const [key, session] of this.h2Sessions) { - session.close(); - } - this.h2Sessions.clear(); - - // Clear function cache if it has a destroy method - if (this.functionCache && typeof this.functionCache.destroy === 'function') { - this.functionCache.destroy(); - } - - this.logger.debug('RequestHandler destroyed'); - } -} \ No newline at end of file diff --git a/ts/proxies/http-proxy/security-manager.ts b/ts/proxies/http-proxy/security-manager.ts deleted file mode 100644 index 858e1f4..0000000 --- a/ts/proxies/http-proxy/security-manager.ts +++ /dev/null @@ -1,413 +0,0 @@ -import type { ILogger } from './models/types.js'; -import type { IRouteConfig } from '../smart-proxy/models/route-types.js'; -import type { IRouteContext } from '../../core/models/route-context.js'; -import { - isIPAuthorized, - normalizeIP, - parseBasicAuthHeader, - cleanupExpiredRateLimits, - type IRateLimitInfo -} from '../../core/utils/security-utils.js'; - -/** - * Manages security features for the HttpProxy - * Implements IP filtering, rate limiting, and authentication. - * Uses shared utilities from security-utils.ts. - */ -export class SecurityManager { - // Cache IP filtering results to avoid constant regex matching - private ipFilterCache: Map> = new Map(); - - // Store rate limits per route and key - private rateLimits: Map> = new Map(); - - // Connection tracking by IP - private connectionsByIP: Map> = new Map(); - private connectionRateByIP: Map = new Map(); - - constructor( - private logger: ILogger, - private routes: IRouteConfig[] = [], - private maxConnectionsPerIP: number = 100, - private connectionRateLimitPerMinute: number = 300 - ) { - // Start periodic cleanup for connection tracking - this.startPeriodicIpCleanup(); - } - - /** - * Update the routes configuration - */ - public setRoutes(routes: IRouteConfig[]): void { - this.routes = routes; - // Reset caches when routes change - this.ipFilterCache.clear(); - } - - /** - * Check if a client is allowed to access a specific route - * - * @param route The route to check access for - * @param context The route context with client information - * @returns True if access is allowed, false otherwise - */ - public isAllowed(route: IRouteConfig, context: IRouteContext): boolean { - if (!route.security) { - return true; // No security restrictions - } - - // --- IP filtering --- - if (!this.isIpAllowed(route, context.clientIp)) { - this.logger.debug(`IP ${context.clientIp} is blocked for route ${route.name || 'unnamed'}`); - return false; - } - - // --- Rate limiting --- - if (route.security.rateLimit?.enabled && !this.isWithinRateLimit(route, context)) { - this.logger.debug(`Rate limit exceeded for route ${route.name || 'unnamed'}`); - return false; - } - - // --- Basic Auth (handled at HTTP level) --- - // Basic auth is not checked here as it requires HTTP headers - // and is handled in the RequestHandler - - return true; - } - - /** - * Check if an IP is allowed based on route security settings - */ - private isIpAllowed(route: IRouteConfig, clientIp: string): boolean { - if (!route.security) { - return true; // No security restrictions - } - - const routeId = route.name || 'unnamed'; - - // Check cache first - if (!this.ipFilterCache.has(routeId)) { - this.ipFilterCache.set(routeId, new Map()); - } - - const routeCache = this.ipFilterCache.get(routeId)!; - if (routeCache.has(clientIp)) { - return routeCache.get(clientIp)!; - } - - // Use shared utility for IP authorization - const allowed = isIPAuthorized( - clientIp, - route.security.ipAllowList, - route.security.ipBlockList - ); - - // Cache the result - routeCache.set(clientIp, allowed); - - return allowed; - } - - /** - * Check if request is within rate limit - */ - private isWithinRateLimit(route: IRouteConfig, context: IRouteContext): boolean { - if (!route.security?.rateLimit?.enabled) { - return true; - } - - const rateLimit = route.security.rateLimit; - const routeId = route.name || 'unnamed'; - - // Determine rate limit key (by IP, path, or header) - let key = context.clientIp; // Default to IP - - if (rateLimit.keyBy === 'path' && context.path) { - key = `${context.clientIp}:${context.path}`; - } else if (rateLimit.keyBy === 'header' && rateLimit.headerName && context.headers) { - const headerValue = context.headers[rateLimit.headerName.toLowerCase()]; - if (headerValue) { - key = `${context.clientIp}:${headerValue}`; - } - } - - // Get or create rate limit tracking for this route - if (!this.rateLimits.has(routeId)) { - this.rateLimits.set(routeId, new Map()); - } - - const routeLimits = this.rateLimits.get(routeId)!; - const now = Date.now(); - - // Get or create rate limit tracking for this key - let limit = routeLimits.get(key); - if (!limit || limit.expiry < now) { - // Create new rate limit or reset expired one - limit = { - count: 1, - expiry: now + (rateLimit.window * 1000) - }; - routeLimits.set(key, limit); - return true; - } - - // Increment the counter - limit.count++; - - // Check if rate limit is exceeded - return limit.count <= rateLimit.maxRequests; - } - - /** - * Clean up expired rate limits - * Should be called periodically to prevent memory leaks - */ - public cleanupExpiredRateLimits(): void { - cleanupExpiredRateLimits(this.rateLimits, { - info: this.logger.info.bind(this.logger), - warn: this.logger.warn.bind(this.logger), - error: this.logger.error.bind(this.logger), - debug: this.logger.debug?.bind(this.logger) - }); - } - - /** - * Check basic auth credentials - * - * @param route The route to check auth for - * @param username The provided username - * @param password The provided password - * @returns True if credentials are valid, false otherwise - */ - public checkBasicAuth(route: IRouteConfig, username: string, password: string): boolean { - if (!route.security?.basicAuth?.enabled) { - return true; - } - - const basicAuth = route.security.basicAuth; - - // Check credentials against configured users - for (const user of basicAuth.users) { - if (user.username === username && user.password === password) { - return true; - } - } - - return false; - } - - /** - * Verify a JWT token - * - * @param route The route to verify the token for - * @param token The JWT token to verify - * @returns True if the token is valid, false otherwise - */ - public verifyJwtToken(route: IRouteConfig, token: string): boolean { - if (!route.security?.jwtAuth?.enabled) { - return true; - } - - try { - const jwtAuth = route.security.jwtAuth; - - // Verify structure - const parts = token.split('.'); - if (parts.length !== 3) { - return false; - } - - // Decode payload - const payload = JSON.parse(Buffer.from(parts[1], 'base64').toString()); - - // Check expiration - if (payload.exp && payload.exp < Math.floor(Date.now() / 1000)) { - return false; - } - - // Check issuer - if (jwtAuth.issuer && payload.iss !== jwtAuth.issuer) { - return false; - } - - // Check audience - if (jwtAuth.audience && payload.aud !== jwtAuth.audience) { - return false; - } - - // Note: In a real implementation, you'd also verify the signature - // using the secret and algorithm specified in jwtAuth - - return true; - } catch (err) { - this.logger.error(`Error verifying JWT: ${err}`); - return false; - } - } - - /** - * Get connections count by IP (checks normalized variants) - */ - public getConnectionCountByIP(ip: string): number { - // Check all normalized variants of the IP - const variants = normalizeIP(ip); - for (const variant of variants) { - const connections = this.connectionsByIP.get(variant); - if (connections) { - return connections.size; - } - } - return 0; - } - - /** - * Check and update connection rate for an IP - * @returns true if within rate limit, false if exceeding limit - */ - public checkConnectionRate(ip: string): boolean { - const now = Date.now(); - const minute = 60 * 1000; - - // Find existing rate tracking (check normalized variants) - const variants = normalizeIP(ip); - let existingKey: string | null = null; - for (const variant of variants) { - if (this.connectionRateByIP.has(variant)) { - existingKey = variant; - break; - } - } - - const key = existingKey || ip; - - if (!this.connectionRateByIP.has(key)) { - this.connectionRateByIP.set(key, [now]); - return true; - } - - // Get timestamps and filter out entries older than 1 minute - const timestamps = this.connectionRateByIP.get(key)!.filter((time) => now - time < minute); - timestamps.push(now); - this.connectionRateByIP.set(key, timestamps); - - // Check if rate exceeds limit - return timestamps.length <= this.connectionRateLimitPerMinute; - } - - /** - * Track connection by IP - */ - public trackConnectionByIP(ip: string, connectionId: string): void { - // Check if any variant already exists - const variants = normalizeIP(ip); - let existingKey: string | null = null; - - for (const variant of variants) { - if (this.connectionsByIP.has(variant)) { - existingKey = variant; - break; - } - } - - const key = existingKey || ip; - if (!this.connectionsByIP.has(key)) { - this.connectionsByIP.set(key, new Set()); - } - this.connectionsByIP.get(key)!.add(connectionId); - } - - /** - * Remove connection tracking for an IP - */ - public removeConnectionByIP(ip: string, connectionId: string): void { - // Check all variants to find where the connection is tracked - const variants = normalizeIP(ip); - - for (const variant of variants) { - if (this.connectionsByIP.has(variant)) { - const connections = this.connectionsByIP.get(variant)!; - connections.delete(connectionId); - if (connections.size === 0) { - this.connectionsByIP.delete(variant); - } - break; - } - } - } - - /** - * Check if IP should be allowed considering connection rate and max connections - * @returns Object with result and reason - */ - public validateIP(ip: string): { allowed: boolean; reason?: string } { - // Check connection count limit - if (this.getConnectionCountByIP(ip) >= this.maxConnectionsPerIP) { - return { - allowed: false, - reason: `Maximum connections per IP (${this.maxConnectionsPerIP}) exceeded` - }; - } - - // Check connection rate limit - if (!this.checkConnectionRate(ip)) { - return { - allowed: false, - reason: `Connection rate limit (${this.connectionRateLimitPerMinute}/min) exceeded` - }; - } - - return { allowed: true }; - } - - /** - * Clears all IP tracking data (for shutdown) - */ - public clearIPTracking(): void { - this.connectionsByIP.clear(); - this.connectionRateByIP.clear(); - } - - /** - * Start periodic cleanup of IP tracking data - */ - private startPeriodicIpCleanup(): void { - // Clean up IP tracking data every minute - setInterval(() => { - this.performIpCleanup(); - }, 60000).unref(); - } - - /** - * Perform cleanup of expired IP data - */ - private performIpCleanup(): void { - const now = Date.now(); - const minute = 60 * 1000; - let cleanedRateLimits = 0; - let cleanedIPs = 0; - - // Clean up expired rate limit timestamps - for (const [ip, timestamps] of this.connectionRateByIP.entries()) { - const validTimestamps = timestamps.filter((time) => now - time < minute); - - if (validTimestamps.length === 0) { - this.connectionRateByIP.delete(ip); - cleanedRateLimits++; - } else if (validTimestamps.length < timestamps.length) { - this.connectionRateByIP.set(ip, validTimestamps); - } - } - - // Clean up IPs with no active connections - for (const [ip, connections] of this.connectionsByIP.entries()) { - if (connections.size === 0) { - this.connectionsByIP.delete(ip); - cleanedIPs++; - } - } - - if (cleanedRateLimits > 0 || cleanedIPs > 0) { - this.logger.debug(`IP cleanup: removed ${cleanedIPs} IPs and ${cleanedRateLimits} rate limits`); - } - } -} \ No newline at end of file diff --git a/ts/proxies/http-proxy/websocket-handler.ts b/ts/proxies/http-proxy/websocket-handler.ts deleted file mode 100644 index 9f496ad..0000000 --- a/ts/proxies/http-proxy/websocket-handler.ts +++ /dev/null @@ -1,581 +0,0 @@ -import * as plugins from '../../plugins.js'; -import '../../core/models/socket-augmentation.js'; -import { type IHttpProxyOptions, type IWebSocketWithHeartbeat, type ILogger, createLogger } from './models/types.js'; -import { ConnectionPool } from './connection-pool.js'; -import { HttpRouter } from '../../routing/router/index.js'; -import type { IRouteConfig, IRouteTarget } from '../smart-proxy/models/route-types.js'; -import type { IRouteContext } from '../../core/models/route-context.js'; -import { toBaseContext } from '../../core/models/route-context.js'; -import { ContextCreator } from './context-creator.js'; -import { SecurityManager } from './security-manager.js'; -import { TemplateUtils } from '../../core/utils/template-utils.js'; -import { getMessageSize, toBuffer } from '../../core/utils/websocket-utils.js'; - -/** - * Handles WebSocket connections and proxying - */ -export class WebSocketHandler { - private heartbeatInterval: NodeJS.Timeout | null = null; - private wsServer: plugins.ws.WebSocketServer | null = null; - private logger: ILogger; - private contextCreator: ContextCreator = new ContextCreator(); - private router: HttpRouter | null = null; - private securityManager: SecurityManager; - - constructor( - private options: IHttpProxyOptions, - private connectionPool: ConnectionPool, - private routes: IRouteConfig[] = [] - ) { - this.logger = createLogger(options.logLevel || 'info'); - this.securityManager = new SecurityManager(this.logger, routes); - - // Initialize router if we have routes - if (routes.length > 0) { - this.router = new HttpRouter(routes, this.logger); - } - } - - /** - * Set the route configurations - */ - public setRoutes(routes: IRouteConfig[]): void { - this.routes = routes; - - // Initialize or update the route router - if (!this.router) { - this.router = new HttpRouter(routes, this.logger); - } else { - this.router.setRoutes(routes); - } - - // Update the security manager - this.securityManager.setRoutes(routes); - } - - /** - * Select the appropriate target from the targets array based on sub-matching criteria - */ - private selectTarget( - targets: IRouteTarget[], - context: { - port: number; - path?: string; - headers?: Record; - method?: string; - } - ): IRouteTarget | null { - // Sort targets by priority (higher first) - const sortedTargets = [...targets].sort((a, b) => (b.priority || 0) - (a.priority || 0)); - - // Find the first matching target - for (const target of sortedTargets) { - if (!target.match) { - // No match criteria means this is a default/fallback target - return target; - } - - // Check port match - if (target.match.ports && !target.match.ports.includes(context.port)) { - continue; - } - - // Check path match (supports wildcards) - if (target.match.path && context.path) { - const pathPattern = target.match.path.replace(/\*/g, '.*'); - const pathRegex = new RegExp(`^${pathPattern}$`); - if (!pathRegex.test(context.path)) { - continue; - } - } - - // Check method match - if (target.match.method && context.method && !target.match.method.includes(context.method)) { - continue; - } - - // Check headers match - if (target.match.headers && context.headers) { - let headersMatch = true; - for (const [key, pattern] of Object.entries(target.match.headers)) { - const headerValue = context.headers[key.toLowerCase()]; - if (!headerValue) { - headersMatch = false; - break; - } - - if (pattern instanceof RegExp) { - if (!pattern.test(headerValue)) { - headersMatch = false; - break; - } - } else if (headerValue !== pattern) { - headersMatch = false; - break; - } - } - if (!headersMatch) { - continue; - } - } - - // All criteria matched - return target; - } - - // No matching target found, return the first target without match criteria (default) - return sortedTargets.find(t => !t.match) || null; - } - - /** - * Initialize WebSocket server on an existing HTTPS server - */ - public initialize(server: plugins.https.Server): void { - // Create WebSocket server - this.wsServer = new plugins.ws.WebSocketServer({ - server: server, - clientTracking: true - }); - - // Handle WebSocket connections - this.wsServer.on('connection', (wsIncoming: IWebSocketWithHeartbeat, req: plugins.http.IncomingMessage) => { - this.handleWebSocketConnection(wsIncoming, req); - }); - - // Start the heartbeat interval - this.startHeartbeat(); - - this.logger.info('WebSocket handler initialized'); - } - - /** - * Start the heartbeat interval to check for inactive WebSocket connections - */ - private startHeartbeat(): void { - // Clean up existing interval if any - if (this.heartbeatInterval) { - clearInterval(this.heartbeatInterval); - } - - // Set up the heartbeat interval (check every 30 seconds) - this.heartbeatInterval = setInterval(() => { - if (!this.wsServer || this.wsServer.clients.size === 0) { - return; // Skip if no active connections - } - - this.logger.debug(`WebSocket heartbeat check for ${this.wsServer.clients.size} clients`); - - this.wsServer.clients.forEach((ws: plugins.wsDefault) => { - const wsWithHeartbeat = ws as IWebSocketWithHeartbeat; - - if (wsWithHeartbeat.isAlive === false) { - this.logger.debug('Terminating inactive WebSocket connection'); - return wsWithHeartbeat.terminate(); - } - - wsWithHeartbeat.isAlive = false; - wsWithHeartbeat.ping(); - }); - }, 30000); - - // Make sure the interval doesn't keep the process alive - if (this.heartbeatInterval.unref) { - this.heartbeatInterval.unref(); - } - } - - /** - * Handle a new WebSocket connection - */ - private handleWebSocketConnection(wsIncoming: IWebSocketWithHeartbeat, req: plugins.http.IncomingMessage): void { - this.logger.debug(`WebSocket connection initiated from ${req.headers.host}`); - - try { - // Initialize heartbeat tracking - wsIncoming.isAlive = true; - wsIncoming.lastPong = Date.now(); - - // Handle pong messages to track liveness - wsIncoming.on('pong', () => { - wsIncoming.isAlive = true; - wsIncoming.lastPong = Date.now(); - }); - - // Create a context for routing - const connectionId = `ws-${Date.now()}-${Math.floor(Math.random() * 10000)}`; - const routeContext = this.contextCreator.createHttpRouteContext(req, { - connectionId, - clientIp: req.socket.remoteAddress?.replace('::ffff:', '') || '0.0.0.0', - serverIp: req.socket.localAddress?.replace('::ffff:', '') || '0.0.0.0', - tlsVersion: req.socket.getTLSVersion?.() || undefined - }); - - // Try modern router first if available - let route: IRouteConfig | undefined; - if (this.router) { - route = this.router.routeReq(req); - } - - // Define destination variables - let destination: { host: string; port: number }; - - // If we found a route with the modern router, use it - if (route && route.action.type === 'forward' && route.action.targets && route.action.targets.length > 0) { - this.logger.debug(`Found matching WebSocket route: ${route.name || 'unnamed'}`); - - // Select the appropriate target from the targets array - const selectedTarget = this.selectTarget(route.action.targets, { - port: routeContext.port, - path: routeContext.path, - headers: routeContext.headers, - method: routeContext.method - }); - - if (!selectedTarget) { - this.logger.error(`No matching target found for route ${route.name}`); - wsIncoming.close(1003, 'No matching target'); - return; - } - - // Check if WebSockets are enabled for this route - if (route.action.websocket?.enabled === false) { - this.logger.debug(`WebSockets are disabled for route: ${route.name || 'unnamed'}`); - wsIncoming.close(1003, 'WebSockets not supported for this route'); - return; - } - - // Check security restrictions if configured to authenticate WebSocket requests - if (route.action.websocket?.authenticateRequest !== false && route.security) { - if (!this.securityManager.isAllowed(route, toBaseContext(routeContext))) { - this.logger.warn(`WebSocket connection denied by security policy for ${routeContext.clientIp}`); - wsIncoming.close(1008, 'Access denied by security policy'); - return; - } - - // Check origin restrictions if configured - const origin = req.headers.origin; - if (origin && route.action.websocket?.allowedOrigins && route.action.websocket.allowedOrigins.length > 0) { - const isAllowed = route.action.websocket.allowedOrigins.some(allowedOrigin => { - // Handle wildcards and template variables - if (allowedOrigin.includes('*') || allowedOrigin.includes('{')) { - const pattern = allowedOrigin.replace(/\*/g, '.*'); - const resolvedPattern = TemplateUtils.resolveTemplateVariables(pattern, routeContext); - const regex = new RegExp(`^${resolvedPattern}$`); - return regex.test(origin); - } - return allowedOrigin === origin; - }); - - if (!isAllowed) { - this.logger.warn(`WebSocket origin ${origin} not allowed for route: ${route.name || 'unnamed'}`); - wsIncoming.close(1008, 'Origin not allowed'); - return; - } - } - } - - // Extract target information, resolving functions if needed - let targetHost: string | string[]; - let targetPort: number; - - try { - // Resolve host if it's a function - if (typeof selectedTarget.host === 'function') { - const resolvedHost = selectedTarget.host(toBaseContext(routeContext)); - targetHost = resolvedHost; - this.logger.debug(`Resolved function-based host for WebSocket: ${Array.isArray(resolvedHost) ? resolvedHost.join(', ') : resolvedHost}`); - } else { - targetHost = selectedTarget.host; - } - - // Resolve port if it's a function - if (typeof selectedTarget.port === 'function') { - targetPort = selectedTarget.port(toBaseContext(routeContext)); - this.logger.debug(`Resolved function-based port for WebSocket: ${targetPort}`); - } else { - targetPort = selectedTarget.port === 'preserve' ? routeContext.port : selectedTarget.port as number; - } - - // Select a single host if an array was provided - const selectedHost = Array.isArray(targetHost) - ? targetHost[Math.floor(Math.random() * targetHost.length)] - : targetHost; - - // Create a destination for the WebSocket connection - destination = { - host: selectedHost, - port: targetPort - }; - - this.logger.debug(`WebSocket destination resolved: ${selectedHost}:${targetPort}`); - } catch (err) { - this.logger.error(`Error evaluating function-based target for WebSocket: ${err}`); - wsIncoming.close(1011, 'Internal server error'); - return; - } - } else { - // No route found - this.logger.warn(`No route configuration for WebSocket host: ${req.headers.host}`); - wsIncoming.close(1008, 'No route configuration for this host'); - return; - } - - // Build target URL with potential path rewriting - // Determine protocol based on the target's configuration - // For WebSocket connections, we use ws for HTTP backends and wss for HTTPS backends - const isTargetSecure = destination.port === 443; - const protocol = isTargetSecure ? 'wss' : 'ws'; - let targetPath = req.url || '/'; - - // Apply path rewriting if configured - if (route?.action.websocket?.rewritePath) { - const originalPath = targetPath; - targetPath = TemplateUtils.resolveTemplateVariables( - route.action.websocket.rewritePath, - {...routeContext, path: targetPath} - ); - this.logger.debug(`WebSocket path rewritten: ${originalPath} -> ${targetPath}`); - } - - const targetUrl = `${protocol}://${destination.host}:${destination.port}${targetPath}`; - - this.logger.debug(`WebSocket connection from ${req.socket.remoteAddress} to ${targetUrl}`); - - // Create headers for outgoing WebSocket connection - const headers: { [key: string]: string } = {}; - - // Copy relevant headers from incoming request - for (const [key, value] of Object.entries(req.headers)) { - if (value && typeof value === 'string' && - key.toLowerCase() !== 'connection' && - key.toLowerCase() !== 'upgrade' && - key.toLowerCase() !== 'sec-websocket-key' && - key.toLowerCase() !== 'sec-websocket-version') { - headers[key] = value; - } - } - - // Always rewrite host header for WebSockets for consistency - headers['host'] = `${destination.host}:${destination.port}`; - - // Add custom headers from route configuration - if (route?.action.websocket?.customHeaders) { - for (const [key, value] of Object.entries(route.action.websocket.customHeaders)) { - // Skip if header already exists and we're not overriding - if (headers[key.toLowerCase()] && !value.startsWith('!')) { - continue; - } - - // Handle special delete directive (!delete) - if (value === '!delete') { - delete headers[key.toLowerCase()]; - continue; - } - - // Handle forced override (!value) - let finalValue: string; - if (value.startsWith('!') && value !== '!delete') { - // Keep the ! but resolve any templates in the rest - const templateValue = value.substring(1); - finalValue = '!' + TemplateUtils.resolveTemplateVariables(templateValue, routeContext); - } else { - // Resolve templates in the entire value - finalValue = TemplateUtils.resolveTemplateVariables(value, routeContext); - } - - // Set the header - headers[key.toLowerCase()] = finalValue; - } - } - - // Create WebSocket connection options - const wsOptions: any = { - headers: headers, - followRedirects: true - }; - - // Add subprotocols if configured - if (route?.action.websocket?.subprotocols && route.action.websocket.subprotocols.length > 0) { - wsOptions.protocols = route.action.websocket.subprotocols; - } else if (req.headers['sec-websocket-protocol']) { - // Pass through client requested protocols - wsOptions.protocols = req.headers['sec-websocket-protocol'].split(',').map(p => p.trim()); - } - - // Create outgoing WebSocket connection - this.logger.debug(`Creating WebSocket connection to ${targetUrl} with options:`, { - headers: wsOptions.headers, - protocols: wsOptions.protocols - }); - const wsOutgoing = new plugins.wsDefault(targetUrl, wsOptions); - this.logger.debug(`WebSocket instance created, waiting for connection...`); - - // Handle connection errors - wsOutgoing.on('error', (err) => { - this.logger.error(`WebSocket target connection error: ${err.message}`); - if (wsIncoming.readyState === wsIncoming.OPEN) { - wsIncoming.close(1011, 'Internal server error'); - } - }); - - // Handle outgoing connection open - wsOutgoing.on('open', () => { - this.logger.debug(`WebSocket target connection opened to ${targetUrl}`); - // Set up custom ping interval if configured - let pingInterval: NodeJS.Timeout | null = null; - if (route?.action.websocket?.pingInterval && route.action.websocket.pingInterval > 0) { - pingInterval = setInterval(() => { - if (wsIncoming.readyState === wsIncoming.OPEN) { - wsIncoming.ping(); - this.logger.debug(`Sent WebSocket ping to client for route: ${route.name || 'unnamed'}`); - } - }, route.action.websocket.pingInterval); - - // Don't keep process alive just for pings - if (pingInterval.unref) pingInterval.unref(); - } - - // Set up custom ping timeout if configured - let pingTimeout: NodeJS.Timeout | null = null; - const pingTimeoutMs = route?.action.websocket?.pingTimeout || 60000; // Default 60s - - // Define timeout function for cleaner code - const resetPingTimeout = () => { - if (pingTimeout) clearTimeout(pingTimeout); - pingTimeout = setTimeout(() => { - this.logger.debug(`WebSocket ping timeout for client connection on route: ${route?.name || 'unnamed'}`); - wsIncoming.terminate(); - }, pingTimeoutMs); - - // Don't keep process alive just for timeouts - if (pingTimeout.unref) pingTimeout.unref(); - }; - - // Reset timeout on pong - wsIncoming.on('pong', () => { - wsIncoming.isAlive = true; - wsIncoming.lastPong = Date.now(); - resetPingTimeout(); - }); - - // Initial ping timeout - resetPingTimeout(); - - // Handle potential message size limits - const maxSize = route?.action.websocket?.maxPayloadSize || 0; - - // Forward incoming messages to outgoing connection - wsIncoming.on('message', (data, isBinary) => { - this.logger.debug(`WebSocket forwarding message from client to target: ${data.toString()}`); - if (wsOutgoing.readyState === wsOutgoing.OPEN) { - // Check message size if limit is set - const messageSize = getMessageSize(data); - if (maxSize > 0 && messageSize > maxSize) { - this.logger.warn(`WebSocket message exceeds max size (${messageSize} > ${maxSize})`); - wsIncoming.close(1009, 'Message too big'); - return; - } - - wsOutgoing.send(data, { binary: isBinary }); - } else { - this.logger.warn(`WebSocket target connection not open (state: ${wsOutgoing.readyState})`); - } - }); - - // Forward outgoing messages to incoming connection - wsOutgoing.on('message', (data, isBinary) => { - this.logger.debug(`WebSocket forwarding message from target to client: ${data.toString()}`); - if (wsIncoming.readyState === wsIncoming.OPEN) { - wsIncoming.send(data, { binary: isBinary }); - } else { - this.logger.warn(`WebSocket client connection not open (state: ${wsIncoming.readyState})`); - } - }); - - // Handle closing of connections - wsIncoming.on('close', (code, reason) => { - this.logger.debug(`WebSocket client connection closed: ${code} ${reason}`); - if (wsOutgoing.readyState === wsOutgoing.OPEN) { - // Ensure code is a valid WebSocket close code number - const validCode = typeof code === 'number' && code >= 1000 && code <= 4999 ? code : 1000; - try { - const reasonString = reason ? toBuffer(reason).toString() : ''; - wsOutgoing.close(validCode, reasonString); - } catch (err) { - this.logger.error('Error closing wsOutgoing:', err); - wsOutgoing.close(validCode); - } - } - - // Clean up timers - if (pingInterval) clearInterval(pingInterval); - if (pingTimeout) clearTimeout(pingTimeout); - }); - - wsOutgoing.on('close', (code, reason) => { - this.logger.debug(`WebSocket target connection closed: ${code} ${reason}`); - if (wsIncoming.readyState === wsIncoming.OPEN) { - // Ensure code is a valid WebSocket close code number - const validCode = typeof code === 'number' && code >= 1000 && code <= 4999 ? code : 1000; - try { - const reasonString = reason ? toBuffer(reason).toString() : ''; - wsIncoming.close(validCode, reasonString); - } catch (err) { - this.logger.error('Error closing wsIncoming:', err); - wsIncoming.close(validCode); - } - } - - // Clean up timers - if (pingInterval) clearInterval(pingInterval); - if (pingTimeout) clearTimeout(pingTimeout); - }); - - this.logger.debug(`WebSocket connection established: ${req.headers.host} -> ${destination.host}:${destination.port}`); - }); - - } catch (error) { - this.logger.error(`Error handling WebSocket connection: ${error.message}`); - if (wsIncoming.readyState === wsIncoming.OPEN) { - wsIncoming.close(1011, 'Internal server error'); - } - } - } - - /** - * Get information about active WebSocket connections - */ - public getConnectionInfo(): { activeConnections: number } { - return { - activeConnections: this.wsServer ? this.wsServer.clients.size : 0 - }; - } - - /** - * Shutdown the WebSocket handler - */ - public shutdown(): void { - // Stop heartbeat interval - if (this.heartbeatInterval) { - clearInterval(this.heartbeatInterval); - this.heartbeatInterval = null; - } - - // Close all WebSocket connections - if (this.wsServer) { - this.logger.info(`Closing ${this.wsServer.clients.size} WebSocket connections`); - - for (const client of this.wsServer.clients) { - try { - client.terminate(); - } catch (error) { - this.logger.error('Error terminating WebSocket client', error); - } - } - - // Close the server - this.wsServer.close(); - this.wsServer = null; - } - } -} \ No newline at end of file diff --git a/ts/proxies/index.ts b/ts/proxies/index.ts index f28dfde..3bdfae9 100644 --- a/ts/proxies/index.ts +++ b/ts/proxies/index.ts @@ -2,16 +2,8 @@ * Proxy implementations module */ -// Export HttpProxy with selective imports to avoid conflicts -export { HttpProxy, CertificateManager, ConnectionPool, RequestHandler, WebSocketHandler } from './http-proxy/index.js'; -export type { IMetricsTracker, MetricsTracker } from './http-proxy/index.js'; -// Export http-proxy models except IAcmeOptions -export type { IHttpProxyOptions, ICertificateEntry, ILogger } from './http-proxy/models/types.js'; -// RouteManager has been unified - use SharedRouteManager from core/routing -export { SharedRouteManager as HttpProxyRouteManager } from '../core/routing/route-manager.js'; - // Export SmartProxy with selective imports to avoid conflicts -export { SmartProxy, ConnectionManager, SecurityManager, TimeoutManager, TlsManager, HttpProxyBridge, RouteConnectionHandler } from './smart-proxy/index.js'; +export { SmartProxy } from './smart-proxy/index.js'; export { SharedRouteManager as SmartProxyRouteManager } from '../core/routing/route-manager.js'; export * from './smart-proxy/utils/index.js'; // Export smart-proxy models except IAcmeOptions diff --git a/ts/proxies/smart-proxy/acme-state-manager.ts b/ts/proxies/smart-proxy/acme-state-manager.ts deleted file mode 100644 index 3afd149..0000000 --- a/ts/proxies/smart-proxy/acme-state-manager.ts +++ /dev/null @@ -1,112 +0,0 @@ -import type { IRouteConfig } from './models/route-types.js'; - -/** - * Global state store for ACME operations - * Tracks active challenge routes and port allocations - */ -export class AcmeStateManager { - private activeChallengeRoutes: Map = new Map(); - private acmePortAllocations: Set = new Set(); - private primaryChallengeRoute: IRouteConfig | null = null; - - /** - * Check if a challenge route is active - */ - public isChallengeRouteActive(): boolean { - return this.activeChallengeRoutes.size > 0; - } - - /** - * Register a challenge route as active - */ - public addChallengeRoute(route: IRouteConfig): void { - this.activeChallengeRoutes.set(route.name, route); - - // Track the primary challenge route - if (!this.primaryChallengeRoute || route.priority > (this.primaryChallengeRoute.priority || 0)) { - this.primaryChallengeRoute = route; - } - - // Track port allocations - const ports = Array.isArray(route.match.ports) ? route.match.ports : [route.match.ports]; - ports.forEach(port => this.acmePortAllocations.add(port)); - } - - /** - * Remove a challenge route - */ - public removeChallengeRoute(routeName: string): void { - const route = this.activeChallengeRoutes.get(routeName); - if (!route) return; - - this.activeChallengeRoutes.delete(routeName); - - // Update primary challenge route if needed - if (this.primaryChallengeRoute?.name === routeName) { - this.primaryChallengeRoute = null; - // Find new primary route with highest priority - let highestPriority = -1; - for (const [_, activeRoute] of this.activeChallengeRoutes) { - const priority = activeRoute.priority || 0; - if (priority > highestPriority) { - highestPriority = priority; - this.primaryChallengeRoute = activeRoute; - } - } - } - - // Update port allocations - only remove if no other routes use this port - const ports = Array.isArray(route.match.ports) ? route.match.ports : [route.match.ports]; - ports.forEach(port => { - let portStillUsed = false; - for (const [_, activeRoute] of this.activeChallengeRoutes) { - const activePorts = Array.isArray(activeRoute.match.ports) ? - activeRoute.match.ports : [activeRoute.match.ports]; - if (activePorts.includes(port)) { - portStillUsed = true; - break; - } - } - if (!portStillUsed) { - this.acmePortAllocations.delete(port); - } - }); - } - - /** - * Get all active challenge routes - */ - public getActiveChallengeRoutes(): IRouteConfig[] { - return Array.from(this.activeChallengeRoutes.values()); - } - - /** - * Get the primary challenge route - */ - public getPrimaryChallengeRoute(): IRouteConfig | null { - return this.primaryChallengeRoute; - } - - /** - * Check if a port is allocated for ACME - */ - public isPortAllocatedForAcme(port: number): boolean { - return this.acmePortAllocations.has(port); - } - - /** - * Get all ACME ports - */ - public getAcmePorts(): number[] { - return Array.from(this.acmePortAllocations); - } - - /** - * Clear all state (for shutdown or reset) - */ - public clear(): void { - this.activeChallengeRoutes.clear(); - this.acmePortAllocations.clear(); - this.primaryChallengeRoute = null; - } -} \ No newline at end of file diff --git a/ts/proxies/smart-proxy/cert-store.ts b/ts/proxies/smart-proxy/cert-store.ts deleted file mode 100644 index 36641b3..0000000 --- a/ts/proxies/smart-proxy/cert-store.ts +++ /dev/null @@ -1,92 +0,0 @@ -import * as plugins from '../../plugins.js'; -import { AsyncFileSystem } from '../../core/utils/fs-utils.js'; -import type { ICertificateData } from './certificate-manager.js'; - -export class CertStore { - constructor(private certDir: string) {} - - public async initialize(): Promise { - await AsyncFileSystem.ensureDir(this.certDir); - } - - public async getCertificate(routeName: string): Promise { - const certPath = this.getCertPath(routeName); - const metaPath = `${certPath}/meta.json`; - - if (!await AsyncFileSystem.exists(metaPath)) { - return null; - } - - try { - const meta = await AsyncFileSystem.readJSON(metaPath); - - const [cert, key] = await Promise.all([ - AsyncFileSystem.readFile(`${certPath}/cert.pem`), - AsyncFileSystem.readFile(`${certPath}/key.pem`) - ]); - - let ca: string | undefined; - const caPath = `${certPath}/ca.pem`; - if (await AsyncFileSystem.exists(caPath)) { - ca = await AsyncFileSystem.readFile(caPath); - } - - return { - cert, - key, - ca, - expiryDate: new Date(meta.expiryDate), - issueDate: new Date(meta.issueDate) - }; - } catch (error) { - console.error(`Failed to load certificate for ${routeName}: ${error}`); - return null; - } - } - - public async saveCertificate( - routeName: string, - certData: ICertificateData - ): Promise { - const certPath = this.getCertPath(routeName); - await AsyncFileSystem.ensureDir(certPath); - - // Save certificate files in parallel - const savePromises = [ - AsyncFileSystem.writeFile(`${certPath}/cert.pem`, certData.cert), - AsyncFileSystem.writeFile(`${certPath}/key.pem`, certData.key) - ]; - - if (certData.ca) { - savePromises.push( - AsyncFileSystem.writeFile(`${certPath}/ca.pem`, certData.ca) - ); - } - - // Save metadata - const meta = { - expiryDate: certData.expiryDate.toISOString(), - issueDate: certData.issueDate.toISOString(), - savedAt: new Date().toISOString() - }; - - savePromises.push( - AsyncFileSystem.writeJSON(`${certPath}/meta.json`, meta) - ); - - await Promise.all(savePromises); - } - - public async deleteCertificate(routeName: string): Promise { - const certPath = this.getCertPath(routeName); - if (await AsyncFileSystem.isDirectory(certPath)) { - await AsyncFileSystem.removeDir(certPath); - } - } - - private getCertPath(routeName: string): string { - // Sanitize route name for filesystem - const safeName = routeName.replace(/[^a-zA-Z0-9-_]/g, '_'); - return `${this.certDir}/${safeName}`; - } -} \ No newline at end of file diff --git a/ts/proxies/smart-proxy/certificate-manager.ts b/ts/proxies/smart-proxy/certificate-manager.ts deleted file mode 100644 index e6c287c..0000000 --- a/ts/proxies/smart-proxy/certificate-manager.ts +++ /dev/null @@ -1,895 +0,0 @@ -import * as plugins from '../../plugins.js'; -import { HttpProxy } from '../http-proxy/index.js'; -import type { IRouteConfig, IRouteTls } from './models/route-types.js'; -import type { IAcmeOptions } from './models/interfaces.js'; -import { CertStore } from './cert-store.js'; -import type { AcmeStateManager } from './acme-state-manager.js'; -import { logger } from '../../core/utils/logger.js'; -import { SocketHandlers } from './utils/route-helpers.js'; - -export interface ICertStatus { - domain: string; - status: 'valid' | 'pending' | 'expired' | 'error'; - expiryDate?: Date; - issueDate?: Date; - source: 'static' | 'acme' | 'custom'; - error?: string; -} - -export interface ICertificateData { - cert: string; - key: string; - ca?: string; - expiryDate: Date; - issueDate: Date; - source?: 'static' | 'acme' | 'custom'; -} - -export class SmartCertManager { - private certStore: CertStore; - private smartAcme: plugins.smartacme.SmartAcme | null = null; - private httpProxy: HttpProxy | null = null; - private renewalTimer: NodeJS.Timeout | null = null; - private pendingChallenges: Map = new Map(); - private challengeRoute: IRouteConfig | null = null; - - // Track certificate status by route name - private certStatus: Map = new Map(); - - // Global ACME defaults from top-level configuration - private globalAcmeDefaults: IAcmeOptions | null = null; - - // Callback to update SmartProxy routes for challenges - private updateRoutesCallback?: (routes: IRouteConfig[]) => Promise; - - // Flag to track if challenge route is currently active - private challengeRouteActive: boolean = false; - - // Flag to track if provisioning is in progress - private isProvisioning: boolean = false; - - // ACME state manager reference - private acmeStateManager: AcmeStateManager | null = null; - - // Custom certificate provision function - private certProvisionFunction?: (domain: string) => Promise; - - // Whether to fallback to ACME if custom provision fails - private certProvisionFallbackToAcme: boolean = true; - - constructor( - private routes: IRouteConfig[], - private certDir: string = './certs', - private acmeOptions?: { - email?: string; - useProduction?: boolean; - port?: number; - }, - private initialState?: { - challengeRouteActive?: boolean; - } - ) { - this.certStore = new CertStore(certDir); - - // Apply initial state if provided - if (initialState) { - this.challengeRouteActive = initialState.challengeRouteActive || false; - } - } - - public setHttpProxy(httpProxy: HttpProxy): void { - this.httpProxy = httpProxy; - } - - - /** - * Set the ACME state manager - */ - public setAcmeStateManager(stateManager: AcmeStateManager): void { - this.acmeStateManager = stateManager; - } - - /** - * Set global ACME defaults from top-level configuration - */ - public setGlobalAcmeDefaults(defaults: IAcmeOptions): void { - this.globalAcmeDefaults = defaults; - } - - /** - * Set custom certificate provision function - */ - public setCertProvisionFunction(fn: (domain: string) => Promise): void { - this.certProvisionFunction = fn; - } - - /** - * Set whether to fallback to ACME if custom provision fails - */ - public setCertProvisionFallbackToAcme(fallback: boolean): void { - this.certProvisionFallbackToAcme = fallback; - } - - /** - * Update the routes array to keep it in sync with SmartProxy - * This prevents stale route data when adding/removing challenge routes - */ - public setRoutes(routes: IRouteConfig[]): void { - this.routes = routes; - } - - /** - * Set callback for updating routes (used for challenge routes) - */ - public setUpdateRoutesCallback(callback: (routes: IRouteConfig[]) => Promise): void { - this.updateRoutesCallback = callback; - try { - logger.log('debug', 'Route update callback set successfully', { component: 'certificate-manager' }); - } catch (error) { - // Silently handle logging errors - console.log('[DEBUG] Route update callback set successfully'); - } - } - - /** - * Initialize certificate manager and provision certificates for all routes - */ - public async initialize(): Promise { - // Create certificate directory if it doesn't exist - await this.certStore.initialize(); - - // Initialize SmartAcme if we have any ACME routes - const hasAcmeRoutes = this.routes.some(r => - r.action.tls?.certificate === 'auto' - ); - - if (hasAcmeRoutes && this.acmeOptions?.email) { - // Create HTTP-01 challenge handler - const http01Handler = new plugins.smartacme.handlers.Http01MemoryHandler(); - - // Set up challenge handler integration with our routing - this.setupChallengeHandler(http01Handler); - - // Create SmartAcme instance with built-in MemoryCertManager and HTTP-01 handler - this.smartAcme = new plugins.smartacme.SmartAcme({ - accountEmail: this.acmeOptions.email, - environment: this.acmeOptions.useProduction ? 'production' : 'integration', - certManager: new plugins.smartacme.certmanagers.MemoryCertManager(), - challengeHandlers: [http01Handler] - }); - - await this.smartAcme.start(); - - // Add challenge route once at initialization if not already active - if (!this.challengeRouteActive) { - logger.log('info', 'Adding ACME challenge route during initialization', { component: 'certificate-manager' }); - await this.addChallengeRoute(); - } else { - logger.log('info', 'Challenge route already active from previous instance', { component: 'certificate-manager' }); - } - } - - // Skip automatic certificate provisioning during initialization - // This will be called later after ports are listening - logger.log('info', 'Certificate manager initialized. Deferring certificate provisioning until after ports are listening.', { component: 'certificate-manager' }); - - // Start renewal timer - this.startRenewalTimer(); - } - - /** - * Provision certificates for all routes that need them - */ - public async provisionAllCertificates(): Promise { - const certRoutes = this.routes.filter(r => - r.action.tls?.mode === 'terminate' || - r.action.tls?.mode === 'terminate-and-reencrypt' - ); - - // Set provisioning flag to prevent concurrent operations - this.isProvisioning = true; - - try { - for (const route of certRoutes) { - try { - await this.provisionCertificate(route, true); // Allow concurrent since we're managing it here - } catch (error) { - logger.log('error', `Failed to provision certificate for route ${route.name}`, { routeName: route.name, error, component: 'certificate-manager' }); - } - } - } finally { - this.isProvisioning = false; - } - } - - /** - * Provision certificate for a single route - */ - public async provisionCertificate(route: IRouteConfig, allowConcurrent: boolean = false): Promise { - const tls = route.action.tls; - if (!tls || (tls.mode !== 'terminate' && tls.mode !== 'terminate-and-reencrypt')) { - return; - } - - // Check if provisioning is already in progress (prevent concurrent provisioning) - if (!allowConcurrent && this.isProvisioning) { - logger.log('info', `Certificate provisioning already in progress, skipping ${route.name}`, { routeName: route.name, component: 'certificate-manager' }); - return; - } - - const domains = this.extractDomainsFromRoute(route); - if (domains.length === 0) { - logger.log('warn', `Route ${route.name} has TLS termination but no domains`, { routeName: route.name, component: 'certificate-manager' }); - return; - } - - const primaryDomain = domains[0]; - - if (tls.certificate === 'auto') { - // ACME certificate - await this.provisionAcmeCertificate(route, domains); - } else if (typeof tls.certificate === 'object') { - // Static certificate - await this.provisionStaticCertificate(route, primaryDomain, tls.certificate); - } - } - - /** - * Provision ACME certificate - */ - private async provisionAcmeCertificate( - route: IRouteConfig, - domains: string[] - ): Promise { - const primaryDomain = domains[0]; - const routeName = route.name || primaryDomain; - - // Check if we already have a valid certificate - const existingCert = await this.certStore.getCertificate(routeName); - if (existingCert && this.isCertificateValid(existingCert)) { - logger.log('info', `Using existing valid certificate for ${primaryDomain}`, { domain: primaryDomain, component: 'certificate-manager' }); - await this.applyCertificate(primaryDomain, existingCert); - this.updateCertStatus(routeName, 'valid', existingCert.source || 'acme', existingCert); - return; - } - - // Check for custom provision function first - if (this.certProvisionFunction) { - try { - logger.log('info', `Attempting custom certificate provision for ${primaryDomain}`, { domain: primaryDomain, component: 'certificate-manager' }); - const result = await this.certProvisionFunction(primaryDomain); - - if (result === 'http01') { - logger.log('info', `Custom function returned 'http01', falling back to Let's Encrypt for ${primaryDomain}`, { domain: primaryDomain, component: 'certificate-manager' }); - // Continue with existing ACME logic below - } else { - // Use custom certificate - const customCert = result as plugins.tsclass.network.ICert; - - // Convert to internal certificate format - const certData: ICertificateData = { - cert: customCert.publicKey, - key: customCert.privateKey, - ca: '', - issueDate: new Date(), - expiryDate: this.extractExpiryDate(customCert.publicKey), - source: 'custom' - }; - - // Store and apply certificate - await this.certStore.saveCertificate(routeName, certData); - await this.applyCertificate(primaryDomain, certData); - this.updateCertStatus(routeName, 'valid', 'custom', certData); - - logger.log('info', `Custom certificate applied for ${primaryDomain}`, { - domain: primaryDomain, - expiryDate: certData.expiryDate, - component: 'certificate-manager' - }); - return; - } - } catch (error) { - logger.log('error', `Custom cert provision failed for ${primaryDomain}: ${error.message}`, { - domain: primaryDomain, - error: error.message, - component: 'certificate-manager' - }); - // Check if we should fallback to ACME - if (!this.certProvisionFallbackToAcme) { - throw error; - } - logger.log('info', `Falling back to Let's Encrypt for ${primaryDomain}`, { domain: primaryDomain, component: 'certificate-manager' }); - } - } - - if (!this.smartAcme) { - throw new Error( - 'SmartAcme not initialized. This usually means no ACME email was provided. ' + - 'Please ensure you have configured ACME with an email address either:\n' + - '1. In the top-level "acme" configuration\n' + - '2. In the route\'s "tls.acme" configuration' - ); - } - - // Apply renewal threshold from global defaults or route config - const renewThreshold = route.action.tls?.acme?.renewBeforeDays || - this.globalAcmeDefaults?.renewThresholdDays || - 30; - - logger.log('info', `Requesting ACME certificate for ${domains.join(', ')} (renew ${renewThreshold} days before expiry)`, { domains: domains.join(', '), renewThreshold, component: 'certificate-manager' }); - this.updateCertStatus(routeName, 'pending', 'acme'); - - try { - // Challenge route should already be active from initialization - // No need to add it for each certificate - - // Determine if we should request a wildcard certificate - // Only request wildcards if: - // 1. The primary domain is not already a wildcard - // 2. The domain has multiple parts (can have subdomains) - // 3. We have DNS-01 challenge support (required for wildcards) - const hasDnsChallenge = (this.smartAcme as any).challengeHandlers?.some((handler: any) => - handler.getSupportedTypes && handler.getSupportedTypes().includes('dns-01') - ); - - const shouldIncludeWildcard = !primaryDomain.startsWith('*.') && - primaryDomain.includes('.') && - primaryDomain.split('.').length >= 2 && - hasDnsChallenge; - - if (shouldIncludeWildcard) { - logger.log('info', `Requesting wildcard certificate for ${primaryDomain} (DNS-01 available)`, { domain: primaryDomain, challengeType: 'DNS-01', component: 'certificate-manager' }); - } - - // Use smartacme to get certificate with optional wildcard - const cert = await this.smartAcme.getCertificateForDomain( - primaryDomain, - shouldIncludeWildcard ? { includeWildcard: true } : undefined - ); - - // SmartAcme's Cert object has these properties: - // - publicKey: The certificate PEM string - // - privateKey: The private key PEM string - // - csr: Certificate signing request - // - validUntil: Timestamp in milliseconds - // - domainName: The domain name - const certData: ICertificateData = { - cert: cert.publicKey, - key: cert.privateKey, - ca: cert.publicKey, // Use same as cert for now - expiryDate: new Date(cert.validUntil), - issueDate: new Date(cert.created), - source: 'acme' - }; - - await this.certStore.saveCertificate(routeName, certData); - await this.applyCertificate(primaryDomain, certData); - this.updateCertStatus(routeName, 'valid', 'acme', certData); - - logger.log('info', `Successfully provisioned ACME certificate for ${primaryDomain}`, { domain: primaryDomain, component: 'certificate-manager' }); - } catch (error) { - logger.log('error', `Failed to provision ACME certificate for ${primaryDomain}: ${error.message}`, { domain: primaryDomain, error: error.message, component: 'certificate-manager' }); - this.updateCertStatus(routeName, 'error', 'acme', undefined, error.message); - throw error; - } - } - - /** - * Provision static certificate - */ - private async provisionStaticCertificate( - route: IRouteConfig, - domain: string, - certConfig: { key: string; cert: string; keyFile?: string; certFile?: string } - ): Promise { - const routeName = route.name || domain; - - try { - let key: string = certConfig.key; - let cert: string = certConfig.cert; - - // Load from files if paths are provided - const smartFileFactory = plugins.smartfile.SmartFileFactory.nodeFs(); - if (certConfig.keyFile) { - const keyFile = await smartFileFactory.fromFilePath(certConfig.keyFile); - key = keyFile.contents.toString(); - } - if (certConfig.certFile) { - const certFile = await smartFileFactory.fromFilePath(certConfig.certFile); - cert = certFile.contents.toString(); - } - - // Parse certificate to get dates - const expiryDate = this.extractExpiryDate(cert); - const issueDate = new Date(); // Current date as issue date - - const certData: ICertificateData = { - cert, - key, - expiryDate, - issueDate, - source: 'static' - }; - - // Save to store for consistency - await this.certStore.saveCertificate(routeName, certData); - await this.applyCertificate(domain, certData); - this.updateCertStatus(routeName, 'valid', 'static', certData); - - logger.log('info', `Successfully loaded static certificate for ${domain}`, { domain, component: 'certificate-manager' }); - } catch (error) { - logger.log('error', `Failed to provision static certificate for ${domain}: ${error.message}`, { domain, error: error.message, component: 'certificate-manager' }); - this.updateCertStatus(routeName, 'error', 'static', undefined, error.message); - throw error; - } - } - - /** - * Apply certificate to HttpProxy - */ - private async applyCertificate(domain: string, certData: ICertificateData): Promise { - if (!this.httpProxy) { - logger.log('warn', `HttpProxy not set, cannot apply certificate for domain ${domain}`, { domain, component: 'certificate-manager' }); - return; - } - - // Apply certificate to HttpProxy - this.httpProxy.updateCertificate(domain, certData.cert, certData.key); - - // Also apply for wildcard if it's a subdomain - if (domain.includes('.') && !domain.startsWith('*.')) { - const parts = domain.split('.'); - if (parts.length >= 2) { - const wildcardDomain = `*.${parts.slice(-2).join('.')}`; - this.httpProxy.updateCertificate(wildcardDomain, certData.cert, certData.key); - } - } - } - - /** - * Extract domains from route configuration - */ - private extractDomainsFromRoute(route: IRouteConfig): string[] { - if (!route.match.domains) { - return []; - } - - const domains = Array.isArray(route.match.domains) - ? route.match.domains - : [route.match.domains]; - - // Filter out wildcards and patterns - return domains.filter(d => - !d.includes('*') && - !d.includes('{') && - d.includes('.') - ); - } - - /** - * Check if certificate is valid - */ - private isCertificateValid(cert: ICertificateData): boolean { - const now = new Date(); - - // Use renewal threshold from global defaults or fallback to 30 days - const renewThresholdDays = this.globalAcmeDefaults?.renewThresholdDays || 30; - const expiryThreshold = new Date(now.getTime() + renewThresholdDays * 24 * 60 * 60 * 1000); - - return cert.expiryDate > expiryThreshold; - } - - /** - * Extract expiry date from a PEM certificate - */ - private extractExpiryDate(_certPem: string): Date { - // For now, we'll default to 90 days for custom certificates - // In production, you might want to use a proper X.509 parser - // or require the custom cert provider to include expiry info - logger.log('info', 'Using default 90-day expiry for custom certificate', { - component: 'certificate-manager' - }); - return new Date(Date.now() + 90 * 24 * 60 * 60 * 1000); - } - - - /** - * Add challenge route to SmartProxy - * - * This method adds a special route for ACME HTTP-01 challenges, which typically uses port 80. - * Since we may already be listening on port 80 for regular routes, we need to be - * careful about how we add this route to avoid binding conflicts. - */ - private async addChallengeRoute(): Promise { - // Check with state manager first - avoid duplication - if (this.acmeStateManager && this.acmeStateManager.isChallengeRouteActive()) { - try { - logger.log('info', 'Challenge route already active in global state, skipping', { component: 'certificate-manager' }); - } catch (error) { - // Silently handle logging errors - console.log('[INFO] Challenge route already active in global state, skipping'); - } - this.challengeRouteActive = true; - return; - } - - if (this.challengeRouteActive) { - try { - logger.log('info', 'Challenge route already active locally, skipping', { component: 'certificate-manager' }); - } catch (error) { - // Silently handle logging errors - console.log('[INFO] Challenge route already active locally, skipping'); - } - return; - } - - if (!this.updateRoutesCallback) { - throw new Error('No route update callback set'); - } - - if (!this.challengeRoute) { - throw new Error('Challenge route not initialized'); - } - - // Get the challenge port - const challengePort = this.globalAcmeDefaults?.port || 80; - - // Check if any existing routes are already using this port - // This helps us determine if we need to create a new binding or can reuse existing one - const portInUseByRoutes = this.routes.some(route => { - const routePorts = Array.isArray(route.match.ports) ? route.match.ports : [route.match.ports]; - return routePorts.some(p => { - // Handle both number and port range objects - if (typeof p === 'number') { - return p === challengePort; - } else if (typeof p === 'object' && 'from' in p && 'to' in p) { - // Port range case - check if challengePort is in range - return challengePort >= p.from && challengePort <= p.to; - } - return false; - }); - }); - - try { - // Log whether port is already in use by other routes - if (portInUseByRoutes) { - try { - logger.log('info', `Port ${challengePort} is already used by another route, merging ACME challenge route`, { - port: challengePort, - component: 'certificate-manager' - }); - } catch (error) { - // Silently handle logging errors - console.log(`[INFO] Port ${challengePort} is already used by another route, merging ACME challenge route`); - } - } else { - try { - logger.log('info', `Adding new ACME challenge route on port ${challengePort}`, { - port: challengePort, - component: 'certificate-manager' - }); - } catch (error) { - // Silently handle logging errors - console.log(`[INFO] Adding new ACME challenge route on port ${challengePort}`); - } - } - - // Add the challenge route to the existing routes - const challengeRoute = this.challengeRoute; - const updatedRoutes = [...this.routes, challengeRoute]; - - // With the re-ordering of start(), port binding should already be done - // This updateRoutes call should just add the route without binding again - await this.updateRoutesCallback(updatedRoutes); - // Keep local routes in sync after updating - this.routes = updatedRoutes; - this.challengeRouteActive = true; - - // Register with state manager - if (this.acmeStateManager) { - this.acmeStateManager.addChallengeRoute(challengeRoute); - } - - try { - logger.log('info', 'ACME challenge route successfully added', { component: 'certificate-manager' }); - } catch (error) { - // Silently handle logging errors - console.log('[INFO] ACME challenge route successfully added'); - } - } catch (error) { - // Enhanced error handling based on error type - if ((error as any).code === 'EADDRINUSE') { - try { - logger.log('warn', `Challenge port ${challengePort} is unavailable - it's already in use by another process. Consider configuring a different ACME port.`, { - port: challengePort, - error: (error as Error).message, - component: 'certificate-manager' - }); - } catch (logError) { - // Silently handle logging errors - console.log(`[WARN] Challenge port ${challengePort} is unavailable - it's already in use by another process. Consider configuring a different ACME port.`); - } - - // Provide a more informative and actionable error message - throw new Error( - `ACME HTTP-01 challenge port ${challengePort} is already in use by another process. ` + - `Please configure a different port using the acme.port setting (e.g., 8080).` - ); - } else if (error.message && error.message.includes('EADDRINUSE')) { - // Some Node.js versions embed the error code in the message rather than the code property - try { - logger.log('warn', `Port ${challengePort} conflict detected: ${error.message}`, { - port: challengePort, - component: 'certificate-manager' - }); - } catch (logError) { - // Silently handle logging errors - console.log(`[WARN] Port ${challengePort} conflict detected: ${error.message}`); - } - - // More detailed error message with suggestions - throw new Error( - `ACME HTTP challenge port ${challengePort} conflict detected. ` + - `To resolve this issue, try one of these approaches:\n` + - `1. Configure a different port in ACME settings (acme.port)\n` + - `2. Add a regular route that uses port ${challengePort} before initializing the certificate manager\n` + - `3. Stop any other services that might be using port ${challengePort}` - ); - } - - // Log and rethrow other types of errors - try { - logger.log('error', `Failed to add challenge route: ${(error as Error).message}`, { - error: (error as Error).message, - component: 'certificate-manager' - }); - } catch (logError) { - // Silently handle logging errors - console.log(`[ERROR] Failed to add challenge route: ${(error as Error).message}`); - } - throw error; - } - } - - /** - * Remove challenge route from SmartProxy - */ - private async removeChallengeRoute(): Promise { - if (!this.challengeRouteActive) { - try { - logger.log('info', 'Challenge route not active, skipping removal', { component: 'certificate-manager' }); - } catch (error) { - // Silently handle logging errors - console.log('[INFO] Challenge route not active, skipping removal'); - } - return; - } - - if (!this.updateRoutesCallback) { - return; - } - - try { - const filteredRoutes = this.routes.filter(r => r.name !== 'acme-challenge'); - await this.updateRoutesCallback(filteredRoutes); - // Keep local routes in sync after updating - this.routes = filteredRoutes; - this.challengeRouteActive = false; - - // Remove from state manager - if (this.acmeStateManager) { - this.acmeStateManager.removeChallengeRoute('acme-challenge'); - } - - try { - logger.log('info', 'ACME challenge route successfully removed', { component: 'certificate-manager' }); - } catch (error) { - // Silently handle logging errors - console.log('[INFO] ACME challenge route successfully removed'); - } - } catch (error) { - try { - logger.log('error', `Failed to remove challenge route: ${error.message}`, { error: error.message, component: 'certificate-manager' }); - } catch (logError) { - // Silently handle logging errors - console.log(`[ERROR] Failed to remove challenge route: ${error.message}`); - } - // Reset the flag even on error to avoid getting stuck - this.challengeRouteActive = false; - throw error; - } - } - - /** - * Start renewal timer - */ - private startRenewalTimer(): void { - // Check for renewals every 12 hours - this.renewalTimer = setInterval(() => { - this.checkAndRenewCertificates(); - }, 12 * 60 * 60 * 1000); - - // Unref the timer so it doesn't keep the process alive - if (this.renewalTimer.unref) { - this.renewalTimer.unref(); - } - - // Also do an immediate check - this.checkAndRenewCertificates(); - } - - /** - * Check and renew certificates that are expiring - */ - private async checkAndRenewCertificates(): Promise { - for (const route of this.routes) { - if (route.action.tls?.certificate === 'auto') { - const routeName = route.name || this.extractDomainsFromRoute(route)[0]; - const cert = await this.certStore.getCertificate(routeName); - - if (cert && !this.isCertificateValid(cert)) { - logger.log('info', `Certificate for ${routeName} needs renewal`, { routeName, component: 'certificate-manager' }); - try { - await this.provisionCertificate(route); - } catch (error) { - logger.log('error', `Failed to renew certificate for ${routeName}: ${error.message}`, { routeName, error: error.message, component: 'certificate-manager' }); - } - } - } - } - } - - /** - * Update certificate status - */ - private updateCertStatus( - routeName: string, - status: ICertStatus['status'], - source: ICertStatus['source'], - certData?: ICertificateData, - error?: string - ): void { - this.certStatus.set(routeName, { - domain: routeName, - status, - source, - expiryDate: certData?.expiryDate, - issueDate: certData?.issueDate, - error - }); - } - - /** - * Get certificate status for a route - */ - public getCertificateStatus(routeName: string): ICertStatus | undefined { - return this.certStatus.get(routeName); - } - - /** - * Force renewal of a certificate - */ - public async renewCertificate(routeName: string): Promise { - const route = this.routes.find(r => r.name === routeName); - if (!route) { - throw new Error(`Route ${routeName} not found`); - } - - // Remove existing certificate to force renewal - await this.certStore.deleteCertificate(routeName); - await this.provisionCertificate(route); - } - - /** - * Setup challenge handler integration with SmartProxy routing - */ - private setupChallengeHandler(http01Handler: plugins.smartacme.handlers.Http01MemoryHandler): void { - // Use challenge port from global config or default to 80 - const challengePort = this.globalAcmeDefaults?.port || 80; - - // Create a challenge route that delegates to SmartAcme's HTTP-01 handler - const challengeRoute: IRouteConfig = { - name: 'acme-challenge', - priority: 1000, // High priority - match: { - ports: challengePort, - path: '/.well-known/acme-challenge/*' - }, - action: { - type: 'socket-handler', - socketHandler: SocketHandlers.httpServer((req, res) => { - // Extract the token from the path - const token = req.url?.split('/').pop(); - if (!token) { - res.status(404); - res.send('Not found'); - return; - } - - // Create mock request/response objects for SmartAcme - let responseData: any = null; - const mockReq = { - url: req.url, - method: req.method, - headers: req.headers - }; - - const mockRes = { - statusCode: 200, - setHeader: (name: string, value: string) => {}, - end: (data: any) => { - responseData = data; - } - }; - - // Use SmartAcme's handler - const handleAcme = () => { - http01Handler.handleRequest(mockReq as any, mockRes as any, () => { - // Not handled by ACME - res.status(404); - res.send('Not found'); - }); - - // Give it a moment to process, then send response - setTimeout(() => { - if (responseData) { - res.header('Content-Type', 'text/plain'); - res.send(String(responseData)); - } else { - res.status(404); - res.send('Not found'); - } - }, 100); - }; - - handleAcme(); - }) - } - }; - - // Store the challenge route to add it when needed - this.challengeRoute = challengeRoute; - } - - /** - * Stop certificate manager - */ - public async stop(): Promise { - if (this.renewalTimer) { - clearInterval(this.renewalTimer); - this.renewalTimer = null; - } - - // Always remove challenge route on shutdown - if (this.challengeRoute) { - logger.log('info', 'Removing ACME challenge route during shutdown', { component: 'certificate-manager' }); - await this.removeChallengeRoute(); - } - - if (this.smartAcme) { - await this.smartAcme.stop(); - } - - // Clear any pending challenges - if (this.pendingChallenges.size > 0) { - this.pendingChallenges.clear(); - } - } - - /** - * Get ACME options (for recreating after route updates) - */ - public getAcmeOptions(): { email?: string; useProduction?: boolean; port?: number } | undefined { - return this.acmeOptions; - } - - /** - * Get certificate manager state - */ - public getState(): { challengeRouteActive: boolean } { - return { - challengeRouteActive: this.challengeRouteActive - }; - } -} - diff --git a/ts/proxies/smart-proxy/connection-manager.ts b/ts/proxies/smart-proxy/connection-manager.ts deleted file mode 100644 index f633cad..0000000 --- a/ts/proxies/smart-proxy/connection-manager.ts +++ /dev/null @@ -1,809 +0,0 @@ -import * as plugins from '../../plugins.js'; -import type { IConnectionRecord } from './models/interfaces.js'; -import { logger } from '../../core/utils/logger.js'; -import { connectionLogDeduplicator } from '../../core/utils/log-deduplicator.js'; -import { LifecycleComponent } from '../../core/utils/lifecycle-component.js'; -import { cleanupSocket } from '../../core/utils/socket-utils.js'; -import { WrappedSocket } from '../../core/models/wrapped-socket.js'; -import { ProtocolDetector } from '../../detection/index.js'; -import type { SmartProxy } from './smart-proxy.js'; - -/** - * Manages connection lifecycle, tracking, and cleanup with performance optimizations - */ -export class ConnectionManager extends LifecycleComponent { - private connectionRecords: Map = new Map(); - private terminationStats: { - incoming: Record; - outgoing: Record; - } = { incoming: {}, outgoing: {} }; - - // Performance optimization: Track connections needing inactivity check - private nextInactivityCheck: Map = new Map(); - - // Connection limits - private readonly maxConnections: number; - private readonly cleanupBatchSize: number = 100; - - // Cleanup queue for batched processing - private cleanupQueue: Set = new Set(); - private cleanupTimer: NodeJS.Timeout | null = null; - private isProcessingCleanup: boolean = false; - - // Route-level connection tracking - private connectionsByRoute: Map> = new Map(); - - constructor( - private smartProxy: SmartProxy - ) { - super(); - - // Set reasonable defaults for connection limits - this.maxConnections = smartProxy.settings.defaults?.security?.maxConnections || 10000; - - // Start inactivity check timer if not disabled - if (!smartProxy.settings.disableInactivityCheck) { - this.startInactivityCheckTimer(); - } - } - - /** - * Generate a unique connection ID - */ - public generateConnectionId(): string { - return Math.random().toString(36).substring(2, 15) + - Math.random().toString(36).substring(2, 15); - } - - /** - * Create and track a new connection - * Accepts either a regular net.Socket or a WrappedSocket for transparent PROXY protocol support - * - * @param socket - The socket for the connection - * @param options - Optional configuration - * @param options.connectionId - Pre-generated connection ID (for atomic IP tracking) - * @param options.skipIpTracking - Skip IP tracking (if already done atomically) - */ - public createConnection( - socket: plugins.net.Socket | WrappedSocket, - options?: { connectionId?: string; skipIpTracking?: boolean } - ): IConnectionRecord | null { - // Enforce connection limit - if (this.connectionRecords.size >= this.maxConnections) { - // Use deduplicated logging for connection limit - connectionLogDeduplicator.log( - 'connection-rejected', - 'warn', - 'Global connection limit reached', - { - reason: 'global-limit', - currentConnections: this.connectionRecords.size, - maxConnections: this.maxConnections, - component: 'connection-manager' - }, - 'global-limit' - ); - socket.destroy(); - return null; - } - - const connectionId = options?.connectionId || this.generateConnectionId(); - const remoteIP = socket.remoteAddress || ''; - const remotePort = socket.remotePort || 0; - const localPort = socket.localPort || 0; - const now = Date.now(); - - const record: IConnectionRecord = { - id: connectionId, - incoming: socket, - outgoing: null, - incomingStartTime: now, - lastActivity: now, - connectionClosed: false, - pendingData: [], - pendingDataSize: 0, - bytesReceived: 0, - bytesSent: 0, - remoteIP, - remotePort, - localPort, - isTLS: false, - tlsHandshakeComplete: false, - hasReceivedInitialData: false, - hasKeepAlive: false, - incomingTerminationReason: null, - outgoingTerminationReason: null, - usingNetworkProxy: false, - isBrowserConnection: false, - domainSwitches: 0 - }; - - this.trackConnection(connectionId, record, options?.skipIpTracking); - return record; - } - - /** - * Track an existing connection - * @param connectionId - The connection ID - * @param record - The connection record - * @param skipIpTracking - Skip IP tracking if already done atomically - */ - public trackConnection(connectionId: string, record: IConnectionRecord, skipIpTracking?: boolean): void { - this.connectionRecords.set(connectionId, record); - if (!skipIpTracking) { - this.smartProxy.securityManager.trackConnectionByIP(record.remoteIP, connectionId); - } - - // Schedule inactivity check - if (!this.smartProxy.settings.disableInactivityCheck) { - this.scheduleInactivityCheck(connectionId, record); - } - } - - /** - * Schedule next inactivity check for a connection - */ - private scheduleInactivityCheck(connectionId: string, record: IConnectionRecord): void { - let timeout = this.smartProxy.settings.inactivityTimeout!; - - if (record.hasKeepAlive) { - if (this.smartProxy.settings.keepAliveTreatment === 'immortal') { - // Don't schedule check for immortal connections - return; - } else if (this.smartProxy.settings.keepAliveTreatment === 'extended') { - const multiplier = this.smartProxy.settings.keepAliveInactivityMultiplier || 6; - timeout = timeout * multiplier; - } - } - - const checkTime = Date.now() + timeout; - this.nextInactivityCheck.set(connectionId, checkTime); - } - - /** - * Start the inactivity check timer - */ - private startInactivityCheckTimer(): void { - // Check more frequently (every 10 seconds) to catch zombies and stuck connections faster - this.setInterval(() => { - this.performOptimizedInactivityCheck(); - }, 10000); - // Note: LifecycleComponent's setInterval already calls unref() - } - - /** - * Get a connection by ID - */ - public getConnection(connectionId: string): IConnectionRecord | undefined { - return this.connectionRecords.get(connectionId); - } - - /** - * Get all active connections - */ - public getConnections(): Map { - return this.connectionRecords; - } - - /** - * Get count of active connections - */ - public getConnectionCount(): number { - return this.connectionRecords.size; - } - - /** - * Track connection by route - */ - public trackConnectionByRoute(routeId: string, connectionId: string): void { - if (!this.connectionsByRoute.has(routeId)) { - this.connectionsByRoute.set(routeId, new Set()); - } - this.connectionsByRoute.get(routeId)!.add(connectionId); - } - - /** - * Remove connection tracking for a route - */ - public removeConnectionByRoute(routeId: string, connectionId: string): void { - if (this.connectionsByRoute.has(routeId)) { - const connections = this.connectionsByRoute.get(routeId)!; - connections.delete(connectionId); - if (connections.size === 0) { - this.connectionsByRoute.delete(routeId); - } - } - } - - /** - * Get connection count by route - */ - public getConnectionCountByRoute(routeId: string): number { - return this.connectionsByRoute.get(routeId)?.size || 0; - } - - /** - * Initiates cleanup once for a connection - */ - public initiateCleanupOnce(record: IConnectionRecord, reason: string = 'normal'): void { - // Use deduplicated logging for cleanup events - connectionLogDeduplicator.log( - 'connection-cleanup', - 'info', - `Connection cleanup: ${reason}`, - { - connectionId: record.id, - remoteIP: record.remoteIP, - reason, - component: 'connection-manager' - }, - reason - ); - - if (record.incomingTerminationReason == null) { - record.incomingTerminationReason = reason; - this.incrementTerminationStat('incoming', reason); - } - - // Add to cleanup queue for batched processing - this.queueCleanup(record.id); - } - - /** - * Queue a connection for cleanup - */ - private queueCleanup(connectionId: string): void { - // Check if connection is already being processed - const record = this.connectionRecords.get(connectionId); - if (!record || record.connectionClosed) { - // Already cleaned up or doesn't exist, skip - return; - } - - this.cleanupQueue.add(connectionId); - - // Process immediately if queue is getting large and not already processing - if (this.cleanupQueue.size >= this.cleanupBatchSize && !this.isProcessingCleanup) { - this.processCleanupQueue(); - } else if (!this.cleanupTimer && !this.isProcessingCleanup) { - // Otherwise, schedule batch processing - this.cleanupTimer = this.setTimeout(() => { - this.processCleanupQueue(); - }, 100); - } - } - - /** - * Process the cleanup queue in batches - */ - private processCleanupQueue(): void { - // Prevent concurrent processing - if (this.isProcessingCleanup) { - return; - } - - this.isProcessingCleanup = true; - - if (this.cleanupTimer) { - this.clearTimeout(this.cleanupTimer); - this.cleanupTimer = null; - } - - try { - // Take a snapshot of items to process - const toCleanup = Array.from(this.cleanupQueue).slice(0, this.cleanupBatchSize); - - // Remove only the items we're processing from the queue - for (const connectionId of toCleanup) { - this.cleanupQueue.delete(connectionId); - const record = this.connectionRecords.get(connectionId); - if (record) { - this.cleanupConnection(record, record.incomingTerminationReason || 'normal'); - } - } - } finally { - // Always reset the processing flag - this.isProcessingCleanup = false; - - // Check if more items were added while we were processing - if (this.cleanupQueue.size > 0) { - this.cleanupTimer = this.setTimeout(() => { - this.processCleanupQueue(); - }, 10); - } - } - } - - /** - * Clean up a connection record - */ - public cleanupConnection(record: IConnectionRecord, reason: string = 'normal'): void { - if (!record.connectionClosed) { - record.connectionClosed = true; - - // Remove from inactivity check - this.nextInactivityCheck.delete(record.id); - - // Track connection termination - this.smartProxy.securityManager.removeConnectionByIP(record.remoteIP, record.id); - - // Remove from route tracking - if (record.routeId) { - this.removeConnectionByRoute(record.routeId, record.id); - } - - // Remove from metrics tracking - if (this.smartProxy.metricsCollector) { - this.smartProxy.metricsCollector.removeConnection(record.id); - } - - // Clean up protocol detection fragments - const context = ProtocolDetector.createConnectionContext({ - sourceIp: record.remoteIP, - sourcePort: record.incoming?.remotePort || 0, - destIp: record.incoming?.localAddress || '', - destPort: record.localPort, - socketId: record.id - }); - - // Clean up any pending detection fragments for this connection - ProtocolDetector.cleanupConnection(context); - - if (record.cleanupTimer) { - clearTimeout(record.cleanupTimer); - record.cleanupTimer = undefined; - } - - // Calculate metrics once - const duration = Date.now() - record.incomingStartTime; - const logData = { - connectionId: record.id, - remoteIP: record.remoteIP, - localPort: record.localPort, - reason, - duration: plugins.prettyMs(duration), - bytes: { in: record.bytesReceived, out: record.bytesSent }, - tls: record.isTLS, - keepAlive: record.hasKeepAlive, - usingNetworkProxy: record.usingNetworkProxy, - domainSwitches: record.domainSwitches || 0, - component: 'connection-manager' - }; - - // Remove all data handlers to make sure we clean up properly - if (record.incoming) { - try { - record.incoming.removeAllListeners('data'); - record.renegotiationHandler = undefined; - } catch (err) { - logger.log('error', `Error removing data handlers: ${err}`, { - connectionId: record.id, - error: err, - component: 'connection-manager' - }); - } - } - - // Handle socket cleanup - check if sockets are still active - const cleanupPromises: Promise[] = []; - - if (record.incoming) { - // Extract underlying socket if it's a WrappedSocket - const incomingSocket = record.incoming instanceof WrappedSocket ? record.incoming.socket : record.incoming; - if (!record.incoming.writable || record.incoming.destroyed) { - // Socket is not active, clean up immediately - cleanupPromises.push(cleanupSocket(incomingSocket, `${record.id}-incoming`, { immediate: true })); - } else { - // Socket is still active, allow graceful cleanup - cleanupPromises.push(cleanupSocket(incomingSocket, `${record.id}-incoming`, { allowDrain: true, gracePeriod: 5000 })); - } - } - - if (record.outgoing) { - // Extract underlying socket if it's a WrappedSocket - const outgoingSocket = record.outgoing instanceof WrappedSocket ? record.outgoing.socket : record.outgoing; - if (!record.outgoing.writable || record.outgoing.destroyed) { - // Socket is not active, clean up immediately - cleanupPromises.push(cleanupSocket(outgoingSocket, `${record.id}-outgoing`, { immediate: true })); - } else { - // Socket is still active, allow graceful cleanup - cleanupPromises.push(cleanupSocket(outgoingSocket, `${record.id}-outgoing`, { allowDrain: true, gracePeriod: 5000 })); - } - } - - // Wait for cleanup to complete - Promise.all(cleanupPromises).catch(err => { - logger.log('error', `Error during socket cleanup: ${err}`, { - connectionId: record.id, - error: err, - component: 'connection-manager' - }); - }); - - // Clear pendingData to avoid memory leaks - record.pendingData = []; - record.pendingDataSize = 0; - - // Remove the record from the tracking map - this.connectionRecords.delete(record.id); - - // Use deduplicated logging for connection termination - if (this.smartProxy.settings.enableDetailedLogging) { - // For detailed logging, include more info but still deduplicate by IP+reason - connectionLogDeduplicator.log( - 'connection-terminated', - 'info', - `Connection terminated: ${record.remoteIP}:${record.localPort}`, - { - ...logData, - duration_ms: duration, - bytesIn: record.bytesReceived, - bytesOut: record.bytesSent - }, - `${record.remoteIP}-${reason}` - ); - } else { - // For normal logging, deduplicate by termination reason - connectionLogDeduplicator.log( - 'connection-terminated', - 'info', - `Connection terminated`, - { - remoteIP: record.remoteIP, - reason, - activeConnections: this.connectionRecords.size, - component: 'connection-manager' - }, - reason // Group by termination reason - ); - } - } - } - - - /** - * Creates a generic error handler for incoming or outgoing sockets - */ - public handleError(side: 'incoming' | 'outgoing', record: IConnectionRecord) { - return (err: Error) => { - const code = (err as any).code; - let reason = 'error'; - - const now = Date.now(); - const connectionDuration = now - record.incomingStartTime; - const lastActivityAge = now - record.lastActivity; - - // Update activity tracking - if (side === 'incoming') { - record.lastActivity = now; - this.scheduleInactivityCheck(record.id, record); - } - - const errorData = { - connectionId: record.id, - side, - remoteIP: record.remoteIP, - error: err.message, - duration: plugins.prettyMs(connectionDuration), - lastActivity: plugins.prettyMs(lastActivityAge), - component: 'connection-manager' - }; - - switch (code) { - case 'ECONNRESET': - reason = 'econnreset'; - logger.log('warn', `ECONNRESET on ${side}: ${record.remoteIP}`, errorData); - break; - case 'ETIMEDOUT': - reason = 'etimedout'; - logger.log('warn', `ETIMEDOUT on ${side}: ${record.remoteIP}`, errorData); - break; - default: - logger.log('error', `Error on ${side}: ${record.remoteIP} - ${err.message}`, errorData); - } - - if (side === 'incoming' && record.incomingTerminationReason == null) { - record.incomingTerminationReason = reason; - this.incrementTerminationStat('incoming', reason); - } else if (side === 'outgoing' && record.outgoingTerminationReason == null) { - record.outgoingTerminationReason = reason; - this.incrementTerminationStat('outgoing', reason); - } - - this.initiateCleanupOnce(record, reason); - }; - } - - /** - * Creates a generic close handler for incoming or outgoing sockets - */ - public handleClose(side: 'incoming' | 'outgoing', record: IConnectionRecord) { - return () => { - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('info', `Connection closed on ${side} side`, { - connectionId: record.id, - side, - remoteIP: record.remoteIP, - component: 'connection-manager' - }); - } - - if (side === 'incoming' && record.incomingTerminationReason == null) { - record.incomingTerminationReason = 'normal'; - this.incrementTerminationStat('incoming', 'normal'); - } else if (side === 'outgoing' && record.outgoingTerminationReason == null) { - record.outgoingTerminationReason = 'normal'; - this.incrementTerminationStat('outgoing', 'normal'); - record.outgoingClosedTime = Date.now(); - } - - this.initiateCleanupOnce(record, 'closed_' + side); - }; - } - - /** - * Increment termination statistics - */ - public incrementTerminationStat(side: 'incoming' | 'outgoing', reason: string): void { - this.terminationStats[side][reason] = (this.terminationStats[side][reason] || 0) + 1; - } - - /** - * Get termination statistics - */ - public getTerminationStats(): { incoming: Record; outgoing: Record } { - return this.terminationStats; - } - - /** - * Optimized inactivity check - only checks connections that are due - */ - private performOptimizedInactivityCheck(): void { - const now = Date.now(); - const connectionsToCheck: string[] = []; - - // Find connections that need checking - for (const [connectionId, checkTime] of this.nextInactivityCheck) { - if (checkTime <= now) { - connectionsToCheck.push(connectionId); - } - } - - // Also check ALL connections for zombie state (destroyed sockets but not cleaned up) - // This is critical for proxy chains where sockets can be destroyed without events - for (const [connectionId, record] of this.connectionRecords) { - if (!record.connectionClosed) { - const incomingDestroyed = record.incoming?.destroyed || false; - const outgoingDestroyed = record.outgoing?.destroyed || false; - - // Check for zombie connections: both sockets destroyed but connection not cleaned up - if (incomingDestroyed && outgoingDestroyed) { - logger.log('warn', `Zombie connection detected: ${connectionId} - both sockets destroyed but not cleaned up`, { - connectionId, - remoteIP: record.remoteIP, - age: plugins.prettyMs(now - record.incomingStartTime), - component: 'connection-manager' - }); - - // Clean up immediately - this.cleanupConnection(record, 'zombie_cleanup'); - continue; - } - - // Check for half-zombie: one socket destroyed - if (incomingDestroyed || outgoingDestroyed) { - const age = now - record.incomingStartTime; - // Use longer grace period for encrypted connections (5 minutes vs 30 seconds) - const gracePeriod = record.isTLS ? 300000 : 30000; - - // Also ensure connection is old enough to avoid premature cleanup - if (age > gracePeriod && age > 10000) { - logger.log('warn', `Half-zombie connection detected: ${connectionId} - ${incomingDestroyed ? 'incoming' : 'outgoing'} destroyed`, { - connectionId, - remoteIP: record.remoteIP, - age: plugins.prettyMs(age), - incomingDestroyed, - outgoingDestroyed, - isTLS: record.isTLS, - gracePeriod: plugins.prettyMs(gracePeriod), - component: 'connection-manager' - }); - - // Clean up - this.cleanupConnection(record, 'half_zombie_cleanup'); - } - } - - // Check for stuck connections: no data sent back to client - if (!record.connectionClosed && record.outgoing && record.bytesReceived > 0 && record.bytesSent === 0) { - const age = now - record.incomingStartTime; - // Use longer grace period for encrypted connections (5 minutes vs 60 seconds) - const stuckThreshold = record.isTLS ? 300000 : 60000; - - // If connection is older than threshold and no data sent back, likely stuck - if (age > stuckThreshold) { - logger.log('warn', `Stuck connection detected: ${connectionId} - received ${record.bytesReceived} bytes but sent 0 bytes`, { - connectionId, - remoteIP: record.remoteIP, - age: plugins.prettyMs(age), - bytesReceived: record.bytesReceived, - targetHost: record.targetHost, - targetPort: record.targetPort, - isTLS: record.isTLS, - threshold: plugins.prettyMs(stuckThreshold), - component: 'connection-manager' - }); - - // Set termination reason and increment stats - if (record.incomingTerminationReason == null) { - record.incomingTerminationReason = 'stuck_no_response'; - this.incrementTerminationStat('incoming', 'stuck_no_response'); - } - - // Clean up - this.cleanupConnection(record, 'stuck_no_response'); - } - } - } - } - - // Process only connections that need checking - for (const connectionId of connectionsToCheck) { - const record = this.connectionRecords.get(connectionId); - if (!record || record.connectionClosed) { - this.nextInactivityCheck.delete(connectionId); - continue; - } - - const inactivityTime = now - record.lastActivity; - - // Use extended timeout for extended-treatment keep-alive connections - let effectiveTimeout = this.smartProxy.settings.inactivityTimeout!; - if (record.hasKeepAlive && this.smartProxy.settings.keepAliveTreatment === 'extended') { - const multiplier = this.smartProxy.settings.keepAliveInactivityMultiplier || 6; - effectiveTimeout = effectiveTimeout * multiplier; - } - - if (inactivityTime > effectiveTimeout) { - // For keep-alive connections, issue a warning first - if (record.hasKeepAlive && !record.inactivityWarningIssued) { - logger.log('warn', `Keep-alive connection inactive: ${record.remoteIP}`, { - connectionId, - remoteIP: record.remoteIP, - inactiveFor: plugins.prettyMs(inactivityTime), - component: 'connection-manager' - }); - - record.inactivityWarningIssued = true; - - // Reschedule check for 10 minutes later - this.nextInactivityCheck.set(connectionId, now + 600000); - - // Try to stimulate activity with a probe packet - if (record.outgoing && !record.outgoing.destroyed) { - try { - record.outgoing.write(Buffer.alloc(0)); - } catch (err) { - logger.log('error', `Error sending probe packet: ${err}`, { - connectionId, - error: err, - component: 'connection-manager' - }); - } - } - } else { - // Close the connection - logger.log('warn', `Closing inactive connection: ${record.remoteIP}`, { - connectionId, - remoteIP: record.remoteIP, - inactiveFor: plugins.prettyMs(inactivityTime), - hasKeepAlive: record.hasKeepAlive, - component: 'connection-manager' - }); - this.cleanupConnection(record, 'inactivity'); - } - } else { - // Reschedule next check - this.scheduleInactivityCheck(connectionId, record); - } - - // Parity check: if outgoing socket closed and incoming remains active - // Increased from 2 minutes to 30 minutes for long-lived connections - if ( - record.outgoingClosedTime && - !record.incoming.destroyed && - !record.connectionClosed && - now - record.outgoingClosedTime > 1800000 // 30 minutes - ) { - // Only close if no data activity for 10 minutes - if (now - record.lastActivity > 600000) { - logger.log('warn', `Parity check failed after extended timeout: ${record.remoteIP}`, { - connectionId, - remoteIP: record.remoteIP, - timeElapsed: plugins.prettyMs(now - record.outgoingClosedTime), - inactiveFor: plugins.prettyMs(now - record.lastActivity), - component: 'connection-manager' - }); - this.cleanupConnection(record, 'parity_check'); - } - } - } - } - - /** - * Legacy method for backward compatibility - */ - public performInactivityCheck(): void { - this.performOptimizedInactivityCheck(); - } - - /** - * Clear all connections (for shutdown) - */ - public async clearConnections(): Promise { - // Delegate to LifecycleComponent's cleanup - await this.cleanup(); - } - - /** - * Override LifecycleComponent's onCleanup method - */ - protected async onCleanup(): Promise { - - // Process connections in batches to avoid blocking - const connections = Array.from(this.connectionRecords.values()); - const batchSize = 100; - let index = 0; - - const processBatch = () => { - const batch = connections.slice(index, index + batchSize); - - for (const record of batch) { - try { - if (record.cleanupTimer) { - clearTimeout(record.cleanupTimer); - record.cleanupTimer = undefined; - } - - // Immediate destruction using socket-utils - const shutdownPromises: Promise[] = []; - - if (record.incoming) { - const incomingSocket = record.incoming instanceof WrappedSocket ? record.incoming.socket : record.incoming; - shutdownPromises.push(cleanupSocket(incomingSocket, `${record.id}-incoming-shutdown`, { immediate: true })); - } - - if (record.outgoing) { - const outgoingSocket = record.outgoing instanceof WrappedSocket ? record.outgoing.socket : record.outgoing; - shutdownPromises.push(cleanupSocket(outgoingSocket, `${record.id}-outgoing-shutdown`, { immediate: true })); - } - - // Don't wait for shutdown cleanup in this batch processing - Promise.all(shutdownPromises).catch(() => {}); - } catch (err) { - logger.log('error', `Error during connection cleanup: ${err}`, { - connectionId: record.id, - error: err, - component: 'connection-manager' - }); - } - } - - index += batchSize; - - // Continue with next batch if needed - if (index < connections.length) { - setImmediate(processBatch); - } else { - // Clear all maps - this.connectionRecords.clear(); - this.nextInactivityCheck.clear(); - this.cleanupQueue.clear(); - this.terminationStats = { incoming: {}, outgoing: {} }; - } - }; - - // Start batch processing - setImmediate(processBatch); - } -} \ No newline at end of file diff --git a/ts/proxies/smart-proxy/http-proxy-bridge.ts b/ts/proxies/smart-proxy/http-proxy-bridge.ts deleted file mode 100644 index 420b0a5..0000000 --- a/ts/proxies/smart-proxy/http-proxy-bridge.ts +++ /dev/null @@ -1,213 +0,0 @@ -import * as plugins from '../../plugins.js'; -import { HttpProxy } from '../http-proxy/index.js'; -import { setupBidirectionalForwarding } from '../../core/utils/socket-utils.js'; -import type { IConnectionRecord } from './models/interfaces.js'; -import type { IRouteConfig } from './models/route-types.js'; -import { WrappedSocket } from '../../core/models/wrapped-socket.js'; -import type { SmartProxy } from './smart-proxy.js'; - -export class HttpProxyBridge { - private httpProxy: HttpProxy | null = null; - - constructor(private smartProxy: SmartProxy) {} - - /** - * Get the HttpProxy instance - */ - public getHttpProxy(): HttpProxy | null { - return this.httpProxy; - } - - /** - * Initialize HttpProxy instance - */ - public async initialize(): Promise { - if (!this.httpProxy && this.smartProxy.settings.useHttpProxy && this.smartProxy.settings.useHttpProxy.length > 0) { - const httpProxyOptions: any = { - port: this.smartProxy.settings.httpProxyPort!, - portProxyIntegration: true, - logLevel: this.smartProxy.settings.enableDetailedLogging ? 'debug' : 'info' - }; - - this.httpProxy = new HttpProxy(httpProxyOptions); - console.log(`Initialized HttpProxy on port ${this.smartProxy.settings.httpProxyPort}`); - - // Apply route configurations to HttpProxy - await this.syncRoutesToHttpProxy(this.smartProxy.settings.routes || []); - } - } - - /** - * Sync routes to HttpProxy - */ - public async syncRoutesToHttpProxy(routes: IRouteConfig[]): Promise { - if (!this.httpProxy) return; - - // Convert routes to HttpProxy format - const httpProxyConfigs = routes - .filter(route => { - // Check if this route matches any of the specified network proxy ports - const routePorts = Array.isArray(route.match.ports) - ? route.match.ports - : [route.match.ports]; - - return routePorts.some(port => - this.smartProxy.settings.useHttpProxy?.includes(port) - ); - }) - .map(route => this.routeToHttpProxyConfig(route)); - - // Apply configurations to HttpProxy - await this.httpProxy.updateRouteConfigs(httpProxyConfigs); - } - - /** - * Convert route to HttpProxy configuration - */ - private routeToHttpProxyConfig(route: IRouteConfig): any { - // Convert route to HttpProxy domain config format - let domain = '*'; - if (route.match.domains) { - if (Array.isArray(route.match.domains)) { - domain = route.match.domains[0] || '*'; - } else { - domain = route.match.domains; - } - } - - return { - ...route, // Keep the original route structure - match: { - ...route.match, - domains: domain // Ensure domains is always set for HttpProxy - } - }; - } - - /** - * Check if connection should use HttpProxy - */ - public shouldUseHttpProxy(connection: IConnectionRecord, routeMatch: any): boolean { - // Only use HttpProxy for TLS termination - return ( - routeMatch.route.action.tls?.mode === 'terminate' || - routeMatch.route.action.tls?.mode === 'terminate-and-reencrypt' - ) && this.httpProxy !== null; - } - - /** - * Forward connection to HttpProxy - */ - public async forwardToHttpProxy( - connectionId: string, - socket: plugins.net.Socket | WrappedSocket, - record: IConnectionRecord, - initialChunk: Buffer, - httpProxyPort: number, - cleanupCallback: (reason: string) => void - ): Promise { - if (!this.httpProxy) { - throw new Error('HttpProxy not initialized'); - } - - // Check if client socket is already destroyed before proceeding - const underlyingSocket = socket instanceof WrappedSocket ? socket.socket : socket; - if (underlyingSocket.destroyed) { - console.log(`[${connectionId}] Client socket already destroyed, skipping HttpProxy forwarding`); - cleanupCallback('client_disconnected_before_proxy'); - return; - } - - const proxySocket = new plugins.net.Socket(); - - // Handle client disconnect during proxy connection setup - const clientDisconnectHandler = () => { - console.log(`[${connectionId}] Client disconnected during HttpProxy connection setup`); - proxySocket.destroy(); - cleanupCallback('client_disconnected_during_setup'); - }; - underlyingSocket.once('close', clientDisconnectHandler); - - try { - await new Promise((resolve, reject) => { - proxySocket.connect(httpProxyPort, 'localhost', () => { - console.log(`[${connectionId}] Connected to HttpProxy for termination`); - resolve(); - }); - - proxySocket.on('error', reject); - }); - } finally { - // Remove the disconnect handler after connection attempt - underlyingSocket.removeListener('close', clientDisconnectHandler); - } - - // Double-check client socket is still connected after async operation - if (underlyingSocket.destroyed) { - console.log(`[${connectionId}] Client disconnected while connecting to HttpProxy`); - proxySocket.destroy(); - cleanupCallback('client_disconnected_after_proxy_connect'); - return; - } - - // Send client IP information header first (custom protocol) - // Format: "CLIENT_IP:\r\n" - const clientIPHeader = Buffer.from(`CLIENT_IP:${record.remoteIP}\r\n`); - proxySocket.write(clientIPHeader); - - // Send initial chunk if present - if (initialChunk) { - // Count the initial chunk bytes - record.bytesReceived += initialChunk.length; - if (this.smartProxy.metricsCollector) { - this.smartProxy.metricsCollector.recordBytes(record.id, initialChunk.length, 0); - } - proxySocket.write(initialChunk); - } - - // Use centralized bidirectional forwarding (underlyingSocket already extracted above) - setupBidirectionalForwarding(underlyingSocket, proxySocket, { - onClientData: (chunk) => { - // Update stats - this is the ONLY place bytes are counted for HttpProxy connections - if (record) { - record.bytesReceived += chunk.length; - if (this.smartProxy.metricsCollector) { - this.smartProxy.metricsCollector.recordBytes(record.id, chunk.length, 0); - } - } - }, - onServerData: (chunk) => { - // Update stats - this is the ONLY place bytes are counted for HttpProxy connections - if (record) { - record.bytesSent += chunk.length; - if (this.smartProxy.metricsCollector) { - this.smartProxy.metricsCollector.recordBytes(record.id, 0, chunk.length); - } - } - }, - onCleanup: (reason) => { - cleanupCallback(reason); - }, - enableHalfOpen: false // Close both when one closes (required for proxy chains) - }); - } - - /** - * Start HttpProxy - */ - public async start(): Promise { - if (this.httpProxy) { - await this.httpProxy.start(); - } - } - - /** - * Stop HttpProxy - */ - public async stop(): Promise { - if (this.httpProxy) { - await this.httpProxy.stop(); - this.httpProxy = null; - } - } -} \ No newline at end of file diff --git a/ts/proxies/smart-proxy/index.ts b/ts/proxies/smart-proxy/index.ts index 9742b13..900bd81 100644 --- a/ts/proxies/smart-proxy/index.ts +++ b/ts/proxies/smart-proxy/index.ts @@ -1,7 +1,7 @@ /** * SmartProxy implementation * - * Version 14.0.0: Unified Route-Based Configuration API + * Version 23.0.0: Rust-backed proxy engine */ // Re-export models export * from './models/index.js'; @@ -9,21 +9,14 @@ export * from './models/index.js'; // Export the main SmartProxy class export { SmartProxy } from './smart-proxy.js'; -// Export core supporting classes -export { ConnectionManager } from './connection-manager.js'; -export { SecurityManager } from './security-manager.js'; -export { TimeoutManager } from './timeout-manager.js'; -export { TlsManager } from './tls-manager.js'; -export { HttpProxyBridge } from './http-proxy-bridge.js'; +// Export Rust bridge and helpers +export { RustProxyBridge } from './rust-proxy-bridge.js'; +export { RoutePreprocessor } from './route-preprocessor.js'; +export { SocketHandlerServer } from './socket-handler-server.js'; +export { RustMetricsAdapter } from './rust-metrics-adapter.js'; // Export route-based components export { SharedRouteManager as RouteManager } from '../../core/routing/route-manager.js'; -export { RouteConnectionHandler } from './route-connection-handler.js'; -export { NFTablesManager } from './nftables-manager.js'; -export { RouteOrchestrator } from './route-orchestrator.js'; - -// Export certificate management -export { SmartCertManager } from './certificate-manager.js'; // Export all helper functions from the utils directory export * from './utils/index.js'; diff --git a/ts/proxies/smart-proxy/metrics-collector.ts b/ts/proxies/smart-proxy/metrics-collector.ts deleted file mode 100644 index 7292585..0000000 --- a/ts/proxies/smart-proxy/metrics-collector.ts +++ /dev/null @@ -1,453 +0,0 @@ -import * as plugins from '../../plugins.js'; -import type { SmartProxy } from './smart-proxy.js'; -import type { - IMetrics, - IThroughputData, - IThroughputHistoryPoint, - IByteTracker -} from './models/metrics-types.js'; -import { ThroughputTracker } from './throughput-tracker.js'; -import { logger } from '../../core/utils/logger.js'; - -/** - * Collects and provides metrics for SmartProxy with clean API - */ -export class MetricsCollector implements IMetrics { - // Throughput tracking - private throughputTracker: ThroughputTracker; - private routeThroughputTrackers = new Map(); - private ipThroughputTrackers = new Map(); - - // Request tracking - private requestTimestamps: number[] = []; - private totalRequests: number = 0; - - // Connection byte tracking for per-route/IP metrics - private connectionByteTrackers = new Map(); - - // Subscriptions - private samplingInterval?: NodeJS.Timeout; - private connectionSubscription?: plugins.smartrx.rxjs.Subscription; - - // Configuration - private readonly sampleIntervalMs: number; - private readonly retentionSeconds: number; - - // Track connection durations for percentile calculations - private connectionDurations: number[] = []; - private bytesInArray: number[] = []; - private bytesOutArray: number[] = []; - - constructor( - private smartProxy: SmartProxy, - config?: { - sampleIntervalMs?: number; - retentionSeconds?: number; - } - ) { - this.sampleIntervalMs = config?.sampleIntervalMs || 1000; - this.retentionSeconds = config?.retentionSeconds || 3600; - this.throughputTracker = new ThroughputTracker(this.retentionSeconds); - } - - // Connection metrics implementation - public connections = { - active: (): number => { - return this.smartProxy.connectionManager.getConnectionCount(); - }, - - total: (): number => { - const stats = this.smartProxy.connectionManager.getTerminationStats(); - let total = this.smartProxy.connectionManager.getConnectionCount(); - - for (const reason in stats.incoming) { - total += stats.incoming[reason]; - } - - return total; - }, - - byRoute: (): Map => { - const routeCounts = new Map(); - const connections = this.smartProxy.connectionManager.getConnections(); - - for (const [_, record] of connections) { - const routeName = (record as any).routeName || - record.routeConfig?.name || - 'unknown'; - - const current = routeCounts.get(routeName) || 0; - routeCounts.set(routeName, current + 1); - } - - return routeCounts; - }, - - byIP: (): Map => { - const ipCounts = new Map(); - - for (const [_, record] of this.smartProxy.connectionManager.getConnections()) { - const ip = record.remoteIP; - const current = ipCounts.get(ip) || 0; - ipCounts.set(ip, current + 1); - } - - return ipCounts; - }, - - topIPs: (limit: number = 10): Array<{ ip: string; count: number }> => { - const ipCounts = this.connections.byIP(); - return Array.from(ipCounts.entries()) - .sort((a, b) => b[1] - a[1]) - .slice(0, limit) - .map(([ip, count]) => ({ ip, count })); - } - }; - - // Throughput metrics implementation - public throughput = { - instant: (): IThroughputData => { - return this.throughputTracker.getRate(1); - }, - - recent: (): IThroughputData => { - return this.throughputTracker.getRate(10); - }, - - average: (): IThroughputData => { - return this.throughputTracker.getRate(60); - }, - - custom: (seconds: number): IThroughputData => { - return this.throughputTracker.getRate(seconds); - }, - - history: (seconds: number): Array => { - return this.throughputTracker.getHistory(seconds); - }, - - byRoute: (windowSeconds: number = 1): Map => { - const routeThroughput = new Map(); - - // Get throughput from each route's dedicated tracker - for (const [route, tracker] of this.routeThroughputTrackers) { - const rate = tracker.getRate(windowSeconds); - if (rate.in > 0 || rate.out > 0) { - routeThroughput.set(route, rate); - } - } - - return routeThroughput; - }, - - byIP: (windowSeconds: number = 1): Map => { - const ipThroughput = new Map(); - - // Get throughput from each IP's dedicated tracker - for (const [ip, tracker] of this.ipThroughputTrackers) { - const rate = tracker.getRate(windowSeconds); - if (rate.in > 0 || rate.out > 0) { - ipThroughput.set(ip, rate); - } - } - - return ipThroughput; - } - }; - - // Request metrics implementation - public requests = { - perSecond: (): number => { - const now = Date.now(); - const oneSecondAgo = now - 1000; - - // Clean old timestamps - this.requestTimestamps = this.requestTimestamps.filter(ts => ts > now - 60000); - - // Count requests in last second - const recentRequests = this.requestTimestamps.filter(ts => ts > oneSecondAgo); - return recentRequests.length; - }, - - perMinute: (): number => { - const now = Date.now(); - const oneMinuteAgo = now - 60000; - - // Count requests in last minute - const recentRequests = this.requestTimestamps.filter(ts => ts > oneMinuteAgo); - return recentRequests.length; - }, - - total: (): number => { - return this.totalRequests; - } - }; - - // Totals implementation - public totals = { - bytesIn: (): number => { - let total = 0; - - // Sum from all active connections - for (const [_, record] of this.smartProxy.connectionManager.getConnections()) { - total += record.bytesReceived; - } - - // TODO: Add historical data from terminated connections - - return total; - }, - - bytesOut: (): number => { - let total = 0; - - // Sum from all active connections - for (const [_, record] of this.smartProxy.connectionManager.getConnections()) { - total += record.bytesSent; - } - - // TODO: Add historical data from terminated connections - - return total; - }, - - connections: (): number => { - return this.connections.total(); - } - }; - - // Helper to calculate percentiles from an array - private calculatePercentile(arr: number[], percentile: number): number { - if (arr.length === 0) return 0; - const sorted = [...arr].sort((a, b) => a - b); - const index = Math.floor((sorted.length - 1) * percentile); - return sorted[index]; - } - - // Percentiles implementation - public percentiles = { - connectionDuration: (): { p50: number; p95: number; p99: number } => { - return { - p50: this.calculatePercentile(this.connectionDurations, 0.5), - p95: this.calculatePercentile(this.connectionDurations, 0.95), - p99: this.calculatePercentile(this.connectionDurations, 0.99) - }; - }, - - bytesTransferred: (): { - in: { p50: number; p95: number; p99: number }; - out: { p50: number; p95: number; p99: number }; - } => { - return { - in: { - p50: this.calculatePercentile(this.bytesInArray, 0.5), - p95: this.calculatePercentile(this.bytesInArray, 0.95), - p99: this.calculatePercentile(this.bytesInArray, 0.99) - }, - out: { - p50: this.calculatePercentile(this.bytesOutArray, 0.5), - p95: this.calculatePercentile(this.bytesOutArray, 0.95), - p99: this.calculatePercentile(this.bytesOutArray, 0.99) - } - }; - } - }; - - /** - * Record a new request - */ - public recordRequest(connectionId: string, routeName: string, remoteIP: string): void { - const now = Date.now(); - this.requestTimestamps.push(now); - this.totalRequests++; - - // Initialize byte tracker for this connection - this.connectionByteTrackers.set(connectionId, { - connectionId, - routeName, - remoteIP, - bytesIn: 0, - bytesOut: 0, - startTime: now, - lastUpdate: now - }); - - // Cleanup old request timestamps - if (this.requestTimestamps.length > 5000) { - // First try to clean up old timestamps (older than 1 minute) - const cutoff = now - 60000; - this.requestTimestamps = this.requestTimestamps.filter(ts => ts > cutoff); - - // If still too many, enforce hard cap of 5000 most recent - if (this.requestTimestamps.length > 5000) { - this.requestTimestamps = this.requestTimestamps.slice(-5000); - } - } - } - - /** - * Record bytes transferred for a connection - */ - public recordBytes(connectionId: string, bytesIn: number, bytesOut: number): void { - // Update global throughput tracker - this.throughputTracker.recordBytes(bytesIn, bytesOut); - - // Update connection-specific tracker - const tracker = this.connectionByteTrackers.get(connectionId); - if (tracker) { - tracker.bytesIn += bytesIn; - tracker.bytesOut += bytesOut; - tracker.lastUpdate = Date.now(); - - // Update per-route throughput tracker - let routeTracker = this.routeThroughputTrackers.get(tracker.routeName); - if (!routeTracker) { - routeTracker = new ThroughputTracker(this.retentionSeconds); - this.routeThroughputTrackers.set(tracker.routeName, routeTracker); - } - routeTracker.recordBytes(bytesIn, bytesOut); - - // Update per-IP throughput tracker - let ipTracker = this.ipThroughputTrackers.get(tracker.remoteIP); - if (!ipTracker) { - ipTracker = new ThroughputTracker(this.retentionSeconds); - this.ipThroughputTrackers.set(tracker.remoteIP, ipTracker); - } - ipTracker.recordBytes(bytesIn, bytesOut); - } - } - - /** - * Clean up tracking for a closed connection - */ - public removeConnection(connectionId: string): void { - const tracker = this.connectionByteTrackers.get(connectionId); - if (tracker) { - // Calculate connection duration - const duration = Date.now() - tracker.startTime; - - // Add to arrays for percentile calculations (bounded to prevent memory growth) - const MAX_SAMPLES = 5000; - - this.connectionDurations.push(duration); - if (this.connectionDurations.length > MAX_SAMPLES) { - this.connectionDurations.shift(); - } - - this.bytesInArray.push(tracker.bytesIn); - if (this.bytesInArray.length > MAX_SAMPLES) { - this.bytesInArray.shift(); - } - - this.bytesOutArray.push(tracker.bytesOut); - if (this.bytesOutArray.length > MAX_SAMPLES) { - this.bytesOutArray.shift(); - } - } - - this.connectionByteTrackers.delete(connectionId); - } - - /** - * Start the metrics collector - */ - public start(): void { - if (!this.smartProxy.routeConnectionHandler) { - throw new Error('MetricsCollector: RouteConnectionHandler not available'); - } - - // Start periodic sampling - this.samplingInterval = setInterval(() => { - // Sample global throughput - this.throughputTracker.takeSample(); - - // Sample per-route throughput - for (const [_, tracker] of this.routeThroughputTrackers) { - tracker.takeSample(); - } - - // Sample per-IP throughput - for (const [_, tracker] of this.ipThroughputTrackers) { - tracker.takeSample(); - } - - // Clean up old connection trackers (connections closed more than 5 minutes ago) - const cutoff = Date.now() - 300000; - for (const [id, tracker] of this.connectionByteTrackers) { - if (tracker.lastUpdate < cutoff) { - this.connectionByteTrackers.delete(id); - } - } - - // Clean up unused route trackers - const activeRoutes = new Set(Array.from(this.connectionByteTrackers.values()).map(t => t.routeName)); - for (const [route, _] of this.routeThroughputTrackers) { - if (!activeRoutes.has(route)) { - this.routeThroughputTrackers.delete(route); - } - } - - // Clean up unused IP trackers - const activeIPs = new Set(Array.from(this.connectionByteTrackers.values()).map(t => t.remoteIP)); - for (const [ip, _] of this.ipThroughputTrackers) { - if (!activeIPs.has(ip)) { - this.ipThroughputTrackers.delete(ip); - } - } - }, this.sampleIntervalMs); - - // Unref the interval so it doesn't keep the process alive - if (this.samplingInterval.unref) { - this.samplingInterval.unref(); - } - - // Subscribe to new connections - this.connectionSubscription = this.smartProxy.routeConnectionHandler.newConnectionSubject.subscribe({ - next: (record) => { - const routeName = record.routeConfig?.name || 'unknown'; - this.recordRequest(record.id, routeName, record.remoteIP); - - if (this.smartProxy.settings?.enableDetailedLogging) { - logger.log('debug', `MetricsCollector: New connection recorded`, { - connectionId: record.id, - remoteIP: record.remoteIP, - routeName, - component: 'metrics' - }); - } - }, - error: (err) => { - logger.log('error', `MetricsCollector: Error in connection subscription`, { - error: err.message, - component: 'metrics' - }); - } - }); - - logger.log('debug', 'MetricsCollector started', { component: 'metrics' }); - } - - /** - * Stop the metrics collector - */ - public stop(): void { - if (this.samplingInterval) { - clearInterval(this.samplingInterval); - this.samplingInterval = undefined; - } - - if (this.connectionSubscription) { - this.connectionSubscription.unsubscribe(); - this.connectionSubscription = undefined; - } - - logger.log('debug', 'MetricsCollector stopped', { component: 'metrics' }); - } - - /** - * Alias for stop() for compatibility - */ - public destroy(): void { - this.stop(); - } -} \ No newline at end of file diff --git a/ts/proxies/smart-proxy/models/interfaces.ts b/ts/proxies/smart-proxy/models/interfaces.ts index fd8b2d7..0ae7335 100644 --- a/ts/proxies/smart-proxy/models/interfaces.ts +++ b/ts/proxies/smart-proxy/models/interfaces.ts @@ -99,10 +99,6 @@ export interface ISmartProxyOptions { keepAliveInactivityMultiplier?: number; // Multiplier for inactivity timeout for keep-alive connections extendedKeepAliveLifetime?: number; // Extended lifetime for keep-alive connections (ms) - // HttpProxy integration - useHttpProxy?: number[]; // Array of ports to forward to HttpProxy - httpProxyPort?: number; // Port where HttpProxy is listening (default: 8443) - // Metrics configuration metrics?: { enabled?: boolean; @@ -139,6 +135,12 @@ export interface ISmartProxyOptions { * Default: true */ certProvisionFallbackToAcme?: boolean; + + /** + * Path to the RustProxy binary. If not set, the binary is located + * automatically via env var, platform package, local build, or PATH. + */ + rustBinaryPath?: string; } /** diff --git a/ts/proxies/smart-proxy/nftables-manager.ts b/ts/proxies/smart-proxy/nftables-manager.ts deleted file mode 100644 index 192bcc0..0000000 --- a/ts/proxies/smart-proxy/nftables-manager.ts +++ /dev/null @@ -1,271 +0,0 @@ -import * as plugins from '../../plugins.js'; -import { NfTablesProxy } from '../nftables-proxy/nftables-proxy.js'; -import type { - NfTableProxyOptions, - PortRange, - NfTablesStatus -} from '../nftables-proxy/models/interfaces.js'; -import type { - IRouteConfig, - TPortRange, - INfTablesOptions -} from './models/route-types.js'; -import type { SmartProxy } from './smart-proxy.js'; - -/** - * Manages NFTables rules based on SmartProxy route configurations - * - * This class bridges the gap between SmartProxy routes and the NFTablesProxy, - * allowing high-performance kernel-level packet forwarding for routes that - * specify NFTables as their forwarding engine. - */ -export class NFTablesManager { - private rulesMap: Map = new Map(); - - /** - * Creates a new NFTablesManager - * - * @param smartProxy The SmartProxy instance - */ - constructor(private smartProxy: SmartProxy) {} - - /** - * Provision NFTables rules for a route - * - * @param route The route configuration - * @returns A promise that resolves to true if successful, false otherwise - */ - public async provisionRoute(route: IRouteConfig): Promise { - // Generate a unique ID for this route - const routeId = this.generateRouteId(route); - - // Skip if route doesn't use NFTables - if (route.action.forwardingEngine !== 'nftables') { - return true; - } - - // Create NFTables options from route configuration - const nftOptions = this.createNfTablesOptions(route); - - // Create and start an NFTablesProxy instance - const proxy = new NfTablesProxy(nftOptions); - - try { - await proxy.start(); - this.rulesMap.set(routeId, proxy); - return true; - } catch (err) { - console.error(`Failed to provision NFTables rules for route ${route.name || 'unnamed'}: ${err.message}`); - return false; - } - } - - /** - * Remove NFTables rules for a route - * - * @param route The route configuration - * @returns A promise that resolves to true if successful, false otherwise - */ - public async deprovisionRoute(route: IRouteConfig): Promise { - const routeId = this.generateRouteId(route); - - const proxy = this.rulesMap.get(routeId); - if (!proxy) { - return true; // Nothing to remove - } - - try { - await proxy.stop(); - this.rulesMap.delete(routeId); - return true; - } catch (err) { - console.error(`Failed to deprovision NFTables rules for route ${route.name || 'unnamed'}: ${err.message}`); - return false; - } - } - - /** - * Update NFTables rules when route changes - * - * @param oldRoute The previous route configuration - * @param newRoute The new route configuration - * @returns A promise that resolves to true if successful, false otherwise - */ - public async updateRoute(oldRoute: IRouteConfig, newRoute: IRouteConfig): Promise { - // Remove old rules and add new ones - await this.deprovisionRoute(oldRoute); - return this.provisionRoute(newRoute); - } - - /** - * Generate a unique ID for a route - * - * @param route The route configuration - * @returns A unique ID string - */ - private generateRouteId(route: IRouteConfig): string { - // Generate a unique ID based on route properties - // Include the route name, match criteria, and a timestamp - const matchStr = JSON.stringify({ - ports: route.match.ports, - domains: route.match.domains - }); - - return `${route.name || 'unnamed'}-${matchStr}-${route.id || Date.now().toString()}`; - } - - /** - * Create NFTablesProxy options from a route configuration - * - * @param route The route configuration - * @returns NFTableProxyOptions object - */ - private createNfTablesOptions(route: IRouteConfig): NfTableProxyOptions { - const { action } = route; - - // Ensure we have targets - if (!action.targets || action.targets.length === 0) { - throw new Error('Route must have targets to use NFTables forwarding'); - } - - // NFTables can only handle a single target, so we use the first target without match criteria - // or the first target if all have match criteria - const defaultTarget = action.targets.find(t => !t.match) || action.targets[0]; - - // Convert port specifications - const fromPorts = this.expandPortRange(route.match.ports); - - // Determine target port - let toPorts: number | PortRange | Array; - - if (defaultTarget.port === 'preserve') { - // 'preserve' means use the same ports as the source - toPorts = fromPorts; - } else if (typeof defaultTarget.port === 'function') { - // For function-based ports, we can't determine at setup time - // Use the "preserve" approach and let NFTables handle it - toPorts = fromPorts; - } else { - toPorts = defaultTarget.port; - } - - // Determine target host - let toHost: string; - if (typeof defaultTarget.host === 'function') { - // Can't determine at setup time, use localhost as a placeholder - // and rely on run-time handling - toHost = 'localhost'; - } else if (Array.isArray(defaultTarget.host)) { - // Use first host for now - NFTables will do simple round-robin - toHost = defaultTarget.host[0]; - } else { - toHost = defaultTarget.host; - } - - // Create options - const options: NfTableProxyOptions = { - fromPort: fromPorts, - toPort: toPorts, - toHost: toHost, - protocol: action.nftables?.protocol || 'tcp', - preserveSourceIP: action.nftables?.preserveSourceIP !== undefined ? - action.nftables.preserveSourceIP : - this.smartProxy.settings.preserveSourceIP, - useIPSets: action.nftables?.useIPSets !== false, - useAdvancedNAT: action.nftables?.useAdvancedNAT, - enableLogging: this.smartProxy.settings.enableDetailedLogging, - deleteOnExit: true, - tableName: action.nftables?.tableName || 'smartproxy' - }; - - // Add security-related options - if (route.security?.ipAllowList?.length) { - options.ipAllowList = route.security.ipAllowList; - } - - if (route.security?.ipBlockList?.length) { - options.ipBlockList = route.security.ipBlockList; - } - - // Add QoS options - if (action.nftables?.maxRate || action.nftables?.priority) { - options.qos = { - enabled: true, - maxRate: action.nftables.maxRate, - priority: action.nftables.priority - }; - } - - return options; - } - - /** - * Expand port range specifications - * - * @param ports The port range specification - * @returns Expanded port range - */ - private expandPortRange(ports: TPortRange): number | PortRange | Array { - // Process different port specifications - if (typeof ports === 'number') { - return ports; - } else if (Array.isArray(ports)) { - const result: Array = []; - - for (const item of ports) { - if (typeof item === 'number') { - result.push(item); - } else if ('from' in item && 'to' in item) { - result.push({ from: item.from, to: item.to }); - } - } - - return result; - } else if (typeof ports === 'object' && ports !== null && 'from' in ports && 'to' in ports) { - return { from: (ports as any).from, to: (ports as any).to }; - } - - // Fallback to port 80 if something went wrong - console.warn('Invalid port range specification, using port 80 as fallback'); - return 80; - } - - /** - * Get status of all managed rules - * - * @returns A promise that resolves to a record of NFTables status objects - */ - public async getStatus(): Promise> { - const result: Record = {}; - - for (const [routeId, proxy] of this.rulesMap.entries()) { - result[routeId] = await proxy.getStatus(); - } - - return result; - } - - /** - * Check if a route is currently provisioned - * - * @param route The route configuration - * @returns True if the route is provisioned, false otherwise - */ - public isRouteProvisioned(route: IRouteConfig): boolean { - const routeId = this.generateRouteId(route); - return this.rulesMap.has(routeId); - } - - /** - * Stop all NFTables rules - * - * @returns A promise that resolves when all rules have been stopped - */ - public async stop(): Promise { - // Stop all NFTables proxies - const stopPromises = Array.from(this.rulesMap.values()).map(proxy => proxy.stop()); - await Promise.all(stopPromises); - - this.rulesMap.clear(); - } -} \ No newline at end of file diff --git a/ts/proxies/smart-proxy/port-manager.ts b/ts/proxies/smart-proxy/port-manager.ts deleted file mode 100644 index 0e40c62..0000000 --- a/ts/proxies/smart-proxy/port-manager.ts +++ /dev/null @@ -1,358 +0,0 @@ -import * as plugins from '../../plugins.js'; -import { logger } from '../../core/utils/logger.js'; -import { cleanupSocket } from '../../core/utils/socket-utils.js'; -import type { SmartProxy } from './smart-proxy.js'; - -/** - * PortManager handles the dynamic creation and removal of port listeners - * - * This class provides methods to add and remove listening ports at runtime, - * allowing SmartProxy to adapt to configuration changes without requiring - * a full restart. - * - * It includes a reference counting system to track how many routes are using - * each port, so ports can be automatically released when they are no longer needed. - */ -export class PortManager { - private servers: Map = new Map(); - private isShuttingDown: boolean = false; - // Track how many routes are using each port - private portRefCounts: Map = new Map(); - - /** - * Create a new PortManager - * - * @param smartProxy The SmartProxy instance - */ - constructor( - private smartProxy: SmartProxy - ) {} - - /** - * Start listening on a specific port - * - * @param port The port number to listen on - * @returns Promise that resolves when the server is listening or rejects on error - */ - public async addPort(port: number): Promise { - // Check if we're already listening on this port - if (this.servers.has(port)) { - // Port is already bound, just increment the reference count - this.incrementPortRefCount(port); - try { - logger.log('debug', `PortManager: Port ${port} is already bound by SmartProxy, reusing binding`, { - port, - component: 'port-manager' - }); - } catch (e) { - console.log(`[DEBUG] PortManager: Port ${port} is already bound by SmartProxy, reusing binding`); - } - return; - } - - // Initialize reference count for new port - this.portRefCounts.set(port, 1); - - // Create a server for this port - const server = plugins.net.createServer((socket) => { - // Check if shutting down - if (this.isShuttingDown) { - cleanupSocket(socket, 'port-manager-shutdown', { immediate: true }); - return; - } - - // Delegate to route connection handler - this.smartProxy.routeConnectionHandler.handleConnection(socket); - }).on('error', (err: Error) => { - try { - logger.log('error', `Server Error on port ${port}: ${err.message}`, { - port, - error: err.message, - component: 'port-manager' - }); - } catch (e) { - console.error(`[ERROR] Server Error on port ${port}: ${err.message}`); - } - }); - - // Start listening on the port - return new Promise((resolve, reject) => { - server.listen(port, () => { - const isHttpProxyPort = this.smartProxy.settings.useHttpProxy?.includes(port); - try { - logger.log('info', `SmartProxy -> OK: Now listening on port ${port}${ - isHttpProxyPort ? ' (HttpProxy forwarding enabled)' : '' - }`, { - port, - isHttpProxyPort: !!isHttpProxyPort, - component: 'port-manager' - }); - } catch (e) { - console.log(`[INFO] SmartProxy -> OK: Now listening on port ${port}${ - isHttpProxyPort ? ' (HttpProxy forwarding enabled)' : '' - }`); - } - - // Store the server reference - this.servers.set(port, server); - resolve(); - }).on('error', (err) => { - // Check if this is an external conflict - const { isConflict, isExternal } = this.isPortConflict(err); - - if (isConflict && !isExternal) { - // This is an internal conflict (port already bound by SmartProxy) - // This shouldn't normally happen because we check servers.has(port) above - logger.log('warn', `Port ${port} binding conflict: already in use by SmartProxy`, { - port, - component: 'port-manager' - }); - // Still increment reference count to maintain tracking - this.incrementPortRefCount(port); - resolve(); - return; - } - - // Log the error and propagate it - logger.log('error', `Failed to listen on port ${port}: ${err.message}`, { - port, - error: err.message, - code: (err as any).code, - component: 'port-manager' - }); - - // Clean up reference count since binding failed - this.portRefCounts.delete(port); - - reject(err); - }); - }); - } - - /** - * Stop listening on a specific port - * - * @param port The port to stop listening on - * @returns Promise that resolves when the server is closed - */ - public async removePort(port: number): Promise { - // Decrement the reference count first - const newRefCount = this.decrementPortRefCount(port); - - // If there are still references to this port, keep it open - if (newRefCount > 0) { - logger.log('debug', `PortManager: Port ${port} still has ${newRefCount} references, keeping open`, { - port, - refCount: newRefCount, - component: 'port-manager' - }); - return; - } - - // Get the server for this port - const server = this.servers.get(port); - if (!server) { - logger.log('warn', `PortManager: Not listening on port ${port}`, { - port, - component: 'port-manager' - }); - // Ensure reference count is reset - this.portRefCounts.delete(port); - return; - } - - // Close the server - return new Promise((resolve) => { - server.close((err) => { - if (err) { - logger.log('error', `Error closing server on port ${port}: ${err.message}`, { - port, - error: err.message, - component: 'port-manager' - }); - } else { - logger.log('info', `SmartProxy -> Stopped listening on port ${port}`, { - port, - component: 'port-manager' - }); - } - - // Remove the server reference and clean up reference counting - this.servers.delete(port); - this.portRefCounts.delete(port); - resolve(); - }); - }); - } - - /** - * Add multiple ports at once - * - * @param ports Array of ports to add - * @returns Promise that resolves when all servers are listening - */ - public async addPorts(ports: number[]): Promise { - const uniquePorts = [...new Set(ports)]; - await Promise.all(uniquePorts.map(port => this.addPort(port))); - } - - /** - * Remove multiple ports at once - * - * @param ports Array of ports to remove - * @returns Promise that resolves when all servers are closed - */ - public async removePorts(ports: number[]): Promise { - const uniquePorts = [...new Set(ports)]; - await Promise.all(uniquePorts.map(port => this.removePort(port))); - } - - /** - * Update listening ports to match the provided list - * - * This will add any ports that aren't currently listening, - * and remove any ports that are no longer needed. - * - * @param ports Array of ports that should be listening - * @returns Promise that resolves when all operations are complete - */ - public async updatePorts(ports: number[]): Promise { - const targetPorts = new Set(ports); - const currentPorts = new Set(this.servers.keys()); - - // Find ports to add and remove - const portsToAdd = ports.filter(port => !currentPorts.has(port)); - const portsToRemove = Array.from(currentPorts).filter(port => !targetPorts.has(port)); - - // Log the changes - if (portsToAdd.length > 0) { - console.log(`PortManager: Adding new listeners for ports: ${portsToAdd.join(', ')}`); - } - - if (portsToRemove.length > 0) { - console.log(`PortManager: Removing listeners for ports: ${portsToRemove.join(', ')}`); - } - - // Add and remove ports - await this.removePorts(portsToRemove); - await this.addPorts(portsToAdd); - } - - /** - * Get all ports that are currently listening - * - * @returns Array of port numbers - */ - public getListeningPorts(): number[] { - return Array.from(this.servers.keys()); - } - - /** - * Mark the port manager as shutting down - */ - public setShuttingDown(isShuttingDown: boolean): void { - this.isShuttingDown = isShuttingDown; - } - - /** - * Close all listening servers - * - * @returns Promise that resolves when all servers are closed - */ - public async closeAll(): Promise { - const allPorts = Array.from(this.servers.keys()); - await this.removePorts(allPorts); - } - - /** - * Get all server instances (for testing or debugging) - */ - public getServers(): Map { - return new Map(this.servers); - } - - /** - * Check if a port is bound by this SmartProxy instance - * - * @param port The port number to check - * @returns True if the port is currently bound by SmartProxy - */ - public isPortBoundBySmartProxy(port: number): boolean { - return this.servers.has(port); - } - - /** - * Get the current reference count for a port - * - * @param port The port number to check - * @returns The number of routes using this port, 0 if none - */ - public getPortRefCount(port: number): number { - return this.portRefCounts.get(port) || 0; - } - - /** - * Increment the reference count for a port - * - * @param port The port number to increment - * @returns The new reference count - */ - public incrementPortRefCount(port: number): number { - const currentCount = this.portRefCounts.get(port) || 0; - const newCount = currentCount + 1; - this.portRefCounts.set(port, newCount); - - logger.log('debug', `Port ${port} reference count increased to ${newCount}`, { - port, - refCount: newCount, - component: 'port-manager' - }); - - return newCount; - } - - /** - * Decrement the reference count for a port - * - * @param port The port number to decrement - * @returns The new reference count - */ - public decrementPortRefCount(port: number): number { - const currentCount = this.portRefCounts.get(port) || 0; - - if (currentCount <= 0) { - logger.log('warn', `Attempted to decrement reference count for port ${port} below zero`, { - port, - component: 'port-manager' - }); - return 0; - } - - const newCount = currentCount - 1; - this.portRefCounts.set(port, newCount); - - logger.log('debug', `Port ${port} reference count decreased to ${newCount}`, { - port, - refCount: newCount, - component: 'port-manager' - }); - - return newCount; - } - - /** - * Determine if a port binding error is due to an external or internal conflict - * - * @param error The error object from a failed port binding - * @returns Object indicating if this is a conflict and if it's external - */ - private isPortConflict(error: any): { isConflict: boolean; isExternal: boolean } { - if (error.code !== 'EADDRINUSE') { - return { isConflict: false, isExternal: false }; - } - - // Check if we already have this port - const isBoundInternally = this.servers.has(Number(error.port)); - return { isConflict: true, isExternal: !isBoundInternally }; - } -} \ No newline at end of file diff --git a/ts/proxies/smart-proxy/route-connection-handler.ts b/ts/proxies/smart-proxy/route-connection-handler.ts deleted file mode 100644 index 5f41274..0000000 --- a/ts/proxies/smart-proxy/route-connection-handler.ts +++ /dev/null @@ -1,1712 +0,0 @@ -import * as plugins from '../../plugins.js'; -import type { IConnectionRecord, ISmartProxyOptions } from './models/interfaces.js'; -import { logger } from '../../core/utils/logger.js'; -import { connectionLogDeduplicator } from '../../core/utils/log-deduplicator.js'; -// Route checking functions have been removed -import type { IRouteConfig, IRouteAction, IRouteTarget } from './models/route-types.js'; -import type { IRouteContext } from '../../core/models/route-context.js'; -import { cleanupSocket, setupSocketHandlers, createSocketWithErrorHandler, setupBidirectionalForwarding } from '../../core/utils/socket-utils.js'; -import { WrappedSocket } from '../../core/models/wrapped-socket.js'; -import { getUnderlyingSocket } from '../../core/models/socket-types.js'; -import { ProxyProtocolParser } from '../../core/utils/proxy-protocol.js'; -import type { SmartProxy } from './smart-proxy.js'; -import { ProtocolDetector } from '../../detection/index.js'; - -/** - * Handles new connection processing and setup logic with support for route-based configuration - */ -export class RouteConnectionHandler { - // Note: Route context caching was considered but not implemented - // as route contexts are lightweight and should be created fresh - // for each connection to ensure accurate context data - - // RxJS Subject for new connections - public newConnectionSubject = new plugins.smartrx.rxjs.Subject(); - - constructor( - private smartProxy: SmartProxy - ) {} - - - /** - * Create a route context object for port and host mapping functions - */ - private createRouteContext(options: { - connectionId: string; - port: number; - domain?: string; - clientIp: string; - serverIp: string; - isTls: boolean; - tlsVersion?: string; - routeName?: string; - routeId?: string; - path?: string; - query?: string; - headers?: Record; - }): IRouteContext { - return { - // Connection information - port: options.port, - domain: options.domain, - clientIp: options.clientIp, - serverIp: options.serverIp, - path: options.path, - query: options.query, - headers: options.headers, - - // TLS information - isTls: options.isTls, - tlsVersion: options.tlsVersion, - - // Route information - routeName: options.routeName, - routeId: options.routeId, - - // Additional properties - timestamp: Date.now(), - connectionId: options.connectionId, - }; - } - - /** - * Determines if SNI is required for routing decisions on this port. - * - * SNI is REQUIRED when: - * - Multiple routes exist on this port (need SNI to pick correct route) - * - Route has dynamic target function (needs ctx.domain) - * - Route has specific domain restriction (strict validation) - * - * SNI is NOT required when: - * - TLS termination mode (HttpProxy handles session resumption) - * - Single route with static target and no domain restriction (or wildcard) - */ - private calculateSniRequirement(port: number): boolean { - const routesOnPort = this.smartProxy.routeManager.getRoutesForPort(port); - - // No routes = no SNI requirement (will fail routing anyway) - if (routesOnPort.length === 0) return false; - - // Check if any route terminates TLS - if so, SNI not required - // (HttpProxy handles session resumption internally) - const hasTermination = routesOnPort.some(route => - route.action.tls?.mode === 'terminate' || - route.action.tls?.mode === 'terminate-and-reencrypt' - ); - if (hasTermination) return false; - - // Multiple routes = need SNI to pick the correct route - if (routesOnPort.length > 1) return true; - - // Single route - check if it needs SNI for validation or routing - const route = routesOnPort[0]; - - // Dynamic host selection requires SNI (function receives ctx.domain) - const hasDynamicTarget = route.action.targets?.some(t => typeof t.host === 'function'); - if (hasDynamicTarget) return true; - - // Specific domain restriction requires SNI for strict validation - const hasSpecificDomain = route.match.domains && !this.isWildcardOnly(route.match.domains); - if (hasSpecificDomain) return true; - - // Single route, static target(s), no domain restriction = SNI not required - return false; - } - - /** - * Check if domains config is wildcard-only (matches everything) - */ - private isWildcardOnly(domains: string | string[]): boolean { - const domainList = Array.isArray(domains) ? domains : [domains]; - return domainList.length === 1 && domainList[0] === '*'; - } - - /** - * Handle a new incoming connection - */ - public handleConnection(socket: plugins.net.Socket): void { - const remoteIP = socket.remoteAddress || ''; - const localPort = socket.localPort || 0; - - // Always wrap the socket to prepare for potential PROXY protocol - const wrappedSocket = new WrappedSocket(socket); - - // If this is from a trusted proxy, log it - if (this.smartProxy.settings.proxyIPs?.includes(remoteIP)) { - logger.log('debug', `Connection from trusted proxy ${remoteIP}, PROXY protocol parsing will be enabled`, { - remoteIP, - component: 'route-handler' - }); - } - - // Generate connection ID first for atomic IP validation and tracking - const connectionId = this.smartProxy.connectionManager.generateConnectionId(); - const clientIP = wrappedSocket.remoteAddress || ''; - - // Atomically validate IP and track the connection to prevent race conditions - // This ensures concurrent connections from the same IP are properly limited - const ipValidation = this.smartProxy.securityManager.validateAndTrackIP(clientIP, connectionId); - if (!ipValidation.allowed) { - connectionLogDeduplicator.log( - 'ip-rejected', - 'warn', - `Connection rejected from ${clientIP}`, - { remoteIP: clientIP, reason: ipValidation.reason, component: 'route-handler' }, - clientIP - ); - cleanupSocket(wrappedSocket.socket, `rejected-${ipValidation.reason}`, { immediate: true }); - return; - } - - // Create a new connection record with the wrapped socket - // Skip IP tracking since we already did it atomically above - const record = this.smartProxy.connectionManager.createConnection(wrappedSocket, { - connectionId, - skipIpTracking: true - }); - if (!record) { - // Connection was rejected due to global limit - clean up the IP tracking we did - this.smartProxy.securityManager.removeConnectionByIP(clientIP, connectionId); - return; - } - - // Emit new connection event - this.newConnectionSubject.next(record); - // Note: connectionId was already generated above for atomic IP tracking - - // Apply socket optimizations (apply to underlying socket) - const underlyingSocket = wrappedSocket.socket; - underlyingSocket.setNoDelay(this.smartProxy.settings.noDelay); - - // Apply keep-alive settings if enabled - if (this.smartProxy.settings.keepAlive) { - underlyingSocket.setKeepAlive(true, this.smartProxy.settings.keepAliveInitialDelay); - record.hasKeepAlive = true; - - // Apply enhanced TCP keep-alive options if enabled - if (this.smartProxy.settings.enableKeepAliveProbes) { - try { - // These are platform-specific and may not be available - if ('setKeepAliveProbes' in underlyingSocket) { - (underlyingSocket as any).setKeepAliveProbes(10); - } - if ('setKeepAliveInterval' in underlyingSocket) { - (underlyingSocket as any).setKeepAliveInterval(1000); - } - } catch (err) { - // Ignore errors - these are optional enhancements - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('warn', `Enhanced TCP keep-alive settings not supported`, { connectionId, error: err, component: 'route-handler' }); - } - } - } - } - - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('info', - `New connection from ${remoteIP} on port ${localPort}. ` + - `Keep-Alive: ${record.hasKeepAlive ? 'Enabled' : 'Disabled'}. ` + - `Active connections: ${this.smartProxy.connectionManager.getConnectionCount()}`, - { - connectionId, - remoteIP, - localPort, - keepAlive: record.hasKeepAlive ? 'Enabled' : 'Disabled', - activeConnections: this.smartProxy.connectionManager.getConnectionCount(), - component: 'route-handler' - } - ); - } else { - logger.log('info', - `New connection from ${remoteIP} on port ${localPort}. Active connections: ${this.smartProxy.connectionManager.getConnectionCount()}`, - { - remoteIP, - localPort, - activeConnections: this.smartProxy.connectionManager.getConnectionCount(), - component: 'route-handler' - } - ); - } - - // Handle the connection - wait for initial data to determine if it's TLS - this.handleInitialData(wrappedSocket, record); - } - - /** - * Handle initial data from a connection to determine routing - */ - private handleInitialData(socket: plugins.net.Socket | WrappedSocket, record: IConnectionRecord): void { - const connectionId = record.id; - const localPort = record.localPort; - let initialDataReceived = false; - - // Check if any routes on this port require TLS handling - const allRoutes = this.smartProxy.routeManager.getRoutes(); - const needsTlsHandling = allRoutes.some(route => { - // Check if route matches this port - const matchesPort = this.smartProxy.routeManager.getRoutesForPort(localPort).includes(route); - - return matchesPort && - route.action.type === 'forward' && - route.action.tls && - (route.action.tls.mode === 'terminate' || - route.action.tls.mode === 'passthrough'); - }); - - // Smart SNI requirement calculation - // Determines if we need SNI for routing decisions on this port - const needsSniForRouting = this.calculateSniRequirement(localPort); - const allowSessionTicket = !needsSniForRouting; - - // If no routes require TLS handling and it's not port 443, route immediately - if (!needsTlsHandling && localPort !== 443) { - // Extract underlying socket for socket-utils functions - const underlyingSocket = getUnderlyingSocket(socket); - // Set up proper socket handlers for immediate routing - setupSocketHandlers( - underlyingSocket, - (reason) => { - // Always cleanup when incoming socket closes - // This prevents connection accumulation in proxy chains - logger.log('debug', `Connection ${connectionId} closed during immediate routing: ${reason}`, { - connectionId, - remoteIP: record.remoteIP, - reason, - hasOutgoing: !!record.outgoing, - outgoingState: record.outgoing?.readyState, - component: 'route-handler' - }); - - // If there's a pending or established outgoing connection, destroy it - if (record.outgoing && !record.outgoing.destroyed) { - logger.log('debug', `Destroying outgoing connection for ${connectionId}`, { - connectionId, - outgoingState: record.outgoing.readyState, - component: 'route-handler' - }); - record.outgoing.destroy(); - } - - // Always cleanup the connection record - this.smartProxy.connectionManager.cleanupConnection(record, reason); - }, - undefined, // Use default timeout handler - 'immediate-route-client' - ); - - // Route immediately for non-TLS connections - this.routeConnection(socket, record, '', undefined); - return; - } - - // Otherwise, wait for initial data to check if it's TLS - // Set an initial timeout for handshake data - let initialTimeout: NodeJS.Timeout | null = setTimeout(() => { - if (!initialDataReceived) { - logger.log('warn', `No initial data received from ${record.remoteIP} after ${this.smartProxy.settings.initialDataTimeout}ms for connection ${connectionId}`, { - connectionId, - timeout: this.smartProxy.settings.initialDataTimeout, - remoteIP: record.remoteIP, - component: 'route-handler' - }); - - // Add a grace period - setTimeout(() => { - if (!initialDataReceived) { - logger.log('warn', `Final initial data timeout after grace period for connection ${connectionId}`, { - connectionId, - component: 'route-handler' - }); - if (record.incomingTerminationReason === null) { - record.incomingTerminationReason = 'initial_timeout'; - this.smartProxy.connectionManager.incrementTerminationStat('incoming', 'initial_timeout'); - } - socket.end(); - this.smartProxy.connectionManager.cleanupConnection(record, 'initial_timeout'); - } - }, 30000); - } - }, this.smartProxy.settings.initialDataTimeout!); - - // Make sure timeout doesn't keep the process alive - if (initialTimeout.unref) { - initialTimeout.unref(); - } - - // Set up error handler - socket.on('error', this.smartProxy.connectionManager.handleError('incoming', record)); - - // Add close/end handlers to catch immediate disconnections - socket.once('close', () => { - if (!initialDataReceived) { - logger.log('warn', `Connection ${connectionId} closed before sending initial data`, { - connectionId, - remoteIP: record.remoteIP, - component: 'route-handler' - }); - if (initialTimeout) { - clearTimeout(initialTimeout); - initialTimeout = null; - } - this.smartProxy.connectionManager.cleanupConnection(record, 'closed_before_data'); - } - }); - - socket.once('end', () => { - if (!initialDataReceived) { - logger.log('debug', `Connection ${connectionId} ended before sending initial data`, { - connectionId, - remoteIP: record.remoteIP, - component: 'route-handler' - }); - if (initialTimeout) { - clearTimeout(initialTimeout); - initialTimeout = null; - } - // Don't cleanup on 'end' - wait for 'close' - } - }); - - // Handler for processing initial data (after potential PROXY protocol) - const processInitialData = async (chunk: Buffer) => { - // Create connection context for protocol detection - const context = ProtocolDetector.createConnectionContext({ - sourceIp: record.remoteIP, - sourcePort: socket.remotePort || 0, - destIp: socket.localAddress || '', - destPort: socket.localPort || 0, - socketId: record.id - }); - - const detectionResult = await ProtocolDetector.detectWithContext( - chunk, - context, - { extractFullHeaders: false } // Only extract essential info for routing - ); - - // Block non-TLS connections on port 443 - if (localPort === 443 && detectionResult.protocol !== 'tls') { - logger.log('warn', `Non-TLS connection ${record.id} detected on port 443. Terminating connection - only TLS traffic is allowed on standard HTTPS port.`, { - connectionId: record.id, - detectedProtocol: detectionResult.protocol, - message: 'Terminating connection - only TLS traffic is allowed on standard HTTPS port.', - component: 'route-handler' - }); - if (record.incomingTerminationReason === null) { - record.incomingTerminationReason = 'non_tls_blocked'; - this.smartProxy.connectionManager.incrementTerminationStat('incoming', 'non_tls_blocked'); - } - socket.end(); - this.smartProxy.connectionManager.cleanupConnection(record, 'non_tls_blocked'); - return; - } - - // Extract domain and protocol info - let serverName = ''; - if (detectionResult.protocol === 'tls') { - record.isTLS = true; - serverName = detectionResult.connectionInfo.domain || ''; - - // Lock the connection to the negotiated SNI - record.lockedDomain = serverName; - - // Check if we should reject connections without SNI - if (!serverName && allowSessionTicket === false) { - logger.log('warn', `No SNI detected in TLS ClientHello for connection ${record.id}; sending TLS alert`, { - connectionId: record.id, - component: 'route-handler' - }); - if (record.incomingTerminationReason === null) { - record.incomingTerminationReason = 'session_ticket_blocked_no_sni'; - this.smartProxy.connectionManager.incrementTerminationStat( - 'incoming', - 'session_ticket_blocked_no_sni' - ); - } - const alert = Buffer.from([0x15, 0x03, 0x03, 0x00, 0x02, 0x01, 0x70]); - try { - // Count the alert bytes being sent - record.bytesSent += alert.length; - if (this.smartProxy.metricsCollector) { - this.smartProxy.metricsCollector.recordBytes(record.id, 0, alert.length); - } - - socket.cork(); - socket.write(alert); - socket.uncork(); - socket.end(); - } catch { - socket.end(); - } - this.smartProxy.connectionManager.cleanupConnection(record, 'session_ticket_blocked_no_sni'); - return; - } - - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('info', `TLS connection with SNI`, { - connectionId: record.id, - serverName: serverName || '(empty)', - component: 'route-handler' - }); - } - } else if (detectionResult.protocol === 'http') { - // For HTTP, extract domain from Host header - serverName = detectionResult.connectionInfo.domain || ''; - - // Store HTTP-specific info for later use - record.httpInfo = { - method: detectionResult.connectionInfo.method, - path: detectionResult.connectionInfo.path, - headers: detectionResult.connectionInfo.headers - }; - - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('info', `HTTP connection detected`, { - connectionId: record.id, - domain: serverName || '(no host header)', - method: detectionResult.connectionInfo.method, - path: detectionResult.connectionInfo.path, - component: 'route-handler' - }); - } - } - - // Find the appropriate route for this connection - this.routeConnection(socket, record, serverName, chunk, detectionResult); - }; - - // First data handler to capture initial TLS handshake or PROXY protocol - socket.once('data', async (chunk: Buffer) => { - // Clear the initial timeout since we've received data - if (initialTimeout) { - clearTimeout(initialTimeout); - initialTimeout = null; - } - - initialDataReceived = true; - record.hasReceivedInitialData = true; - - // Check if this is from a trusted proxy and might have PROXY protocol - if (this.smartProxy.settings.proxyIPs?.includes(socket.remoteAddress || '') && this.smartProxy.settings.acceptProxyProtocol !== false) { - // Check if this starts with PROXY protocol - if (chunk.toString('ascii', 0, Math.min(6, chunk.length)).startsWith('PROXY ')) { - try { - const parseResult = ProxyProtocolParser.parse(chunk); - - if (parseResult.proxyInfo) { - // Update the wrapped socket with real client info (if it's a WrappedSocket) - if (socket instanceof WrappedSocket) { - socket.setProxyInfo(parseResult.proxyInfo.sourceIP, parseResult.proxyInfo.sourcePort); - } - - // Update connection record with real client info - record.remoteIP = parseResult.proxyInfo.sourceIP; - record.remotePort = parseResult.proxyInfo.sourcePort; - - logger.log('info', `PROXY protocol parsed successfully`, { - connectionId, - realClientIP: parseResult.proxyInfo.sourceIP, - realClientPort: parseResult.proxyInfo.sourcePort, - proxyIP: socket.remoteAddress, - component: 'route-handler' - }); - - // Process remaining data if any - if (parseResult.remainingData.length > 0) { - processInitialData(parseResult.remainingData); - } else { - // Wait for more data - socket.once('data', processInitialData); - } - return; - } - } catch (error) { - logger.log('error', `Failed to parse PROXY protocol from trusted proxy`, { - connectionId, - error: error.message, - proxyIP: socket.remoteAddress, - component: 'route-handler' - }); - // Continue processing as normal data - } - } - } - - // Process as normal data (no PROXY protocol) - processInitialData(chunk); - }); - } - - /** - * Route the connection based on match criteria - */ - private routeConnection( - socket: plugins.net.Socket | WrappedSocket, - record: IConnectionRecord, - serverName: string, - initialChunk?: Buffer, - detectionResult?: any // Using any temporarily to avoid circular dependency issues - ): void { - const connectionId = record.id; - const localPort = record.localPort; - const remoteIP = record.remoteIP; - - // Check if this is an HTTP proxy port - const isHttpProxyPort = this.smartProxy.settings.useHttpProxy?.includes(localPort); - - // For HTTP proxy ports without TLS, skip domain check since domain info comes from HTTP headers - const skipDomainCheck = isHttpProxyPort && !record.isTLS; - - // Create route context for matching - const routeContext: IRouteContext = { - port: localPort, - domain: skipDomainCheck ? undefined : serverName, // Skip domain if HTTP proxy without TLS - clientIp: remoteIP, - serverIp: socket.localAddress || '', - path: undefined, // We don't have path info at this point - isTls: record.isTLS, - tlsVersion: undefined, // We don't extract TLS version yet - timestamp: Date.now(), - connectionId: record.id - }; - - // Find matching route - const routeMatch = this.smartProxy.routeManager.findMatchingRoute(routeContext); - - if (!routeMatch) { - logger.log('warn', `No route found for ${serverName || 'connection'} on port ${localPort} (connection: ${connectionId})`, { - connectionId, - serverName: serverName || 'connection', - localPort, - component: 'route-handler' - }); - - // No matching route, use default/fallback handling - logger.log('info', `Using default route handling for connection ${connectionId}`, { - connectionId, - component: 'route-handler' - }); - - // Check default security settings - const defaultSecuritySettings = this.smartProxy.settings.defaults?.security; - if (defaultSecuritySettings) { - if (defaultSecuritySettings.ipAllowList && defaultSecuritySettings.ipAllowList.length > 0) { - const isAllowed = this.smartProxy.securityManager.isIPAuthorized( - remoteIP, - defaultSecuritySettings.ipAllowList, - defaultSecuritySettings.ipBlockList || [] - ); - - if (!isAllowed) { - logger.log('warn', `IP ${remoteIP} not in default allowed list for connection ${connectionId}`, { - connectionId, - remoteIP, - component: 'route-handler' - }); - socket.end(); - this.smartProxy.connectionManager.cleanupConnection(record, 'ip_blocked'); - return; - } - } - } - - // Setup direct connection with default settings - if (this.smartProxy.settings.defaults?.target) { - // Use defaults from configuration - const targetHost = this.smartProxy.settings.defaults.target.host; - const targetPort = this.smartProxy.settings.defaults.target.port; - - return this.setupDirectConnection( - socket, - record, - serverName, - initialChunk, - undefined, - targetHost, - targetPort - ); - } else { - // No default target available, terminate the connection - logger.log('warn', `No default target configured for connection ${connectionId}. Closing connection`, { - connectionId, - component: 'route-handler' - }); - socket.end(); - this.smartProxy.connectionManager.cleanupConnection(record, 'no_default_target'); - return; - } - } - - // A matching route was found - const route = routeMatch.route; - - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('info', `Route matched`, { - connectionId, - routeName: route.name || 'unnamed', - serverName: serverName || 'connection', - localPort, - component: 'route-handler' - }); - } - - // Apply route-specific security checks - if (route.security) { - // Check IP allow/block lists - if (route.security.ipAllowList || route.security.ipBlockList) { - const isIPAllowed = this.smartProxy.securityManager.isIPAuthorized( - remoteIP, - route.security.ipAllowList || [], - route.security.ipBlockList || [] - ); - - if (!isIPAllowed) { - // Deduplicated logging for route IP blocks - connectionLogDeduplicator.log( - 'ip-rejected', - 'warn', - `IP blocked by route security`, - { - connectionId, - remoteIP, - routeName: route.name || 'unnamed', - reason: 'route-ip-blocked', - component: 'route-handler' - }, - remoteIP - ); - socket.end(); - this.smartProxy.connectionManager.cleanupConnection(record, 'route_ip_blocked'); - return; - } - } - - // Check max connections per route - if (route.security.maxConnections !== undefined) { - const routeId = route.id || route.name || 'unnamed'; - const currentConnections = this.smartProxy.connectionManager.getConnectionCountByRoute(routeId); - - if (currentConnections >= route.security.maxConnections) { - // Deduplicated logging for route connection limits - connectionLogDeduplicator.log( - 'connection-rejected', - 'warn', - `Route connection limit reached`, - { - connectionId, - routeName: route.name, - currentConnections, - maxConnections: route.security.maxConnections, - reason: 'route-limit', - component: 'route-handler' - }, - `route-limit-${route.name}` - ); - socket.end(); - this.smartProxy.connectionManager.cleanupConnection(record, 'route_connection_limit'); - return; - } - } - - // Check authentication requirements - if (route.security.authentication || route.security.basicAuth || route.security.jwtAuth) { - // Authentication checks would typically happen at the HTTP layer - // For non-HTTP connections or passthrough, we can't enforce authentication - if (route.action.type === 'forward' && route.action.tls?.mode !== 'terminate') { - logger.log('warn', `Route ${route.name} has authentication configured but it cannot be enforced for non-terminated connections`, { - connectionId, - routeName: route.name, - tlsMode: route.action.tls?.mode || 'none', - component: 'route-handler' - }); - } - } - } - - // Handle the route based on its action type - switch (route.action.type) { - case 'forward': - return this.handleForwardAction(socket, record, route, initialChunk, detectionResult); - - case 'socket-handler': - logger.log('info', `Handling socket-handler action for route ${route.name}`, { - connectionId, - routeName: route.name, - component: 'route-handler' - }); - this.handleSocketHandlerAction(socket, record, route, initialChunk); - return; - - default: - logger.log('error', `Unknown action type '${(route.action as any).type}' for connection ${connectionId}`, { - connectionId, - actionType: (route.action as any).type, - component: 'route-handler' - }); - socket.end(); - this.smartProxy.connectionManager.cleanupConnection(record, 'unknown_action'); - } - } - - /** - * Select the appropriate target from the targets array based on sub-matching criteria - */ - private selectTarget( - targets: IRouteTarget[], - context: { - port: number; - path?: string; - headers?: Record; - method?: string; - } - ): IRouteTarget | null { - // Sort targets by priority (higher first) - const sortedTargets = [...targets].sort((a, b) => (b.priority || 0) - (a.priority || 0)); - - // Find the first matching target - for (const target of sortedTargets) { - if (!target.match) { - // No match criteria means this is a default/fallback target - return target; - } - - // Check port match - if (target.match.ports && !target.match.ports.includes(context.port)) { - continue; - } - - // Check path match (supports wildcards) - if (target.match.path && context.path) { - const pathPattern = target.match.path.replace(/\*/g, '.*'); - const pathRegex = new RegExp(`^${pathPattern}$`); - if (!pathRegex.test(context.path)) { - continue; - } - } - - // Check method match - if (target.match.method && context.method && !target.match.method.includes(context.method)) { - continue; - } - - // Check headers match - if (target.match.headers && context.headers) { - let headersMatch = true; - for (const [key, pattern] of Object.entries(target.match.headers)) { - const headerValue = context.headers[key.toLowerCase()]; - if (!headerValue) { - headersMatch = false; - break; - } - - if (pattern instanceof RegExp) { - if (!pattern.test(headerValue)) { - headersMatch = false; - break; - } - } else if (headerValue !== pattern) { - headersMatch = false; - break; - } - } - if (!headersMatch) { - continue; - } - } - - // All criteria matched - return target; - } - - // No matching target found, return the first target without match criteria (default) - return sortedTargets.find(t => !t.match) || null; - } - - /** - * Handle a forward action for a route - */ - private handleForwardAction( - socket: plugins.net.Socket | WrappedSocket, - record: IConnectionRecord, - route: IRouteConfig, - initialChunk?: Buffer, - detectionResult?: any // Using any temporarily to avoid circular dependency issues - ): void { - const connectionId = record.id; - const action = route.action as IRouteAction; - - // Store the route config in the connection record for metrics and other uses - record.routeConfig = route; - record.routeId = route.id || route.name || 'unnamed'; - - // Track connection by route - this.smartProxy.connectionManager.trackConnectionByRoute(record.routeId, record.id); - - // Check if this route uses NFTables for forwarding - if (action.forwardingEngine === 'nftables') { - // NFTables handles packet forwarding at the kernel level - // The application should NOT interfere with these connections - - // Log the connection for monitoring purposes - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('info', `NFTables forwarding (kernel-level)`, { - connectionId: record.id, - source: `${record.remoteIP}:${socket.remotePort}`, - destination: `${socket.localAddress}:${record.localPort}`, - routeName: route.name || 'unnamed', - domain: record.lockedDomain || 'n/a', - component: 'route-handler' - }); - } else { - logger.log('info', `NFTables forwarding`, { - connectionId: record.id, - remoteIP: record.remoteIP, - localPort: record.localPort, - routeName: route.name || 'unnamed', - component: 'route-handler' - }); - } - - // Additional NFTables-specific logging if configured - if (action.nftables) { - const nftConfig = action.nftables; - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('info', `NFTables config`, { - connectionId: record.id, - protocol: nftConfig.protocol || 'tcp', - preserveSourceIP: nftConfig.preserveSourceIP || false, - priority: nftConfig.priority || 'default', - maxRate: nftConfig.maxRate || 'unlimited', - component: 'route-handler' - }); - } - } - - // For NFTables routes, we should still track the connection but not interfere - // Mark the connection as using network proxy so it's cleaned up properly - record.usingNetworkProxy = true; - - // We don't close the socket - just let it remain open - // The kernel-level NFTables rules will handle the actual forwarding - - // Set up cleanup when the socket eventually closes - socket.once('close', () => { - this.smartProxy.connectionManager.cleanupConnection(record, 'nftables_closed'); - }); - - return; - } - - // Select the appropriate target from the targets array - if (!action.targets || action.targets.length === 0) { - logger.log('error', `Forward action missing targets configuration for connection ${connectionId}`, { - connectionId, - component: 'route-handler' - }); - socket.end(); - this.smartProxy.connectionManager.cleanupConnection(record, 'missing_targets'); - return; - } - - // Create context for target selection - const targetSelectionContext = { - port: record.localPort, - path: record.httpInfo?.path, - headers: record.httpInfo?.headers, - method: record.httpInfo?.method - }; - - const selectedTarget = this.selectTarget(action.targets, targetSelectionContext); - if (!selectedTarget) { - logger.log('error', `No matching target found for connection ${connectionId}`, { - connectionId, - port: targetSelectionContext.port, - component: 'route-handler' - }); - socket.end(); - this.smartProxy.connectionManager.cleanupConnection(record, 'no_matching_target'); - return; - } - - // Create the routing context for this connection - const routeContext = this.createRouteContext({ - connectionId: record.id, - port: record.localPort, - domain: record.lockedDomain, - clientIp: record.remoteIP, - serverIp: socket.localAddress || '', - isTls: record.isTLS || false, - tlsVersion: record.tlsVersion, - routeName: route.name, - routeId: route.id, - }); - - // Note: Route contexts are not cached to ensure fresh data for each connection - - // Determine host using function or static value - let targetHost: string | string[]; - if (typeof selectedTarget.host === 'function') { - try { - targetHost = selectedTarget.host(routeContext); - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('info', `Dynamic host resolved to ${Array.isArray(targetHost) ? targetHost.join(', ') : targetHost} for connection ${connectionId}`, { - connectionId, - targetHost: Array.isArray(targetHost) ? targetHost.join(', ') : targetHost, - component: 'route-handler' - }); - } - } catch (err) { - logger.log('error', `Error in host mapping function for connection ${connectionId}: ${err}`, { - connectionId, - error: err, - component: 'route-handler' - }); - socket.end(); - this.smartProxy.connectionManager.cleanupConnection(record, 'host_mapping_error'); - return; - } - } else { - targetHost = selectedTarget.host; - } - - // If an array of hosts, select one randomly for load balancing - const selectedHost = Array.isArray(targetHost) - ? targetHost[Math.floor(Math.random() * targetHost.length)] - : targetHost; - - // Determine port using function or static value - let targetPort: number; - if (typeof selectedTarget.port === 'function') { - try { - targetPort = selectedTarget.port(routeContext); - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('info', `Dynamic port mapping from ${record.localPort} to ${targetPort} for connection ${connectionId}`, { - connectionId, - sourcePort: record.localPort, - targetPort, - component: 'route-handler' - }); - } - // Store the resolved target port in the context for potential future use - routeContext.targetPort = targetPort; - } catch (err) { - logger.log('error', `Error in port mapping function for connection ${connectionId}: ${err}`, { - connectionId, - error: err, - component: 'route-handler' - }); - socket.end(); - this.smartProxy.connectionManager.cleanupConnection(record, 'port_mapping_error'); - return; - } - } else if (selectedTarget.port === 'preserve') { - // Use incoming port if port is 'preserve' - targetPort = record.localPort; - } else { - // Use static port from configuration - targetPort = selectedTarget.port; - } - - // Store the resolved host in the context - routeContext.targetHost = selectedHost; - - // Get effective settings (target overrides route-level settings) - const effectiveTls = selectedTarget.tls || action.tls; - const effectiveWebsocket = selectedTarget.websocket || action.websocket; - const effectiveSendProxyProtocol = selectedTarget.sendProxyProtocol !== undefined - ? selectedTarget.sendProxyProtocol - : action.sendProxyProtocol; - - // Determine if this needs TLS handling - if (effectiveTls) { - switch (effectiveTls.mode) { - case 'passthrough': - // For TLS passthrough, just forward directly - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('info', `Using TLS passthrough to ${selectedHost}:${targetPort} for connection ${connectionId}`, { - connectionId, - targetHost: selectedHost, - targetPort, - component: 'route-handler' - }); - } - - return this.setupDirectConnection( - socket, - record, - record.lockedDomain, - initialChunk, - undefined, - selectedHost, - targetPort - ); - - case 'terminate': - case 'terminate-and-reencrypt': - // For TLS termination, use HttpProxy - if (this.smartProxy.httpProxyBridge.getHttpProxy()) { - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('info', `Using HttpProxy for TLS termination to ${Array.isArray(selectedTarget.host) ? selectedTarget.host.join(', ') : selectedTarget.host} for connection ${connectionId}`, { - connectionId, - targetHost: selectedTarget.host, - component: 'route-handler' - }); - } - - // If we have an initial chunk with TLS data, start processing it - if (initialChunk && record.isTLS) { - this.smartProxy.httpProxyBridge.forwardToHttpProxy( - connectionId, - socket, - record, - initialChunk, - this.smartProxy.settings.httpProxyPort || 8443, - (reason) => this.smartProxy.connectionManager.cleanupConnection(record, reason) - ); - return; - } - - // This shouldn't normally happen - we should have TLS data at this point - logger.log('error', `TLS termination route without TLS data for connection ${connectionId}`, { - connectionId, - component: 'route-handler' - }); - socket.end(); - this.smartProxy.connectionManager.cleanupConnection(record, 'tls_error'); - return; - } else { - logger.log('error', `HttpProxy not available for TLS termination for connection ${connectionId}`, { - connectionId, - component: 'route-handler' - }); - socket.end(); - this.smartProxy.connectionManager.cleanupConnection(record, 'no_http_proxy'); - return; - } - } - } else { - // No TLS settings - check if this port should use HttpProxy - const isHttpProxyPort = this.smartProxy.settings.useHttpProxy?.includes(record.localPort); - - // Debug logging - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('debug', `Checking HttpProxy forwarding: port=${record.localPort}, useHttpProxy=${JSON.stringify(this.smartProxy.settings.useHttpProxy)}, isHttpProxyPort=${isHttpProxyPort}, hasHttpProxy=${!!this.smartProxy.httpProxyBridge.getHttpProxy()}`, { - connectionId, - localPort: record.localPort, - useHttpProxy: this.smartProxy.settings.useHttpProxy, - isHttpProxyPort, - hasHttpProxy: !!this.smartProxy.httpProxyBridge.getHttpProxy(), - component: 'route-handler' - }); - } - - if (isHttpProxyPort && this.smartProxy.httpProxyBridge.getHttpProxy()) { - // Forward non-TLS connections to HttpProxy if configured - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('info', `Using HttpProxy for non-TLS connection ${connectionId} on port ${record.localPort}`, { - connectionId, - port: record.localPort, - component: 'route-handler' - }); - } - - this.smartProxy.httpProxyBridge.forwardToHttpProxy( - connectionId, - socket, - record, - initialChunk, - this.smartProxy.settings.httpProxyPort || 8443, - (reason) => this.smartProxy.connectionManager.cleanupConnection(record, reason) - ); - return; - } else { - // Basic forwarding - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('info', `Using basic forwarding to ${Array.isArray(selectedTarget.host) ? selectedTarget.host.join(', ') : selectedTarget.host}:${selectedTarget.port} for connection ${connectionId}`, { - connectionId, - targetHost: selectedTarget.host, - targetPort: selectedTarget.port, - component: 'route-handler' - }); - } - - // Get the appropriate host value - let targetHost: string; - - if (typeof selectedTarget.host === 'function') { - // For function-based host, use the same routeContext created earlier - const hostResult = selectedTarget.host(routeContext); - targetHost = Array.isArray(hostResult) - ? hostResult[Math.floor(Math.random() * hostResult.length)] - : hostResult; - } else { - // For static host value - targetHost = Array.isArray(selectedTarget.host) - ? selectedTarget.host[Math.floor(Math.random() * selectedTarget.host.length)] - : selectedTarget.host; - } - - // Determine port - either function-based, static, or preserve incoming port - let targetPort: number; - if (typeof selectedTarget.port === 'function') { - targetPort = selectedTarget.port(routeContext); - } else if (selectedTarget.port === 'preserve') { - targetPort = record.localPort; - } else { - targetPort = selectedTarget.port; - } - - // Update the connection record and context with resolved values - record.targetHost = targetHost; - record.targetPort = targetPort; - - return this.setupDirectConnection( - socket, - record, - record.lockedDomain, - initialChunk, - undefined, - targetHost, - targetPort - ); - } - } - } - - /** - * Handle a socket-handler action for a route - */ - private async handleSocketHandlerAction( - socket: plugins.net.Socket | WrappedSocket, - record: IConnectionRecord, - route: IRouteConfig, - initialChunk?: Buffer - ): Promise { - const connectionId = record.id; - - // Store the route config in the connection record for metrics and other uses - record.routeConfig = route; - record.routeId = route.id || route.name || 'unnamed'; - - // Track connection by route - this.smartProxy.connectionManager.trackConnectionByRoute(record.routeId, record.id); - - if (!route.action.socketHandler) { - logger.log('error', 'socket-handler action missing socketHandler function', { - connectionId, - routeName: route.name, - component: 'route-handler' - }); - socket.destroy(); - this.smartProxy.connectionManager.cleanupConnection(record, 'missing_handler'); - return; - } - - // Track event listeners added by the handler so we can clean them up - const originalOn = socket.on.bind(socket); - const originalOnce = socket.once.bind(socket); - const trackedListeners: Array<{event: string; listener: (...args: any[]) => void}> = []; - - // Override socket.on to track listeners - socket.on = function(event: string, listener: (...args: any[]) => void) { - trackedListeners.push({event, listener}); - return originalOn(event, listener); - } as any; - - // Override socket.once to track listeners - socket.once = function(event: string, listener: (...args: any[]) => void) { - trackedListeners.push({event, listener}); - return originalOnce(event, listener); - } as any; - - // Set up automatic cleanup when socket closes - const cleanupHandler = () => { - // Remove all tracked listeners - for (const {event, listener} of trackedListeners) { - socket.removeListener(event, listener); - } - // Restore original methods - socket.on = originalOn; - socket.once = originalOnce; - }; - - // Listen for socket close to trigger cleanup - originalOnce('close', cleanupHandler); - originalOnce('error', cleanupHandler); - - // Create route context for the handler - const routeContext = this.createRouteContext({ - connectionId: record.id, - port: record.localPort, - domain: record.lockedDomain, - clientIp: record.remoteIP, - serverIp: socket.localAddress || '', - isTls: record.isTLS || false, - tlsVersion: record.tlsVersion, - routeName: route.name, - routeId: route.id, - }); - - try { - // Call the handler with the appropriate socket (extract underlying if needed) - const handlerSocket = getUnderlyingSocket(socket); - const result = route.action.socketHandler(handlerSocket, routeContext); - - // Handle async handlers properly - if (result instanceof Promise) { - result - .then(() => { - // Emit initial chunk after async handler completes - if (initialChunk && initialChunk.length > 0) { - socket.emit('data', initialChunk); - } - }) - .catch(error => { - logger.log('error', 'Socket handler error', { - connectionId, - routeName: route.name, - error: error.message, - component: 'route-handler' - }); - // Remove all event listeners before destroying to prevent memory leaks - socket.removeAllListeners(); - if (!socket.destroyed) { - socket.destroy(); - } - this.smartProxy.connectionManager.cleanupConnection(record, 'handler_error'); - }); - } else { - // For sync handlers, emit on next tick - if (initialChunk && initialChunk.length > 0) { - process.nextTick(() => { - socket.emit('data', initialChunk); - }); - } - } - } catch (error) { - logger.log('error', 'Socket handler error', { - connectionId, - routeName: route.name, - error: error.message, - component: 'route-handler' - }); - // Remove all event listeners before destroying to prevent memory leaks - socket.removeAllListeners(); - if (!socket.destroyed) { - socket.destroy(); - } - this.smartProxy.connectionManager.cleanupConnection(record, 'handler_error'); - } - } - - - /** - * Sets up a direct connection to the target - */ - private setupDirectConnection( - socket: plugins.net.Socket | WrappedSocket, - record: IConnectionRecord, - serverName?: string, - initialChunk?: Buffer, - overridePort?: number, - targetHost?: string, - targetPort?: number - ): void { - const connectionId = record.id; - - // Determine target host and port if not provided - const finalTargetHost = - targetHost || record.targetHost || this.smartProxy.settings.defaults?.target?.host || 'localhost'; - - // Determine target port - const finalTargetPort = - targetPort || - record.targetPort || - (overridePort !== undefined ? overridePort : this.smartProxy.settings.defaults?.target?.port || 443); - - // Update record with final target information - record.targetHost = finalTargetHost; - record.targetPort = finalTargetPort; - - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('info', `Setting up direct connection ${connectionId} to ${finalTargetHost}:${finalTargetPort}`, { - connectionId, - targetHost: finalTargetHost, - targetPort: finalTargetPort, - component: 'route-handler' - }); - } - - // Setup connection options - const connectionOptions: plugins.net.NetConnectOpts = { - host: finalTargetHost, - port: finalTargetPort, - }; - - // Preserve source IP if configured - if (this.smartProxy.settings.defaults?.preserveSourceIP || this.smartProxy.settings.preserveSourceIP) { - connectionOptions.localAddress = record.remoteIP.replace('::ffff:', ''); - } - - // Store initial data if provided - if (initialChunk) { - // Don't count bytes here - they will be counted when actually forwarded through bidirectional forwarding - record.pendingData.push(Buffer.from(initialChunk)); - record.pendingDataSize = initialChunk.length; - } - - // Create the target socket with immediate error handling - const targetSocket = createSocketWithErrorHandler({ - port: finalTargetPort, - host: finalTargetHost, - timeout: this.smartProxy.settings.connectionTimeout || 30000, // Connection timeout (default: 30s) - onError: (error) => { - // Connection failed - clean up everything immediately - // Check if connection record is still valid (client might have disconnected) - if (record.connectionClosed) { - logger.log('debug', `Backend connection failed but client already disconnected for ${connectionId}`, { - connectionId, - errorCode: (error as any).code, - component: 'route-handler' - }); - return; - } - - logger.log('error', - `Connection setup error for ${connectionId} to ${finalTargetHost}:${finalTargetPort}: ${error.message} (${(error as any).code})`, - { - connectionId, - targetHost: finalTargetHost, - targetPort: finalTargetPort, - errorMessage: error.message, - errorCode: (error as any).code, - component: 'route-handler' - } - ); - - // Log specific error types for easier debugging - if ((error as any).code === 'ECONNREFUSED') { - logger.log('error', - `Connection ${connectionId}: Target ${finalTargetHost}:${finalTargetPort} refused connection. Check if the target service is running and listening on that port.`, - { - connectionId, - targetHost: finalTargetHost, - targetPort: finalTargetPort, - recommendation: 'Check if the target service is running and listening on that port.', - component: 'route-handler' - } - ); - } - - // Resume the incoming socket to prevent it from hanging - if (socket && !socket.destroyed) { - socket.resume(); - } - - // Clean up the incoming socket - if (socket && !socket.destroyed) { - socket.destroy(); - } - - // Clean up the connection record - this is critical! - this.smartProxy.connectionManager.cleanupConnection(record, `connection_failed_${(error as any).code || 'unknown'}`); - }, - onConnect: async () => { - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('info', `Connection ${connectionId} established to target ${finalTargetHost}:${finalTargetPort}`, { - connectionId, - targetHost: finalTargetHost, - targetPort: finalTargetPort, - component: 'route-handler' - }); - } - - // Clear any error listeners added by createSocketWithErrorHandler - targetSocket.removeAllListeners('error'); - - // Add the normal error handler for established connections - targetSocket.on('error', this.smartProxy.connectionManager.handleError('outgoing', record)); - - // Check if we should send PROXY protocol header - const shouldSendProxyProtocol = record.routeConfig?.action?.sendProxyProtocol || - this.smartProxy.settings.sendProxyProtocol; - - if (shouldSendProxyProtocol) { - try { - // Generate PROXY protocol header - const proxyInfo = { - protocol: (record.remoteIP.includes(':') ? 'TCP6' : 'TCP4') as 'TCP4' | 'TCP6', - sourceIP: record.remoteIP, - sourcePort: record.remotePort || socket.remotePort || 0, - destinationIP: socket.localAddress || '', - destinationPort: socket.localPort || 0 - }; - - const proxyHeader = ProxyProtocolParser.generate(proxyInfo); - - // Note: PROXY protocol headers are sent to the backend, not to the client - // They are internal protocol overhead and shouldn't be counted in client-facing metrics - - // Send PROXY protocol header first - await new Promise((resolve, reject) => { - targetSocket.write(proxyHeader, (err) => { - if (err) { - logger.log('error', `Failed to send PROXY protocol header`, { - connectionId, - error: err.message, - component: 'route-handler' - }); - reject(err); - } else { - logger.log('info', `PROXY protocol header sent to backend`, { - connectionId, - targetHost: finalTargetHost, - targetPort: finalTargetPort, - sourceIP: proxyInfo.sourceIP, - sourcePort: proxyInfo.sourcePort, - component: 'route-handler' - }); - resolve(); - } - }); - }); - } catch (error) { - logger.log('error', `Error sending PROXY protocol header`, { - connectionId, - error: error.message, - component: 'route-handler' - }); - // Continue anyway - don't break the connection - } - } - - // Flush any pending data to target - if (record.pendingData.length > 0) { - const combinedData = Buffer.concat(record.pendingData); - - if (this.smartProxy.settings.enableDetailedLogging) { - console.log( - `[${connectionId}] Forwarding ${combinedData.length} bytes of initial data to target` - ); - } - - // Record the initial chunk bytes for metrics - record.bytesReceived += combinedData.length; - if (this.smartProxy.metricsCollector) { - this.smartProxy.metricsCollector.recordBytes(record.id, combinedData.length, 0); - } - - // Write pending data immediately - targetSocket.write(combinedData, (err) => { - if (err) { - logger.log('error', `Error writing pending data to target for connection ${connectionId}: ${err.message}`, { - connectionId, - error: err.message, - component: 'route-handler' - }); - return this.smartProxy.connectionManager.cleanupConnection(record, 'write_error'); - } - }); - - // Clear the buffer now that we've processed it - record.pendingData = []; - record.pendingDataSize = 0; - } - - // Use centralized bidirectional forwarding setup - // Extract underlying sockets for socket-utils functions - const incomingSocket = getUnderlyingSocket(socket); - - setupBidirectionalForwarding(incomingSocket, targetSocket, { - onClientData: (chunk) => { - record.bytesReceived += chunk.length; - this.smartProxy.timeoutManager.updateActivity(record); - - // Record bytes for metrics - if (this.smartProxy.metricsCollector) { - this.smartProxy.metricsCollector.recordBytes(record.id, chunk.length, 0); - } - }, - onServerData: (chunk) => { - record.bytesSent += chunk.length; - this.smartProxy.timeoutManager.updateActivity(record); - - // Record bytes for metrics - if (this.smartProxy.metricsCollector) { - this.smartProxy.metricsCollector.recordBytes(record.id, 0, chunk.length); - } - }, - onCleanup: (reason) => { - this.smartProxy.connectionManager.cleanupConnection(record, reason); - }, - enableHalfOpen: false // Default: close both when one closes (required for proxy chains) - }); - - // Apply timeouts using TimeoutManager - const timeout = this.smartProxy.timeoutManager.getEffectiveInactivityTimeout(record); - // Skip timeout for immortal connections (MAX_SAFE_INTEGER would cause issues) - if (timeout !== Number.MAX_SAFE_INTEGER) { - const safeTimeout = this.smartProxy.timeoutManager.ensureSafeTimeout(timeout); - socket.setTimeout(safeTimeout); - targetSocket.setTimeout(safeTimeout); - } - - // Log successful connection - logger.log('info', - `Connection established: ${record.remoteIP} -> ${finalTargetHost}:${finalTargetPort}` + - `${serverName ? ` (SNI: ${serverName})` : record.lockedDomain ? ` (Domain: ${record.lockedDomain})` : ''}`, - { - remoteIP: record.remoteIP, - targetHost: finalTargetHost, - targetPort: finalTargetPort, - sni: serverName || undefined, - domain: !serverName && record.lockedDomain ? record.lockedDomain : undefined, - component: 'route-handler' - } - ); - - // Add TLS renegotiation handler if needed - if (serverName) { - // Create connection info object for the existing connection - const connInfo = { - sourceIp: record.remoteIP, - sourcePort: record.incoming.remotePort || 0, - destIp: record.incoming.localAddress || '', - destPort: record.incoming.localPort || 0, - }; - - // Create a renegotiation handler function - const renegotiationHandler = this.smartProxy.tlsManager.createRenegotiationHandler( - connectionId, - serverName, - connInfo, - (_connectionId, reason) => this.smartProxy.connectionManager.cleanupConnection(record, reason) - ); - - // Store the handler in the connection record so we can remove it during cleanup - record.renegotiationHandler = renegotiationHandler; - - // Add the handler to the socket - socket.on('data', renegotiationHandler); - - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('info', `TLS renegotiation handler installed for connection ${connectionId} with SNI ${serverName}`, { - connectionId, - serverName, - component: 'route-handler' - }); - } - } - - // Set connection timeout - record.cleanupTimer = this.smartProxy.timeoutManager.setupConnectionTimeout(record, (record, reason) => { - logger.log('warn', `Connection ${connectionId} from ${record.remoteIP} exceeded max lifetime, forcing cleanup`, { - connectionId, - remoteIP: record.remoteIP, - component: 'route-handler' - }); - this.smartProxy.connectionManager.cleanupConnection(record, reason); - }); - - // Mark TLS handshake as complete for TLS connections - if (record.isTLS) { - record.tlsHandshakeComplete = true; - } - } - }); - - // Set outgoing socket immediately so it can be cleaned up if client disconnects - record.outgoing = targetSocket; - record.outgoingStartTime = Date.now(); - - // Apply socket optimizations - targetSocket.setNoDelay(this.smartProxy.settings.noDelay); - - // Apply keep-alive settings if enabled - if (this.smartProxy.settings.keepAlive) { - targetSocket.setKeepAlive(true, this.smartProxy.settings.keepAliveInitialDelay); - - // Apply enhanced TCP keep-alive options if enabled - if (this.smartProxy.settings.enableKeepAliveProbes) { - try { - if ('setKeepAliveProbes' in targetSocket) { - (targetSocket as any).setKeepAliveProbes(10); - } - if ('setKeepAliveInterval' in targetSocket) { - (targetSocket as any).setKeepAliveInterval(1000); - } - } catch (err) { - // Ignore errors - these are optional enhancements - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('warn', `Enhanced TCP keep-alive not supported for outgoing socket on connection ${connectionId}: ${err}`, { - connectionId, - error: err, - component: 'route-handler' - }); - } - } - } - } - - // Setup error handlers for incoming socket - socket.on('error', this.smartProxy.connectionManager.handleError('incoming', record)); - - // Handle timeouts with keep-alive awareness - socket.on('timeout', () => { - // For keep-alive connections, just log a warning instead of closing - if (record.hasKeepAlive) { - logger.log('warn', `Timeout event on incoming keep-alive connection ${connectionId} from ${record.remoteIP} after ${plugins.prettyMs(this.smartProxy.settings.socketTimeout || 3600000)}. Connection preserved.`, { - connectionId, - remoteIP: record.remoteIP, - timeout: plugins.prettyMs(this.smartProxy.settings.socketTimeout || 3600000), - status: 'Connection preserved', - component: 'route-handler' - }); - return; - } - - // For non-keep-alive connections, proceed with normal cleanup - logger.log('warn', `Timeout on incoming side for connection ${connectionId} from ${record.remoteIP} after ${plugins.prettyMs(this.smartProxy.settings.socketTimeout || 3600000)}`, { - connectionId, - remoteIP: record.remoteIP, - timeout: plugins.prettyMs(this.smartProxy.settings.socketTimeout || 3600000), - component: 'route-handler' - }); - if (record.incomingTerminationReason === null) { - record.incomingTerminationReason = 'timeout'; - this.smartProxy.connectionManager.incrementTerminationStat('incoming', 'timeout'); - } - this.smartProxy.connectionManager.cleanupConnection(record, 'timeout_incoming'); - }); - - targetSocket.on('timeout', () => { - // For keep-alive connections, just log a warning instead of closing - if (record.hasKeepAlive) { - logger.log('warn', `Timeout event on outgoing keep-alive connection ${connectionId} from ${record.remoteIP} after ${plugins.prettyMs(this.smartProxy.settings.socketTimeout || 3600000)}. Connection preserved.`, { - connectionId, - remoteIP: record.remoteIP, - timeout: plugins.prettyMs(this.smartProxy.settings.socketTimeout || 3600000), - status: 'Connection preserved', - component: 'route-handler' - }); - return; - } - - // For non-keep-alive connections, proceed with normal cleanup - logger.log('warn', `Timeout on outgoing side for connection ${connectionId} from ${record.remoteIP} after ${plugins.prettyMs(this.smartProxy.settings.socketTimeout || 3600000)}`, { - connectionId, - remoteIP: record.remoteIP, - timeout: plugins.prettyMs(this.smartProxy.settings.socketTimeout || 3600000), - component: 'route-handler' - }); - if (record.outgoingTerminationReason === null) { - record.outgoingTerminationReason = 'timeout'; - this.smartProxy.connectionManager.incrementTerminationStat('outgoing', 'timeout'); - } - this.smartProxy.connectionManager.cleanupConnection(record, 'timeout_outgoing'); - }); - - // Apply socket timeouts - this.smartProxy.timeoutManager.applySocketTimeouts(record); - } -} \ No newline at end of file diff --git a/ts/proxies/smart-proxy/route-orchestrator.ts b/ts/proxies/smart-proxy/route-orchestrator.ts deleted file mode 100644 index 0d2dbc6..0000000 --- a/ts/proxies/smart-proxy/route-orchestrator.ts +++ /dev/null @@ -1,297 +0,0 @@ -import { logger } from '../../core/utils/logger.js'; -import type { IRouteConfig } from './models/route-types.js'; -import type { ILogger } from '../http-proxy/models/types.js'; -import { RouteValidator } from './utils/route-validator.js'; -import { Mutex } from './utils/mutex.js'; -import type { PortManager } from './port-manager.js'; -import type { SharedRouteManager as RouteManager } from '../../core/routing/route-manager.js'; -import type { HttpProxyBridge } from './http-proxy-bridge.js'; -import type { NFTablesManager } from './nftables-manager.js'; -import type { SmartCertManager } from './certificate-manager.js'; - -/** - * Orchestrates route updates and coordination between components - * Extracted from SmartProxy to reduce class complexity - */ -export class RouteOrchestrator { - private routeUpdateLock: Mutex; - private portManager: PortManager; - private routeManager: RouteManager; - private httpProxyBridge: HttpProxyBridge; - private nftablesManager: NFTablesManager; - private certManager: SmartCertManager | null = null; - private logger: ILogger; - - constructor( - portManager: PortManager, - routeManager: RouteManager, - httpProxyBridge: HttpProxyBridge, - nftablesManager: NFTablesManager, - certManager: SmartCertManager | null, - logger: ILogger - ) { - this.portManager = portManager; - this.routeManager = routeManager; - this.httpProxyBridge = httpProxyBridge; - this.nftablesManager = nftablesManager; - this.certManager = certManager; - this.logger = logger; - this.routeUpdateLock = new Mutex(); - } - - /** - * Set or update certificate manager reference - */ - public setCertManager(certManager: SmartCertManager | null): void { - this.certManager = certManager; - } - - /** - * Get certificate manager reference - */ - public getCertManager(): SmartCertManager | null { - return this.certManager; - } - - /** - * Update routes with validation and coordination - */ - public async updateRoutes( - oldRoutes: IRouteConfig[], - newRoutes: IRouteConfig[], - options: { - acmePort?: number; - acmeOptions?: any; - acmeState?: any; - globalChallengeRouteActive?: boolean; - createCertificateManager?: ( - routes: IRouteConfig[], - certStore: string, - acmeOptions?: any, - initialState?: any - ) => Promise; - verifyChallengeRouteRemoved?: () => Promise; - } = {} - ): Promise<{ - portUsageMap: Map>; - newChallengeRouteActive: boolean; - newCertManager?: SmartCertManager; - }> { - return this.routeUpdateLock.runExclusive(async () => { - // Validate route configurations - const validation = RouteValidator.validateRoutes(newRoutes); - if (!validation.valid) { - RouteValidator.logValidationErrors(validation.errors); - throw new Error(`Route validation failed: ${validation.errors.size} route(s) have errors`); - } - - // Track port usage before and after updates - const oldPortUsage = this.updatePortUsageMap(oldRoutes); - const newPortUsage = this.updatePortUsageMap(newRoutes); - - // Get the lists of currently listening ports and new ports needed - const currentPorts = new Set(this.portManager.getListeningPorts()); - const newPortsSet = new Set(newPortUsage.keys()); - - // Log the port usage for debugging - this.logger.debug(`Current listening ports: ${Array.from(currentPorts).join(', ')}`); - this.logger.debug(`Ports needed for new routes: ${Array.from(newPortsSet).join(', ')}`); - - // Find orphaned ports - ports that no longer have any routes - const orphanedPorts = this.findOrphanedPorts(oldPortUsage, newPortUsage); - - // Find new ports that need binding (only ports that we aren't already listening on) - const newBindingPorts = Array.from(newPortsSet).filter(p => !currentPorts.has(p)); - - // Check for ACME challenge port to give it special handling - const acmePort = options.acmePort || 80; - const acmePortNeeded = newPortsSet.has(acmePort); - const acmePortListed = newBindingPorts.includes(acmePort); - - if (acmePortNeeded && acmePortListed) { - this.logger.info(`Adding ACME challenge port ${acmePort} to routes`); - } - - // Update NFTables routes - await this.updateNfTablesRoutes(oldRoutes, newRoutes); - - // Update routes in RouteManager - this.routeManager.updateRoutes(newRoutes); - - // Release orphaned ports first to free resources - if (orphanedPorts.length > 0) { - this.logger.info(`Releasing ${orphanedPorts.length} orphaned ports: ${orphanedPorts.join(', ')}`); - await this.portManager.removePorts(orphanedPorts); - } - - // Add new ports if needed - if (newBindingPorts.length > 0) { - this.logger.info(`Binding to ${newBindingPorts.length} new ports: ${newBindingPorts.join(', ')}`); - - // Handle port binding with improved error recovery - try { - await this.portManager.addPorts(newBindingPorts); - } catch (error) { - // Special handling for port binding errors - if ((error as any).code === 'EADDRINUSE') { - const port = (error as any).port || newBindingPorts[0]; - const isAcmePort = port === acmePort; - - if (isAcmePort) { - this.logger.warn(`Could not bind to ACME challenge port ${port}. It may be in use by another application.`); - - // Re-throw with more helpful message - throw new Error( - `ACME challenge port ${port} is already in use by another application. ` + - `Configure a different port in settings.acme.port (e.g., 8080) or free up port ${port}.` - ); - } - } - - // Re-throw the original error for other cases - throw error; - } - } - - // If HttpProxy is initialized, resync the configurations - if (this.httpProxyBridge.getHttpProxy()) { - await this.httpProxyBridge.syncRoutesToHttpProxy(newRoutes); - } - - // Update certificate manager if needed - let newCertManager: SmartCertManager | undefined; - let newChallengeRouteActive = options.globalChallengeRouteActive || false; - - if (this.certManager && options.createCertificateManager) { - const existingAcmeOptions = this.certManager.getAcmeOptions(); - const existingState = this.certManager.getState(); - - // Store global state before stopping - newChallengeRouteActive = existingState.challengeRouteActive; - - // Keep certificate manager routes in sync before stopping - this.certManager.setRoutes(newRoutes); - - await this.certManager.stop(); - - // Verify the challenge route has been properly removed - if (options.verifyChallengeRouteRemoved) { - await options.verifyChallengeRouteRemoved(); - } - - // Create new certificate manager with preserved state - newCertManager = await options.createCertificateManager( - newRoutes, - './certs', - existingAcmeOptions, - { challengeRouteActive: newChallengeRouteActive } - ); - - this.certManager = newCertManager; - } - - return { - portUsageMap: newPortUsage, - newChallengeRouteActive, - newCertManager - }; - }); - } - - /** - * Update port usage map based on the provided routes - */ - public updatePortUsageMap(routes: IRouteConfig[]): Map> { - const portUsage = new Map>(); - - for (const route of routes) { - // Get the ports for this route - const portsConfig = Array.isArray(route.match.ports) - ? route.match.ports - : [route.match.ports]; - - // Expand port range objects to individual port numbers - const expandedPorts: number[] = []; - for (const portConfig of portsConfig) { - if (typeof portConfig === 'number') { - expandedPorts.push(portConfig); - } else if (typeof portConfig === 'object' && 'from' in portConfig && 'to' in portConfig) { - // Expand the port range - for (let p = portConfig.from; p <= portConfig.to; p++) { - expandedPorts.push(p); - } - } - } - - // Use route name if available, otherwise generate a unique ID - const routeName = route.name || `unnamed_${Math.random().toString(36).substring(2, 9)}`; - - // Add each port to the usage map - for (const port of expandedPorts) { - if (!portUsage.has(port)) { - portUsage.set(port, new Set()); - } - portUsage.get(port)!.add(routeName); - } - } - - // Log port usage for debugging - for (const [port, routes] of portUsage.entries()) { - this.logger.debug(`Port ${port} is used by ${routes.size} routes: ${Array.from(routes).join(', ')}`); - } - - return portUsage; - } - - /** - * Find ports that have no routes in the new configuration - */ - private findOrphanedPorts(oldUsage: Map>, newUsage: Map>): number[] { - const orphanedPorts: number[] = []; - - for (const [port, routes] of oldUsage.entries()) { - if (!newUsage.has(port) || newUsage.get(port)!.size === 0) { - orphanedPorts.push(port); - } - } - - return orphanedPorts; - } - - /** - * Update NFTables routes - */ - private async updateNfTablesRoutes(oldRoutes: IRouteConfig[], newRoutes: IRouteConfig[]): Promise { - // Get existing routes that use NFTables and update them - const oldNfTablesRoutes = oldRoutes.filter( - r => r.action.forwardingEngine === 'nftables' - ); - - const newNfTablesRoutes = newRoutes.filter( - r => r.action.forwardingEngine === 'nftables' - ); - - // Update existing NFTables routes - for (const oldRoute of oldNfTablesRoutes) { - const newRoute = newNfTablesRoutes.find(r => r.name === oldRoute.name); - - if (!newRoute) { - // Route was removed - await this.nftablesManager.deprovisionRoute(oldRoute); - } else { - // Route was updated - await this.nftablesManager.updateRoute(oldRoute, newRoute); - } - } - - // Add new NFTables routes - for (const newRoute of newNfTablesRoutes) { - const oldRoute = oldNfTablesRoutes.find(r => r.name === newRoute.name); - - if (!oldRoute) { - // New route - await this.nftablesManager.provisionRoute(newRoute); - } - } - } -} \ No newline at end of file diff --git a/ts/proxies/smart-proxy/route-preprocessor.ts b/ts/proxies/smart-proxy/route-preprocessor.ts new file mode 100644 index 0000000..7f1b030 --- /dev/null +++ b/ts/proxies/smart-proxy/route-preprocessor.ts @@ -0,0 +1,122 @@ +import type { IRouteConfig, IRouteAction, IRouteTarget } from './models/route-types.js'; +import { logger } from '../../core/utils/logger.js'; + +/** + * Preprocesses routes before sending them to Rust. + * + * Strips non-serializable fields (functions, callbacks) and classifies + * routes that must be handled by TypeScript (socket-handler, dynamic host/port). + */ +export class RoutePreprocessor { + /** + * Map of route name/id โ†’ original route config (with JS functions preserved). + * Used by the socket handler server to look up the original handler. + */ + private originalRoutes = new Map(); + + /** + * Preprocess routes for the Rust binary. + * + * - Routes with `socketHandler` callbacks are marked as socket-handler type + * (Rust will relay these back to TS) + * - Routes with dynamic `host`/`port` functions are converted to socket-handler + * type (Rust relays, TS resolves the function) + * - Non-serializable fields are stripped + * - Original routes are preserved in the local map for handler lookup + */ + public preprocessForRust(routes: IRouteConfig[]): IRouteConfig[] { + this.originalRoutes.clear(); + return routes.map((route, index) => this.preprocessRoute(route, index)); + } + + /** + * Get the original route config (with JS functions) by route name or id. + */ + public getOriginalRoute(routeKey: string): IRouteConfig | undefined { + return this.originalRoutes.get(routeKey); + } + + /** + * Get all original routes that have socket handlers or dynamic functions. + */ + public getHandlerRoutes(): Map { + return new Map(this.originalRoutes); + } + + private preprocessRoute(route: IRouteConfig, index: number): IRouteConfig { + const routeKey = route.name || route.id || `route_${index}`; + + // Check if this route needs TS-side handling + const needsTsHandling = this.routeNeedsTsHandling(route); + + if (needsTsHandling) { + // Store the original route for handler lookup + this.originalRoutes.set(routeKey, route); + } + + // Create a clean copy for Rust + const cleanRoute: IRouteConfig = { + ...route, + action: this.cleanAction(route.action, routeKey, needsTsHandling), + }; + + // Ensure we have a name for handler lookup + if (!cleanRoute.name && !cleanRoute.id) { + cleanRoute.name = routeKey; + } + + return cleanRoute; + } + + private routeNeedsTsHandling(route: IRouteConfig): boolean { + // Socket handler routes always need TS + if (route.action.type === 'socket-handler' && route.action.socketHandler) { + return true; + } + + // Routes with dynamic host/port functions need TS + if (route.action.targets) { + for (const target of route.action.targets) { + if (typeof target.host === 'function' || typeof target.port === 'function') { + return true; + } + } + } + + return false; + } + + private cleanAction(action: IRouteAction, routeKey: string, needsTsHandling: boolean): IRouteAction { + const cleanAction: IRouteAction = { ...action }; + + if (needsTsHandling) { + // Convert to socket-handler type for Rust (Rust will relay back to TS) + cleanAction.type = 'socket-handler'; + // Remove the JS handler (not serializable) + delete (cleanAction as any).socketHandler; + } + + // Clean targets - replace functions with static values + if (cleanAction.targets) { + cleanAction.targets = cleanAction.targets.map(t => this.cleanTarget(t)); + } + + return cleanAction; + } + + private cleanTarget(target: IRouteTarget): IRouteTarget { + const clean: IRouteTarget = { ...target }; + + // Replace function host with placeholder + if (typeof clean.host === 'function') { + clean.host = 'localhost'; + } + + // Replace function port with placeholder + if (typeof clean.port === 'function') { + clean.port = 0; + } + + return clean; + } +} diff --git a/ts/proxies/smart-proxy/rust-binary-locator.ts b/ts/proxies/smart-proxy/rust-binary-locator.ts new file mode 100644 index 0000000..d94f0a5 --- /dev/null +++ b/ts/proxies/smart-proxy/rust-binary-locator.ts @@ -0,0 +1,112 @@ +import * as plugins from '../../plugins.js'; +import { logger } from '../../core/utils/logger.js'; + +/** + * Locates the RustProxy binary using a priority-ordered search strategy: + * 1. SMARTPROXY_RUST_BINARY environment variable + * 2. Platform-specific optional npm package + * 3. Local development build at ./rust/target/release/rustproxy + * 4. System PATH + */ +export class RustBinaryLocator { + private cachedPath: string | null = null; + + /** + * Find the RustProxy binary path. + * Returns null if no binary is available. + */ + public async findBinary(): Promise { + if (this.cachedPath !== null) { + return this.cachedPath; + } + + const path = await this.searchBinary(); + this.cachedPath = path; + return path; + } + + /** + * Clear the cached binary path (e.g., after a failed launch). + */ + public clearCache(): void { + this.cachedPath = null; + } + + private async searchBinary(): Promise { + // 1. Environment variable override + const envPath = process.env.SMARTPROXY_RUST_BINARY; + if (envPath) { + if (await this.isExecutable(envPath)) { + logger.log('info', `RustProxy binary found via SMARTPROXY_RUST_BINARY: ${envPath}`, { component: 'rust-locator' }); + return envPath; + } + logger.log('warn', `SMARTPROXY_RUST_BINARY set but not executable: ${envPath}`, { component: 'rust-locator' }); + } + + // 2. Platform-specific optional npm package + const platformBinary = await this.findPlatformPackageBinary(); + if (platformBinary) { + logger.log('info', `RustProxy binary found in platform package: ${platformBinary}`, { component: 'rust-locator' }); + return platformBinary; + } + + // 3. Local development build + const localPaths = [ + plugins.path.resolve(process.cwd(), 'rust/target/release/rustproxy'), + plugins.path.resolve(process.cwd(), 'rust/target/debug/rustproxy'), + ]; + for (const localPath of localPaths) { + if (await this.isExecutable(localPath)) { + logger.log('info', `RustProxy binary found at local path: ${localPath}`, { component: 'rust-locator' }); + return localPath; + } + } + + // 4. System PATH + const systemPath = await this.findInPath('rustproxy'); + if (systemPath) { + logger.log('info', `RustProxy binary found in system PATH: ${systemPath}`, { component: 'rust-locator' }); + return systemPath; + } + + logger.log('error', 'No RustProxy binary found. Set SMARTPROXY_RUST_BINARY, install the platform package, or build with: cd rust && cargo build --release', { component: 'rust-locator' }); + return null; + } + + private async findPlatformPackageBinary(): Promise { + const platform = process.platform; + const arch = process.arch; + const packageName = `@push.rocks/smartproxy-${platform}-${arch}`; + + try { + // Try to resolve the platform-specific package + const packagePath = require.resolve(`${packageName}/rustproxy`); + if (await this.isExecutable(packagePath)) { + return packagePath; + } + } catch { + // Package not installed - expected for development + } + return null; + } + + private async isExecutable(filePath: string): Promise { + try { + await plugins.fs.promises.access(filePath, plugins.fs.constants.X_OK); + return true; + } catch { + return false; + } + } + + private async findInPath(binaryName: string): Promise { + const pathDirs = (process.env.PATH || '').split(plugins.path.delimiter); + for (const dir of pathDirs) { + const fullPath = plugins.path.join(dir, binaryName); + if (await this.isExecutable(fullPath)) { + return fullPath; + } + } + return null; + } +} diff --git a/ts/proxies/smart-proxy/rust-metrics-adapter.ts b/ts/proxies/smart-proxy/rust-metrics-adapter.ts new file mode 100644 index 0000000..e8701a0 --- /dev/null +++ b/ts/proxies/smart-proxy/rust-metrics-adapter.ts @@ -0,0 +1,136 @@ +import type { IMetrics, IThroughputData, IThroughputHistoryPoint } from './models/metrics-types.js'; +import type { RustProxyBridge } from './rust-proxy-bridge.js'; + +/** + * Adapts Rust JSON metrics to the IMetrics interface. + * + * Polls the Rust binary periodically via the bridge and caches the result. + * All IMetrics getters read from the cache synchronously. + * Fields not yet in Rust (percentiles, per-IP, history) return zero/empty. + */ +export class RustMetricsAdapter implements IMetrics { + private bridge: RustProxyBridge; + private cache: any = null; + private pollTimer: ReturnType | null = null; + private pollIntervalMs: number; + + // Cumulative totals tracked across polls + private cumulativeBytesIn = 0; + private cumulativeBytesOut = 0; + private cumulativeConnections = 0; + + constructor(bridge: RustProxyBridge, pollIntervalMs = 1000) { + this.bridge = bridge; + this.pollIntervalMs = pollIntervalMs; + } + + public startPolling(): void { + if (this.pollTimer) return; + this.pollTimer = setInterval(async () => { + try { + this.cache = await this.bridge.getMetrics(); + // Update cumulative totals + if (this.cache) { + this.cumulativeBytesIn = this.cache.totalBytesIn ?? this.cache.total_bytes_in ?? 0; + this.cumulativeBytesOut = this.cache.totalBytesOut ?? this.cache.total_bytes_out ?? 0; + this.cumulativeConnections = this.cache.totalConnections ?? this.cache.total_connections ?? 0; + } + } catch { + // Ignore poll errors (bridge may be shutting down) + } + }, this.pollIntervalMs); + if (this.pollTimer.unref) { + this.pollTimer.unref(); + } + } + + public stopPolling(): void { + if (this.pollTimer) { + clearInterval(this.pollTimer); + this.pollTimer = null; + } + } + + // --- IMetrics implementation --- + + public connections = { + active: (): number => { + return this.cache?.activeConnections ?? this.cache?.active_connections ?? 0; + }, + total: (): number => { + return this.cumulativeConnections; + }, + byRoute: (): Map => { + return new Map(); + }, + byIP: (): Map => { + return new Map(); + }, + topIPs: (_limit?: number): Array<{ ip: string; count: number }> => { + return []; + }, + }; + + public throughput = { + instant: (): IThroughputData => { + return { in: this.cache?.bytesInPerSecond ?? 0, out: this.cache?.bytesOutPerSecond ?? 0 }; + }, + recent: (): IThroughputData => { + return this.throughput.instant(); + }, + average: (): IThroughputData => { + return this.throughput.instant(); + }, + custom: (_seconds: number): IThroughputData => { + return this.throughput.instant(); + }, + history: (_seconds: number): Array => { + return []; + }, + byRoute: (_windowSeconds?: number): Map => { + return new Map(); + }, + byIP: (_windowSeconds?: number): Map => { + return new Map(); + }, + }; + + public requests = { + perSecond: (): number => { + return this.cache?.requestsPerSecond ?? 0; + }, + perMinute: (): number => { + return (this.cache?.requestsPerSecond ?? 0) * 60; + }, + total: (): number => { + return this.cache?.totalRequests ?? this.cache?.total_requests ?? 0; + }, + }; + + public totals = { + bytesIn: (): number => { + return this.cumulativeBytesIn; + }, + bytesOut: (): number => { + return this.cumulativeBytesOut; + }, + connections: (): number => { + return this.cumulativeConnections; + }, + }; + + public percentiles = { + connectionDuration: (): { p50: number; p95: number; p99: number } => { + return { p50: 0, p95: 0, p99: 0 }; + }, + bytesTransferred: (): { + in: { p50: number; p95: number; p99: number }; + out: { p50: number; p95: number; p99: number }; + } => { + return { + in: { p50: 0, p95: 0, p99: 0 }, + out: { p50: 0, p95: 0, p99: 0 }, + }; + }, + }; +} diff --git a/ts/proxies/smart-proxy/rust-proxy-bridge.ts b/ts/proxies/smart-proxy/rust-proxy-bridge.ts new file mode 100644 index 0000000..1b04773 --- /dev/null +++ b/ts/proxies/smart-proxy/rust-proxy-bridge.ts @@ -0,0 +1,278 @@ +import * as plugins from '../../plugins.js'; +import { logger } from '../../core/utils/logger.js'; +import { RustBinaryLocator } from './rust-binary-locator.js'; +import type { IRouteConfig } from './models/route-types.js'; +import { ChildProcess, spawn } from 'child_process'; +import { createInterface, Interface as ReadlineInterface } from 'readline'; + +/** + * Management request sent to the Rust binary via stdin. + */ +interface IManagementRequest { + id: string; + method: string; + params: Record; +} + +/** + * Management response received from the Rust binary via stdout. + */ +interface IManagementResponse { + id: string; + success: boolean; + result?: any; + error?: string; +} + +/** + * Management event received from the Rust binary (unsolicited). + */ +interface IManagementEvent { + event: string; + data: any; +} + +/** + * Bridge between TypeScript SmartProxy and the Rust binary. + * Communicates via JSON-over-stdin/stdout IPC protocol. + */ +export class RustProxyBridge extends plugins.EventEmitter { + private locator = new RustBinaryLocator(); + private process: ChildProcess | null = null; + private readline: ReadlineInterface | null = null; + private pendingRequests = new Map void; + reject: (error: Error) => void; + timer: NodeJS.Timeout; + }>(); + private requestCounter = 0; + private isRunning = false; + private binaryPath: string | null = null; + private readonly requestTimeoutMs = 30000; + + /** + * Spawn the Rust binary in management mode. + * Returns true if the binary was found and spawned successfully. + */ + public async spawn(): Promise { + this.binaryPath = await this.locator.findBinary(); + if (!this.binaryPath) { + return false; + } + + return new Promise((resolve) => { + try { + this.process = spawn(this.binaryPath!, ['--management'], { + stdio: ['pipe', 'pipe', 'pipe'], + env: { ...process.env }, + }); + + // Handle stderr (logging from Rust goes here) + this.process.stderr?.on('data', (data: Buffer) => { + const lines = data.toString().split('\n').filter(l => l.trim()); + for (const line of lines) { + logger.log('debug', `[rustproxy] ${line}`, { component: 'rust-bridge' }); + } + }); + + // Handle stdout (JSON IPC) + this.readline = createInterface({ input: this.process.stdout! }); + this.readline.on('line', (line: string) => { + this.handleLine(line.trim()); + }); + + // Handle process exit + this.process.on('exit', (code, signal) => { + logger.log('info', `RustProxy process exited (code=${code}, signal=${signal})`, { component: 'rust-bridge' }); + this.cleanup(); + this.emit('exit', code, signal); + }); + + this.process.on('error', (err) => { + logger.log('error', `RustProxy process error: ${err.message}`, { component: 'rust-bridge' }); + this.cleanup(); + resolve(false); + }); + + // Wait for the 'ready' event from Rust + const readyTimeout = setTimeout(() => { + logger.log('error', 'RustProxy did not send ready event within 10s', { component: 'rust-bridge' }); + this.kill(); + resolve(false); + }, 10000); + + this.once('management:ready', () => { + clearTimeout(readyTimeout); + this.isRunning = true; + logger.log('info', 'RustProxy bridge connected', { component: 'rust-bridge' }); + resolve(true); + }); + } catch (err: any) { + logger.log('error', `Failed to spawn RustProxy: ${err.message}`, { component: 'rust-bridge' }); + resolve(false); + } + }); + } + + /** + * Send a management command to the Rust process and wait for the response. + */ + public async sendCommand(method: string, params: Record = {}): Promise { + if (!this.process || !this.isRunning) { + throw new Error('RustProxy bridge is not running'); + } + + const id = `req_${++this.requestCounter}`; + const request: IManagementRequest = { id, method, params }; + + return new Promise((resolve, reject) => { + const timer = setTimeout(() => { + this.pendingRequests.delete(id); + reject(new Error(`RustProxy command '${method}' timed out after ${this.requestTimeoutMs}ms`)); + }, this.requestTimeoutMs); + + this.pendingRequests.set(id, { resolve, reject, timer }); + + const json = JSON.stringify(request) + '\n'; + this.process!.stdin!.write(json, (err) => { + if (err) { + clearTimeout(timer); + this.pendingRequests.delete(id); + reject(new Error(`Failed to write to RustProxy stdin: ${err.message}`)); + } + }); + }); + } + + // Convenience methods for each management command + + public async startProxy(config: any): Promise { + await this.sendCommand('start', { config }); + } + + public async stopProxy(): Promise { + await this.sendCommand('stop'); + } + + public async updateRoutes(routes: IRouteConfig[]): Promise { + await this.sendCommand('updateRoutes', { routes }); + } + + public async getMetrics(): Promise { + return this.sendCommand('getMetrics'); + } + + public async getStatistics(): Promise { + return this.sendCommand('getStatistics'); + } + + public async provisionCertificate(routeName: string): Promise { + await this.sendCommand('provisionCertificate', { routeName }); + } + + public async renewCertificate(routeName: string): Promise { + await this.sendCommand('renewCertificate', { routeName }); + } + + public async getCertificateStatus(routeName: string): Promise { + return this.sendCommand('getCertificateStatus', { routeName }); + } + + public async getListeningPorts(): Promise { + const result = await this.sendCommand('getListeningPorts'); + return result?.ports ?? []; + } + + public async getNftablesStatus(): Promise { + return this.sendCommand('getNftablesStatus'); + } + + public async setSocketHandlerRelay(socketPath: string): Promise { + await this.sendCommand('setSocketHandlerRelay', { socketPath }); + } + + public async addListeningPort(port: number): Promise { + await this.sendCommand('addListeningPort', { port }); + } + + public async removeListeningPort(port: number): Promise { + await this.sendCommand('removeListeningPort', { port }); + } + + public async loadCertificate(domain: string, cert: string, key: string, ca?: string): Promise { + await this.sendCommand('loadCertificate', { domain, cert, key, ca }); + } + + /** + * Kill the Rust process. + */ + public kill(): void { + if (this.process) { + this.process.kill('SIGTERM'); + // Force kill after 5 seconds + setTimeout(() => { + if (this.process) { + this.process.kill('SIGKILL'); + } + }, 5000).unref(); + } + } + + /** + * Whether the bridge is currently running. + */ + public get running(): boolean { + return this.isRunning; + } + + private handleLine(line: string): void { + if (!line) return; + + let parsed: any; + try { + parsed = JSON.parse(line); + } catch { + logger.log('warn', `Non-JSON output from RustProxy: ${line}`, { component: 'rust-bridge' }); + return; + } + + // Check if it's an event (has 'event' field) + if ('event' in parsed) { + const event = parsed as IManagementEvent; + this.emit(`management:${event.event}`, event.data); + return; + } + + // Otherwise it's a response (has 'id' field) + if ('id' in parsed) { + const response = parsed as IManagementResponse; + const pending = this.pendingRequests.get(response.id); + if (pending) { + clearTimeout(pending.timer); + this.pendingRequests.delete(response.id); + if (response.success) { + pending.resolve(response.result); + } else { + pending.reject(new Error(response.error || 'Unknown error from RustProxy')); + } + } + } + } + + private cleanup(): void { + this.isRunning = false; + this.process = null; + + if (this.readline) { + this.readline.close(); + this.readline = null; + } + + // Reject all pending requests + for (const [id, pending] of this.pendingRequests) { + clearTimeout(pending.timer); + pending.reject(new Error('RustProxy process exited')); + } + this.pendingRequests.clear(); + } +} diff --git a/ts/proxies/smart-proxy/security-manager.ts b/ts/proxies/smart-proxy/security-manager.ts deleted file mode 100644 index a3c54d7..0000000 --- a/ts/proxies/smart-proxy/security-manager.ts +++ /dev/null @@ -1,269 +0,0 @@ -import * as plugins from '../../plugins.js'; -import type { SmartProxy } from './smart-proxy.js'; -import { connectionLogDeduplicator } from '../../core/utils/log-deduplicator.js'; -import { isIPAuthorized, normalizeIP } from '../../core/utils/security-utils.js'; - -/** - * Handles security aspects like IP tracking, rate limiting, and authorization - * for SmartProxy. This is a lightweight wrapper that uses shared utilities. - */ -export class SecurityManager { - private connectionsByIP: Map> = new Map(); - private connectionRateByIP: Map = new Map(); - private cleanupInterval: NodeJS.Timeout | null = null; - - constructor(private smartProxy: SmartProxy) { - // Start periodic cleanup every 60 seconds - this.startPeriodicCleanup(); - } - - /** - * Get connections count by IP (checks normalized variants) - */ - public getConnectionCountByIP(ip: string): number { - // Check all normalized variants of the IP - const variants = normalizeIP(ip); - for (const variant of variants) { - const connections = this.connectionsByIP.get(variant); - if (connections) { - return connections.size; - } - } - return 0; - } - - /** - * Check and update connection rate for an IP - * @returns true if within rate limit, false if exceeding limit - */ - public checkConnectionRate(ip: string): boolean { - const now = Date.now(); - const minute = 60 * 1000; - - // Find existing rate tracking (check normalized variants) - const variants = normalizeIP(ip); - let existingKey: string | null = null; - for (const variant of variants) { - if (this.connectionRateByIP.has(variant)) { - existingKey = variant; - break; - } - } - - const key = existingKey || ip; - - if (!this.connectionRateByIP.has(key)) { - this.connectionRateByIP.set(key, [now]); - return true; - } - - // Get timestamps and filter out entries older than 1 minute - const timestamps = this.connectionRateByIP.get(key)!.filter((time) => now - time < minute); - timestamps.push(now); - this.connectionRateByIP.set(key, timestamps); - - // Check if rate exceeds limit - return timestamps.length <= this.smartProxy.settings.connectionRateLimitPerMinute!; - } - - /** - * Track connection by IP - */ - public trackConnectionByIP(ip: string, connectionId: string): void { - // Check if any variant already exists - const variants = normalizeIP(ip); - let existingKey: string | null = null; - - for (const variant of variants) { - if (this.connectionsByIP.has(variant)) { - existingKey = variant; - break; - } - } - - const key = existingKey || ip; - if (!this.connectionsByIP.has(key)) { - this.connectionsByIP.set(key, new Set()); - } - this.connectionsByIP.get(key)!.add(connectionId); - } - - /** - * Remove connection tracking for an IP - */ - public removeConnectionByIP(ip: string, connectionId: string): void { - // Check all variants to find where the connection is tracked - const variants = normalizeIP(ip); - - for (const variant of variants) { - if (this.connectionsByIP.has(variant)) { - const connections = this.connectionsByIP.get(variant)!; - connections.delete(connectionId); - if (connections.size === 0) { - this.connectionsByIP.delete(variant); - } - break; - } - } - } - - /** - * Check if an IP is authorized using security rules - * - * This method is used to determine if an IP is allowed to connect, based on security - * rules configured in the route configuration. The allowed and blocked IPs are - * typically derived from route.security.ipAllowList and ipBlockList. - * - * @param ip - The IP address to check - * @param allowedIPs - Array of allowed IP patterns from security.ipAllowList - * @param blockedIPs - Array of blocked IP patterns from security.ipBlockList - * @returns true if IP is authorized, false if blocked - */ - public isIPAuthorized(ip: string, allowedIPs: string[], blockedIPs: string[] = []): boolean { - return isIPAuthorized(ip, allowedIPs, blockedIPs); - } - - /** - * Check if IP should be allowed considering connection rate and max connections - * @returns Object with result and reason - */ - public validateIP(ip: string): { allowed: boolean; reason?: string } { - // Check connection count limit - if ( - this.smartProxy.settings.maxConnectionsPerIP && - this.getConnectionCountByIP(ip) >= this.smartProxy.settings.maxConnectionsPerIP - ) { - return { - allowed: false, - reason: `Maximum connections per IP (${this.smartProxy.settings.maxConnectionsPerIP}) exceeded` - }; - } - - // Check connection rate limit - if ( - this.smartProxy.settings.connectionRateLimitPerMinute && - !this.checkConnectionRate(ip) - ) { - return { - allowed: false, - reason: `Connection rate limit (${this.smartProxy.settings.connectionRateLimitPerMinute}/min) exceeded` - }; - } - - return { allowed: true }; - } - - /** - * Atomically validate an IP and track the connection if allowed. - * This prevents race conditions where concurrent connections could bypass per-IP limits. - * - * @param ip - The IP address to validate - * @param connectionId - The connection ID to track if validation passes - * @returns Object with validation result and reason - */ - public validateAndTrackIP(ip: string, connectionId: string): { allowed: boolean; reason?: string } { - // Check connection count limit BEFORE tracking - if ( - this.smartProxy.settings.maxConnectionsPerIP && - this.getConnectionCountByIP(ip) >= this.smartProxy.settings.maxConnectionsPerIP - ) { - return { - allowed: false, - reason: `Maximum connections per IP (${this.smartProxy.settings.maxConnectionsPerIP}) exceeded` - }; - } - - // Check connection rate limit - if ( - this.smartProxy.settings.connectionRateLimitPerMinute && - !this.checkConnectionRate(ip) - ) { - return { - allowed: false, - reason: `Connection rate limit (${this.smartProxy.settings.connectionRateLimitPerMinute}/min) exceeded` - }; - } - - // Validation passed - immediately track to prevent race conditions - this.trackConnectionByIP(ip, connectionId); - - return { allowed: true }; - } - - /** - * Clears all IP tracking data (for shutdown) - */ - public clearIPTracking(): void { - if (this.cleanupInterval) { - clearInterval(this.cleanupInterval); - this.cleanupInterval = null; - } - this.connectionsByIP.clear(); - this.connectionRateByIP.clear(); - } - - /** - * Start periodic cleanup of expired data - */ - private startPeriodicCleanup(): void { - this.cleanupInterval = setInterval(() => { - this.performCleanup(); - }, 60000); // Run every minute - - // Unref the timer so it doesn't keep the process alive - if (this.cleanupInterval.unref) { - this.cleanupInterval.unref(); - } - } - - /** - * Perform cleanup of expired rate limits and empty IP entries - */ - private performCleanup(): void { - const now = Date.now(); - const minute = 60 * 1000; - let cleanedRateLimits = 0; - let cleanedIPs = 0; - - // Clean up expired rate limit timestamps - for (const [ip, timestamps] of this.connectionRateByIP.entries()) { - const validTimestamps = timestamps.filter(time => now - time < minute); - - if (validTimestamps.length === 0) { - // No valid timestamps, remove the IP entry - this.connectionRateByIP.delete(ip); - cleanedRateLimits++; - } else if (validTimestamps.length < timestamps.length) { - // Some timestamps expired, update with valid ones - this.connectionRateByIP.set(ip, validTimestamps); - } - } - - // Clean up IPs with no active connections - for (const [ip, connections] of this.connectionsByIP.entries()) { - if (connections.size === 0) { - this.connectionsByIP.delete(ip); - cleanedIPs++; - } - } - - // Log cleanup stats if anything was cleaned - if (cleanedRateLimits > 0 || cleanedIPs > 0) { - if (this.smartProxy.settings.enableDetailedLogging) { - connectionLogDeduplicator.log( - 'ip-cleanup', - 'debug', - 'IP tracking cleanup completed', - { - cleanedRateLimits, - cleanedIPs, - remainingIPs: this.connectionsByIP.size, - remainingRateLimits: this.connectionRateByIP.size, - component: 'security-manager' - }, - 'periodic-cleanup' - ); - } - } - } -} \ No newline at end of file diff --git a/ts/proxies/smart-proxy/smart-proxy.ts b/ts/proxies/smart-proxy/smart-proxy.ts index d3a1467..7463554 100644 --- a/ts/proxies/smart-proxy/smart-proxy.ts +++ b/ts/proxies/smart-proxy/smart-proxy.ts @@ -1,940 +1,409 @@ import * as plugins from '../../plugins.js'; import { logger } from '../../core/utils/logger.js'; -import { connectionLogDeduplicator } from '../../core/utils/log-deduplicator.js'; -// Importing required components -import { ConnectionManager } from './connection-manager.js'; -import { SecurityManager } from './security-manager.js'; -import { TlsManager } from './tls-manager.js'; -import { HttpProxyBridge } from './http-proxy-bridge.js'; -import { TimeoutManager } from './timeout-manager.js'; -import { PortManager } from './port-manager.js'; +// Rust bridge and helpers +import { RustProxyBridge } from './rust-proxy-bridge.js'; +import { RustBinaryLocator } from './rust-binary-locator.js'; +import { RoutePreprocessor } from './route-preprocessor.js'; +import { SocketHandlerServer } from './socket-handler-server.js'; +import { RustMetricsAdapter } from './rust-metrics-adapter.js'; + +// Route management import { SharedRouteManager as RouteManager } from '../../core/routing/route-manager.js'; -import { RouteConnectionHandler } from './route-connection-handler.js'; -import { NFTablesManager } from './nftables-manager.js'; - -// Certificate manager -import { SmartCertManager, type ICertStatus } from './certificate-manager.js'; - -// Import types and utilities -import type { - ISmartProxyOptions -} from './models/interfaces.js'; -import type { IRouteConfig } from './models/route-types.js'; - -// Import mutex for route update synchronization +import { RouteValidator } from './utils/route-validator.js'; import { Mutex } from './utils/mutex.js'; -// Import route validator -import { RouteValidator } from './utils/route-validator.js'; - -// Import route orchestrator for route management -import { RouteOrchestrator } from './route-orchestrator.js'; - -// Import ACME state manager -import { AcmeStateManager } from './acme-state-manager.js'; - -// Import metrics collector -import { MetricsCollector } from './metrics-collector.js'; +// Types +import type { ISmartProxyOptions, TSmartProxyCertProvisionObject } from './models/interfaces.js'; +import type { IRouteConfig } from './models/route-types.js'; import type { IMetrics } from './models/metrics-types.js'; /** - * SmartProxy - Pure route-based API + * SmartProxy - Rust-backed proxy engine with TypeScript configuration API. * - * SmartProxy is a unified proxy system that works with routes to define connection handling behavior. - * Each route contains matching criteria (ports, domains, etc.) and an action to take (forward, redirect, block). - * - * Configuration is provided through a set of routes, with each route defining: - * - What to match (ports, domains, paths, client IPs) - * - What to do with matching traffic (forward, redirect, block) - * - How to handle TLS (passthrough, terminate, terminate-and-reencrypt) - * - Security settings (IP restrictions, connection limits) - * - Advanced options (timeout, headers, etc.) + * All networking (TCP, TLS, HTTP reverse proxy, connection management, security, + * NFTables) is handled by the Rust binary. TypeScript is only: + * - The npm module interface (types, route helpers) + * - The thin IPC wrapper (this class) + * - Socket-handler callback relay (for JS-defined handlers) + * - Certificate provisioning callbacks (certProvisionFunction) */ export class SmartProxy extends plugins.EventEmitter { - // Port manager handles dynamic listener management - private portManager: PortManager; - private connectionLogger: NodeJS.Timeout | null = null; - private isShuttingDown: boolean = false; - - // Component managers - public connectionManager: ConnectionManager; - public securityManager: SecurityManager; - public tlsManager: TlsManager; - public httpProxyBridge: HttpProxyBridge; - public timeoutManager: TimeoutManager; + public settings: ISmartProxyOptions; public routeManager: RouteManager; - public routeConnectionHandler: RouteConnectionHandler; - public nftablesManager: NFTablesManager; - - // Certificate manager for ACME and static certificates - public certManager: SmartCertManager | null = null; - - // Global challenge route tracking - private globalChallengeRouteActive: boolean = false; + + private bridge: RustProxyBridge; + private preprocessor: RoutePreprocessor; + private socketHandlerServer: SocketHandlerServer | null = null; + private metricsAdapter: RustMetricsAdapter; private routeUpdateLock: Mutex; - public acmeStateManager: AcmeStateManager; - - // Metrics collector - public metricsCollector: MetricsCollector; - - // Route orchestrator for managing route updates - private routeOrchestrator: RouteOrchestrator; - - // Track port usage across route updates - private portUsageMap: Map> = new Map(); - - /** - * Constructor for SmartProxy - * - * @param settingsArg Configuration options containing routes and other settings - * Routes define how traffic is matched and handled, with each route having: - * - match: criteria for matching traffic (ports, domains, paths, IPs) - * - action: what to do with matched traffic (forward, redirect, block) - * - * Example: - * ```ts - * const proxy = new SmartProxy({ - * routes: [ - * { - * match: { - * ports: 443, - * domains: ['example.com', '*.example.com'] - * }, - * action: { - * type: 'forward', - * target: { host: '10.0.0.1', port: 8443 }, - * tls: { mode: 'passthrough' } - * } - * } - * ], - * defaults: { - * target: { host: 'localhost', port: 8080 }, - * security: { ipAllowList: ['*'] } - * } - * }); - * ``` - */ + constructor(settingsArg: ISmartProxyOptions) { super(); - - // Set reasonable defaults for all settings + + // Apply defaults this.settings = { ...settingsArg, initialDataTimeout: settingsArg.initialDataTimeout || 120000, socketTimeout: settingsArg.socketTimeout || 3600000, - inactivityCheckInterval: settingsArg.inactivityCheckInterval || 60000, maxConnectionLifetime: settingsArg.maxConnectionLifetime || 86400000, inactivityTimeout: settingsArg.inactivityTimeout || 14400000, gracefulShutdownTimeout: settingsArg.gracefulShutdownTimeout || 30000, - noDelay: settingsArg.noDelay !== undefined ? settingsArg.noDelay : true, - keepAlive: settingsArg.keepAlive !== undefined ? settingsArg.keepAlive : true, - keepAliveInitialDelay: settingsArg.keepAliveInitialDelay || 10000, - maxPendingDataSize: settingsArg.maxPendingDataSize || 10 * 1024 * 1024, - disableInactivityCheck: settingsArg.disableInactivityCheck || false, - enableKeepAliveProbes: - settingsArg.enableKeepAliveProbes !== undefined ? settingsArg.enableKeepAliveProbes : true, - enableDetailedLogging: settingsArg.enableDetailedLogging || false, - enableTlsDebugLogging: settingsArg.enableTlsDebugLogging || false, - enableRandomizedTimeouts: settingsArg.enableRandomizedTimeouts || false, maxConnectionsPerIP: settingsArg.maxConnectionsPerIP || 100, connectionRateLimitPerMinute: settingsArg.connectionRateLimitPerMinute || 300, keepAliveTreatment: settingsArg.keepAliveTreatment || 'extended', keepAliveInactivityMultiplier: settingsArg.keepAliveInactivityMultiplier || 6, extendedKeepAliveLifetime: settingsArg.extendedKeepAliveLifetime || 7 * 24 * 60 * 60 * 1000, - httpProxyPort: settingsArg.httpProxyPort || 8443, }; - - // Normalize ACME options if provided (support both email and accountEmail) + + // Normalize ACME options if (this.settings.acme) { - // Support both 'email' and 'accountEmail' fields if (this.settings.acme.accountEmail && !this.settings.acme.email) { this.settings.acme.email = this.settings.acme.accountEmail; } - - // Set reasonable defaults for commonly used fields this.settings.acme = { - enabled: this.settings.acme.enabled !== false, // Enable by default if acme object exists + enabled: this.settings.acme.enabled !== false, port: this.settings.acme.port || 80, email: this.settings.acme.email, useProduction: this.settings.acme.useProduction || false, renewThresholdDays: this.settings.acme.renewThresholdDays || 30, - autoRenew: this.settings.acme.autoRenew !== false, // Enable by default + autoRenew: this.settings.acme.autoRenew !== false, certificateStore: this.settings.acme.certificateStore || './certs', skipConfiguredCerts: this.settings.acme.skipConfiguredCerts || false, renewCheckIntervalHours: this.settings.acme.renewCheckIntervalHours || 24, routeForwards: this.settings.acme.routeForwards || [], - ...this.settings.acme // Preserve any additional fields + ...this.settings.acme, }; } - - // Initialize component managers - this.timeoutManager = new TimeoutManager(this); - this.securityManager = new SecurityManager(this); - this.connectionManager = new ConnectionManager(this); - - // Create the route manager with SharedRouteManager API - // Create a logger adapter to match ILogger interface - const loggerAdapter = { - debug: (message: string, data?: any) => logger.log('debug', message, data), - info: (message: string, data?: any) => logger.log('info', message, data), - warn: (message: string, data?: any) => logger.log('warn', message, data), - error: (message: string, data?: any) => logger.log('error', message, data) - }; - - // Validate initial routes - if (this.settings.routes && this.settings.routes.length > 0) { + + // Validate routes + if (this.settings.routes?.length) { const validation = RouteValidator.validateRoutes(this.settings.routes); if (!validation.valid) { RouteValidator.logValidationErrors(validation.errors); throw new Error(`Initial route validation failed: ${validation.errors.size} route(s) have errors`); } } - + + // Create logger adapter + const loggerAdapter = { + debug: (message: string, data?: any) => logger.log('debug', message, data), + info: (message: string, data?: any) => logger.log('info', message, data), + warn: (message: string, data?: any) => logger.log('warn', message, data), + error: (message: string, data?: any) => logger.log('error', message, data), + }; + + // Initialize components this.routeManager = new RouteManager({ logger: loggerAdapter, enableDetailedLogging: this.settings.enableDetailedLogging, - routes: this.settings.routes + routes: this.settings.routes, }); - - // Create other required components - this.tlsManager = new TlsManager(this); - this.httpProxyBridge = new HttpProxyBridge(this); - - // Initialize connection handler with route support - this.routeConnectionHandler = new RouteConnectionHandler(this); - - // Initialize port manager - this.portManager = new PortManager(this); - - // Initialize NFTablesManager - this.nftablesManager = new NFTablesManager(this); - - // Initialize route update mutex for synchronization + this.bridge = new RustProxyBridge(); + this.preprocessor = new RoutePreprocessor(); + this.metricsAdapter = new RustMetricsAdapter(this.bridge); this.routeUpdateLock = new Mutex(); - - // Initialize ACME state manager - this.acmeStateManager = new AcmeStateManager(); - - // Initialize metrics collector with reference to this SmartProxy instance - this.metricsCollector = new MetricsCollector(this, { - sampleIntervalMs: this.settings.metrics?.sampleIntervalMs, - retentionSeconds: this.settings.metrics?.retentionSeconds - }); - - // Initialize route orchestrator for managing route updates - this.routeOrchestrator = new RouteOrchestrator( - this.portManager, - this.routeManager, - this.httpProxyBridge, - this.nftablesManager, - null, // certManager will be set later - loggerAdapter - ); - } - - /** - * The settings for the SmartProxy - */ - public settings: ISmartProxyOptions; - - /** - * Helper method to create and configure certificate manager - * This ensures consistent setup including the required ACME callback - */ - private async createCertificateManager( - routes: IRouteConfig[], - certStore: string = './certs', - acmeOptions?: any, - initialState?: { challengeRouteActive?: boolean } - ): Promise { - const certManager = new SmartCertManager(routes, certStore, acmeOptions, initialState); - - // Always set up the route update callback for ACME challenges - certManager.setUpdateRoutesCallback(async (routes) => { - await this.updateRoutes(routes); - }); - - // Connect with HttpProxy if available - if (this.httpProxyBridge.getHttpProxy()) { - certManager.setHttpProxy(this.httpProxyBridge.getHttpProxy()); - } - - // Set the ACME state manager - certManager.setAcmeStateManager(this.acmeStateManager); - - // Pass down the global ACME config if available - if (this.settings.acme) { - certManager.setGlobalAcmeDefaults(this.settings.acme); - } - - // Pass down the custom certificate provision function if available - if (this.settings.certProvisionFunction) { - certManager.setCertProvisionFunction(this.settings.certProvisionFunction); - } - - // Pass down the fallback to ACME setting - if (this.settings.certProvisionFallbackToAcme !== undefined) { - certManager.setCertProvisionFallbackToAcme(this.settings.certProvisionFallbackToAcme); - } - - await certManager.initialize(); - return certManager; } /** - * Initialize certificate manager + * Start the proxy. + * Spawns the Rust binary, configures socket relay if needed, sends routes, handles cert provisioning. */ - private async initializeCertificateManager(): Promise { - // Extract global ACME options if any routes use auto certificates - const autoRoutes = this.settings.routes.filter(r => - r.action.tls?.certificate === 'auto' - ); - - if (autoRoutes.length === 0 && !this.hasStaticCertRoutes()) { - logger.log('info', 'No routes require certificate management', { component: 'certificate-manager' }); - return; - } - - // Prepare ACME options with priority: - // 1. Use top-level ACME config if available - // 2. Fall back to first auto route's ACME config - // 3. Otherwise use undefined - let acmeOptions: { email?: string; useProduction?: boolean; port?: number } | undefined; - - if (this.settings.acme?.email) { - // Use top-level ACME config - acmeOptions = { - email: this.settings.acme.email, - useProduction: this.settings.acme.useProduction || false, - port: this.settings.acme.port || 80 - }; - logger.log('info', `Using top-level ACME configuration with email: ${acmeOptions.email}`, { component: 'certificate-manager' }); - } else if (autoRoutes.length > 0) { - // Check for route-level ACME config - const routeWithAcme = autoRoutes.find(r => r.action.tls?.acme?.email); - if (routeWithAcme?.action.tls?.acme) { - const routeAcme = routeWithAcme.action.tls.acme; - acmeOptions = { - email: routeAcme.email, - useProduction: routeAcme.useProduction || false, - port: routeAcme.challengePort || 80 - }; - logger.log('info', `Using route-level ACME configuration from route '${routeWithAcme.name}' with email: ${acmeOptions.email}`, { component: 'certificate-manager' }); - } - } - - // Validate we have required configuration - if (autoRoutes.length > 0 && !acmeOptions?.email) { + public async start(): Promise { + // Spawn Rust binary + const spawned = await this.bridge.spawn(); + if (!spawned) { throw new Error( - 'ACME email is required for automatic certificate provisioning. ' + - 'Please provide email in either:\n' + - '1. Top-level "acme" configuration\n' + - '2. Individual route\'s "tls.acme" configuration' + 'RustProxy binary not found. Set SMARTPROXY_RUST_BINARY env var, install the platform package, ' + + 'or build locally with: cd rust && cargo build --release' ); } - - // Use the helper method to create and configure the certificate manager - this.certManager = await this.createCertificateManager( - this.settings.routes, - this.settings.acme?.certificateStore || './certs', - acmeOptions - ); - } - - /** - * Check if we have routes with static certificates - */ - private hasStaticCertRoutes(): boolean { - return this.settings.routes.some(r => - r.action.tls?.certificate && - r.action.tls.certificate !== 'auto' - ); - } - - /** - * Start the proxy server with support for both configuration types - */ - public async start() { - // Don't start if already shutting down - if (this.isShuttingDown) { - logger.log('warn', "Cannot start SmartProxy while it's in the shutdown process"); - return; - } - // Validate the route configuration - const configWarnings = this.routeManager.validateConfiguration(); - - // Also validate ACME configuration - const acmeWarnings = this.validateAcmeConfiguration(); - const allWarnings = [...configWarnings, ...acmeWarnings]; - - if (allWarnings.length > 0) { - logger.log('warn', `${allWarnings.length} configuration warnings found`, { count: allWarnings.length }); - for (const warning of allWarnings) { - logger.log('warn', `${warning}`); - } - } - - // Get listening ports from RouteManager - const listeningPorts = this.routeManager.getListeningPorts(); - - // Initialize port usage tracking using RouteOrchestrator - this.portUsageMap = this.routeOrchestrator.updatePortUsageMap(this.settings.routes); - - // Log port usage for startup - logger.log('info', `SmartProxy starting with ${listeningPorts.length} ports: ${listeningPorts.join(', ')}`, { - portCount: listeningPorts.length, - ports: listeningPorts, - component: 'smart-proxy' + // Handle unexpected exit + this.bridge.on('exit', (code: number | null, signal: string | null) => { + logger.log('error', `RustProxy exited unexpectedly (code=${code}, signal=${signal})`, { component: 'smart-proxy' }); + this.emit('error', new Error(`RustProxy exited (code=${code}, signal=${signal})`)); }); - // Provision NFTables rules for routes that use NFTables - for (const route of this.settings.routes) { - if (route.action.forwardingEngine === 'nftables') { - await this.nftablesManager.provisionRoute(route); - } + // Start socket handler relay if any routes need TS-side handling + const hasHandlerRoutes = this.settings.routes.some( + (r) => + (r.action.type === 'socket-handler' && r.action.socketHandler) || + r.action.targets?.some((t) => typeof t.host === 'function' || typeof t.port === 'function') + ); + + if (hasHandlerRoutes) { + this.socketHandlerServer = new SocketHandlerServer(this.preprocessor); + await this.socketHandlerServer.start(); + await this.bridge.setSocketHandlerRelay(this.socketHandlerServer.getSocketPath()); } - // Initialize and start HttpProxy if needed - before port binding - if (this.settings.useHttpProxy && this.settings.useHttpProxy.length > 0) { - await this.httpProxyBridge.initialize(); - await this.httpProxyBridge.start(); - } + // Preprocess routes (strip JS functions, convert socket-handler routes) + const rustRoutes = this.preprocessor.preprocessForRust(this.settings.routes); - // Start port listeners using the PortManager BEFORE initializing certificate manager - // This ensures all required ports are bound and ready when adding ACME challenge routes - await this.portManager.addPorts(listeningPorts); - - // Initialize certificate manager AFTER port binding is complete - // This ensures the ACME challenge port is already bound and ready when needed - await this.initializeCertificateManager(); - - // Connect certificate manager with HttpProxy if both are available - if (this.certManager && this.httpProxyBridge.getHttpProxy()) { - this.certManager.setHttpProxy(this.httpProxyBridge.getHttpProxy()); - } + // Build Rust config + const config = this.buildRustConfig(rustRoutes); - // Now that ports are listening, provision any required certificates - if (this.certManager) { - logger.log('info', 'Starting certificate provisioning now that ports are ready', { component: 'certificate-manager' }); - await this.certManager.provisionAllCertificates(); - } - - // Start the metrics collector now that all components are initialized - this.metricsCollector.start(); + // Start the Rust proxy + await this.bridge.startProxy(config); - // Set up periodic connection logging and inactivity checks - this.connectionLogger = setInterval(() => { - // Immediately return if shutting down - if (this.isShuttingDown) return; + // Handle certProvisionFunction + await this.provisionCertificatesViaCallback(); - // Perform inactivity check - this.connectionManager.performInactivityCheck(); + // Start metrics polling + this.metricsAdapter.startPolling(); - // Log connection statistics - const now = Date.now(); - let maxIncoming = 0; - let maxOutgoing = 0; - let tlsConnections = 0; - let nonTlsConnections = 0; - let completedTlsHandshakes = 0; - let pendingTlsHandshakes = 0; - let keepAliveConnections = 0; - let httpProxyConnections = 0; - - // Get connection records for analysis - const connectionRecords = this.connectionManager.getConnections(); - - // Analyze active connections - for (const record of connectionRecords.values()) { - // Track connection stats - if (record.isTLS) { - tlsConnections++; - if (record.tlsHandshakeComplete) { - completedTlsHandshakes++; - } else { - pendingTlsHandshakes++; - } - } else { - nonTlsConnections++; - } - - if (record.hasKeepAlive) { - keepAliveConnections++; - } - - if (record.usingNetworkProxy) { - httpProxyConnections++; - } - - maxIncoming = Math.max(maxIncoming, now - record.incomingStartTime); - if (record.outgoingStartTime) { - maxOutgoing = Math.max(maxOutgoing, now - record.outgoingStartTime); - } - } - - // Get termination stats - const terminationStats = this.connectionManager.getTerminationStats(); - - // Log detailed stats - logger.log('info', 'Connection statistics', { - activeConnections: connectionRecords.size, - tls: { - total: tlsConnections, - completed: completedTlsHandshakes, - pending: pendingTlsHandshakes - }, - nonTls: nonTlsConnections, - keepAlive: keepAliveConnections, - httpProxy: httpProxyConnections, - longestRunning: { - incoming: plugins.prettyMs(maxIncoming), - outgoing: plugins.prettyMs(maxOutgoing) - }, - terminationStats: { - incoming: terminationStats.incoming, - outgoing: terminationStats.outgoing - }, - component: 'connection-manager' - }); - }, this.settings.inactivityCheckInterval || 60000); - - // Make sure the interval doesn't keep the process alive - if (this.connectionLogger.unref) { - this.connectionLogger.unref(); - } + logger.log('info', 'SmartProxy started (Rust engine)', { component: 'smart-proxy' }); } - + /** - * Extract domain configurations from routes for certificate provisioning - * - * Note: This method has been removed as we now work directly with routes + * Stop the proxy. */ - - /** - * Stop the proxy server - */ - public async stop() { - logger.log('info', 'SmartProxy shutting down...'); - this.isShuttingDown = true; - this.portManager.setShuttingDown(true); - - // Stop certificate manager - if (this.certManager) { - await this.certManager.stop(); - logger.log('info', 'Certificate manager stopped'); - } - - // Stop NFTablesManager - await this.nftablesManager.stop(); - logger.log('info', 'NFTablesManager stopped'); + public async stop(): Promise { + logger.log('info', 'SmartProxy shutting down...', { component: 'smart-proxy' }); - // Stop the connection logger - if (this.connectionLogger) { - clearInterval(this.connectionLogger); - this.connectionLogger = null; - } + // Stop metrics polling + this.metricsAdapter.stopPolling(); - // Stop all port listeners - await this.portManager.closeAll(); - logger.log('info', 'All servers closed. Cleaning up active connections...'); - - // Clean up all active connections - await this.connectionManager.clearConnections(); - - // Stop HttpProxy - await this.httpProxyBridge.stop(); - - // Clear ACME state manager - this.acmeStateManager.clear(); - - // Stop metrics collector - this.metricsCollector.stop(); - - // Clean up ProtocolDetector singleton - const detection = await import('../../detection/index.js'); - detection.ProtocolDetector.destroy(); - - // Flush any pending deduplicated logs - connectionLogDeduplicator.flushAll(); - - logger.log('info', 'SmartProxy shutdown complete.'); - } - - /** - * Updates the domain configurations for the proxy - * - * Note: This legacy method has been removed. Use updateRoutes instead. - */ - public async updateDomainConfigs(): Promise { - logger.log('warn', 'Method updateDomainConfigs() is deprecated. Use updateRoutes() instead.'); - throw new Error('updateDomainConfigs() is deprecated - use updateRoutes() instead'); - } - - /** - * Verify the challenge route has been properly removed from routes - */ - private async verifyChallengeRouteRemoved(): Promise { - const maxRetries = 10; - const retryDelay = 100; // milliseconds - - for (let i = 0; i < maxRetries; i++) { - // Check if the challenge route is still in the active routes - const challengeRouteExists = this.settings.routes.some(r => r.name === 'acme-challenge'); - - if (!challengeRouteExists) { - try { - logger.log('info', 'Challenge route successfully removed from routes'); - } catch (error) { - // Silently handle logging errors - console.log('[INFO] Challenge route successfully removed from routes'); - } - return; - } - - // Wait before retrying - await plugins.smartdelay.delayFor(retryDelay); - } - - const error = `Failed to verify challenge route removal after ${maxRetries} attempts`; + // Stop Rust proxy try { - logger.log('error', error); - } catch (logError) { - // Silently handle logging errors - console.log(`[ERROR] ${error}`); + await this.bridge.stopProxy(); + } catch { + // Ignore if already stopped } - throw new Error(error); + this.bridge.kill(); + + // Stop socket handler relay + if (this.socketHandlerServer) { + await this.socketHandlerServer.stop(); + this.socketHandlerServer = null; + } + + logger.log('info', 'SmartProxy shutdown complete.', { component: 'smart-proxy' }); } - + /** - * Update routes with new configuration - * - * This method replaces the current route configuration with the provided routes. - * It also provisions certificates for routes that require TLS termination and have - * `certificate: 'auto'` set in their TLS configuration. - * - * @param newRoutes Array of route configurations to use - * - * Example: - * ```ts - * proxy.updateRoutes([ - * { - * match: { ports: 443, domains: 'secure.example.com' }, - * action: { - * type: 'forward', - * target: { host: '10.0.0.1', port: 8443 }, - * tls: { mode: 'terminate', certificate: 'auto' } - * } - * } - * ]); - * ``` + * Update routes atomically. */ public async updateRoutes(newRoutes: IRouteConfig[]): Promise { return this.routeUpdateLock.runExclusive(async () => { - try { - logger.log('info', `Updating routes (${newRoutes.length} routes)`, { - routeCount: newRoutes.length, - component: 'smart-proxy' - }); - } catch (error) { - // Silently handle logging errors - console.log(`[INFO] Updating routes (${newRoutes.length} routes)`); + // Validate + const validation = RouteValidator.validateRoutes(newRoutes); + if (!validation.valid) { + RouteValidator.logValidationErrors(validation.errors); + throw new Error(`Route validation failed: ${validation.errors.size} route(s) have errors`); } - // Update route orchestrator dependencies if cert manager changed - if (this.certManager && !this.routeOrchestrator.getCertManager()) { - this.routeOrchestrator.setCertManager(this.certManager); - } - - // Delegate the complex route update logic to RouteOrchestrator - const updateResult = await this.routeOrchestrator.updateRoutes( - this.settings.routes, - newRoutes, - { - acmePort: this.settings.acme?.port || 80, - acmeOptions: this.certManager?.getAcmeOptions(), - acmeState: this.certManager?.getState(), - globalChallengeRouteActive: this.globalChallengeRouteActive, - createCertificateManager: this.createCertificateManager.bind(this), - verifyChallengeRouteRemoved: this.verifyChallengeRouteRemoved.bind(this) - } + // Preprocess for Rust + const rustRoutes = this.preprocessor.preprocessForRust(newRoutes); + + // Send to Rust + await this.bridge.updateRoutes(rustRoutes); + + // Update local route manager + this.routeManager.updateRoutes(newRoutes); + + // Update socket handler relay if handler routes changed + const hasHandlerRoutes = newRoutes.some( + (r) => + (r.action.type === 'socket-handler' && r.action.socketHandler) || + r.action.targets?.some((t) => typeof t.host === 'function' || typeof t.port === 'function') ); - - // Update settings with the new routes - this.settings.routes = newRoutes; - - // Update global state from orchestrator results - this.globalChallengeRouteActive = updateResult.newChallengeRouteActive; - - // Update port usage map from orchestrator - this.portUsageMap = updateResult.portUsageMap; - - // If certificate manager was recreated, update our reference - if (updateResult.newCertManager) { - this.certManager = updateResult.newCertManager; - // Update the orchestrator's reference too - this.routeOrchestrator.setCertManager(this.certManager); + + if (hasHandlerRoutes && !this.socketHandlerServer) { + this.socketHandlerServer = new SocketHandlerServer(this.preprocessor); + await this.socketHandlerServer.start(); + await this.bridge.setSocketHandlerRelay(this.socketHandlerServer.getSocketPath()); + } else if (!hasHandlerRoutes && this.socketHandlerServer) { + await this.socketHandlerServer.stop(); + this.socketHandlerServer = null; } + + // Update stored routes + this.settings.routes = newRoutes; + + // Handle cert provisioning for new routes + await this.provisionCertificatesViaCallback(); + + logger.log('info', `Routes updated (${newRoutes.length} routes)`, { component: 'smart-proxy' }); }); } - + /** - * Manually provision a certificate for a route + * Provision a certificate for a named route. */ public async provisionCertificate(routeName: string): Promise { - if (!this.certManager) { - throw new Error('Certificate manager not initialized'); - } - - const route = this.settings.routes.find(r => r.name === routeName); - if (!route) { - throw new Error(`Route ${routeName} not found`); - } - - await this.certManager.provisionCertificate(route); + await this.bridge.provisionCertificate(routeName); } - // Port usage tracking methods moved to RouteOrchestrator - /** - * Force renewal of a certificate + * Force renewal of a certificate. */ public async renewCertificate(routeName: string): Promise { - if (!this.certManager) { - throw new Error('Certificate manager not initialized'); - } - - await this.certManager.renewCertificate(routeName); + await this.bridge.renewCertificate(routeName); } - + /** - * Get certificate status for a route + * Get certificate status for a route (async - calls Rust). */ - public getCertificateStatus(routeName: string): ICertStatus | undefined { - if (!this.certManager) { - return undefined; - } - - return this.certManager.getCertificateStatus(routeName); + public async getCertificateStatus(routeName: string): Promise { + return this.bridge.getCertificateStatus(routeName); } - + /** - * Get proxy metrics with clean API - * - * @returns IMetrics interface with grouped metrics methods + * Get the metrics interface. */ public getMetrics(): IMetrics { - return this.metricsCollector; + return this.metricsAdapter; } - + /** - * Validates if a domain name is valid for certificate issuance + * Get statistics (async - calls Rust). */ - private isValidDomain(domain: string): boolean { - // Very basic domain validation - if (!domain || domain.length === 0) { - return false; - } - - // Check for wildcard domains (they can't get ACME certs) - if (domain.includes('*')) { - logger.log('warn', `Wildcard domains like "${domain}" are not supported for automatic ACME certificates`, { domain, component: 'certificate-manager' }); - return false; - } - - // Check if domain has at least one dot and no invalid characters - const validDomainRegex = /^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$/; - if (!validDomainRegex.test(domain)) { - logger.log('warn', `Domain "${domain}" has invalid format for certificate issuance`, { domain, component: 'certificate-manager' }); - return false; - } - - return true; + public async getStatistics(): Promise { + return this.bridge.getStatistics(); } - + /** - * Add a new listening port without changing the route configuration - * - * This allows you to add a port listener without updating routes. - * Useful for preparing to listen on a port before adding routes for it. - * - * @param port The port to start listening on - * @returns Promise that resolves when the port is listening + * Add a listening port at runtime. */ public async addListeningPort(port: number): Promise { - return this.portManager.addPort(port); + await this.bridge.addListeningPort(port); } /** - * Stop listening on a specific port without changing the route configuration - * - * This allows you to stop a port listener without updating routes. - * Useful for temporary maintenance or port changes. - * - * @param port The port to stop listening on - * @returns Promise that resolves when the port is closed + * Remove a listening port at runtime. */ public async removeListeningPort(port: number): Promise { - return this.portManager.removePort(port); + await this.bridge.removeListeningPort(port); } /** - * Get a list of all ports currently being listened on - * - * @returns Array of port numbers + * Get all currently listening ports (async - calls Rust). */ - public getListeningPorts(): number[] { - return this.portManager.getListeningPorts(); + public async getListeningPorts(): Promise { + return this.bridge.getListeningPorts(); } /** - * Get statistics about current connections - */ - public getStatistics(): any { - const connectionRecords = this.connectionManager.getConnections(); - const terminationStats = this.connectionManager.getTerminationStats(); - - let tlsConnections = 0; - let nonTlsConnections = 0; - let keepAliveConnections = 0; - let httpProxyConnections = 0; - - // Analyze active connections - for (const record of connectionRecords.values()) { - if (record.isTLS) tlsConnections++; - else nonTlsConnections++; - if (record.hasKeepAlive) keepAliveConnections++; - if (record.usingNetworkProxy) httpProxyConnections++; - } - - return { - activeConnections: connectionRecords.size, - tlsConnections, - nonTlsConnections, - keepAliveConnections, - httpProxyConnections, - terminationStats, - acmeEnabled: !!this.certManager, - port80HandlerPort: this.certManager ? 80 : null, - routeCount: this.settings.routes.length, - activePorts: this.portManager.getListeningPorts().length, - listeningPorts: this.portManager.getListeningPorts() - }; - } - - /** - * Get a list of eligible domains for ACME certificates + * Get eligible domains for ACME certificates (sync - reads local routes). */ public getEligibleDomainsForCertificates(): string[] { const domains: string[] = []; - - // Get domains from routes - const routes = this.settings.routes || []; - - for (const route of routes) { + for (const route of this.settings.routes || []) { if (!route.match.domains) continue; - - // Skip routes without TLS termination or auto certificates - if (route.action.type !== 'forward' || - !route.action.tls || - route.action.tls.mode === 'passthrough' || - route.action.tls.certificate !== 'auto') continue; - - const routeDomains = Array.isArray(route.match.domains) - ? route.match.domains - : [route.match.domains]; - - // Skip domains that can't be used with ACME - const eligibleDomains = routeDomains.filter(domain => - !domain.includes('*') && this.isValidDomain(domain) - ); - - domains.push(...eligibleDomains); + if ( + route.action.type !== 'forward' || + !route.action.tls || + route.action.tls.mode === 'passthrough' || + route.action.tls.certificate !== 'auto' + ) + continue; + + const routeDomains = Array.isArray(route.match.domains) ? route.match.domains : [route.match.domains]; + const eligible = routeDomains.filter((d) => !d.includes('*') && this.isValidDomain(d)); + domains.push(...eligible); } - - // Legacy mode is no longer supported - return domains; } - + /** - * Get NFTables status + * Get NFTables status (async - calls Rust). */ public async getNfTablesStatus(): Promise> { - return this.nftablesManager.getStatus(); - } - - /** - * Validate ACME configuration - */ - private validateAcmeConfiguration(): string[] { - const warnings: string[] = []; - - // Check for routes with certificate: 'auto' - const autoRoutes = this.settings.routes.filter(r => - r.action.tls?.certificate === 'auto' - ); - - if (autoRoutes.length === 0) { - return warnings; - } - - // Check if we have ACME email configuration - const hasTopLevelEmail = this.settings.acme?.email; - const routesWithEmail = autoRoutes.filter(r => r.action.tls?.acme?.email); - - if (!hasTopLevelEmail && routesWithEmail.length === 0) { - warnings.push( - 'Routes with certificate: "auto" require ACME email configuration. ' + - 'Add email to either top-level "acme" config or individual route\'s "tls.acme" config.' - ); - } - - // Check for port 80 availability for challenges - if (autoRoutes.length > 0) { - const challengePort = this.settings.acme?.port || 80; - const portsInUse = this.routeManager.getListeningPorts(); - - if (!portsInUse.includes(challengePort)) { - warnings.push( - `Port ${challengePort} is not configured for any routes but is needed for ACME challenges. ` + - `Add a route listening on port ${challengePort} or ensure it's accessible for HTTP-01 challenges.` - ); - } - } - - // Check for mismatched environments - if (this.settings.acme?.useProduction) { - const stagingRoutes = autoRoutes.filter(r => - r.action.tls?.acme?.useProduction === false - ); - if (stagingRoutes.length > 0) { - warnings.push( - 'Top-level ACME uses production but some routes use staging. ' + - 'Consider aligning environments to avoid certificate issues.' - ); - } - } - - // Check for wildcard domains with auto certificates - for (const route of autoRoutes) { - const domains = Array.isArray(route.match.domains) - ? route.match.domains - : [route.match.domains]; - - const wildcardDomains = domains.filter(d => d?.includes('*')); - if (wildcardDomains.length > 0) { - warnings.push( - `Route "${route.name}" has wildcard domain(s) ${wildcardDomains.join(', ')} ` + - 'with certificate: "auto". Wildcard certificates require DNS-01 challenges, ' + - 'which are not currently supported. Use static certificates instead.' - ); - } - } - - return warnings; + return this.bridge.getNftablesStatus(); } -} \ No newline at end of file + // --- Private helpers --- + + /** + * Build the Rust configuration object from TS settings. + */ + private buildRustConfig(routes: IRouteConfig[]): any { + return { + routes, + defaults: this.settings.defaults, + acme: this.settings.acme + ? { + enabled: this.settings.acme.enabled, + email: this.settings.acme.email, + useProduction: this.settings.acme.useProduction, + port: this.settings.acme.port, + renewThresholdDays: this.settings.acme.renewThresholdDays, + autoRenew: this.settings.acme.autoRenew, + certificateStore: this.settings.acme.certificateStore, + renewCheckIntervalHours: this.settings.acme.renewCheckIntervalHours, + } + : undefined, + connectionTimeout: this.settings.connectionTimeout, + initialDataTimeout: this.settings.initialDataTimeout, + socketTimeout: this.settings.socketTimeout, + maxConnectionLifetime: this.settings.maxConnectionLifetime, + gracefulShutdownTimeout: this.settings.gracefulShutdownTimeout, + maxConnectionsPerIp: this.settings.maxConnectionsPerIP, + connectionRateLimitPerMinute: this.settings.connectionRateLimitPerMinute, + keepAliveTreatment: this.settings.keepAliveTreatment, + keepAliveInactivityMultiplier: this.settings.keepAliveInactivityMultiplier, + extendedKeepAliveLifetime: this.settings.extendedKeepAliveLifetime, + acceptProxyProtocol: this.settings.acceptProxyProtocol, + sendProxyProtocol: this.settings.sendProxyProtocol, + }; + } + + /** + * For routes with certificate: 'auto', call certProvisionFunction if set. + * If the callback returns a cert object, load it into Rust. + * If it returns 'http01', let Rust handle ACME. + */ + private async provisionCertificatesViaCallback(): Promise { + const provisionFn = this.settings.certProvisionFunction; + if (!provisionFn) return; + + for (const route of this.settings.routes) { + if (route.action.tls?.certificate !== 'auto') continue; + if (!route.match.domains) continue; + + const domains = Array.isArray(route.match.domains) ? route.match.domains : [route.match.domains]; + + for (const domain of domains) { + if (domain.includes('*')) continue; + + try { + const result: TSmartProxyCertProvisionObject = await provisionFn(domain); + + if (result === 'http01') { + // Rust handles ACME for this domain + continue; + } + + // Got a static cert object - load it into Rust + if (result && typeof result === 'object') { + const certObj = result as plugins.tsclass.network.ICert; + await this.bridge.loadCertificate( + domain, + certObj.publicKey, + certObj.privateKey, + ); + logger.log('info', `Certificate loaded via provision function for ${domain}`, { component: 'smart-proxy' }); + } + } catch (err: any) { + logger.log('warn', `certProvisionFunction failed for ${domain}: ${err.message}`, { component: 'smart-proxy' }); + + // Fallback to ACME if enabled + if (this.settings.certProvisionFallbackToAcme !== false) { + logger.log('info', `Falling back to ACME for ${domain}`, { component: 'smart-proxy' }); + } + } + } + } + } + + private isValidDomain(domain: string): boolean { + if (!domain || domain.length === 0) return false; + if (domain.includes('*')) return false; + const validDomainRegex = + /^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$/; + return validDomainRegex.test(domain); + } +} diff --git a/ts/proxies/smart-proxy/socket-handler-server.ts b/ts/proxies/smart-proxy/socket-handler-server.ts new file mode 100644 index 0000000..b475515 --- /dev/null +++ b/ts/proxies/smart-proxy/socket-handler-server.ts @@ -0,0 +1,178 @@ +import * as plugins from '../../plugins.js'; +import { logger } from '../../core/utils/logger.js'; +import type { IRouteConfig, IRouteContext } from './models/route-types.js'; +import type { RoutePreprocessor } from './route-preprocessor.js'; + +/** + * Unix domain socket server that receives relayed connections from the Rust proxy. + * + * When Rust encounters a route of type `socket-handler`, it connects to this + * Unix socket, sends a JSON metadata line, then proxies the raw TCP bytes. + * This server reads the metadata, finds the original JS handler, builds an + * IRouteContext, and hands the socket to the handler. + */ +export class SocketHandlerServer { + private server: plugins.net.Server | null = null; + private socketPath: string; + private preprocessor: RoutePreprocessor; + + constructor(preprocessor: RoutePreprocessor) { + this.preprocessor = preprocessor; + this.socketPath = `/tmp/smartproxy-relay-${process.pid}.sock`; + } + + /** + * The Unix socket path this server listens on. + */ + public getSocketPath(): string { + return this.socketPath; + } + + /** + * Start listening for relayed connections from Rust. + */ + public async start(): Promise { + // Clean up stale socket file + try { + await plugins.fs.promises.unlink(this.socketPath); + } catch { + // Ignore if doesn't exist + } + + return new Promise((resolve, reject) => { + this.server = plugins.net.createServer((socket) => { + this.handleConnection(socket); + }); + + this.server.on('error', (err) => { + logger.log('error', `SocketHandlerServer error: ${err.message}`, { component: 'socket-handler-server' }); + }); + + this.server.listen(this.socketPath, () => { + logger.log('info', `SocketHandlerServer listening on ${this.socketPath}`, { component: 'socket-handler-server' }); + resolve(); + }); + + this.server.on('error', reject); + }); + } + + /** + * Stop the server and clean up. + */ + public async stop(): Promise { + if (this.server) { + return new Promise((resolve) => { + this.server!.close(() => { + this.server = null; + // Clean up socket file + plugins.fs.unlink(this.socketPath, () => resolve()); + }); + }); + } + } + + /** + * Handle an incoming relayed connection from Rust. + * + * Protocol: Rust sends a single JSON line with metadata, then raw bytes follow. + * JSON format: { "routeKey": "my-route", "remoteIP": "1.2.3.4", "remotePort": 12345, + * "localPort": 443, "isTLS": true, "domain": "example.com" } + */ + private handleConnection(socket: plugins.net.Socket): void { + let metadataBuffer = ''; + let metadataParsed = false; + + const onData = (chunk: Buffer) => { + if (metadataParsed) return; + + metadataBuffer += chunk.toString('utf8'); + const newlineIndex = metadataBuffer.indexOf('\n'); + + if (newlineIndex === -1) { + // Haven't received full metadata line yet + if (metadataBuffer.length > 8192) { + logger.log('error', 'Socket handler metadata too large, closing', { component: 'socket-handler-server' }); + socket.destroy(); + } + return; + } + + metadataParsed = true; + socket.removeListener('data', onData); + + const metadataJson = metadataBuffer.slice(0, newlineIndex); + const remainingData = metadataBuffer.slice(newlineIndex + 1); + + let metadata: any; + try { + metadata = JSON.parse(metadataJson); + } catch { + logger.log('error', `Invalid socket handler metadata JSON: ${metadataJson.slice(0, 200)}`, { component: 'socket-handler-server' }); + socket.destroy(); + return; + } + + this.dispatchToHandler(socket, metadata, remainingData); + }; + + socket.on('data', onData); + socket.on('error', (err) => { + logger.log('error', `Socket handler relay error: ${err.message}`, { component: 'socket-handler-server' }); + }); + } + + /** + * Dispatch a relayed connection to the appropriate JS handler. + */ + private dispatchToHandler(socket: plugins.net.Socket, metadata: any, remainingData: string): void { + const routeKey = metadata.routeKey as string; + if (!routeKey) { + logger.log('error', 'Socket handler relay missing routeKey', { component: 'socket-handler-server' }); + socket.destroy(); + return; + } + + const originalRoute = this.preprocessor.getOriginalRoute(routeKey); + if (!originalRoute) { + logger.log('error', `No handler found for route: ${routeKey}`, { component: 'socket-handler-server' }); + socket.destroy(); + return; + } + + const handler = originalRoute.action.socketHandler; + if (!handler) { + logger.log('error', `Route ${routeKey} has no socketHandler`, { component: 'socket-handler-server' }); + socket.destroy(); + return; + } + + // Build route context + const context: IRouteContext = { + port: metadata.localPort || 0, + domain: metadata.domain, + clientIp: metadata.remoteIP || 'unknown', + serverIp: '0.0.0.0', + path: metadata.path, + isTls: metadata.isTLS || false, + tlsVersion: metadata.tlsVersion, + routeName: originalRoute.name, + routeId: originalRoute.id, + timestamp: Date.now(), + connectionId: metadata.connectionId || `relay-${Date.now()}`, + }; + + // If there was remaining data after the metadata line, push it back + if (remainingData.length > 0) { + socket.unshift(Buffer.from(remainingData, 'utf8')); + } + + // Call the handler + try { + handler(socket, context); + } catch (err: any) { + logger.log('error', `Socket handler threw for route ${routeKey}: ${err.message}`, { component: 'socket-handler-server' }); + socket.destroy(); + } + } +} diff --git a/ts/proxies/smart-proxy/throughput-tracker.ts b/ts/proxies/smart-proxy/throughput-tracker.ts deleted file mode 100644 index dd9938c..0000000 --- a/ts/proxies/smart-proxy/throughput-tracker.ts +++ /dev/null @@ -1,138 +0,0 @@ -import type { IThroughputSample, IThroughputData, IThroughputHistoryPoint } from './models/metrics-types.js'; - -/** - * Tracks throughput data using time-series sampling - */ -export class ThroughputTracker { - private samples: IThroughputSample[] = []; - private readonly maxSamples: number; - private accumulatedBytesIn: number = 0; - private accumulatedBytesOut: number = 0; - private lastSampleTime: number = 0; - - constructor(retentionSeconds: number = 3600) { - // Keep samples for the retention period at 1 sample per second - this.maxSamples = retentionSeconds; - } - - /** - * Record bytes transferred (called on every data transfer) - */ - public recordBytes(bytesIn: number, bytesOut: number): void { - this.accumulatedBytesIn += bytesIn; - this.accumulatedBytesOut += bytesOut; - } - - /** - * Take a sample of accumulated bytes (called every second) - */ - public takeSample(): void { - const now = Date.now(); - - // Record accumulated bytes since last sample - this.samples.push({ - timestamp: now, - bytesIn: this.accumulatedBytesIn, - bytesOut: this.accumulatedBytesOut - }); - - // Reset accumulators - this.accumulatedBytesIn = 0; - this.accumulatedBytesOut = 0; - this.lastSampleTime = now; - - // Maintain circular buffer - remove oldest samples - if (this.samples.length > this.maxSamples) { - this.samples.shift(); - } - } - - /** - * Get throughput rate over specified window (bytes per second) - */ - public getRate(windowSeconds: number): IThroughputData { - if (this.samples.length === 0) { - return { in: 0, out: 0 }; - } - - const now = Date.now(); - const windowStart = now - (windowSeconds * 1000); - - // Find samples within the window - const relevantSamples = this.samples.filter(s => s.timestamp > windowStart); - - if (relevantSamples.length === 0) { - return { in: 0, out: 0 }; - } - - // Calculate total bytes in window - const totalBytesIn = relevantSamples.reduce((sum, s) => sum + s.bytesIn, 0); - const totalBytesOut = relevantSamples.reduce((sum, s) => sum + s.bytesOut, 0); - - // Use actual number of seconds covered by samples for accurate rate - const oldestSampleTime = relevantSamples[0].timestamp; - const newestSampleTime = relevantSamples[relevantSamples.length - 1].timestamp; - const actualSeconds = Math.max(1, (newestSampleTime - oldestSampleTime) / 1000 + 1); - - return { - in: Math.round(totalBytesIn / actualSeconds), - out: Math.round(totalBytesOut / actualSeconds) - }; - } - - /** - * Get throughput history for specified duration - */ - public getHistory(durationSeconds: number): IThroughputHistoryPoint[] { - const now = Date.now(); - const startTime = now - (durationSeconds * 1000); - - // Filter samples within duration - const relevantSamples = this.samples.filter(s => s.timestamp > startTime); - - // Convert to history points with per-second rates - const history: IThroughputHistoryPoint[] = []; - - for (let i = 0; i < relevantSamples.length; i++) { - const sample = relevantSamples[i]; - - // For the first sample or samples after gaps, we can't calculate rate - if (i === 0 || sample.timestamp - relevantSamples[i - 1].timestamp > 2000) { - history.push({ - timestamp: sample.timestamp, - in: sample.bytesIn, - out: sample.bytesOut - }); - } else { - // Calculate rate based on time since previous sample - const prevSample = relevantSamples[i - 1]; - const timeDelta = (sample.timestamp - prevSample.timestamp) / 1000; - - history.push({ - timestamp: sample.timestamp, - in: Math.round(sample.bytesIn / timeDelta), - out: Math.round(sample.bytesOut / timeDelta) - }); - } - } - - return history; - } - - /** - * Clear all samples - */ - public clear(): void { - this.samples = []; - this.accumulatedBytesIn = 0; - this.accumulatedBytesOut = 0; - this.lastSampleTime = 0; - } - - /** - * Get sample count for debugging - */ - public getSampleCount(): number { - return this.samples.length; - } -} \ No newline at end of file diff --git a/ts/proxies/smart-proxy/timeout-manager.ts b/ts/proxies/smart-proxy/timeout-manager.ts deleted file mode 100644 index 8653ba8..0000000 --- a/ts/proxies/smart-proxy/timeout-manager.ts +++ /dev/null @@ -1,196 +0,0 @@ -import type { IConnectionRecord } from './models/interfaces.js'; -import type { SmartProxy } from './smart-proxy.js'; - -/** - * Manages timeouts and inactivity tracking for connections - */ -export class TimeoutManager { - constructor(private smartProxy: SmartProxy) {} - - /** - * Ensure timeout values don't exceed Node.js max safe integer - */ - public ensureSafeTimeout(timeout: number): number { - const MAX_SAFE_TIMEOUT = 2147483647; // Maximum safe value (2^31 - 1) - return Math.min(Math.floor(timeout), MAX_SAFE_TIMEOUT); - } - - /** - * Generate a slightly randomized timeout to prevent thundering herd - */ - public randomizeTimeout(baseTimeout: number, variationPercent: number = 5): number { - const safeBaseTimeout = this.ensureSafeTimeout(baseTimeout); - const variation = safeBaseTimeout * (variationPercent / 100); - return this.ensureSafeTimeout( - safeBaseTimeout + Math.floor(Math.random() * variation * 2) - variation - ); - } - - /** - * Update connection activity timestamp - */ - public updateActivity(record: IConnectionRecord): void { - record.lastActivity = Date.now(); - - // Clear any inactivity warning - if (record.inactivityWarningIssued) { - record.inactivityWarningIssued = false; - } - } - - /** - * Calculate effective inactivity timeout based on connection type - */ - public getEffectiveInactivityTimeout(record: IConnectionRecord): number { - let effectiveTimeout = this.smartProxy.settings.inactivityTimeout || 14400000; // 4 hours default - - // For immortal keep-alive connections, use an extremely long timeout - if (record.hasKeepAlive && this.smartProxy.settings.keepAliveTreatment === 'immortal') { - return Number.MAX_SAFE_INTEGER; - } - - // For extended keep-alive connections, apply multiplier - if (record.hasKeepAlive && this.smartProxy.settings.keepAliveTreatment === 'extended') { - const multiplier = this.smartProxy.settings.keepAliveInactivityMultiplier || 6; - effectiveTimeout = effectiveTimeout * multiplier; - } - - return this.ensureSafeTimeout(effectiveTimeout); - } - - /** - * Calculate effective max lifetime based on connection type - */ - public getEffectiveMaxLifetime(record: IConnectionRecord): number { - // Use route-specific timeout if available from the routeConfig - const baseTimeout = record.routeConfig?.action.advanced?.timeout || - this.smartProxy.settings.maxConnectionLifetime || - 86400000; // 24 hours default - - // For immortal keep-alive connections, use an extremely long lifetime - if (record.hasKeepAlive && this.smartProxy.settings.keepAliveTreatment === 'immortal') { - return Number.MAX_SAFE_INTEGER; - } - - // For extended keep-alive connections, use the extended lifetime setting - if (record.hasKeepAlive && this.smartProxy.settings.keepAliveTreatment === 'extended') { - return this.ensureSafeTimeout( - this.smartProxy.settings.extendedKeepAliveLifetime || 7 * 24 * 60 * 60 * 1000 // 7 days default - ); - } - - // Apply randomization if enabled - if (this.smartProxy.settings.enableRandomizedTimeouts) { - return this.randomizeTimeout(baseTimeout); - } - - return this.ensureSafeTimeout(baseTimeout); - } - - /** - * Setup connection timeout - * @returns The cleanup timer - */ - public setupConnectionTimeout( - record: IConnectionRecord, - onTimeout: (record: IConnectionRecord, reason: string) => void - ): NodeJS.Timeout | null { - // Clear any existing timer - if (record.cleanupTimer) { - clearTimeout(record.cleanupTimer); - } - - // Skip timeout for immortal keep-alive connections - if (record.hasKeepAlive && this.smartProxy.settings.keepAliveTreatment === 'immortal') { - return null; - } - - // Calculate effective timeout - const effectiveLifetime = this.getEffectiveMaxLifetime(record); - - // Set up the timeout - const timer = setTimeout(() => { - // Call the provided callback - onTimeout(record, 'connection_timeout'); - }, effectiveLifetime); - - // Make sure timeout doesn't keep the process alive - if (timer.unref) { - timer.unref(); - } - - return timer; - } - - /** - * Check for inactivity on a connection - * @returns Object with check results - */ - public checkInactivity(record: IConnectionRecord): { - isInactive: boolean; - shouldWarn: boolean; - inactivityTime: number; - effectiveTimeout: number; - } { - // Skip for connections with inactivity check disabled - if (this.smartProxy.settings.disableInactivityCheck) { - return { - isInactive: false, - shouldWarn: false, - inactivityTime: 0, - effectiveTimeout: 0 - }; - } - - // Skip for immortal keep-alive connections - if (record.hasKeepAlive && this.smartProxy.settings.keepAliveTreatment === 'immortal') { - return { - isInactive: false, - shouldWarn: false, - inactivityTime: 0, - effectiveTimeout: 0 - }; - } - - const now = Date.now(); - const inactivityTime = now - record.lastActivity; - const effectiveTimeout = this.getEffectiveInactivityTimeout(record); - - // Check if inactive - const isInactive = inactivityTime > effectiveTimeout; - - // For keep-alive connections, we should warn first - const shouldWarn = record.hasKeepAlive && - isInactive && - !record.inactivityWarningIssued; - - return { - isInactive, - shouldWarn, - inactivityTime, - effectiveTimeout - }; - } - - /** - * Apply socket timeout settings - */ - public applySocketTimeouts(record: IConnectionRecord): void { - // Skip for immortal keep-alive connections - if (record.hasKeepAlive && this.smartProxy.settings.keepAliveTreatment === 'immortal') { - // Disable timeouts completely for immortal connections - record.incoming.setTimeout(0); - if (record.outgoing) { - record.outgoing.setTimeout(0); - } - return; - } - - // Apply normal timeouts - const timeout = this.ensureSafeTimeout(this.smartProxy.settings.socketTimeout || 3600000); // 1 hour default - record.incoming.setTimeout(timeout); - if (record.outgoing) { - record.outgoing.setTimeout(timeout); - } - } -} \ No newline at end of file diff --git a/ts/proxies/smart-proxy/tls-manager.ts b/ts/proxies/smart-proxy/tls-manager.ts deleted file mode 100644 index 9efc169..0000000 --- a/ts/proxies/smart-proxy/tls-manager.ts +++ /dev/null @@ -1,171 +0,0 @@ -import * as plugins from '../../plugins.js'; -import { SniHandler } from '../../tls/sni/sni-handler.js'; -import { ProtocolDetector, TlsDetector } from '../../detection/index.js'; -import type { SmartProxy } from './smart-proxy.js'; - -/** - * Interface for connection information used for SNI extraction - */ -interface IConnectionInfo { - sourceIp: string; - sourcePort: number; - destIp: string; - destPort: number; -} - -/** - * Manages TLS-related operations including SNI extraction and validation - */ -export class TlsManager { - constructor(private smartProxy: SmartProxy) {} - - /** - * Check if a data chunk appears to be a TLS handshake - */ - public isTlsHandshake(chunk: Buffer): boolean { - return SniHandler.isTlsHandshake(chunk); - } - - /** - * Check if a data chunk appears to be a TLS ClientHello - */ - public isClientHello(chunk: Buffer): boolean { - return SniHandler.isClientHello(chunk); - } - - /** - * Extract Server Name Indication (SNI) from TLS handshake - */ - public extractSNI( - chunk: Buffer, - connInfo: IConnectionInfo, - previousDomain?: string - ): string | undefined { - // Use the SniHandler to process the TLS packet - return SniHandler.processTlsPacket( - chunk, - connInfo, - this.smartProxy.settings.enableTlsDebugLogging || false, - previousDomain - ); - } - -/** - * Check for SNI mismatch during renegotiation - */ - public checkRenegotiationSNI( - chunk: Buffer, - connInfo: IConnectionInfo, - expectedDomain: string, - connectionId: string - ): { hasMismatch: boolean; extractedSNI?: string } { - // Only process if this looks like a TLS ClientHello - if (!this.isClientHello(chunk)) { - return { hasMismatch: false }; - } - - try { - // Extract SNI with renegotiation support - const newSNI = SniHandler.extractSNIWithResumptionSupport( - chunk, - connInfo, - this.smartProxy.settings.enableTlsDebugLogging || false - ); - - // Skip if no SNI was found - if (!newSNI) return { hasMismatch: false }; - - // Check for SNI mismatch - if (newSNI !== expectedDomain) { - if (this.smartProxy.settings.enableTlsDebugLogging) { - console.log( - `[${connectionId}] Renegotiation with different SNI: ${expectedDomain} -> ${newSNI}. ` + - `Terminating connection - SNI domain switching is not allowed.` - ); - } - return { hasMismatch: true, extractedSNI: newSNI }; - } else if (this.smartProxy.settings.enableTlsDebugLogging) { - console.log( - `[${connectionId}] Renegotiation detected with same SNI: ${newSNI}. Allowing.` - ); - } - } catch (err) { - console.log( - `[${connectionId}] Error processing ClientHello: ${err}. Allowing connection to continue.` - ); - } - - return { hasMismatch: false }; - } - - /** - * Create a renegotiation handler function for a connection - */ - public createRenegotiationHandler( - connectionId: string, - lockedDomain: string, - connInfo: IConnectionInfo, - onMismatch: (connectionId: string, reason: string) => void - ): (chunk: Buffer) => void { - return (chunk: Buffer) => { - const result = this.checkRenegotiationSNI(chunk, connInfo, lockedDomain, connectionId); - if (result.hasMismatch) { - onMismatch(connectionId, 'sni_mismatch'); - } - }; - } - - /** - * Analyze TLS connection for browser fingerprinting - * This helps identify browser vs non-browser connections - */ - public analyzeClientHello(chunk: Buffer): { - isBrowserConnection: boolean; - isRenewal: boolean; - hasSNI: boolean; - } { - // Default result - const result = { - isBrowserConnection: false, - isRenewal: false, - hasSNI: false - }; - - try { - // Check if it's a ClientHello - if (!this.isClientHello(chunk)) { - return result; - } - - // Check for session resumption - const resumptionInfo = SniHandler.hasSessionResumption( - chunk, - this.smartProxy.settings.enableTlsDebugLogging || false - ); - - // Extract SNI - const sni = SniHandler.extractSNI( - chunk, - this.smartProxy.settings.enableTlsDebugLogging || false - ); - - // Update result - result.isRenewal = resumptionInfo.isResumption; - result.hasSNI = !!sni; - - // Browsers typically: - // 1. Send SNI extension - // 2. Have a variety of extensions (ALPN, etc.) - // 3. Use standard cipher suites - // ...more complex heuristics could be implemented here - - // Simple heuristic: presence of SNI suggests browser - result.isBrowserConnection = !!sni; - - return result; - } catch (err) { - console.log(`Error analyzing ClientHello: ${err}`); - return result; - } - } -} \ No newline at end of file diff --git a/ts/routing/index.ts b/ts/routing/index.ts index 1076983..879ad65 100644 --- a/ts/routing/index.ts +++ b/ts/routing/index.ts @@ -2,8 +2,8 @@ * Routing functionality module */ -// Export types and models from HttpProxy -export * from '../proxies/http-proxy/models/http-types.js'; +// Export types and models +export * from './models/http-types.js'; // Export router functionality export * from './router/index.js'; diff --git a/ts/routing/models/http-types.ts b/ts/routing/models/http-types.ts index 738c85d..1677eb6 100644 --- a/ts/routing/models/http-types.ts +++ b/ts/routing/models/http-types.ts @@ -1,6 +1,149 @@ /** - * This file re-exports HTTP types from the HttpProxy module - * for backward compatibility. All HTTP types are now consolidated - * in the HttpProxy module. + * HTTP types for routing module. + * These were previously in http-proxy and are now self-contained here. */ -export * from '../../proxies/http-proxy/models/http-types.js'; \ No newline at end of file +import * as plugins from '../../plugins.js'; +import { HttpStatus as ProtocolHttpStatus, getStatusText as getProtocolStatusText } from '../../protocols/http/index.js'; + +/** + * HTTP-specific event types + */ +export enum HttpEvents { + REQUEST_RECEIVED = 'request-received', + REQUEST_FORWARDED = 'request-forwarded', + REQUEST_HANDLED = 'request-handled', + REQUEST_ERROR = 'request-error', +} + +// Re-export for backward compatibility with subset of commonly used codes +export const HttpStatus = { + OK: ProtocolHttpStatus.OK, + MOVED_PERMANENTLY: ProtocolHttpStatus.MOVED_PERMANENTLY, + FOUND: ProtocolHttpStatus.FOUND, + TEMPORARY_REDIRECT: ProtocolHttpStatus.TEMPORARY_REDIRECT, + PERMANENT_REDIRECT: ProtocolHttpStatus.PERMANENT_REDIRECT, + BAD_REQUEST: ProtocolHttpStatus.BAD_REQUEST, + UNAUTHORIZED: ProtocolHttpStatus.UNAUTHORIZED, + FORBIDDEN: ProtocolHttpStatus.FORBIDDEN, + NOT_FOUND: ProtocolHttpStatus.NOT_FOUND, + METHOD_NOT_ALLOWED: ProtocolHttpStatus.METHOD_NOT_ALLOWED, + REQUEST_TIMEOUT: ProtocolHttpStatus.REQUEST_TIMEOUT, + TOO_MANY_REQUESTS: ProtocolHttpStatus.TOO_MANY_REQUESTS, + INTERNAL_SERVER_ERROR: ProtocolHttpStatus.INTERNAL_SERVER_ERROR, + NOT_IMPLEMENTED: ProtocolHttpStatus.NOT_IMPLEMENTED, + BAD_GATEWAY: ProtocolHttpStatus.BAD_GATEWAY, + SERVICE_UNAVAILABLE: ProtocolHttpStatus.SERVICE_UNAVAILABLE, + GATEWAY_TIMEOUT: ProtocolHttpStatus.GATEWAY_TIMEOUT, +} as const; + +/** + * Base error class for HTTP-related errors + */ +export class HttpError extends Error { + constructor(message: string, public readonly statusCode: number = HttpStatus.INTERNAL_SERVER_ERROR) { + super(message); + this.name = 'HttpError'; + } +} + +/** + * Error related to certificate operations + */ +export class CertificateError extends HttpError { + constructor( + message: string, + public readonly domain: string, + public readonly isRenewal: boolean = false + ) { + super(`${message} for domain ${domain}${isRenewal ? ' (renewal)' : ''}`, HttpStatus.INTERNAL_SERVER_ERROR); + this.name = 'CertificateError'; + } +} + +/** + * Error related to server operations + */ +export class ServerError extends HttpError { + constructor(message: string, public readonly code?: string, statusCode: number = HttpStatus.INTERNAL_SERVER_ERROR) { + super(message, statusCode); + this.name = 'ServerError'; + } +} + +/** + * Error for bad requests + */ +export class BadRequestError extends HttpError { + constructor(message: string) { + super(message, HttpStatus.BAD_REQUEST); + this.name = 'BadRequestError'; + } +} + +/** + * Error for not found resources + */ +export class NotFoundError extends HttpError { + constructor(message: string = 'Resource not found') { + super(message, HttpStatus.NOT_FOUND); + this.name = 'NotFoundError'; + } +} + +/** + * Redirect configuration for HTTP requests + */ +export interface IRedirectConfig { + source: string; + destination: string; + type: number; + preserveQuery?: boolean; +} + +/** + * HTTP router configuration + */ +export interface IRouterConfig { + routes: Array<{ + path: string; + method?: string; + handler: (req: plugins.http.IncomingMessage, res: plugins.http.ServerResponse) => void | Promise; + }>; + notFoundHandler?: (req: plugins.http.IncomingMessage, res: plugins.http.ServerResponse) => void; + errorHandler?: (error: Error, req: plugins.http.IncomingMessage, res: plugins.http.ServerResponse) => void; +} + +/** + * HTTP request method types + */ +export type HttpMethod = 'GET' | 'POST' | 'PUT' | 'DELETE' | 'PATCH' | 'HEAD' | 'OPTIONS' | 'CONNECT' | 'TRACE'; + +/** + * Helper function to get HTTP status text + */ +export function getStatusText(status: number): string { + return getProtocolStatusText(status as ProtocolHttpStatus); +} + +// Legacy interfaces for backward compatibility +export interface IDomainOptions { + domainName: string; + sslRedirect: boolean; + acmeMaintenance: boolean; + forward?: { ip: string; port: number }; + acmeForward?: { ip: string; port: number }; +} + +export interface IDomainCertificate { + options: IDomainOptions; + certObtained: boolean; + obtainingInProgress: boolean; + certificate?: string; + privateKey?: string; + expiryDate?: Date; + lastRenewalAttempt?: Date; +} + +// Backward compatibility exports +export { HttpError as Port80HandlerError }; +export { CertificateError as CertError };