Compare commits

...

26 Commits

Author SHA1 Message Date
jkunz e806f7257f v27.9.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-26 15:11:10 +00:00
jkunz af4908b63f feat(smart-proxy): add hot-reloadable global ingress security policy across Rust and TypeScript proxy layers 2026-04-26 15:11:10 +00:00
jkunz 8fa3a51b03 v27.8.2
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-26 11:25:24 +00:00
jkunz 088ef6ab09 fix(rustproxy-metrics): retain inactive per-IP metric buckets briefly to capture final throughput before pruning 2026-04-26 11:25:24 +00:00
jkunz fdb5ec59bc v27.8.1
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-26 09:17:11 +00:00
jkunz 1ea290a085 fix(rustproxy-metrics): preserve high-throughput IPs in metrics snapshots when active-connection rankings are saturated 2026-04-26 09:17:11 +00:00
jkunz cb71f32b90 v27.8.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-14 12:43:59 +00:00
jkunz 46155ab12c feat(metrics): add per-domain HTTP request rate metrics 2026-04-14 12:43:59 +00:00
jkunz 490a310b54 v27.7.4
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-14 09:17:55 +00:00
jkunz 6c5180573a fix(rustproxy metrics): use stable route metrics keys across HTTP and passthrough listeners 2026-04-14 09:17:55 +00:00
jkunz 30e5ab308f v27.7.3
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-14 01:14:33 +00:00
jkunz d2a54b3491 fix(repo): no changes detected 2026-04-14 01:14:33 +00:00
jkunz dc922c97df v27.7.2
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-14 00:55:25 +00:00
jkunz 8d1bae7604 fix(docs): clarify metrics documentation for domain normalization and saturating gauges 2026-04-14 00:55:25 +00:00
jkunz 200e86e311 v27.7.1
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-14 00:54:12 +00:00
jkunz a53a2c4ca5 fix(rustproxy-http,rustproxy-metrics): fix domain-scoped request host detection and harden connection metrics cleanup 2026-04-14 00:54:12 +00:00
jkunz 6ee7237357 v27.7.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-13 23:21:54 +00:00
jkunz b5b4c608f0 feat(smart-proxy): add typed Rust config serialization and regex header contract coverage 2026-04-13 23:21:54 +00:00
jkunz af132f40fc v27.6.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-13 18:33:28 +00:00
jkunz 781634446a feat(metrics): track per-IP domain request metrics across HTTP and TCP passthrough traffic 2026-04-13 18:33:28 +00:00
jkunz e988d935b6 v27.5.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-06 12:46:09 +00:00
jkunz 99a026627d feat(security): add domain-scoped IP allow list support across HTTP and passthrough filtering 2026-04-06 12:46:09 +00:00
jkunz 572e31587a v27.4.0
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-04 19:25:06 +00:00
jkunz 8587fb997c feat(rustproxy): add HTTP/3 proxy service wiring for QUIC listeners 2026-04-04 19:25:06 +00:00
jkunz 9ba101c59b v27.3.1
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-04-04 18:54:05 +00:00
jkunz 1ad3e61c15 fix(metrics): correct frontend and backend protocol connection tracking across h1, h2, h3, and websocket traffic 2026-04-04 18:54:05 +00:00
83 changed files with 6928 additions and 2047 deletions
+89
View File
@@ -1,5 +1,94 @@
# Changelog
## 2026-04-26 - 27.9.0 - feat(smart-proxy)
add hot-reloadable global ingress security policy across Rust and TypeScript proxy layers
- adds global securityPolicy config with blocked IP and CIDR support to SmartProxy and RustProxy options
- introduces management IPC support to update the security policy at runtime via setSecurityPolicy
- enforces the global block list early for TCP, UDP, and QUIC traffic before route selection and backend handling
## 2026-04-26 - 27.8.2 - fix(rustproxy-metrics)
retain inactive per-IP metric buckets briefly to capture final throughput before pruning
- adds a bounded retention window for closed IP buckets so short-lived transfers are still included in per-IP throughput sampling
- prunes expired inactive IP tracking by TTL and hard cap to prevent unbounded metric map growth
- updates Rust and throughput tests to expect zero active connections during the temporary retention period
## 2026-04-26 - 27.8.1 - fix(rustproxy-metrics)
preserve high-throughput IPs in metrics snapshots when active-connection rankings are saturated
- Select snapshot IPs using a blend of active-connection and throughput rankings instead of only active connections
- Adds a regression test to ensure a high-bandwidth IP remains included when many other IPs have more active connections
## 2026-04-14 - 27.8.0 - feat(metrics)
add per-domain HTTP request rate metrics
- Record canonicalized HTTP request rates per domain in the Rust metrics collector and expose per-second and last-minute values in snapshots.
- Add TypeScript metrics interfaces and adapter support for requests.byDomain().
- Cover HTTP domain rate tracking and ensure TLS passthrough SNI traffic does not affect HTTP request rate metrics.
## 2026-04-14 - 27.7.4 - fix(rustproxy metrics)
use stable route metrics keys across HTTP and passthrough listeners
- adds a shared RouteConfig::metrics_key helper that prefers route name and falls back to route id
- updates HTTP, TCP, UDP, and QUIC metrics labeling to use the shared route metrics key consistently
- keeps route cancellation and rate limiter indexing bound to route config ids where required
- adds tests covering metrics key selection behavior
## 2026-04-14 - 27.7.3 - fix(repo)
no changes detected
## 2026-04-14 - 27.7.2 - fix(docs)
clarify metrics documentation for domain normalization and saturating gauges
- Document that per-IP domain keys are normalized to lowercase and have trailing dots stripped before counting.
- Clarify that the saturating close pattern also applies to connection and UDP active gauges.
## 2026-04-14 - 27.7.1 - fix(rustproxy-http,rustproxy-metrics)
fix domain-scoped request host detection and harden connection metrics cleanup
- use a shared request host extractor that falls back to URI authority so domain-scoped IP allow lists work for HTTP/2 and HTTP/3 requests without a Host header
- add request filter and host extraction tests covering domain-scoped ACL behavior
- prevent connection counters from underflowing during close handling and clean up per-IP metrics entries more safely
- normalize tracked domain keys in metrics to reduce duplicate entries caused by case or trailing-dot variations
## 2026-04-13 - 27.7.0 - feat(smart-proxy)
add typed Rust config serialization and regex header contract coverage
- serialize SmartProxy routes and top-level options into explicit Rust-safe types, including header regex literals, UDP field normalization, ACME, defaults, and proxy settings
- support JS-style regex header literals with flags in Rust header matching and add cross-contract tests for route preprocessing and config deserialization
- improve TypeScript safety for Rust bridge and metrics integration by replacing loose any-based payloads with dedicated Rust type definitions
## 2026-04-13 - 27.6.0 - feat(metrics)
track per-IP domain request metrics across HTTP and TCP passthrough traffic
- records domain request counts per frontend IP from HTTP Host headers and TCP SNI
- exposes per-IP domain maps and top IP-domain request pairs through the TypeScript metrics adapter
- bounds per-IP domain tracking and prunes stale entries to limit memory growth
- adds metrics system documentation covering architecture, collected data, and known gaps
## 2026-04-06 - 27.5.0 - feat(security)
add domain-scoped IP allow list support across HTTP and passthrough filtering
- extend route security types to accept IP allow entries scoped to specific domains
- apply domain-aware IP checks using Host headers for HTTP and SNI context for QUIC and passthrough connections
- preserve compatibility for existing plain allow list entries and add validation and tests for scoped matching
## 2026-04-04 - 27.4.0 - feat(rustproxy)
add HTTP/3 proxy service wiring for QUIC listeners
- registers H3ProxyService with the UDP listener manager so QUIC connections can serve HTTP/3
- keeps proxy IP configuration intact while enabling HTTP/3 handling during listener setup
## 2026-04-04 - 27.3.1 - fix(metrics)
correct frontend and backend protocol connection tracking across h1, h2, h3, and websocket traffic
- move frontend protocol accounting from per-request to connection lifetime tracking for HTTP/1, HTTP/2, and HTTP/3
- add backend protocol guards to connection drivers so active protocol metrics reflect live upstream connections
- prevent protocol counter underflow by using atomic saturating decrements in the metrics collector
- read backend protocol distribution directly from cached aggregate counters in the Rust metrics adapter
## 2026-04-04 - 27.3.0 - feat(test)
add end-to-end WebSocket proxy test coverage
Generated
+6 -2
View File
@@ -12,9 +12,11 @@
"npm:@push.rocks/smartserve@^2.0.3": "2.0.3", "npm:@push.rocks/smartserve@^2.0.3": "2.0.3",
"npm:@tsclass/tsclass@^9.5.0": "9.5.0", "npm:@tsclass/tsclass@^9.5.0": "9.5.0",
"npm:@types/node@^25.5.0": "25.5.0", "npm:@types/node@^25.5.0": "25.5.0",
"npm:@types/ws@^8.18.1": "8.18.1",
"npm:minimatch@^10.2.4": "10.2.4", "npm:minimatch@^10.2.4": "10.2.4",
"npm:typescript@^6.0.2": "6.0.2", "npm:typescript@^6.0.2": "6.0.2",
"npm:why-is-node-running@^3.2.2": "3.2.2" "npm:why-is-node-running@^3.2.2": "3.2.2",
"npm:ws@^8.20.0": "8.20.0"
}, },
"npm": { "npm": {
"@api.global/typedrequest-interfaces@2.0.2": { "@api.global/typedrequest-interfaces@2.0.2": {
@@ -6743,9 +6745,11 @@
"npm:@push.rocks/smartserve@^2.0.3", "npm:@push.rocks/smartserve@^2.0.3",
"npm:@tsclass/tsclass@^9.5.0", "npm:@tsclass/tsclass@^9.5.0",
"npm:@types/node@^25.5.0", "npm:@types/node@^25.5.0",
"npm:@types/ws@^8.18.1",
"npm:minimatch@^10.2.4", "npm:minimatch@^10.2.4",
"npm:typescript@^6.0.2", "npm:typescript@^6.0.2",
"npm:why-is-node-running@^3.2.2" "npm:why-is-node-running@^3.2.2",
"npm:ws@^8.20.0"
] ]
} }
} }
+1 -1
View File
@@ -1,6 +1,6 @@
{ {
"name": "@push.rocks/smartproxy", "name": "@push.rocks/smartproxy",
"version": "27.3.0", "version": "27.9.0",
"private": false, "private": false,
"description": "A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.", "description": "A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.",
"main": "dist_ts/index.js", "main": "dist_ts/index.js",
+484
View File
@@ -0,0 +1,484 @@
# SmartProxy Metrics System
## Architecture
Two-tier design separating the data plane from the observation plane:
**Hot path (per-chunk, lock-free):** All recording in the proxy data plane touches only `AtomicU64` counters. No `Mutex` is ever acquired on the forwarding path. `CountingBody` batches flushes every 64KB to reduce DashMap shard contention.
**Cold path (1Hz sampling):** A background tokio task drains pending atomics into `ThroughputTracker` circular buffers (Mutex-guarded), producing per-second throughput history. Same task prunes orphaned entries and cleans up rate limiter state.
**Read path (on-demand):** `snapshot()` reads all atomics and locks ThroughputTrackers to build a serializable `Metrics` struct. TypeScript polls this at 1s intervals via IPC.
```
Data Plane (lock-free) Background (1Hz) Read Path
───────────────────── ────────────────── ─────────
record_bytes() ──> AtomicU64 ──┐
record_http_request() ──> AtomicU64 ──┤
connection_opened/closed() ──> AtomicU64 ──┤ sample_all() snapshot()
backend_*() ──> DashMap<AtomicU64> ──┤────> drain atomics ──────> Metrics struct
protocol_*() ──> AtomicU64 ──┤ feed ThroughputTrackers ──> JSON
datagram_*() ──> AtomicU64 ──┘ prune orphans ──> IPC stdout
──> TS cache
──> IMetrics API
```
### Key Types
| Type | Crate | Purpose |
|---|---|---|
| `MetricsCollector` | `rustproxy-metrics` | Central store. All DashMaps, atomics, and throughput trackers |
| `ThroughputTracker` | `rustproxy-metrics` | Circular buffer of 1Hz samples. Default 3600 capacity (1 hour) |
| `ForwardMetricsCtx` | `rustproxy-passthrough` | Carries `Arc<MetricsCollector>` + route_id + source_ip through TCP forwarding |
| `CountingBody` | `rustproxy-http` | Wraps HTTP bodies, batches byte recording per 64KB, flushes on drop |
| `ProtocolGuard` | `rustproxy-http` | RAII guard for frontend/backend protocol active/total counters |
| `ConnectionGuard` | `rustproxy-passthrough` | RAII guard calling `connection_closed()` on drop |
| `RustMetricsAdapter` | TypeScript | Polls Rust via IPC, implements `IMetrics` interface over cached JSON |
---
## What's Collected
### Global Counters
| Metric | Type | Updated by |
|---|---|---|
| Active connections | AtomicU64 | `connection_opened/closed` |
| Total connections (lifetime) | AtomicU64 | `connection_opened` |
| Total bytes in | AtomicU64 | `record_bytes` |
| Total bytes out | AtomicU64 | `record_bytes` |
| Total HTTP requests | AtomicU64 | `record_http_request` |
| Active UDP sessions | AtomicU64 | `udp_session_opened/closed` |
| Total UDP sessions | AtomicU64 | `udp_session_opened` |
| Total datagrams in | AtomicU64 | `record_datagram_in` |
| Total datagrams out | AtomicU64 | `record_datagram_out` |
### Per-Route Metrics (keyed by route ID string)
| Metric | Storage |
|---|---|
| Active connections | `DashMap<String, AtomicU64>` |
| Total connections | `DashMap<String, AtomicU64>` |
| Bytes in / out | `DashMap<String, AtomicU64>` |
| Pending throughput (in, out) | `DashMap<String, (AtomicU64, AtomicU64)>` |
| Throughput history | `DashMap<String, Mutex<ThroughputTracker>>` |
Entries are pruned via `retain_routes()` when routes are removed.
### Per-IP Metrics (keyed by IP string)
| Metric | Storage |
|---|---|
| Active connections | `DashMap<String, AtomicU64>` |
| Total connections | `DashMap<String, AtomicU64>` |
| Bytes in / out | `DashMap<String, AtomicU64>` |
| Pending throughput (in, out) | `DashMap<String, (AtomicU64, AtomicU64)>` |
| Throughput history | `DashMap<String, Mutex<ThroughputTracker>>` |
| Domain requests | `DashMap<String, DashMap<String, AtomicU64>>` (IP → domain → count) |
All seven maps for an IP are evicted when its active connection count drops to 0. Safety-net pruning in `sample_all()` catches entries orphaned by races. Snapshots cap at 100 IPs, sorted by active connections descending.
**Domain request tracking:** Records which domains each frontend IP has requested. Populated from HTTP Host headers (for HTTP/1.1, HTTP/2, HTTP/3 requests) and from SNI (for TLS passthrough connections). Domain keys are normalized to lowercase with any trailing dot stripped so the same hostname does not fragment across multiple counters. Capped at 256 domains per IP (`MAX_DOMAINS_PER_IP`) to prevent subdomain-spray abuse. Inner DashMap uses 2 shards to minimise base memory per IP (~200 bytes). Common case (IP + domain both known) is two DashMap reads + one atomic increment with zero allocation.
### Per-Backend Metrics (keyed by "host:port")
| Metric | Storage |
|---|---|
| Active connections | `DashMap<String, AtomicU64>` |
| Total connections | `DashMap<String, AtomicU64>` |
| Detected protocol (h1/h2/h3) | `DashMap<String, String>` |
| Connect errors | `DashMap<String, AtomicU64>` |
| Handshake errors | `DashMap<String, AtomicU64>` |
| Request errors | `DashMap<String, AtomicU64>` |
| Total connect time (microseconds) | `DashMap<String, AtomicU64>` |
| Connect count | `DashMap<String, AtomicU64>` |
| Pool hits | `DashMap<String, AtomicU64>` |
| Pool misses | `DashMap<String, AtomicU64>` |
| H2 failures (fallback to H1) | `DashMap<String, AtomicU64>` |
All per-backend maps are evicted when active count reaches 0. Pruned via `retain_backends()`.
### Frontend Protocol Distribution
Tracked via `ProtocolGuard` RAII guards and `FrontendProtocolTracker`. Five protocol categories, each with active + total counters (AtomicU64):
| Protocol | Where detected |
|---|---|
| h1 | `FrontendProtocolTracker` on first HTTP/1.x request |
| h2 | `FrontendProtocolTracker` on first HTTP/2 request |
| h3 | `ProtocolGuard::frontend("h3")` in H3ProxyService |
| ws | `ProtocolGuard::frontend("ws")` on WebSocket upgrade |
| other | Fallback (TCP passthrough without HTTP) |
Uses `fetch_update` for saturating decrements to prevent underflow races. The same saturating-close pattern is also used for connection and UDP active gauges.
### Backend Protocol Distribution
Same five categories (h1/h2/h3/ws/other), tracked via `ProtocolGuard::backend()` at connection establishment. Backend h2 failures (fallback to h1) are separately counted.
### Throughput History
`ThroughputTracker` is a circular buffer storing `ThroughputSample { timestamp_ms, bytes_in, bytes_out }` at 1Hz.
- Global tracker: 1 instance, default 3600 capacity
- Per-route trackers: 1 per active route
- Per-IP trackers: 1 per connected IP (evicted with the IP)
- HTTP request tracker: reuses ThroughputTracker with bytes_in = request count, bytes_out = 0
Query methods:
- `instant()` — last 1 second average
- `recent()` — last 10 seconds average
- `throughput(N)` — last N seconds average
- `history(N)` — last N raw samples in chronological order
Snapshots return 60 samples of global throughput history.
### Protocol Detection Cache
Not part of MetricsCollector. Maintained by `HttpProxyService`'s protocol detection system. Injected into the metrics snapshot at read time by `get_metrics()`.
Each entry records: host, port, domain, detected protocol (h1/h2/h3), H3 port, age, last accessed, last probed, suppression flags, cooldown timers, consecutive failure counts.
---
## Instrumentation Points
### TCP Passthrough (`rustproxy-passthrough`)
**Connection lifecycle** — `tcp_listener.rs`:
- Accept: `conn_tracker.connection_opened(&ip)` (rate limiter) + `ConnectionTrackerGuard` RAII
- Route match: `metrics.connection_opened(route_id, source_ip)` + `ConnectionGuard` RAII
- Close: Both guards call their respective `_closed()` methods on drop
**Byte recording** — `forwarder.rs` (`forward_bidirectional_with_timeouts`):
- Initial peeked data recorded immediately
- Per-chunk in both directions: `record_bytes(n, 0, ...)` / `record_bytes(0, n, ...)`
- Same pattern in `forward_bidirectional_split_with_timeouts` (tcp_listener.rs) for TLS-terminated paths
### HTTP Proxy (`rustproxy-http`)
**Request counting** — `proxy_service.rs`:
- `record_http_request()` called once per request after route matching succeeds
**Body byte counting** — `counting_body.rs` wrapping:
- Request bodies: `CountingBody::new(body, ..., Direction::In)` — counts client-to-upstream bytes
- Response bodies: `CountingBody::new(body, ..., Direction::Out)` — counts upstream-to-client bytes
- Batched flush every 64KB (`BYTE_FLUSH_THRESHOLD = 65_536`), remainder flushed on drop
- Also updates `connection_activity` atomic (idle watchdog) and `active_requests` counter (streaming detection)
**Backend metrics** — `proxy_service.rs`:
- `backend_connection_opened(key, connect_time)` — after TCP/TLS connect succeeds
- `backend_connection_closed(key)` — on teardown
- `backend_connect_error(key)` — TCP/TLS connect failure or timeout
- `backend_handshake_error(key)` — H1/H2 protocol handshake failure
- `backend_request_error(key)` — send_request failure
- `backend_h2_failure(key)` — H2 attempted, fell back to H1
- `backend_pool_hit(key)` / `backend_pool_miss(key)` — connection pool reuse
- `set_backend_protocol(key, proto)` — records detected protocol
**WebSocket** — `proxy_service.rs`:
- Does NOT use CountingBody; records bytes directly per-chunk in both directions of the bidirectional copy loop
### QUIC (`rustproxy-passthrough`)
**Connection level** — `quic_handler.rs`:
- `connection_opened` / `connection_closed` via `QuicConnGuard` RAII
- `conn_tracker.connection_opened/closed` for rate limiting
**Stream level**:
- For QUIC-to-TCP stream forwarding: `record_bytes(bytes_in, bytes_out, ...)` called once per stream at completion (post-hoc, not per-chunk)
- For HTTP/3: delegates to `HttpProxyService.handle_request()`, so all HTTP proxy metrics apply
**H3 specifics** — `h3_service.rs`:
- `ProtocolGuard::frontend("h3")` tracks the H3 connection
- H3 request bodies: `record_bytes(data.len(), 0, ...)` called directly (not CountingBody) since H3 uses `stream.send_data()`
- H3 response bodies: wrapped in CountingBody like HTTP/1 and HTTP/2
### UDP (`rustproxy-passthrough`)
**Session lifecycle** — `udp_listener.rs` / `udp_session.rs`:
- `udp_session_opened()` + `connection_opened(route_id, source_ip)` on new session
- `udp_session_closed()` + `connection_closed(route_id, source_ip)` on idle reap or port drain
**Datagram counting** — `udp_listener.rs`:
- Inbound: `record_bytes(len, 0, ...)` + `record_datagram_in()`
- Outbound (backend reply): `record_bytes(0, len, ...)` + `record_datagram_out()`
---
## Sampling Loop
`lib.rs` spawns a tokio task at configurable interval (default 1000ms):
```rust
loop {
tokio::select! {
_ = cancel => break,
_ = interval.tick() => {
metrics.sample_all();
conn_tracker.cleanup_stale_timestamps();
http_proxy.cleanup_all_rate_limiters();
}
}
}
```
`sample_all()` performs in one pass:
1. Drains `global_pending_tp_in/out` into global ThroughputTracker, samples
2. Drains per-route pending counters into per-route trackers, samples each
3. Samples idle route trackers (no new data) to advance their window
4. Drains per-IP pending counters into per-IP trackers, samples each
5. Drains `pending_http_requests` into HTTP request throughput tracker
6. Prunes orphaned per-IP entries (bytes/throughput maps with no matching ip_connections key)
7. Prunes orphaned per-backend entries (error/stats maps with no matching active/total key)
---
## Data Flow: Rust to TypeScript
```
MetricsCollector.snapshot()
├── reads all AtomicU64 counters
├── iterates DashMaps (routes, IPs, backends)
├── locks ThroughputTrackers for instant/recent rates + history
└── produces Metrics struct
RustProxy::get_metrics()
├── calls snapshot()
├── enriches with detectedProtocols from HTTP proxy protocol cache
└── returns Metrics
management.rs "getMetrics" IPC command
├── calls get_metrics()
├── serde_json::to_value (camelCase)
└── writes JSON to stdout
RustProxyBridge (TypeScript)
├── reads JSON from Rust process stdout
└── returns parsed object
RustMetricsAdapter
├── setInterval polls bridge.getMetrics() every 1s
├── stores raw JSON in this.cache
└── IMetrics methods read synchronously from cache
SmartProxy.getMetrics()
└── returns the RustMetricsAdapter instance
```
### IPC JSON Shape (Metrics)
```json
{
"activeConnections": 42,
"totalConnections": 1000,
"bytesIn": 123456789,
"bytesOut": 987654321,
"throughputInBytesPerSec": 50000,
"throughputOutBytesPerSec": 80000,
"throughputRecentInBytesPerSec": 45000,
"throughputRecentOutBytesPerSec": 75000,
"routes": {
"<route-id>": {
"activeConnections": 5,
"totalConnections": 100,
"bytesIn": 0, "bytesOut": 0,
"throughputInBytesPerSec": 0, "throughputOutBytesPerSec": 0,
"throughputRecentInBytesPerSec": 0, "throughputRecentOutBytesPerSec": 0
}
},
"ips": {
"<ip>": {
"activeConnections": 2, "totalConnections": 10,
"bytesIn": 0, "bytesOut": 0,
"throughputInBytesPerSec": 0, "throughputOutBytesPerSec": 0,
"domainRequests": {
"example.com": 4821,
"api.example.com": 312
}
}
},
"backends": {
"<host:port>": {
"activeConnections": 3, "totalConnections": 50,
"protocol": "h2",
"connectErrors": 0, "handshakeErrors": 0, "requestErrors": 0,
"totalConnectTimeUs": 150000, "connectCount": 50,
"poolHits": 40, "poolMisses": 10, "h2Failures": 1
}
},
"throughputHistory": [
{ "timestampMs": 1713000000000, "bytesIn": 50000, "bytesOut": 80000 }
],
"totalHttpRequests": 5000,
"httpRequestsPerSec": 100,
"httpRequestsPerSecRecent": 95,
"activeUdpSessions": 0, "totalUdpSessions": 5,
"totalDatagramsIn": 1000, "totalDatagramsOut": 1000,
"frontendProtocols": {
"h1Active": 10, "h1Total": 500,
"h2Active": 5, "h2Total": 200,
"h3Active": 1, "h3Total": 50,
"wsActive": 2, "wsTotal": 30,
"otherActive": 0, "otherTotal": 0
},
"backendProtocols": { "...same shape..." },
"detectedProtocols": [
{
"host": "backend", "port": 443, "domain": "example.com",
"protocol": "h2", "h3Port": 443,
"ageSecs": 120, "lastAccessedSecs": 5, "lastProbedSecs": 120,
"h2Suppressed": false, "h3Suppressed": false,
"h2CooldownRemainingSecs": null, "h3CooldownRemainingSecs": null,
"h2ConsecutiveFailures": null, "h3ConsecutiveFailures": null
}
]
}
```
### IPC JSON Shape (Statistics)
Lightweight administrative summary, fetched on-demand (not polled):
```json
{
"activeConnections": 42,
"totalConnections": 1000,
"routesCount": 5,
"listeningPorts": [80, 443, 8443],
"uptimeSeconds": 86400
}
```
---
## TypeScript Consumer API
`SmartProxy.getMetrics()` returns an `IMetrics` object. All members are synchronous methods reading from the polled cache.
### connections
| Method | Return | Source |
|---|---|---|
| `active()` | `number` | `cache.activeConnections` |
| `total()` | `number` | `cache.totalConnections` |
| `byRoute()` | `Map<string, number>` | `cache.routes[name].activeConnections` |
| `byIP()` | `Map<string, number>` | `cache.ips[ip].activeConnections` |
| `topIPs(limit?)` | `Array<{ip, count}>` | `cache.ips` sorted by active desc, default 10 |
| `domainRequestsByIP()` | `Map<string, Map<string, number>>` | `cache.ips[ip].domainRequests` |
| `topDomainRequests(limit?)` | `Array<{ip, domain, count}>` | Flattened from all IPs, sorted by count desc, default 20 |
| `frontendProtocols()` | `IProtocolDistribution` | `cache.frontendProtocols.*` |
| `backendProtocols()` | `IProtocolDistribution` | `cache.backendProtocols.*` |
### throughput
| Method | Return | Source |
|---|---|---|
| `instant()` | `{in, out}` | `cache.throughputInBytesPerSec/Out` |
| `recent()` | `{in, out}` | `cache.throughputRecentInBytesPerSec/Out` |
| `average()` | `{in, out}` | Falls back to `instant()` (not wired to windowed average) |
| `custom(seconds)` | `{in, out}` | Falls back to `instant()` (not wired) |
| `history(seconds)` | `IThroughputHistoryPoint[]` | `cache.throughputHistory` sliced to last N entries |
| `byRoute(windowSeconds?)` | `Map<string, {in, out}>` | `cache.routes[name].throughputIn/OutBytesPerSec` |
| `byIP(windowSeconds?)` | `Map<string, {in, out}>` | `cache.ips[ip].throughputIn/OutBytesPerSec` |
### requests
| Method | Return | Source |
|---|---|---|
| `perSecond()` | `number` | `cache.httpRequestsPerSec` |
| `perMinute()` | `number` | `cache.httpRequestsPerSecRecent * 60` |
| `total()` | `number` | `cache.totalHttpRequests` (fallback: totalConnections) |
### totals
| Method | Return | Source |
|---|---|---|
| `bytesIn()` | `number` | `cache.bytesIn` |
| `bytesOut()` | `number` | `cache.bytesOut` |
| `connections()` | `number` | `cache.totalConnections` |
### backends
| Method | Return | Source |
|---|---|---|
| `byBackend()` | `Map<string, IBackendMetrics>` | `cache.backends[key].*` with computed `avgConnectTimeMs` and `poolHitRate` |
| `protocols()` | `Map<string, string>` | `cache.backends[key].protocol` |
| `topByErrors(limit?)` | `IBackendMetrics[]` | Sorted by total errors desc |
| `detectedProtocols()` | `IProtocolCacheEntry[]` | `cache.detectedProtocols` passthrough |
`IBackendMetrics`: `{ protocol, activeConnections, totalConnections, connectErrors, handshakeErrors, requestErrors, avgConnectTimeMs, poolHitRate, h2Failures }`
### udp
| Method | Return | Source |
|---|---|---|
| `activeSessions()` | `number` | `cache.activeUdpSessions` |
| `totalSessions()` | `number` | `cache.totalUdpSessions` |
| `datagramsIn()` | `number` | `cache.totalDatagramsIn` |
| `datagramsOut()` | `number` | `cache.totalDatagramsOut` |
### percentiles (stub)
`connectionDuration()` and `bytesTransferred()` always return zeros. Not implemented.
---
## Configuration
```typescript
interface IMetricsConfig {
enabled: boolean; // default true
sampleIntervalMs: number; // default 1000 (1Hz sampling + TS polling)
retentionSeconds: number; // default 3600 (ThroughputTracker capacity)
enableDetailedTracking: boolean;
enablePercentiles: boolean;
cacheResultsMs: number;
prometheusEnabled: boolean; // not wired
prometheusPath: string; // not wired
prometheusPrefix: string; // not wired
}
```
Rust-side config (`MetricsConfig` in `rustproxy-config`):
```rust
pub struct MetricsConfig {
pub enabled: Option<bool>,
pub sample_interval_ms: Option<u64>, // default 1000
pub retention_seconds: Option<u64>, // default 3600
}
```
---
## Design Decisions
**Lock-free hot path.** `record_bytes()` is the most frequently called method (per-chunk in TCP, per-64KB in HTTP). It only touches `AtomicU64` with `Relaxed` ordering and short-circuits zero-byte directions to skip DashMap lookups entirely.
**CountingBody batching.** HTTP body frames are typically 16KB. Flushing to MetricsCollector every 64KB reduces DashMap shard contention by ~4x compared to per-frame recording.
**RAII guards everywhere.** `ConnectionGuard`, `ConnectionTrackerGuard`, `QuicConnGuard`, `ProtocolGuard`, `FrontendProtocolTracker` all use Drop to guarantee counter cleanup on all exit paths including panics.
**Saturating decrements.** Protocol counters use `fetch_update` loops instead of `fetch_sub` to prevent underflow to `u64::MAX` from concurrent close races.
**Bounded memory.** Per-IP entries evicted on last connection close. Per-backend entries evicted on last connection close. Snapshot caps IPs and backends at 100 each. `sample_all()` prunes orphaned entries every second.
**Two-phase throughput.** Pending bytes accumulate in lock-free atomics. The 1Hz cold path drains them into Mutex-guarded ThroughputTrackers. This means the hot path never contends on a Mutex, while the cold path does minimal work (one drain + one sample per tracker).
---
## Known Gaps
| Gap | Status |
|---|---|
| `throughput.average()` / `throughput.custom(seconds)` | Fall back to `instant()`. Not wired to Rust windowed queries. |
| `percentiles.connectionDuration()` / `percentiles.bytesTransferred()` | Stub returning zeros. No histogram in Rust. |
| Prometheus export | Config fields exist but not wired to any exporter. |
| `LogDeduplicator` | Implemented in `rustproxy-metrics` but not connected to any call site. |
| Rate limit hit counters | Rate-limited requests return 429 but no counter is recorded in MetricsCollector. |
| QUIC stream byte counting | Post-hoc (per-stream totals after close), not per-chunk like TCP. |
| Throughput history in snapshot | Capped at 60 samples. TS `history(seconds)` cannot return more than 60 points regardless of `retentionSeconds`. |
| Per-route total connections / bytes | Available in Rust JSON but `IMetrics.connections.byRoute()` only exposes active connections. |
| Per-IP total connections / bytes | Available in Rust JSON but `IMetrics.connections.byIP()` only exposes active connections. |
| IPC response typing | `RustProxyBridge` declares `result: any` for both metrics commands. No type-safe response. |
+1
View File
@@ -1319,6 +1319,7 @@ dependencies = [
"rustproxy-http", "rustproxy-http",
"rustproxy-metrics", "rustproxy-metrics",
"rustproxy-routing", "rustproxy-routing",
"rustproxy-security",
"serde", "serde",
"serde_json", "serde_json",
"socket2 0.5.10", "socket2 0.5.10",
+4 -4
View File
@@ -3,15 +3,15 @@
//! Configuration types for RustProxy, fully compatible with SmartProxy's JSON schema. //! Configuration types for RustProxy, fully compatible with SmartProxy's JSON schema.
//! All types use `#[serde(rename_all = "camelCase")]` to match TypeScript field naming. //! All types use `#[serde(rename_all = "camelCase")]` to match TypeScript field naming.
pub mod route_types;
pub mod proxy_options; pub mod proxy_options;
pub mod tls_types; pub mod route_types;
pub mod security_types; pub mod security_types;
pub mod tls_types;
pub mod validation; pub mod validation;
// Re-export all primary types // Re-export all primary types
pub use route_types::*;
pub use proxy_options::*; pub use proxy_options::*;
pub use tls_types::*; pub use route_types::*;
pub use security_types::*; pub use security_types::*;
pub use tls_types::*;
pub use validation::*; pub use validation::*;
+261 -23
View File
@@ -97,6 +97,16 @@ pub struct MetricsConfig {
pub retention_seconds: Option<u64>, pub retention_seconds: Option<u64>,
} }
/// Global ingress security policy.
///
/// Deserialized from camelCase JSON (`blockedIps` / `blockedCidrs`) and,
/// per the option docs below, enforced before route selection.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SecurityPolicy {
/// Individual client IPs to block; omitted from serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub blocked_ips: Option<Vec<String>>,
/// CIDR ranges to block; omitted from serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub blocked_cidrs: Option<Vec<String>>,
}
/// RustProxy configuration options. /// RustProxy configuration options.
/// Matches TypeScript: `ISmartProxyOptions` /// Matches TypeScript: `ISmartProxyOptions`
/// ///
@@ -129,7 +139,6 @@ pub struct RustProxyOptions {
pub defaults: Option<DefaultConfig>, pub defaults: Option<DefaultConfig>,
// ─── Timeout Settings ──────────────────────────────────────────── // ─── Timeout Settings ────────────────────────────────────────────
/// Timeout for establishing connection to backend (ms), default: 30000 /// Timeout for establishing connection to backend (ms), default: 30000
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub connection_timeout: Option<u64>, pub connection_timeout: Option<u64>,
@@ -159,7 +168,6 @@ pub struct RustProxyOptions {
pub graceful_shutdown_timeout: Option<u64>, pub graceful_shutdown_timeout: Option<u64>,
// ─── Socket Optimization ───────────────────────────────────────── // ─── Socket Optimization ─────────────────────────────────────────
/// Disable Nagle's algorithm (default: true) /// Disable Nagle's algorithm (default: true)
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub no_delay: Option<bool>, pub no_delay: Option<bool>,
@@ -177,7 +185,6 @@ pub struct RustProxyOptions {
pub max_pending_data_size: Option<u64>, pub max_pending_data_size: Option<u64>,
// ─── Enhanced Features ─────────────────────────────────────────── // ─── Enhanced Features ───────────────────────────────────────────
/// Disable inactivity checking entirely /// Disable inactivity checking entirely
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub disable_inactivity_check: Option<bool>, pub disable_inactivity_check: Option<bool>,
@@ -199,7 +206,6 @@ pub struct RustProxyOptions {
pub enable_randomized_timeouts: Option<bool>, pub enable_randomized_timeouts: Option<bool>,
// ─── Rate Limiting ─────────────────────────────────────────────── // ─── Rate Limiting ───────────────────────────────────────────────
/// Maximum simultaneous connections from a single IP /// Maximum simultaneous connections from a single IP
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub max_connections_per_ip: Option<u64>, pub max_connections_per_ip: Option<u64>,
@@ -213,7 +219,6 @@ pub struct RustProxyOptions {
pub max_connections: Option<u64>, pub max_connections: Option<u64>,
// ─── Keep-Alive Settings ───────────────────────────────────────── // ─── Keep-Alive Settings ─────────────────────────────────────────
/// How to treat keep-alive connections /// How to treat keep-alive connections
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub keep_alive_treatment: Option<KeepAliveTreatment>, pub keep_alive_treatment: Option<KeepAliveTreatment>,
@@ -227,7 +232,6 @@ pub struct RustProxyOptions {
pub extended_keep_alive_lifetime: Option<u64>, pub extended_keep_alive_lifetime: Option<u64>,
// ─── HttpProxy Integration ─────────────────────────────────────── // ─── HttpProxy Integration ───────────────────────────────────────
/// Array of ports to forward to HttpProxy /// Array of ports to forward to HttpProxy
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub use_http_proxy: Option<Vec<u16>>, pub use_http_proxy: Option<Vec<u16>>,
@@ -237,13 +241,15 @@ pub struct RustProxyOptions {
pub http_proxy_port: Option<u16>, pub http_proxy_port: Option<u16>,
// ─── Metrics ───────────────────────────────────────────────────── // ─── Metrics ─────────────────────────────────────────────────────
/// Metrics configuration /// Metrics configuration
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub metrics: Option<MetricsConfig>, pub metrics: Option<MetricsConfig>,
// ─── ACME ──────────────────────────────────────────────────────── /// Global ingress security policy, enforced before route selection.
#[serde(skip_serializing_if = "Option::is_none")]
pub security_policy: Option<SecurityPolicy>,
// ─── ACME ────────────────────────────────────────────────────────
/// Global ACME configuration /// Global ACME configuration
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub acme: Option<AcmeOptions>, pub acme: Option<AcmeOptions>,
@@ -283,6 +289,7 @@ impl Default for RustProxyOptions {
use_http_proxy: None, use_http_proxy: None,
http_proxy_port: None, http_proxy_port: None,
metrics: None, metrics: None,
security_policy: None,
acme: None, acme: None,
} }
} }
@@ -318,7 +325,8 @@ impl RustProxyOptions {
/// Get all unique ports that routes listen on. /// Get all unique ports that routes listen on.
pub fn all_listening_ports(&self) -> Vec<u16> { pub fn all_listening_ports(&self) -> Vec<u16> {
let mut ports: Vec<u16> = self.routes let mut ports: Vec<u16> = self
.routes
.iter() .iter()
.flat_map(|r| r.listening_ports()) .flat_map(|r| r.listening_ports())
.collect(); .collect();
@@ -340,7 +348,12 @@ mod tests {
route_match: RouteMatch { route_match: RouteMatch {
ports: PortRange::Single(listen_port), ports: PortRange::Single(listen_port),
domains: Some(DomainSpec::Single(domain.to_string())), domains: Some(DomainSpec::Single(domain.to_string())),
path: None, client_ip: None, transport: None, tls_version: None, headers: None, protocol: None, path: None,
client_ip: None,
transport: None,
tls_version: None,
headers: None,
protocol: None,
}, },
action: RouteAction { action: RouteAction {
action_type: RouteActionType::Forward, action_type: RouteActionType::Forward,
@@ -348,14 +361,30 @@ mod tests {
target_match: None, target_match: None,
host: HostSpec::Single(host.to_string()), host: HostSpec::Single(host.to_string()),
port: PortSpec::Fixed(port), port: PortSpec::Fixed(port),
tls: None, websocket: None, load_balancing: None, send_proxy_protocol: None, tls: None,
headers: None, advanced: None, backend_transport: None, priority: None, websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
}]), }]),
tls: None, websocket: None, load_balancing: None, advanced: None, tls: None,
options: None, send_proxy_protocol: None, udp: None, websocket: None,
load_balancing: None,
advanced: None,
options: None,
send_proxy_protocol: None,
udp: None,
}, },
headers: None, security: None, name: None, description: None, headers: None,
priority: None, tags: None, enabled: None, security: None,
name: None,
description: None,
priority: None,
tags: None,
enabled: None,
} }
} }
@@ -363,8 +392,12 @@ mod tests {
let mut route = make_route(domain, host, port, 443); let mut route = make_route(domain, host, port, 443);
route.action.tls = Some(RouteTls { route.action.tls = Some(RouteTls {
mode: TlsMode::Passthrough, mode: TlsMode::Passthrough,
certificate: None, acme: None, versions: None, ciphers: None, certificate: None,
honor_cipher_order: None, session_timeout: None, acme: None,
versions: None,
ciphers: None,
honor_cipher_order: None,
session_timeout: None,
}); });
route route
} }
@@ -410,6 +443,209 @@ mod tests {
assert_eq!(parsed.connection_timeout, Some(5000)); assert_eq!(parsed.connection_timeout, Some(5000));
} }
// Contract test: a maximal TypeScript-shaped config (camelCase keys, mixed
// port specs, JS-style regex headers, domain-scoped allow-list entries)
// must round-trip through serde into `RustProxyOptions` with every field
// landing in the expected snake_case Rust slot.
#[test]
fn test_deserialize_ts_contract_route_shapes() {
// Fixture mirrors the TS `ISmartProxyOptions` JSON schema.
let value = serde_json::json!({
"routes": [{
"name": "contract-route",
"match": {
"ports": [443, { "from": 8443, "to": 8444 }],
"domains": ["api.example.com", "*.example.com"],
"transport": "udp",
"protocol": "http3",
"headers": {
"content-type": "/^application\\/json$/i"
}
},
"action": {
"type": "forward",
"targets": [{
"match": {
"ports": [443],
"path": "/api/*",
"method": ["GET"],
"headers": {
"x-env": "/^(prod|stage)$/"
}
},
"host": ["backend-a", "backend-b"],
"port": "preserve",
"sendProxyProtocol": true,
"backendTransport": "tcp"
}],
"tls": {
"mode": "terminate",
"certificate": "auto"
},
"sendProxyProtocol": true,
"udp": {
"maxSessionsPerIp": 321,
"quic": {
"enableHttp3": true
}
}
},
"security": {
"ipAllowList": [{
"ip": "10.0.0.0/8",
"domains": ["api.example.com"]
}]
}
}],
"preserveSourceIp": true,
"proxyIps": ["10.0.0.1"],
"acceptProxyProtocol": true,
"sendProxyProtocol": true,
"noDelay": true,
"keepAlive": true,
"keepAliveInitialDelay": 1500,
"maxPendingDataSize": 4096,
"disableInactivityCheck": true,
"enableKeepAliveProbes": true,
"enableDetailedLogging": true,
"enableTlsDebugLogging": true,
"enableRandomizedTimeouts": true,
"connectionTimeout": 5000,
"initialDataTimeout": 7000,
"socketTimeout": 9000,
"inactivityCheckInterval": 1100,
"maxConnectionLifetime": 13000,
"inactivityTimeout": 15000,
"gracefulShutdownTimeout": 17000,
"maxConnectionsPerIp": 20,
"connectionRateLimitPerMinute": 30,
"keepAliveTreatment": "extended",
"keepAliveInactivityMultiplier": 2.0,
"extendedKeepAliveLifetime": 19000,
"metrics": {
"enabled": true,
"sampleIntervalMs": 250,
"retentionSeconds": 60
},
"acme": {
"enabled": true,
"email": "ops@example.com",
"environment": "staging",
"useProduction": false,
"skipConfiguredCerts": true,
"renewThresholdDays": 14,
"renewCheckIntervalHours": 12,
"autoRenew": true,
"port": 80
}
});
let options: RustProxyOptions = serde_json::from_value(value).unwrap();
// Top-level scalar options: each camelCase key maps to its snake_case field.
assert_eq!(options.routes.len(), 1);
assert_eq!(options.preserve_source_ip, Some(true));
assert_eq!(options.proxy_ips, Some(vec!["10.0.0.1".to_string()]));
assert_eq!(options.accept_proxy_protocol, Some(true));
assert_eq!(options.send_proxy_protocol, Some(true));
assert_eq!(options.no_delay, Some(true));
assert_eq!(options.keep_alive, Some(true));
assert_eq!(options.keep_alive_initial_delay, Some(1500));
assert_eq!(options.max_pending_data_size, Some(4096));
assert_eq!(options.disable_inactivity_check, Some(true));
assert_eq!(options.enable_keep_alive_probes, Some(true));
assert_eq!(options.enable_detailed_logging, Some(true));
assert_eq!(options.enable_tls_debug_logging, Some(true));
assert_eq!(options.enable_randomized_timeouts, Some(true));
assert_eq!(options.connection_timeout, Some(5000));
assert_eq!(options.initial_data_timeout, Some(7000));
assert_eq!(options.socket_timeout, Some(9000));
assert_eq!(options.inactivity_check_interval, Some(1100));
assert_eq!(options.max_connection_lifetime, Some(13000));
assert_eq!(options.inactivity_timeout, Some(15000));
assert_eq!(options.graceful_shutdown_timeout, Some(17000));
assert_eq!(options.max_connections_per_ip, Some(20));
assert_eq!(options.connection_rate_limit_per_minute, Some(30));
// "extended" string deserializes into the KeepAliveTreatment enum variant.
assert_eq!(
options.keep_alive_treatment,
Some(KeepAliveTreatment::Extended)
);
assert_eq!(options.keep_alive_inactivity_multiplier, Some(2.0));
assert_eq!(options.extended_keep_alive_lifetime, Some(19000));
// Route match: transport/protocol strings and JS-style regex header values
// are preserved verbatim (regex parsing happens later, not at deserialize).
let route = &options.routes[0];
assert_eq!(route.route_match.transport, Some(TransportProtocol::Udp));
assert_eq!(route.route_match.protocol.as_deref(), Some("http3"));
assert_eq!(
route
.route_match
.headers
.as_ref()
.unwrap()
.get("content-type")
.unwrap(),
"/^application\\/json$/i"
);
// Target: host list, "preserve" special port, per-target transport/PROXY flags.
let target = &route.action.targets.as_ref().unwrap()[0];
assert!(matches!(target.host, HostSpec::List(_)));
assert!(matches!(target.port, PortSpec::Special(ref p) if p == "preserve"));
assert_eq!(target.backend_transport, Some(TransportProtocol::Tcp));
assert_eq!(target.send_proxy_protocol, Some(true));
assert_eq!(
target
.target_match
.as_ref()
.unwrap()
.headers
.as_ref()
.unwrap()
.get("x-env")
.unwrap(),
"/^(prod|stage)$/"
);
// Action-level flags and nested UDP/QUIC options.
assert_eq!(route.action.send_proxy_protocol, Some(true));
assert_eq!(
route.action.udp.as_ref().unwrap().max_sessions_per_ip,
Some(321)
);
assert_eq!(
route
.action
.udp
.as_ref()
.unwrap()
.quic
.as_ref()
.unwrap()
.enable_http3,
Some(true)
);
// Security: the `{ ip, domains }` object form must pick the untagged
// DomainScoped variant, not the Plain string variant.
let allow_list = route
.security
.as_ref()
.unwrap()
.ip_allow_list
.as_ref()
.unwrap();
assert!(matches!(
&allow_list[0],
crate::security_types::IpAllowEntry::DomainScoped { ip, domains }
if ip == "10.0.0.0/8" && domains == &vec!["api.example.com".to_string()]
));
// Metrics sub-config.
let metrics = options.metrics.as_ref().unwrap();
assert_eq!(metrics.enabled, Some(true));
assert_eq!(metrics.sample_interval_ms, Some(250));
assert_eq!(metrics.retention_seconds, Some(60));
// ACME sub-config, including "staging" → AcmeEnvironment::Staging.
let acme = options.acme.as_ref().unwrap();
assert_eq!(acme.enabled, Some(true));
assert_eq!(acme.email.as_deref(), Some("ops@example.com"));
assert_eq!(acme.environment, Some(AcmeEnvironment::Staging));
assert_eq!(acme.use_production, Some(false));
assert_eq!(acme.skip_configured_certs, Some(true));
assert_eq!(acme.renew_threshold_days, Some(14));
assert_eq!(acme.renew_check_interval_hours, Some(12));
assert_eq!(acme.auto_renew, Some(true));
assert_eq!(acme.port, Some(80));
}
#[test] #[test]
fn test_default_timeouts() { fn test_default_timeouts() {
let options = RustProxyOptions::default(); let options = RustProxyOptions::default();
@@ -438,9 +674,9 @@ mod tests {
fn test_all_listening_ports() { fn test_all_listening_ports() {
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![ routes: vec![
make_route("a.com", "backend", 8080, 80), // port 80 make_route("a.com", "backend", 8080, 80), // port 80
make_passthrough_route("b.com", "backend", 443), // port 443 make_passthrough_route("b.com", "backend", 443), // port 443
make_route("c.com", "backend", 9090, 80), // port 80 (duplicate) make_route("c.com", "backend", 9090, 80), // port 80 (duplicate)
], ],
..Default::default() ..Default::default()
}; };
@@ -464,9 +700,11 @@ mod tests {
#[test] #[test]
fn test_deserialize_example_json() { fn test_deserialize_example_json() {
let content = std::fs::read_to_string( let content = std::fs::read_to_string(concat!(
concat!(env!("CARGO_MANIFEST_DIR"), "/../../config/example.json") env!("CARGO_MANIFEST_DIR"),
).unwrap(); "/../../config/example.json"
))
.unwrap();
let options: RustProxyOptions = serde_json::from_str(&content).unwrap(); let options: RustProxyOptions = serde_json::from_str(&content).unwrap();
assert_eq!(options.routes.len(), 4); assert_eq!(options.routes.len(), 4);
let ports = options.all_listening_ports(); let ports = options.all_listening_ports();
@@ -1,8 +1,8 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use crate::tls_types::RouteTls;
use crate::security_types::RouteSecurity; use crate::security_types::RouteSecurity;
use crate::tls_types::RouteTls;
// ─── Port Range ────────────────────────────────────────────────────── // ─── Port Range ──────────────────────────────────────────────────────
@@ -32,12 +32,13 @@ impl PortRange {
pub fn to_ports(&self) -> Vec<u16> { pub fn to_ports(&self) -> Vec<u16> {
match self { match self {
PortRange::Single(p) => vec![*p], PortRange::Single(p) => vec![*p],
PortRange::List(items) => { PortRange::List(items) => items
items.iter().flat_map(|item| match item { .iter()
.flat_map(|item| match item {
PortRangeItem::Port(p) => vec![*p], PortRangeItem::Port(p) => vec![*p],
PortRangeItem::Range(r) => (r.from..=r.to).collect(), PortRangeItem::Range(r) => (r.from..=r.to).collect(),
}).collect() })
} .collect(),
} }
} }
} }
@@ -105,7 +106,8 @@ impl From<Vec<&str>> for DomainSpec {
} }
/// Header match value: either exact string or regex pattern. /// Header match value: either exact string or regex pattern.
/// In JSON, all values come as strings. Regex patterns are prefixed with `/` and suffixed with `/`. /// In JSON, all values come as strings. Regex patterns use JS-style literal syntax,
/// e.g. `/^application\/json$/` or `/^application\/json$/i`.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
pub enum HeaderMatchValue { pub enum HeaderMatchValue {
@@ -654,6 +656,11 @@ impl RouteConfig {
self.route_match.ports.to_ports() self.route_match.ports.to_ports()
} }
/// Stable key used for frontend route-scoped metrics.
///
/// Prefers the human-readable `name`, falling back to `id`; returns `None`
/// when the route has neither (callers must then skip route-scoped keys).
pub fn metrics_key(&self) -> Option<&str> {
self.name.as_deref().or(self.id.as_deref())
}
/// Get the TLS mode for this route (from action-level or first target). /// Get the TLS mode for this route (from action-level or first target).
pub fn tls_mode(&self) -> Option<&crate::tls_types::TlsMode> { pub fn tls_mode(&self) -> Option<&crate::tls_types::TlsMode> {
// Check action-level TLS first // Check action-level TLS first
@@ -671,3 +678,63 @@ impl RouteConfig {
None None
} }
} }
#[cfg(test)]
mod tests {
use super::*;
// Minimal RouteConfig with only `name`/`id` varying — everything else is
// the empty/default shape, since metrics_key() reads nothing else.
fn test_route(name: Option<&str>, id: Option<&str>) -> RouteConfig {
RouteConfig {
id: id.map(str::to_string),
route_match: RouteMatch {
ports: PortRange::Single(443),
transport: None,
domains: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
protocol: None,
},
action: RouteAction {
action_type: RouteActionType::Forward,
targets: None,
tls: None,
websocket: None,
load_balancing: None,
advanced: None,
options: None,
send_proxy_protocol: None,
udp: None,
},
headers: None,
security: None,
name: name.map(str::to_string),
description: None,
priority: None,
tags: None,
enabled: None,
}
}
// name wins even when both name and id are present.
#[test]
fn metrics_key_prefers_name() {
let route = test_route(Some("named-route"), Some("route-id"));
assert_eq!(route.metrics_key(), Some("named-route"));
}
// id is the fallback when name is absent.
#[test]
fn metrics_key_falls_back_to_id() {
let route = test_route(None, Some("route-id"));
assert_eq!(route.metrics_key(), Some("route-id"));
}
// Neither name nor id → no metrics key at all.
#[test]
fn metrics_key_is_absent_without_name_or_id() {
let route = test_route(None, None);
assert_eq!(route.metrics_key(), None);
}
}
@@ -103,14 +103,27 @@ pub struct JwtAuthConfig {
pub exclude_paths: Option<Vec<String>>, pub exclude_paths: Option<Vec<String>>,
} }
/// An entry in the IP allow list: either a plain IP/CIDR string
/// or a domain-scoped entry that restricts the IP to specific domains.
///
/// `#[serde(untagged)]`: a JSON string deserializes as `Plain`, a JSON
/// object with `ip` + `domains` as `DomainScoped` — matching the two shapes
/// the TypeScript config emits for `ipAllowList`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum IpAllowEntry {
/// Plain IP/CIDR — allowed for all domains on this route
Plain(String),
/// Domain-scoped — allowed only when the requested domain matches
DomainScoped { ip: String, domains: Vec<String> },
}
/// Security options for routes. /// Security options for routes.
/// Matches TypeScript: `IRouteSecurity` /// Matches TypeScript: `IRouteSecurity`
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct RouteSecurity { pub struct RouteSecurity {
/// IP addresses that are allowed to connect /// IP addresses that are allowed to connect.
/// Entries can be plain strings (full route access) or objects with
/// `{ ip, domains }` to scope access to specific domains.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub ip_allow_list: Option<Vec<String>>, pub ip_allow_list: Option<Vec<IpAllowEntry>>,
/// IP addresses that are blocked from connecting /// IP addresses that are blocked from connecting
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub ip_block_list: Option<Vec<String>>, pub ip_block_list: Option<Vec<String>>,
+17 -8
View File
@@ -1,6 +1,6 @@
use thiserror::Error; use thiserror::Error;
use crate::route_types::{RouteConfig, RouteActionType}; use crate::route_types::{RouteActionType, RouteConfig};
/// Validation errors for route configurations. /// Validation errors for route configurations.
#[derive(Debug, Error)] #[derive(Debug, Error)]
@@ -30,9 +30,10 @@ pub enum ValidationError {
/// Validate a single route configuration. /// Validate a single route configuration.
pub fn validate_route(route: &RouteConfig) -> Result<(), Vec<ValidationError>> { pub fn validate_route(route: &RouteConfig) -> Result<(), Vec<ValidationError>> {
let mut errors = Vec::new(); let mut errors = Vec::new();
let name = route.name.clone().unwrap_or_else(|| { let name = route
route.id.clone().unwrap_or_else(|| "unnamed".to_string()) .name
}); .clone()
.unwrap_or_else(|| route.id.clone().unwrap_or_else(|| "unnamed".to_string()));
// Check ports // Check ports
let ports = route.listening_ports(); let ports = route.listening_ports();
@@ -160,7 +161,9 @@ mod tests {
let mut route = make_valid_route(); let mut route = make_valid_route();
route.action.targets = None; route.action.targets = None;
let errors = validate_route(&route).unwrap_err(); let errors = validate_route(&route).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::MissingTargets { .. }))); assert!(errors
.iter()
.any(|e| matches!(e, ValidationError::MissingTargets { .. })));
} }
#[test] #[test]
@@ -168,7 +171,9 @@ mod tests {
let mut route = make_valid_route(); let mut route = make_valid_route();
route.action.targets = Some(vec![]); route.action.targets = Some(vec![]);
let errors = validate_route(&route).unwrap_err(); let errors = validate_route(&route).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::EmptyTargets { .. }))); assert!(errors
.iter()
.any(|e| matches!(e, ValidationError::EmptyTargets { .. })));
} }
#[test] #[test]
@@ -176,7 +181,9 @@ mod tests {
let mut route = make_valid_route(); let mut route = make_valid_route();
route.route_match.ports = PortRange::Single(0); route.route_match.ports = PortRange::Single(0);
let errors = validate_route(&route).unwrap_err(); let errors = validate_route(&route).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. }))); assert!(errors
.iter()
.any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. })));
} }
#[test] #[test]
@@ -186,7 +193,9 @@ mod tests {
let mut r2 = make_valid_route(); let mut r2 = make_valid_route();
r2.id = Some("route-1".to_string()); r2.id = Some("route-1".to_string());
let errors = validate_routes(&[r1, r2]).unwrap_err(); let errors = validate_routes(&[r1, r2]).unwrap_err();
assert!(errors.iter().any(|e| matches!(e, ValidationError::DuplicateId { .. }))); assert!(errors
.iter()
.any(|e| matches!(e, ValidationError::DuplicateId { .. })));
} }
#[test] #[test]
@@ -3,8 +3,8 @@
//! Reuses idle keep-alive connections to avoid per-request TCP+TLS handshakes. //! Reuses idle keep-alive connections to avoid per-request TCP+TLS handshakes.
//! HTTP/2 and HTTP/3 connections are multiplexed (clone the sender / share the connection). //! HTTP/2 and HTTP/3 connections are multiplexed (clone the sender / share the connection).
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use bytes::Bytes; use bytes::Bytes;
@@ -105,13 +105,19 @@ impl ConnectionPool {
/// Try to check out an idle HTTP/1.1 sender for the given key. /// Try to check out an idle HTTP/1.1 sender for the given key.
/// Returns `None` if no usable idle connection exists. /// Returns `None` if no usable idle connection exists.
pub fn checkout_h1(&self, key: &PoolKey) -> Option<http1::SendRequest<BoxBody<Bytes, hyper::Error>>> { pub fn checkout_h1(
&self,
key: &PoolKey,
) -> Option<http1::SendRequest<BoxBody<Bytes, hyper::Error>>> {
let mut entry = self.h1_pool.get_mut(key)?; let mut entry = self.h1_pool.get_mut(key)?;
let idles = entry.value_mut(); let idles = entry.value_mut();
while let Some(idle) = idles.pop() { while let Some(idle) = idles.pop() {
// Check if the connection is still alive and ready // Check if the connection is still alive and ready
if idle.idle_since.elapsed() < IDLE_TIMEOUT && idle.sender.is_ready() && !idle.sender.is_closed() { if idle.idle_since.elapsed() < IDLE_TIMEOUT
&& idle.sender.is_ready()
&& !idle.sender.is_closed()
{
// H1 pool hit — no logging on hot path // H1 pool hit — no logging on hot path
return Some(idle.sender); return Some(idle.sender);
} }
@@ -128,7 +134,11 @@ impl ConnectionPool {
/// Return an HTTP/1.1 sender to the pool after the response body has been prepared. /// Return an HTTP/1.1 sender to the pool after the response body has been prepared.
/// The caller should NOT call this if the sender is closed or not ready. /// The caller should NOT call this if the sender is closed or not ready.
pub fn checkin_h1(&self, key: PoolKey, sender: http1::SendRequest<BoxBody<Bytes, hyper::Error>>) { pub fn checkin_h1(
&self,
key: PoolKey,
sender: http1::SendRequest<BoxBody<Bytes, hyper::Error>>,
) {
if sender.is_closed() || !sender.is_ready() { if sender.is_closed() || !sender.is_ready() {
return; // Don't pool broken connections return; // Don't pool broken connections
} }
@@ -145,7 +155,10 @@ impl ConnectionPool {
/// Try to get a cloned HTTP/2 sender for the given key. /// Try to get a cloned HTTP/2 sender for the given key.
/// HTTP/2 senders are Clone-able (multiplexed), so we clone rather than remove. /// HTTP/2 senders are Clone-able (multiplexed), so we clone rather than remove.
pub fn checkout_h2(&self, key: &PoolKey) -> Option<(http2::SendRequest<BoxBody<Bytes, hyper::Error>>, Duration)> { pub fn checkout_h2(
&self,
key: &PoolKey,
) -> Option<(http2::SendRequest<BoxBody<Bytes, hyper::Error>>, Duration)> {
let entry = self.h2_pool.get(key)?; let entry = self.h2_pool.get(key)?;
let pooled = entry.value(); let pooled = entry.value();
let age = pooled.created_at.elapsed(); let age = pooled.created_at.elapsed();
@@ -184,16 +197,23 @@ impl ConnectionPool {
/// Register an HTTP/2 sender in the pool. Returns the generation ID for this entry. /// Register an HTTP/2 sender in the pool. Returns the generation ID for this entry.
/// The caller should pass this generation to the connection driver so it can use /// The caller should pass this generation to the connection driver so it can use
/// `remove_h2_if_generation` instead of `remove_h2` to avoid phantom eviction. /// `remove_h2_if_generation` instead of `remove_h2` to avoid phantom eviction.
pub fn register_h2(&self, key: PoolKey, sender: http2::SendRequest<BoxBody<Bytes, hyper::Error>>) -> u64 { pub fn register_h2(
&self,
key: PoolKey,
sender: http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
) -> u64 {
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed); let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
if sender.is_closed() { if sender.is_closed() {
return gen; return gen;
} }
self.h2_pool.insert(key, PooledH2 { self.h2_pool.insert(
sender, key,
created_at: Instant::now(), PooledH2 {
generation: gen, sender,
}); created_at: Instant::now(),
generation: gen,
},
);
gen gen
} }
@@ -204,7 +224,11 @@ impl ConnectionPool {
pub fn checkout_h3( pub fn checkout_h3(
&self, &self,
key: &PoolKey, key: &PoolKey,
) -> Option<(h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>, quinn::Connection, Duration)> { ) -> Option<(
h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
quinn::Connection,
Duration,
)> {
let entry = self.h3_pool.get(key)?; let entry = self.h3_pool.get(key)?;
let pooled = entry.value(); let pooled = entry.value();
let age = pooled.created_at.elapsed(); let age = pooled.created_at.elapsed();
@@ -234,12 +258,15 @@ impl ConnectionPool {
send_request: h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>, send_request: h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
) -> u64 { ) -> u64 {
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed); let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
self.h3_pool.insert(key, PooledH3 { self.h3_pool.insert(
send_request, key,
connection, PooledH3 {
created_at: Instant::now(), send_request,
generation: gen, connection,
}); created_at: Instant::now(),
generation: gen,
},
);
gen gen
} }
@@ -280,7 +307,9 @@ impl ConnectionPool {
// Evict dead or aged-out H2 connections // Evict dead or aged-out H2 connections
let mut dead_h2 = Vec::new(); let mut dead_h2 = Vec::new();
for entry in h2_pool.iter() { for entry in h2_pool.iter() {
if entry.value().sender.is_closed() || entry.value().created_at.elapsed() >= MAX_H2_AGE { if entry.value().sender.is_closed()
|| entry.value().created_at.elapsed() >= MAX_H2_AGE
{
dead_h2.push(entry.key().clone()); dead_h2.push(entry.key().clone());
} }
} }
@@ -1,8 +1,8 @@
//! A body wrapper that counts bytes flowing through and reports them to MetricsCollector. //! A body wrapper that counts bytes flowing through and reports them to MetricsCollector.
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use bytes::Bytes; use bytes::Bytes;
@@ -76,7 +76,11 @@ impl<B> CountingBody<B> {
/// Set the connection-level activity tracker. When set, each data frame /// Set the connection-level activity tracker. When set, each data frame
/// updates this timestamp to prevent the idle watchdog from killing the /// updates this timestamp to prevent the idle watchdog from killing the
/// connection during active body streaming. /// connection during active body streaming.
pub fn with_connection_activity(mut self, activity: Arc<AtomicU64>, start: std::time::Instant) -> Self { pub fn with_connection_activity(
mut self,
activity: Arc<AtomicU64>,
start: std::time::Instant,
) -> Self {
self.connection_activity = Some(activity); self.connection_activity = Some(activity);
self.activity_start = Some(start); self.activity_start = Some(start);
self self
@@ -134,7 +138,9 @@ where
} }
// Keep the connection-level idle watchdog alive on every frame // Keep the connection-level idle watchdog alive on every frame
// (this is just one atomic store — cheap enough per-frame) // (this is just one atomic store — cheap enough per-frame)
if let (Some(activity), Some(start)) = (&this.connection_activity, &this.activity_start) { if let (Some(activity), Some(start)) =
(&this.connection_activity, &this.activity_start)
{
activity.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); activity.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
} }
} }
+31 -10
View File
@@ -11,14 +11,14 @@ use std::task::{Context, Poll};
use bytes::{Buf, Bytes}; use bytes::{Buf, Bytes};
use http_body::Frame; use http_body::Frame;
use http_body_util::BodyExt;
use http_body_util::combinators::BoxBody; use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use tracing::{debug, warn}; use tracing::{debug, warn};
use rustproxy_config::RouteConfig; use rustproxy_config::RouteConfig;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::proxy_service::{ConnActivity, HttpProxyService}; use crate::proxy_service::{ConnActivity, HttpProxyService, ProtocolGuard};
/// HTTP/3 proxy service. /// HTTP/3 proxy service.
/// ///
@@ -48,6 +48,10 @@ impl H3ProxyService {
let remote_addr = real_client_addr.unwrap_or_else(|| connection.remote_address()); let remote_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
debug!("HTTP/3 connection from {} on port {}", remote_addr, port); debug!("HTTP/3 connection from {} on port {}", remote_addr, port);
// Track frontend H3 connection for the QUIC connection's lifetime.
let _frontend_h3_guard =
ProtocolGuard::frontend(Arc::clone(self.http_proxy.metrics()), "h3");
let mut h3_conn: h3::server::Connection<h3_quinn::Connection, Bytes> = let mut h3_conn: h3::server::Connection<h3_quinn::Connection, Bytes> =
h3::server::builder() h3::server::builder()
.send_grease(false) .send_grease(false)
@@ -89,8 +93,15 @@ impl H3ProxyService {
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = handle_h3_request( if let Err(e) = handle_h3_request(
request, stream, port, remote_addr, &http_proxy, request_cancel, request,
).await { stream,
port,
remote_addr,
&http_proxy,
request_cancel,
)
.await
{
debug!("HTTP/3 request error from {}: {}", remote_addr, e); debug!("HTTP/3 request error from {}: {}", remote_addr, e);
} }
}); });
@@ -150,11 +161,14 @@ async fn handle_h3_request(
// Delegate to HttpProxyService — same backend path as TCP/HTTP: // Delegate to HttpProxyService — same backend path as TCP/HTTP:
// route matching, ALPN protocol detection, connection pool, H1/H2/H3 auto. // route matching, ALPN protocol detection, connection pool, H1/H2/H3 auto.
let conn_activity = ConnActivity::new_standalone(); let conn_activity = ConnActivity::new_standalone();
let response = http_proxy.handle_request(req, peer_addr, port, cancel, conn_activity).await let response = http_proxy
.handle_request(req, peer_addr, port, cancel, conn_activity)
.await
.map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?; .map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?;
// Await the body reader to get the H3 stream back // Await the body reader to get the H3 stream back
let mut stream = body_reader.await let mut stream = body_reader
.await
.map_err(|e| anyhow::anyhow!("Body reader task failed: {}", e))?; .map_err(|e| anyhow::anyhow!("Body reader task failed: {}", e))?;
// Send response headers over H3 (skip hop-by-hop headers) // Send response headers over H3 (skip hop-by-hop headers)
@@ -167,10 +181,13 @@ async fn handle_h3_request(
} }
h3_response = h3_response.header(name, value); h3_response = h3_response.header(name, value);
} }
let h3_response = h3_response.body(()) let h3_response = h3_response
.body(())
.map_err(|e| anyhow::anyhow!("Failed to build H3 response: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to build H3 response: {}", e))?;
stream.send_response(h3_response).await stream
.send_response(h3_response)
.await
.map_err(|e| anyhow::anyhow!("Failed to send H3 response: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to send H3 response: {}", e))?;
// Stream response body back over H3 // Stream response body back over H3
@@ -179,7 +196,9 @@ async fn handle_h3_request(
match frame { match frame {
Ok(frame) => { Ok(frame) => {
if let Ok(data) = frame.into_data() { if let Ok(data) = frame.into_data() {
stream.send_data(data).await stream
.send_data(data)
.await
.map_err(|e| anyhow::anyhow!("Failed to send H3 data: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to send H3 data: {}", e))?;
} }
} }
@@ -191,7 +210,9 @@ async fn handle_h3_request(
} }
// Finish the H3 stream (send QUIC FIN) // Finish the H3 stream (send QUIC FIN)
stream.finish().await stream
.finish()
.await
.map_err(|e| anyhow::anyhow!("Failed to finish H3 stream: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to finish H3 stream: {}", e))?;
Ok(()) Ok(())
+2 -1
View File
@@ -5,14 +5,15 @@
pub mod connection_pool; pub mod connection_pool;
pub mod counting_body; pub mod counting_body;
pub mod h3_service;
pub mod protocol_cache; pub mod protocol_cache;
pub mod proxy_service; pub mod proxy_service;
pub mod request_filter; pub mod request_filter;
mod request_host;
pub mod response_filter; pub mod response_filter;
pub mod shutdown_on_drop; pub mod shutdown_on_drop;
pub mod template; pub mod template;
pub mod upstream_selector; pub mod upstream_selector;
pub mod h3_service;
pub use connection_pool::*; pub use connection_pool::*;
pub use counting_body::*; pub use counting_body::*;
@@ -144,10 +144,14 @@ impl FailureState {
} }
fn all_expired(&self) -> bool { fn all_expired(&self) -> bool {
let h2_expired = self.h2.as_ref() let h2_expired = self
.h2
.as_ref()
.map(|r| r.failed_at.elapsed() >= r.cooldown) .map(|r| r.failed_at.elapsed() >= r.cooldown)
.unwrap_or(true); .unwrap_or(true);
let h3_expired = self.h3.as_ref() let h3_expired = self
.h3
.as_ref()
.map(|r| r.failed_at.elapsed() >= r.cooldown) .map(|r| r.failed_at.elapsed() >= r.cooldown)
.unwrap_or(true); .unwrap_or(true);
h2_expired && h3_expired h2_expired && h3_expired
@@ -355,9 +359,13 @@ impl ProtocolCache {
let record = entry.get_mut(protocol); let record = entry.get_mut(protocol);
let (consecutive, new_cooldown) = match record { let (consecutive, new_cooldown) = match record {
Some(existing) if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) => { Some(existing)
if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) =>
{
// Still within the "recent" window — escalate // Still within the "recent" window — escalate
let c = existing.consecutive_failures.saturating_add(1) let c = existing
.consecutive_failures
.saturating_add(1)
.min(PROTOCOL_FAILURE_ESCALATION_CAP); .min(PROTOCOL_FAILURE_ESCALATION_CAP);
(c, escalate_cooldown(c)) (c, escalate_cooldown(c))
} }
@@ -394,8 +402,13 @@ impl ProtocolCache {
if protocol == DetectedProtocol::H1 { if protocol == DetectedProtocol::H1 {
return false; return false;
} }
self.failures.get(key) self.failures
.and_then(|entry| entry.get(protocol).map(|r| r.failed_at.elapsed() < r.cooldown)) .get(key)
.and_then(|entry| {
entry
.get(protocol)
.map(|r| r.failed_at.elapsed() < r.cooldown)
})
.unwrap_or(false) .unwrap_or(false)
} }
@@ -464,19 +477,18 @@ impl ProtocolCache {
/// Snapshot all non-expired cache entries for metrics/UI display. /// Snapshot all non-expired cache entries for metrics/UI display.
pub fn snapshot(&self) -> Vec<ProtocolCacheEntry> { pub fn snapshot(&self) -> Vec<ProtocolCacheEntry> {
self.cache.iter() self.cache
.iter()
.filter(|entry| entry.value().last_accessed_at.elapsed() < PROTOCOL_CACHE_TTL) .filter(|entry| entry.value().last_accessed_at.elapsed() < PROTOCOL_CACHE_TTL)
.map(|entry| { .map(|entry| {
let key = entry.key(); let key = entry.key();
let val = entry.value(); let val = entry.value();
let failure_info = self.failures.get(key); let failure_info = self.failures.get(key);
let (h2_sup, h2_cd, h2_cons) = Self::suppression_info( let (h2_sup, h2_cd, h2_cons) =
failure_info.as_deref().and_then(|f| f.h2.as_ref()), Self::suppression_info(failure_info.as_deref().and_then(|f| f.h2.as_ref()));
); let (h3_sup, h3_cd, h3_cons) =
let (h3_sup, h3_cd, h3_cons) = Self::suppression_info( Self::suppression_info(failure_info.as_deref().and_then(|f| f.h3.as_ref()));
failure_info.as_deref().and_then(|f| f.h3.as_ref()),
);
ProtocolCacheEntry { ProtocolCacheEntry {
host: key.host.clone(), host: key.host.clone(),
@@ -507,7 +519,13 @@ impl ProtocolCache {
/// Insert a protocol detection result with an optional H3 port. /// Insert a protocol detection result with an optional H3 port.
/// Logs protocol transitions when overwriting an existing entry. /// Logs protocol transitions when overwriting an existing entry.
/// No suppression check — callers must check before calling. /// No suppression check — callers must check before calling.
fn insert_internal(&self, key: ProtocolCacheKey, protocol: DetectedProtocol, h3_port: Option<u16>, reason: &str) { fn insert_internal(
&self,
key: ProtocolCacheKey,
protocol: DetectedProtocol,
h3_port: Option<u16>,
reason: &str,
) {
// Check for existing entry to log protocol transitions // Check for existing entry to log protocol transitions
if let Some(existing) = self.cache.get(&key) { if let Some(existing) = self.cache.get(&key) {
if existing.protocol != protocol { if existing.protocol != protocol {
@@ -522,7 +540,9 @@ impl ProtocolCache {
// Evict oldest entry if at capacity // Evict oldest entry if at capacity
if self.cache.len() >= PROTOCOL_CACHE_MAX_ENTRIES && !self.cache.contains_key(&key) { if self.cache.len() >= PROTOCOL_CACHE_MAX_ENTRIES && !self.cache.contains_key(&key) {
let oldest = self.cache.iter() let oldest = self
.cache
.iter()
.min_by_key(|entry| entry.value().last_accessed_at) .min_by_key(|entry| entry.value().last_accessed_at)
.map(|entry| entry.key().clone()); .map(|entry| entry.key().clone());
if let Some(oldest_key) = oldest { if let Some(oldest_key) = oldest {
@@ -531,13 +551,16 @@ impl ProtocolCache {
} }
let now = Instant::now(); let now = Instant::now();
self.cache.insert(key, CachedEntry { self.cache.insert(
protocol, key,
detected_at: now, CachedEntry {
last_accessed_at: now, protocol,
last_probed_at: now, detected_at: now,
h3_port, last_accessed_at: now,
}); last_probed_at: now,
h3_port,
},
);
} }
/// Reduce a failure record's remaining cooldown to `target`, if it currently /// Reduce a failure record's remaining cooldown to `target`, if it currently
@@ -582,26 +605,34 @@ impl ProtocolCache {
interval.tick().await; interval.tick().await;
// Clean expired cache entries (sliding TTL based on last_accessed_at) // Clean expired cache entries (sliding TTL based on last_accessed_at)
let expired: Vec<ProtocolCacheKey> = cache.iter() let expired: Vec<ProtocolCacheKey> = cache
.iter()
.filter(|entry| entry.value().last_accessed_at.elapsed() >= PROTOCOL_CACHE_TTL) .filter(|entry| entry.value().last_accessed_at.elapsed() >= PROTOCOL_CACHE_TTL)
.map(|entry| entry.key().clone()) .map(|entry| entry.key().clone())
.collect(); .collect();
if !expired.is_empty() { if !expired.is_empty() {
debug!("Protocol cache cleanup: removing {} expired entries", expired.len()); debug!(
"Protocol cache cleanup: removing {} expired entries",
expired.len()
);
for key in expired { for key in expired {
cache.remove(&key); cache.remove(&key);
} }
} }
// Clean fully-expired failure entries // Clean fully-expired failure entries
let expired_failures: Vec<ProtocolCacheKey> = failures.iter() let expired_failures: Vec<ProtocolCacheKey> = failures
.iter()
.filter(|entry| entry.value().all_expired()) .filter(|entry| entry.value().all_expired())
.map(|entry| entry.key().clone()) .map(|entry| entry.key().clone())
.collect(); .collect();
if !expired_failures.is_empty() { if !expired_failures.is_empty() {
debug!("Protocol cache cleanup: removing {} expired failure entries", expired_failures.len()); debug!(
"Protocol cache cleanup: removing {} expired failure entries",
expired_failures.len()
);
for key in expired_failures { for key in expired_failures {
failures.remove(&key); failures.remove(&key);
} }
@@ -609,7 +640,8 @@ impl ProtocolCache {
// Safety net: cap failures map at 2× max entries // Safety net: cap failures map at 2× max entries
if failures.len() > PROTOCOL_CACHE_MAX_ENTRIES * 2 { if failures.len() > PROTOCOL_CACHE_MAX_ENTRIES * 2 {
let oldest: Vec<ProtocolCacheKey> = failures.iter() let oldest: Vec<ProtocolCacheKey> = failures
.iter()
.filter(|e| e.value().all_expired()) .filter(|e| e.value().all_expired())
.map(|e| e.key().clone()) .map(|e| e.key().clone())
.take(failures.len() - PROTOCOL_CACHE_MAX_ENTRIES) .take(failures.len() - PROTOCOL_CACHE_MAX_ENTRIES)
File diff suppressed because it is too large Load Diff
+144 -39
View File
@@ -4,13 +4,15 @@ use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use bytes::Bytes; use bytes::Bytes;
use http_body_util::Full;
use http_body_util::BodyExt;
use hyper::{Request, Response, StatusCode};
use http_body_util::combinators::BoxBody; use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper::{Request, Response, StatusCode};
use rustproxy_config::RouteSecurity; use rustproxy_config::RouteSecurity;
use rustproxy_security::{IpFilter, BasicAuthValidator, JwtValidator, RateLimiter}; use rustproxy_security::{BasicAuthValidator, IpFilter, JwtValidator, RateLimiter};
use crate::request_host::extract_request_host;
pub struct RequestFilter; pub struct RequestFilter;
@@ -35,13 +37,14 @@ impl RequestFilter {
let client_ip = peer_addr.ip(); let client_ip = peer_addr.ip();
let request_path = req.uri().path(); let request_path = req.uri().path();
// IP filter // IP filter (domain-aware: use the same host extraction as route matching)
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() { if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]); let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
let block = security.ip_block_list.as_deref().unwrap_or(&[]); let block = security.ip_block_list.as_deref().unwrap_or(&[]);
let filter = IpFilter::new(allow, block); let filter = IpFilter::new(allow, block);
let normalized = IpFilter::normalize_ip(&client_ip); let normalized = IpFilter::normalize_ip(&client_ip);
if !filter.is_allowed(&normalized) { let host = extract_request_host(req);
if !filter.is_allowed_for_domain(&normalized, host) {
return Some(error_response(StatusCode::FORBIDDEN, "Access denied")); return Some(error_response(StatusCode::FORBIDDEN, "Access denied"));
} }
} }
@@ -55,16 +58,15 @@ impl RequestFilter {
!limiter.check(&key) !limiter.check(&key)
} else { } else {
// Create a per-check limiter (less ideal but works for non-shared case) // Create a per-check limiter (less ideal but works for non-shared case)
let limiter = RateLimiter::new( let limiter =
rate_limit_config.max_requests, RateLimiter::new(rate_limit_config.max_requests, rate_limit_config.window);
rate_limit_config.window,
);
let key = Self::rate_limit_key(rate_limit_config, req, peer_addr); let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
!limiter.check(&key) !limiter.check(&key)
}; };
if should_block { if should_block {
let message = rate_limit_config.error_message let message = rate_limit_config
.error_message
.as_deref() .as_deref()
.unwrap_or("Rate limit exceeded"); .unwrap_or("Rate limit exceeded");
return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message)); return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message));
@@ -80,36 +82,48 @@ impl RequestFilter {
if let Some(ref basic_auth) = security.basic_auth { if let Some(ref basic_auth) = security.basic_auth {
if basic_auth.enabled { if basic_auth.enabled {
// Check basic auth exclude paths // Check basic auth exclude paths
let skip_basic = basic_auth.exclude_paths.as_ref() let skip_basic = basic_auth
.exclude_paths
.as_ref()
.map(|paths| Self::path_matches_any(request_path, paths)) .map(|paths| Self::path_matches_any(request_path, paths))
.unwrap_or(false); .unwrap_or(false);
if !skip_basic { if !skip_basic {
let users: Vec<(String, String)> = basic_auth.users.iter() let users: Vec<(String, String)> = basic_auth
.users
.iter()
.map(|c| (c.username.clone(), c.password.clone())) .map(|c| (c.username.clone(), c.password.clone()))
.collect(); .collect();
let validator = BasicAuthValidator::new(users, basic_auth.realm.clone()); let validator = BasicAuthValidator::new(users, basic_auth.realm.clone());
let auth_header = req.headers() let auth_header = req
.headers()
.get("authorization") .get("authorization")
.and_then(|v| v.to_str().ok()); .and_then(|v| v.to_str().ok());
match auth_header { match auth_header {
Some(header) => { Some(header) => {
if validator.validate(header).is_none() { if validator.validate(header).is_none() {
return Some(Response::builder() return Some(
.status(StatusCode::UNAUTHORIZED) Response::builder()
.header("WWW-Authenticate", validator.www_authenticate()) .status(StatusCode::UNAUTHORIZED)
.body(boxed_body("Invalid credentials")) .header(
.unwrap()); "WWW-Authenticate",
validator.www_authenticate(),
)
.body(boxed_body("Invalid credentials"))
.unwrap(),
);
} }
} }
None => { None => {
return Some(Response::builder() return Some(
.status(StatusCode::UNAUTHORIZED) Response::builder()
.header("WWW-Authenticate", validator.www_authenticate()) .status(StatusCode::UNAUTHORIZED)
.body(boxed_body("Authentication required")) .header("WWW-Authenticate", validator.www_authenticate())
.unwrap()); .body(boxed_body("Authentication required"))
.unwrap(),
);
} }
} }
} }
@@ -120,7 +134,9 @@ impl RequestFilter {
if let Some(ref jwt_auth) = security.jwt_auth { if let Some(ref jwt_auth) = security.jwt_auth {
if jwt_auth.enabled { if jwt_auth.enabled {
// Check JWT auth exclude paths // Check JWT auth exclude paths
let skip_jwt = jwt_auth.exclude_paths.as_ref() let skip_jwt = jwt_auth
.exclude_paths
.as_ref()
.map(|paths| Self::path_matches_any(request_path, paths)) .map(|paths| Self::path_matches_any(request_path, paths))
.unwrap_or(false); .unwrap_or(false);
@@ -132,18 +148,25 @@ impl RequestFilter {
jwt_auth.audience.as_deref(), jwt_auth.audience.as_deref(),
); );
let auth_header = req.headers() let auth_header = req
.headers()
.get("authorization") .get("authorization")
.and_then(|v| v.to_str().ok()); .and_then(|v| v.to_str().ok());
match auth_header.and_then(JwtValidator::extract_token) { match auth_header.and_then(JwtValidator::extract_token) {
Some(token) => { Some(token) => {
if validator.validate(token).is_err() { if validator.validate(token).is_err() {
return Some(error_response(StatusCode::UNAUTHORIZED, "Invalid token")); return Some(error_response(
StatusCode::UNAUTHORIZED,
"Invalid token",
));
} }
} }
None => { None => {
return Some(error_response(StatusCode::UNAUTHORIZED, "Bearer token required")); return Some(error_response(
StatusCode::UNAUTHORIZED,
"Bearer token required",
));
} }
} }
} }
@@ -203,14 +226,19 @@ impl RequestFilter {
} }
/// Check IP-based security (for use in passthrough / TCP-level connections). /// Check IP-based security (for use in passthrough / TCP-level connections).
/// `domain` is the SNI from the TLS handshake (if available) for domain-scoped filtering.
/// Returns true if allowed, false if blocked. /// Returns true if allowed, false if blocked.
pub fn check_ip_security(security: &RouteSecurity, client_ip: &std::net::IpAddr) -> bool { pub fn check_ip_security(
security: &RouteSecurity,
client_ip: &std::net::IpAddr,
domain: Option<&str>,
) -> bool {
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() { if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]); let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
let block = security.ip_block_list.as_deref().unwrap_or(&[]); let block = security.ip_block_list.as_deref().unwrap_or(&[]);
let filter = IpFilter::new(allow, block); let filter = IpFilter::new(allow, block);
let normalized = IpFilter::normalize_ip(client_ip); let normalized = IpFilter::normalize_ip(client_ip);
filter.is_allowed(&normalized) filter.is_allowed_for_domain(&normalized, domain)
} else { } else {
true true
} }
@@ -233,19 +261,28 @@ impl RequestFilter {
return None; return None;
} }
let origin = req.headers() let origin = req
.headers()
.get("origin") .get("origin")
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.unwrap_or("*"); .unwrap_or("*");
Some(Response::builder() Some(
.status(StatusCode::NO_CONTENT) Response::builder()
.header("Access-Control-Allow-Origin", origin) .status(StatusCode::NO_CONTENT)
.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS") .header("Access-Control-Allow-Origin", origin)
.header("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With") .header(
.header("Access-Control-Max-Age", "86400") "Access-Control-Allow-Methods",
.body(boxed_body("")) "GET, POST, PUT, DELETE, PATCH, OPTIONS",
.unwrap()) )
.header(
"Access-Control-Allow-Headers",
"Content-Type, Authorization, X-Requested-With",
)
.header("Access-Control-Max-Age", "86400")
.body(boxed_body(""))
.unwrap(),
)
} }
} }
@@ -260,3 +297,71 @@ fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes,
fn boxed_body(data: &str) -> BoxBody<Bytes, hyper::Error> { fn boxed_body(data: &str) -> BoxBody<Bytes, hyper::Error> {
BoxBody::new(Full::new(Bytes::from(data.to_string())).map_err(|never| match never {})) BoxBody::new(Full::new(Bytes::from(data.to_string())).map_err(|never| match never {}))
} }
#[cfg(test)]
mod tests {
use bytes::Bytes;
use http_body_util::Empty;
use hyper::{Request, StatusCode, Version};
use rustproxy_config::{IpAllowEntry, RouteSecurity};
use super::RequestFilter;
fn domain_scoped_security() -> RouteSecurity {
RouteSecurity {
ip_allow_list: Some(vec![IpAllowEntry::DomainScoped {
ip: "10.8.0.2".to_string(),
domains: vec!["*.abc.xyz".to_string()],
}]),
ip_block_list: None,
max_connections: None,
authentication: None,
rate_limit: None,
basic_auth: None,
jwt_auth: None,
}
}
fn peer_addr() -> std::net::SocketAddr {
std::net::SocketAddr::from(([10, 8, 0, 2], 4242))
}
fn request(uri: &str, version: Version, host: Option<&str>) -> Request<Empty<Bytes>> {
let mut builder = Request::builder().uri(uri).version(version);
if let Some(host) = host {
builder = builder.header("host", host);
}
builder.body(Empty::<Bytes>::new()).unwrap()
}
#[test]
fn domain_scoped_acl_allows_uri_authority_without_host_header() {
let security = domain_scoped_security();
let req = request("https://outline.abc.xyz/", Version::HTTP_2, None);
assert!(RequestFilter::apply(&security, &req, &peer_addr()).is_none());
}
#[test]
fn domain_scoped_acl_allows_host_header_with_port() {
let security = domain_scoped_security();
let req = request(
"https://unrelated.invalid/",
Version::HTTP_11,
Some("outline.abc.xyz:443"),
);
assert!(RequestFilter::apply(&security, &req, &peer_addr()).is_none());
}
#[test]
fn domain_scoped_acl_denies_non_matching_uri_authority() {
let security = domain_scoped_security();
let req = request("https://outline.other.xyz/", Version::HTTP_2, None);
let response = RequestFilter::apply(&security, &req, &peer_addr())
.expect("non-matching domain should be denied");
assert_eq!(response.status(), StatusCode::FORBIDDEN);
}
}
@@ -0,0 +1,43 @@
use hyper::Request;
/// Extract the effective request host for routing and scoped ACL checks.
///
/// Prefer the explicit `Host` header when present, otherwise fall back to the
/// URI authority used by HTTP/2 and HTTP/3 requests.
pub(crate) fn extract_request_host<B>(req: &Request<B>) -> Option<&str> {
req.headers()
.get("host")
.and_then(|value| value.to_str().ok())
.map(|host| host.split(':').next().unwrap_or(host))
.or_else(|| req.uri().host())
}
#[cfg(test)]
mod tests {
use bytes::Bytes;
use http_body_util::Empty;
use hyper::Request;
use super::extract_request_host;
#[test]
fn extracts_host_header_before_uri_authority() {
let req = Request::builder()
.uri("https://uri.abc.xyz/test")
.header("host", "header.abc.xyz:443")
.body(Empty::<Bytes>::new())
.unwrap();
assert_eq!(extract_request_host(&req), Some("header.abc.xyz"));
}
#[test]
fn falls_back_to_uri_authority_when_host_header_missing() {
let req = Request::builder()
.uri("https://outline.abc.xyz/test")
.body(Empty::<Bytes>::new())
.unwrap();
assert_eq!(extract_request_host(&req), Some("outline.abc.xyz"));
}
}
@@ -3,7 +3,7 @@
use hyper::header::{HeaderMap, HeaderName, HeaderValue}; use hyper::header::{HeaderMap, HeaderName, HeaderValue};
use rustproxy_config::RouteConfig; use rustproxy_config::RouteConfig;
use crate::template::{RequestContext, expand_template}; use crate::template::{expand_template, RequestContext};
pub struct ResponseFilter; pub struct ResponseFilter;
@@ -11,12 +11,17 @@ impl ResponseFilter {
/// Apply response headers from route config and CORS settings. /// Apply response headers from route config and CORS settings.
/// If a `RequestContext` is provided, template variables in header values will be expanded. /// If a `RequestContext` is provided, template variables in header values will be expanded.
/// Also injects Alt-Svc header for routes with HTTP/3 enabled. /// Also injects Alt-Svc header for routes with HTTP/3 enabled.
pub fn apply_headers(route: &RouteConfig, headers: &mut HeaderMap, req_ctx: Option<&RequestContext>) { pub fn apply_headers(
route: &RouteConfig,
headers: &mut HeaderMap,
req_ctx: Option<&RequestContext>,
) {
// Inject Alt-Svc for HTTP/3 advertisement if QUIC/HTTP3 is enabled on this route // Inject Alt-Svc for HTTP/3 advertisement if QUIC/HTTP3 is enabled on this route
if let Some(ref udp) = route.action.udp { if let Some(ref udp) = route.action.udp {
if let Some(ref quic) = udp.quic { if let Some(ref quic) = udp.quic {
if quic.enable_http3.unwrap_or(false) { if quic.enable_http3.unwrap_or(false) {
let port = quic.alt_svc_port let port = quic
.alt_svc_port
.or_else(|| req_ctx.map(|c| c.port)) .or_else(|| req_ctx.map(|c| c.port))
.unwrap_or(443); .unwrap_or(443);
let max_age = quic.alt_svc_max_age.unwrap_or(86400); let max_age = quic.alt_svc_max_age.unwrap_or(86400);
@@ -63,10 +68,7 @@ impl ResponseFilter {
headers.insert("access-control-allow-origin", val); headers.insert("access-control-allow-origin", val);
} }
} else { } else {
headers.insert( headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
"access-control-allow-origin",
HeaderValue::from_static("*"),
);
} }
// Allow-Methods // Allow-Methods
@@ -62,17 +62,11 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncWrite for Shutdown
self.inner.as_ref().unwrap().is_write_vectored() self.inner.as_ref().unwrap().is_write_vectored()
} }
fn poll_flush( fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
Pin::new(self.get_mut().inner.as_mut().unwrap()).poll_flush(cx) Pin::new(self.get_mut().inner.as_mut().unwrap()).poll_flush(cx)
} }
fn poll_shutdown( fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut(); let this = self.get_mut();
let result = Pin::new(this.inner.as_mut().unwrap()).poll_shutdown(cx); let result = Pin::new(this.inner.as_mut().unwrap()).poll_shutdown(cx);
if result.is_ready() { if result.is_ready() {
@@ -93,7 +87,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Drop for ShutdownOnDrop
let _ = tokio::time::timeout( let _ = tokio::time::timeout(
std::time::Duration::from_secs(2), std::time::Duration::from_secs(2),
tokio::io::AsyncWriteExt::shutdown(&mut stream), tokio::io::AsyncWriteExt::shutdown(&mut stream),
).await; )
.await;
// stream is dropped here — all resources freed // stream is dropped here — all resources freed
}); });
} }
+6 -2
View File
@@ -39,7 +39,8 @@ pub fn expand_headers(
headers: &HashMap<String, String>, headers: &HashMap<String, String>,
ctx: &RequestContext, ctx: &RequestContext,
) -> HashMap<String, String> { ) -> HashMap<String, String> {
headers.iter() headers
.iter()
.map(|(k, v)| (k.clone(), expand_template(v, ctx))) .map(|(k, v)| (k.clone(), expand_template(v, ctx)))
.collect() .collect()
} }
@@ -150,7 +151,10 @@ mod tests {
let ctx = test_context(); let ctx = test_context();
let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}"; let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}";
let result = expand_template(template, &ctx); let result = expand_template(template, &ctx);
assert_eq!(result, "192.168.1.100|example.com|443|/api/v1/users|api-route|42"); assert_eq!(
result,
"192.168.1.100|example.com|443|/api/v1/users|api-route|42"
);
} }
#[test] #[test]
@@ -7,7 +7,7 @@ use std::sync::Arc;
use std::sync::Mutex; use std::sync::Mutex;
use dashmap::DashMap; use dashmap::DashMap;
use rustproxy_config::{RouteTarget, LoadBalancingAlgorithm}; use rustproxy_config::{LoadBalancingAlgorithm, RouteTarget};
/// Upstream selection result. /// Upstream selection result.
pub struct UpstreamSelection { pub struct UpstreamSelection {
@@ -51,21 +51,19 @@ impl UpstreamSelector {
} }
// Determine load balancing algorithm // Determine load balancing algorithm
let algorithm = target.load_balancing.as_ref() let algorithm = target
.load_balancing
.as_ref()
.map(|lb| &lb.algorithm) .map(|lb| &lb.algorithm)
.unwrap_or(&LoadBalancingAlgorithm::RoundRobin); .unwrap_or(&LoadBalancingAlgorithm::RoundRobin);
let idx = match algorithm { let idx = match algorithm {
LoadBalancingAlgorithm::RoundRobin => { LoadBalancingAlgorithm::RoundRobin => self.round_robin_select(&hosts, port),
self.round_robin_select(&hosts, port)
}
LoadBalancingAlgorithm::IpHash => { LoadBalancingAlgorithm::IpHash => {
let hash = Self::ip_hash(client_addr); let hash = Self::ip_hash(client_addr);
hash % hosts.len() hash % hosts.len()
} }
LoadBalancingAlgorithm::LeastConnections => { LoadBalancingAlgorithm::LeastConnections => self.least_connections_select(&hosts, port),
self.least_connections_select(&hosts, port)
}
}; };
UpstreamSelection { UpstreamSelection {
@@ -78,9 +76,7 @@ impl UpstreamSelector {
fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize { fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize {
let key = format!("{}:{}", hosts[0], port); let key = format!("{}:{}", hosts[0], port);
let mut counters = self.round_robin.lock().unwrap(); let mut counters = self.round_robin.lock().unwrap();
let counter = counters let counter = counters.entry(key).or_insert_with(|| AtomicUsize::new(0));
.entry(key)
.or_insert_with(|| AtomicUsize::new(0));
let idx = counter.fetch_add(1, Ordering::Relaxed); let idx = counter.fetch_add(1, Ordering::Relaxed);
idx % hosts.len() idx % hosts.len()
} }
@@ -91,7 +87,8 @@ impl UpstreamSelector {
for (i, host) in hosts.iter().enumerate() { for (i, host) in hosts.iter().enumerate() {
let key = format!("{}:{}", host, port); let key = format!("{}:{}", host, port);
let conns = self.active_connections let conns = self
.active_connections
.get(&key) .get(&key)
.map(|entry| entry.value().load(Ordering::Relaxed)) .map(|entry| entry.value().load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
@@ -228,13 +225,21 @@ mod tests {
selector.connection_started("backend:8080"); selector.connection_started("backend:8080");
selector.connection_started("backend:8080"); selector.connection_started("backend:8080");
assert_eq!( assert_eq!(
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed), selector
.active_connections
.get("backend:8080")
.unwrap()
.load(Ordering::Relaxed),
2 2
); );
selector.connection_ended("backend:8080"); selector.connection_ended("backend:8080");
assert_eq!( assert_eq!(
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed), selector
.active_connections
.get("backend:8080")
.unwrap()
.load(Ordering::Relaxed),
1 1
); );
File diff suppressed because it is too large Load Diff
+2 -2
View File
@@ -2,10 +2,10 @@
//! //!
//! Metrics and throughput tracking for RustProxy. //! Metrics and throughput tracking for RustProxy.
pub mod throughput;
pub mod collector; pub mod collector;
pub mod log_dedup; pub mod log_dedup;
pub mod throughput;
pub use throughput::*;
pub use collector::*; pub use collector::*;
pub use log_dedup::*; pub use log_dedup::*;
pub use throughput::*;
+11 -8
View File
@@ -1,6 +1,6 @@
use dashmap::DashMap; use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tracing::info; use tracing::info;
@@ -47,13 +47,16 @@ impl LogDeduplicator {
let map_key = format!("{}:{}", category, key); let map_key = format!("{}:{}", category, key);
let now = Instant::now(); let now = Instant::now();
let entry = self.events.entry(map_key).or_insert_with(|| AggregatedEvent { let entry = self
category: category.to_string(), .events
first_message: message.to_string(), .entry(map_key)
count: AtomicU64::new(0), .or_insert_with(|| AggregatedEvent {
first_seen: now, category: category.to_string(),
last_seen: now, first_message: message.to_string(),
}); count: AtomicU64::new(0),
first_seen: now,
last_seen: now,
});
let count = entry.count.fetch_add(1, Ordering::Relaxed) + 1; let count = entry.count.fetch_add(1, Ordering::Relaxed) + 1;
+146 -1
View File
@@ -29,6 +29,113 @@ pub struct ThroughputTracker {
created_at: Instant, created_at: Instant,
} }
/// Seconds elapsed since the Unix epoch, according to the system clock.
///
/// A clock reading earlier than the epoch (or any other error) collapses
/// to 0 instead of panicking.
fn unix_timestamp_seconds() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Circular buffer for per-second event counts.
///
/// Unlike `ThroughputTracker`, events are recorded directly into the current
/// second so request counts remain stable even when the collector is sampled
/// more frequently than once per second.
pub(crate) struct RequestRateTracker {
    // Ring buffer of completed one-second buckets; grows up to `capacity`.
    samples: Vec<u64>,
    // Next slot in `samples` to overwrite once the buffer is full.
    write_index: usize,
    // Number of valid samples currently stored (saturates at `capacity`).
    count: usize,
    // Maximum number of one-second buckets retained (at least 1).
    capacity: usize,
    // Unix second the in-progress bucket belongs to; `None` until the
    // first event/advance is observed.
    current_second: Option<u64>,
    // Events accumulated for `current_second`, not yet flushed to `samples`.
    current_count: u64,
}
impl RequestRateTracker {
    /// Build a tracker retaining `retention_seconds` seconds of history
    /// (clamped to a minimum of one second).
    pub(crate) fn new(retention_seconds: usize) -> Self {
        let capacity = retention_seconds.max(1);
        Self {
            samples: Vec::with_capacity(capacity),
            write_index: 0,
            count: 0,
            capacity,
            current_second: None,
            current_count: 0,
        }
    }

    /// Store one finished per-second bucket, overwriting the oldest slot
    /// once the ring buffer has reached `capacity`.
    fn push_sample(&mut self, count: u64) {
        if self.samples.len() == self.capacity {
            self.samples[self.write_index] = count;
        } else {
            self.samples.push(count);
        }
        self.write_index = (self.write_index + 1) % self.capacity;
        self.count = (self.count + 1).min(self.capacity);
    }

    /// Record a single event at the current wall-clock second.
    pub(crate) fn record_event(&mut self) {
        self.record_events_at(unix_timestamp_seconds(), 1);
    }

    /// Record `count` events attributed to second `now_sec`, rolling the
    /// buffer forward first if `now_sec` is newer than the current bucket.
    pub(crate) fn record_events_at(&mut self, now_sec: u64, count: u64) {
        self.advance_to(now_sec);
        self.current_count = self.current_count.saturating_add(count);
    }

    /// Roll the buffer forward to the current wall-clock second.
    pub(crate) fn advance_to_now(&mut self) {
        self.advance_to(unix_timestamp_seconds());
    }

    /// Roll the buffer forward to `now_sec`: flush the in-progress bucket,
    /// pad any skipped seconds with zero buckets, and open a fresh bucket.
    /// A `now_sec` at or before the current second is a no-op.
    pub(crate) fn advance_to(&mut self, now_sec: u64) {
        match self.current_second {
            // First observation — just open a bucket at `now_sec`.
            None => {
                self.current_second = Some(now_sec);
                self.current_count = 0;
            }
            Some(prev) if now_sec > prev => {
                // The finished second goes into the ring buffer…
                self.push_sample(self.current_count);
                // …followed by one zero bucket per skipped second.
                for _ in 1..(now_sec - prev) {
                    self.push_sample(0);
                }
                self.current_second = Some(now_sec);
                self.current_count = 0;
            }
            // Same second (or the clock stepped backwards): nothing to flush.
            Some(_) => {}
        }
    }

    /// Sum the most recent `window_seconds` *completed* buckets; the
    /// in-progress bucket is excluded.
    fn sum_recent(&self, window_seconds: usize) -> u64 {
        let window = window_seconds.min(self.count);
        (0..window)
            .map(|i| {
                // Walk backwards from the slot just before `write_index`,
                // wrapping around the ring buffer when necessary.
                let back = i + 1;
                let idx = if self.write_index >= back {
                    self.write_index - back
                } else {
                    self.capacity - (back - self.write_index)
                };
                self.samples.get(idx).copied().unwrap_or(0)
            })
            .sum()
    }

    /// Events recorded in the most recently completed second.
    pub(crate) fn last_second(&self) -> u64 {
        self.sum_recent(1)
    }

    /// Events recorded over the last 60 completed seconds.
    pub(crate) fn last_minute(&self) -> u64 {
        self.sum_recent(60)
    }

    /// `true` when no events are pending and every retained bucket is zero,
    /// i.e. the tracker can safely be pruned.
    pub(crate) fn is_idle(&self) -> bool {
        self.current_count == 0 && self.sum_recent(self.capacity) == 0
    }
}
impl ThroughputTracker { impl ThroughputTracker {
/// Create a new tracker with the given capacity (seconds of retention). /// Create a new tracker with the given capacity (seconds of retention).
pub fn new(retention_seconds: usize) -> Self { pub fn new(retention_seconds: usize) -> Self {
@@ -46,7 +153,8 @@ impl ThroughputTracker {
/// Record bytes (called from data flow callbacks). /// Record bytes (called from data flow callbacks).
pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64) { pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64) {
self.pending_bytes_in.fetch_add(bytes_in, Ordering::Relaxed); self.pending_bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
self.pending_bytes_out.fetch_add(bytes_out, Ordering::Relaxed); self.pending_bytes_out
.fetch_add(bytes_out, Ordering::Relaxed);
} }
/// Take a sample (called at 1Hz). /// Take a sample (called at 1Hz).
@@ -229,4 +337,41 @@ mod tests {
let history = tracker.history(10); let history = tracker.history(10);
assert!(history.is_empty()); assert!(history.is_empty());
} }
#[test]
fn test_request_rate_tracker_counts_last_second_and_last_minute() {
    let mut rt = RequestRateTracker::new(60);
    // Two batches within the same second accumulate into a single bucket.
    rt.record_events_at(100, 2);
    rt.record_events_at(100, 3);
    // Advancing one second flushes that bucket into the ring buffer.
    rt.advance_to(101);
    assert_eq!(rt.last_second(), 5);
    assert_eq!(rt.last_minute(), 5);
}
#[test]
fn test_request_rate_tracker_adds_zero_samples_for_gaps() {
    let mut rt = RequestRateTracker::new(60);
    rt.record_events_at(100, 4);
    // Jumping from second 100 to 102 must insert a zero bucket for 101.
    rt.record_events_at(102, 1);
    rt.advance_to(103);
    assert_eq!(rt.last_second(), 1);
    assert_eq!(rt.last_minute(), 5);
}
#[test]
fn test_request_rate_tracker_decays_to_zero_over_window() {
    let mut rt = RequestRateTracker::new(60);
    rt.record_events_at(100, 7);
    rt.advance_to(101);
    // A full retention window of silence decays both rates to zero.
    rt.advance_to(161);
    assert_eq!(rt.last_second(), 0);
    assert_eq!(rt.last_minute(), 0);
    assert!(rt.is_idle());
}
} }
@@ -10,6 +10,7 @@ description = "Raw TCP/SNI passthrough engine for RustProxy"
rustproxy-config = { workspace = true } rustproxy-config = { workspace = true }
rustproxy-routing = { workspace = true } rustproxy-routing = { workspace = true }
rustproxy-metrics = { workspace = true } rustproxy-metrics = { workspace = true }
rustproxy-security = { workspace = true }
tokio = { workspace = true } tokio = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
thiserror = { workspace = true } thiserror = { workspace = true }
@@ -7,8 +7,8 @@
use std::collections::HashSet; use std::collections::HashSet;
use std::net::IpAddr; use std::net::IpAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use dashmap::DashMap; use dashmap::DashMap;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
@@ -73,7 +73,9 @@ impl ConnectionRegistry {
pub fn recycle_for_cert_change(&self, cert_domain: &str) { pub fn recycle_for_cert_change(&self, cert_domain: &str) {
let mut recycled = 0u64; let mut recycled = 0u64;
self.connections.retain(|_, entry| { self.connections.retain(|_, entry| {
let matches = entry.domain.as_deref() let matches = entry
.domain
.as_deref()
.map(|d| domain_matches(cert_domain, d) || domain_matches(d, cert_domain)) .map(|d| domain_matches(cert_domain, d) || domain_matches(d, cert_domain))
.unwrap_or(false); .unwrap_or(false);
if matches { if matches {
@@ -100,7 +102,11 @@ impl ConnectionRegistry {
let mut recycled = 0u64; let mut recycled = 0u64;
self.connections.retain(|_, entry| { self.connections.retain(|_, entry| {
if entry.route_id.as_deref() == Some(route_id) { if entry.route_id.as_deref() == Some(route_id) {
if !RequestFilter::check_ip_security(new_security, &entry.source_ip) { if !RequestFilter::check_ip_security(
new_security,
&entry.source_ip,
entry.domain.as_deref(),
) {
info!( info!(
"Terminating connection from {} — IP now blocked on route '{}'", "Terminating connection from {} — IP now blocked on route '{}'",
entry.source_ip, route_id entry.source_ip, route_id
@@ -31,7 +31,8 @@ impl ConnectionTracker {
pub fn try_accept(&self, ip: &IpAddr) -> bool { pub fn try_accept(&self, ip: &IpAddr) -> bool {
// Check per-IP connection limit // Check per-IP connection limit
if let Some(max) = self.max_per_ip { if let Some(max) = self.max_per_ip {
let count = self.active let count = self
.active
.get(ip) .get(ip)
.map(|c| c.value().load(Ordering::Relaxed)) .map(|c| c.value().load(Ordering::Relaxed))
.unwrap_or(0); .unwrap_or(0);
@@ -48,7 +49,10 @@ impl ConnectionTracker {
let timestamps = entry.value_mut(); let timestamps = entry.value_mut();
// Remove timestamps older than 1 minute // Remove timestamps older than 1 minute
while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) { while timestamps
.front()
.is_some_and(|t| now.duration_since(*t) >= one_minute)
{
timestamps.pop_front(); timestamps.pop_front();
} }
@@ -111,7 +115,6 @@ impl ConnectionTracker {
pub fn tracked_ips(&self) -> usize { pub fn tracked_ips(&self) -> usize {
self.active.len() self.active.len()
} }
} }
#[cfg(test)] #[cfg(test)]
@@ -1,8 +1,8 @@
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::debug; use tracing::debug;
use rustproxy_metrics::MetricsCollector; use rustproxy_metrics::MetricsCollector;
@@ -87,7 +87,12 @@ pub async fn forward_bidirectional_with_timeouts(
if let Some(data) = initial_data { if let Some(data) = initial_data {
backend.write_all(data).await?; backend.write_all(data).await?;
if let Some(ref ctx) = metrics { if let Some(ref ctx) = metrics {
ctx.collector.record_bytes(data.len() as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref()); ctx.collector.record_bytes(
data.len() as u64,
0,
ctx.route_id.as_deref(),
ctx.source_ip.as_deref(),
);
} }
} }
@@ -123,14 +128,17 @@ pub async fn forward_bidirectional_with_timeouts(
total += n as u64; total += n as u64;
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
if let Some(ref ctx) = metrics_c2b { if let Some(ref ctx) = metrics_c2b {
ctx.collector.record_bytes(n as u64, 0, ctx.route_id.as_deref(), ctx.source_ip.as_deref()); ctx.collector.record_bytes(
n as u64,
0,
ctx.route_id.as_deref(),
ctx.source_ip.as_deref(),
);
} }
} }
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify) // Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
let _ = tokio::time::timeout( let _ =
std::time::Duration::from_secs(2), tokio::time::timeout(std::time::Duration::from_secs(2), backend_write.shutdown()).await;
backend_write.shutdown(),
).await;
total total
}); });
@@ -154,14 +162,17 @@ pub async fn forward_bidirectional_with_timeouts(
total += n as u64; total += n as u64;
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
if let Some(ref ctx) = metrics_b2c { if let Some(ref ctx) = metrics_b2c {
ctx.collector.record_bytes(0, n as u64, ctx.route_id.as_deref(), ctx.source_ip.as_deref()); ctx.collector.record_bytes(
0,
n as u64,
ctx.route_id.as_deref(),
ctx.source_ip.as_deref(),
);
} }
} }
// Graceful shutdown with timeout (sends TCP FIN / TLS close_notify) // Graceful shutdown with timeout (sends TCP FIN / TLS close_notify)
let _ = tokio::time::timeout( let _ =
std::time::Duration::from_secs(2), tokio::time::timeout(std::time::Duration::from_secs(2), client_write.shutdown()).await;
client_write.shutdown(),
).await;
total total
}); });
+16 -16
View File
@@ -4,26 +4,26 @@
//! Handles TCP listening, TLS ClientHello SNI extraction, bidirectional forwarding, //! Handles TCP listening, TLS ClientHello SNI extraction, bidirectional forwarding,
//! and UDP datagram session tracking with forwarding. //! and UDP datagram session tracking with forwarding.
pub mod tcp_listener; pub mod connection_registry;
pub mod sni_parser; pub mod connection_tracker;
pub mod forwarder; pub mod forwarder;
pub mod proxy_protocol; pub mod proxy_protocol;
pub mod tls_handler;
pub mod connection_tracker;
pub mod connection_registry;
pub mod socket_opts;
pub mod udp_session;
pub mod udp_listener;
pub mod quic_handler; pub mod quic_handler;
pub mod sni_parser;
pub mod socket_opts;
pub mod tcp_listener;
pub mod tls_handler;
pub mod udp_listener;
pub mod udp_session;
pub use tcp_listener::*; pub use connection_registry::*;
pub use sni_parser::*; pub use connection_tracker::*;
pub use forwarder::*; pub use forwarder::*;
pub use proxy_protocol::*; pub use proxy_protocol::*;
pub use tls_handler::*;
pub use connection_tracker::*;
pub use connection_registry::*;
pub use socket_opts::*;
pub use udp_session::*;
pub use udp_listener::*;
pub use quic_handler::*; pub use quic_handler::*;
pub use sni_parser::*;
pub use socket_opts::*;
pub use tcp_listener::*;
pub use tls_handler::*;
pub use udp_listener::*;
pub use udp_session::*;
@@ -54,8 +54,8 @@ pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
.position(|w| w == b"\r\n") .position(|w| w == b"\r\n")
.ok_or(ProxyProtocolError::InvalidHeader)?; .ok_or(ProxyProtocolError::InvalidHeader)?;
let line = std::str::from_utf8(&data[..line_end]) let line =
.map_err(|_| ProxyProtocolError::InvalidHeader)?; std::str::from_utf8(&data[..line_end]).map_err(|_| ProxyProtocolError::InvalidHeader)?;
if !line.starts_with("PROXY ") { if !line.starts_with("PROXY ") {
return Err(ProxyProtocolError::InvalidHeader); return Err(ProxyProtocolError::InvalidHeader);
@@ -148,7 +148,10 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
let command = data[12] & 0x0F; let command = data[12] & 0x0F;
// 0x0 = LOCAL, 0x1 = PROXY // 0x0 = LOCAL, 0x1 = PROXY
if command > 1 { if command > 1 {
return Err(ProxyProtocolError::Parse(format!("Unknown command: {}", command))); return Err(ProxyProtocolError::Parse(format!(
"Unknown command: {}",
command
)));
} }
// Address family (high nibble) + transport (low nibble) of byte 13 // Address family (high nibble) + transport (low nibble) of byte 13
@@ -182,7 +185,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
// AF_INET (0x1) + STREAM (0x1) = TCP4 // AF_INET (0x1) + STREAM (0x1) = TCP4
(0x1, 0x1) => { (0x1, 0x1) => {
if addr_len < 12 { if addr_len < 12 {
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string())); return Err(ProxyProtocolError::Parse(
"IPv4 address block too short".to_string(),
));
} }
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]); let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]); let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
@@ -200,7 +205,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
// AF_INET (0x1) + DGRAM (0x2) = UDP4 // AF_INET (0x1) + DGRAM (0x2) = UDP4
(0x1, 0x2) => { (0x1, 0x2) => {
if addr_len < 12 { if addr_len < 12 {
return Err(ProxyProtocolError::Parse("IPv4 address block too short".to_string())); return Err(ProxyProtocolError::Parse(
"IPv4 address block too short".to_string(),
));
} }
let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]); let src_ip = Ipv4Addr::new(addr_block[0], addr_block[1], addr_block[2], addr_block[3]);
let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]); let dst_ip = Ipv4Addr::new(addr_block[4], addr_block[5], addr_block[6], addr_block[7]);
@@ -218,7 +225,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
// AF_INET6 (0x2) + STREAM (0x1) = TCP6 // AF_INET6 (0x2) + STREAM (0x1) = TCP6
(0x2, 0x1) => { (0x2, 0x1) => {
if addr_len < 36 { if addr_len < 36 {
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string())); return Err(ProxyProtocolError::Parse(
"IPv6 address block too short".to_string(),
));
} }
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap()); let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap()); let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
@@ -236,7 +245,9 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
// AF_INET6 (0x2) + DGRAM (0x2) = UDP6 // AF_INET6 (0x2) + DGRAM (0x2) = UDP6
(0x2, 0x2) => { (0x2, 0x2) => {
if addr_len < 36 { if addr_len < 36 {
return Err(ProxyProtocolError::Parse("IPv6 address block too short".to_string())); return Err(ProxyProtocolError::Parse(
"IPv6 address block too short".to_string(),
));
} }
let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap()); let src_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[0..16]).unwrap());
let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap()); let dst_ip = Ipv6Addr::from(<[u8; 16]>::try_from(&addr_block[16..32]).unwrap());
@@ -268,11 +279,7 @@ pub fn parse_v2(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtoc
} }
/// Generate a PROXY protocol v2 binary header. /// Generate a PROXY protocol v2 binary header.
pub fn generate_v2( pub fn generate_v2(source: &SocketAddr, dest: &SocketAddr, transport: ProxyV2Transport) -> Vec<u8> {
source: &SocketAddr,
dest: &SocketAddr,
transport: ProxyV2Transport,
) -> Vec<u8> {
let transport_nibble: u8 = match transport { let transport_nibble: u8 = match transport {
ProxyV2Transport::Stream => 0x1, ProxyV2Transport::Stream => 0x1,
ProxyV2Transport::Datagram => 0x2, ProxyV2Transport::Datagram => 0x2,
@@ -462,7 +469,10 @@ mod tests {
header.push(0x11); header.push(0x11);
header.extend_from_slice(&12u16.to_be_bytes()); header.extend_from_slice(&12u16.to_be_bytes());
header.extend_from_slice(&[0u8; 12]); header.extend_from_slice(&[0u8; 12]);
assert!(matches!(parse_v2(&header), Err(ProxyProtocolError::UnsupportedVersion))); assert!(matches!(
parse_v2(&header),
Err(ProxyProtocolError::UnsupportedVersion)
));
} }
#[test] #[test]
@@ -26,11 +26,12 @@ use tracing::{debug, info, warn};
use rustproxy_config::{RouteConfig, TransportProtocol}; use rustproxy_config::{RouteConfig, TransportProtocol};
use rustproxy_metrics::MetricsCollector; use rustproxy_metrics::MetricsCollector;
use rustproxy_routing::{MatchContext, RouteManager}; use rustproxy_routing::{MatchContext, RouteManager};
use rustproxy_security::IpBlockList;
use rustproxy_http::h3_service::H3ProxyService; use rustproxy_http::h3_service::H3ProxyService;
use crate::connection_tracker::ConnectionTracker;
use crate::connection_registry::{ConnectionEntry, ConnectionRegistry}; use crate::connection_registry::{ConnectionEntry, ConnectionRegistry};
use crate::connection_tracker::ConnectionTracker;
/// Create a QUIC server endpoint on the given port with the provided TLS config. /// Create a QUIC server endpoint on the given port with the provided TLS config.
/// ///
@@ -48,8 +49,7 @@ pub fn create_quic_endpoint(
quinn::EndpointConfig::default(), quinn::EndpointConfig::default(),
Some(server_config), Some(server_config),
socket, socket,
quinn::default_runtime() quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
)?; )?;
info!("QUIC endpoint listening on port {}", port); info!("QUIC endpoint listening on port {}", port);
@@ -97,6 +97,7 @@ pub fn create_quic_endpoint_with_proxy_relay(
port: u16, port: u16,
tls_config: Arc<RustlsServerConfig>, tls_config: Arc<RustlsServerConfig>,
proxy_ips: Arc<Vec<IpAddr>>, proxy_ips: Arc<Vec<IpAddr>>,
security_policy: Arc<ArcSwap<IpBlockList>>,
cancel: CancellationToken, cancel: CancellationToken,
) -> anyhow::Result<QuicProxyRelay> { ) -> anyhow::Result<QuicProxyRelay> {
// Bind external socket on the real port // Bind external socket on the real port
@@ -119,8 +120,7 @@ pub fn create_quic_endpoint_with_proxy_relay(
quinn::EndpointConfig::default(), quinn::EndpointConfig::default(),
Some(server_config), Some(server_config),
internal_socket, internal_socket,
quinn::default_runtime() quinn::default_runtime().ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
)?; )?;
let real_client_map = Arc::new(DashMap::new()); let real_client_map = Arc::new(DashMap::new());
@@ -129,12 +129,20 @@ pub fn create_quic_endpoint_with_proxy_relay(
external_socket, external_socket,
quinn_internal_addr, quinn_internal_addr,
proxy_ips, proxy_ips,
security_policy,
Arc::clone(&real_client_map), Arc::clone(&real_client_map),
cancel, cancel,
)); ));
info!("QUIC endpoint with PROXY relay on port {} (quinn internal: {})", port, quinn_internal_addr); info!(
Ok(QuicProxyRelay { endpoint, relay_task, real_client_map }) "QUIC endpoint with PROXY relay on port {} (quinn internal: {})",
port, quinn_internal_addr
);
Ok(QuicProxyRelay {
endpoint,
relay_task,
real_client_map,
})
} }
/// Main relay loop: reads datagrams from the external socket, filters PROXY v2 /// Main relay loop: reads datagrams from the external socket, filters PROXY v2
@@ -144,6 +152,7 @@ async fn quic_proxy_relay_loop(
external_socket: Arc<UdpSocket>, external_socket: Arc<UdpSocket>,
quinn_internal_addr: SocketAddr, quinn_internal_addr: SocketAddr,
proxy_ips: Arc<Vec<IpAddr>>, proxy_ips: Arc<Vec<IpAddr>>,
security_policy: Arc<ArcSwap<IpBlockList>>,
real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>, real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>,
cancel: CancellationToken, cancel: CancellationToken,
) { ) {
@@ -184,26 +193,43 @@ async fn quic_proxy_relay_loop(
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) { if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
match crate::proxy_protocol::parse_v2(datagram) { match crate::proxy_protocol::parse_v2(datagram) {
Ok((header, _consumed)) => { Ok((header, _consumed)) => {
debug!("QUIC PROXY v2 from {}: real client {}", src_addr, header.source_addr); debug!(
"QUIC PROXY v2 from {}: real client {}",
src_addr, header.source_addr
);
proxy_addr_map.insert(src_addr, header.source_addr); proxy_addr_map.insert(src_addr, header.source_addr);
continue; // consume the PROXY v2 datagram continue; // consume the PROXY v2 datagram
} }
Err(e) => { Err(e) => {
debug!("QUIC proxy relay: failed to parse PROXY v2 from {}: {}", src_addr, e); debug!(
"QUIC proxy relay: failed to parse PROXY v2 from {}: {}",
src_addr, e
);
} }
} }
} }
} }
// Determine real client address // Determine real client address
let real_client = proxy_addr_map.get(&src_addr) let real_client = proxy_addr_map
.get(&src_addr)
.map(|r| *r) .map(|r| *r)
.unwrap_or(src_addr); .unwrap_or(src_addr);
let block_list = security_policy.load();
if !block_list.is_empty() && block_list.is_blocked(&real_client.ip()) {
debug!(
"QUIC datagram from {} blocked by global security policy",
real_client
);
continue;
}
// Get or create relay session for this external source // Get or create relay session for this external source
let session = match relay_sessions.get(&src_addr) { let session = match relay_sessions.get(&src_addr) {
Some(s) => { Some(s) => {
s.last_activity.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed); s.last_activity
.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
Arc::clone(s.value()) Arc::clone(s.value())
} }
None => { None => {
@@ -216,7 +242,10 @@ async fn quic_proxy_relay_loop(
} }
}; };
if let Err(e) = relay_socket.connect(quinn_internal_addr).await { if let Err(e) = relay_socket.connect(quinn_internal_addr).await {
warn!("QUIC relay: failed to connect relay socket to {}: {}", quinn_internal_addr, e); warn!(
"QUIC relay: failed to connect relay socket to {}: {}",
quinn_internal_addr, e
);
continue; continue;
} }
let relay_local_addr = match relay_socket.local_addr() { let relay_local_addr = match relay_socket.local_addr() {
@@ -248,8 +277,10 @@ async fn quic_proxy_relay_loop(
}); });
relay_sessions.insert(src_addr, Arc::clone(&session)); relay_sessions.insert(src_addr, Arc::clone(&session));
debug!("QUIC relay: new session for {} (relay {}), real client {}", debug!(
src_addr, relay_local_addr, real_client); "QUIC relay: new session for {} (relay {}), real client {}",
src_addr, relay_local_addr, real_client
);
session session
} }
@@ -264,9 +295,11 @@ async fn quic_proxy_relay_loop(
if last_cleanup.elapsed() >= cleanup_interval { if last_cleanup.elapsed() >= cleanup_interval {
last_cleanup = Instant::now(); last_cleanup = Instant::now();
let now_ms = epoch.elapsed().as_millis() as u64; let now_ms = epoch.elapsed().as_millis() as u64;
let stale_keys: Vec<SocketAddr> = relay_sessions.iter() let stale_keys: Vec<SocketAddr> = relay_sessions
.iter()
.filter(|entry| { .filter(|entry| {
let age = now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed)); let age =
now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
age > session_timeout_ms age > session_timeout_ms
}) })
.map(|entry| *entry.key()) .map(|entry| *entry.key())
@@ -287,13 +320,17 @@ async fn quic_proxy_relay_loop(
// Also clean orphaned proxy_addr_map entries (PROXY header received // Also clean orphaned proxy_addr_map entries (PROXY header received
// but no relay session was ever created, e.g. client never sent data) // but no relay session was ever created, e.g. client never sent data)
let orphaned: Vec<SocketAddr> = proxy_addr_map.iter() let orphaned: Vec<SocketAddr> = proxy_addr_map
.iter()
.filter(|entry| relay_sessions.get(entry.key()).is_none()) .filter(|entry| relay_sessions.get(entry.key()).is_none())
.map(|entry| *entry.key()) .map(|entry| *entry.key())
.collect(); .collect();
for key in orphaned { for key in orphaned {
proxy_addr_map.remove(&key); proxy_addr_map.remove(&key);
debug!("QUIC relay: cleaned up orphaned proxy_addr_map entry for {}", key); debug!(
"QUIC relay: cleaned up orphaned proxy_addr_map entry for {}",
key
);
} }
} }
} }
@@ -328,8 +365,14 @@ async fn relay_return_path(
} }
}; };
if let Err(e) = external_socket.send_to(&buf[..len], external_src_addr).await { if let Err(e) = external_socket
debug!("QUIC relay return send error to {}: {}", external_src_addr, e); .send_to(&buf[..len], external_src_addr)
.await
{
debug!(
"QUIC relay return send error to {}: {}",
external_src_addr, e
);
break; break;
} }
} }
@@ -353,6 +396,7 @@ pub async fn quic_accept_loop(
real_client_map: Option<Arc<DashMap<SocketAddr, SocketAddr>>>, real_client_map: Option<Arc<DashMap<SocketAddr, SocketAddr>>>,
route_cancels: Arc<DashMap<String, CancellationToken>>, route_cancels: Arc<DashMap<String, CancellationToken>>,
connection_registry: Arc<ConnectionRegistry>, connection_registry: Arc<ConnectionRegistry>,
security_policy: Arc<ArcSwap<IpBlockList>>,
) { ) {
loop { loop {
let incoming = tokio::select! { let incoming = tokio::select! {
@@ -374,11 +418,21 @@ pub async fn quic_accept_loop(
let remote_addr = incoming.remote_address(); let remote_addr = incoming.remote_address();
// Resolve real client IP from PROXY protocol map if available // Resolve real client IP from PROXY protocol map if available
let real_addr = real_client_map.as_ref() let real_addr = real_client_map
.as_ref()
.and_then(|map| map.get(&remote_addr).map(|r| *r)) .and_then(|map| map.get(&remote_addr).map(|r| *r))
.unwrap_or(remote_addr); .unwrap_or(remote_addr);
let ip = real_addr.ip(); let ip = real_addr.ip();
let block_list = security_policy.load();
if !block_list.is_empty() && block_list.is_blocked(&ip) {
debug!(
"QUIC connection from {} blocked by global security policy",
real_addr
);
continue;
}
// Per-IP rate limiting // Per-IP rate limiting
if !conn_tracker.try_accept(&ip) { if !conn_tracker.try_accept(&ip) {
debug!("QUIC connection rejected from {} (rate limit)", real_addr); debug!("QUIC connection rejected from {} (rate limit)", real_addr);
@@ -409,23 +463,27 @@ pub async fn quic_accept_loop(
} }
}; };
// Check route-level IP security (previously missing for QUIC) // Check route-level IP security for QUIC (domain from SNI context)
if let Some(ref security) = route.security { if let Some(ref security) = route.security {
if !rustproxy_http::request_filter::RequestFilter::check_ip_security( if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
security, &ip, security, &ip, ctx.domain,
) { ) {
debug!("QUIC connection from {} blocked by route security", real_addr); debug!(
"QUIC connection from {} blocked by route security",
real_addr
);
continue; continue;
} }
} }
conn_tracker.connection_opened(&ip); conn_tracker.connection_opened(&ip);
let route_id = route.name.clone().or(route.id.clone()); let route_id = route.metrics_key().map(str::to_string);
metrics.connection_opened(route_id.as_deref(), Some(&ip_str)); metrics.connection_opened(route_id.as_deref(), Some(&ip_str));
// Resolve per-route cancel token (child of global cancel) // Resolve per-route cancel token (child of global cancel)
let route_cancel = match route_id.as_deref() { let route_cancel = match route_id.as_deref() {
Some(id) => route_cancels.entry(id.to_string()) Some(id) => route_cancels
.entry(id.to_string())
.or_insert_with(|| cancel.child_token()) .or_insert_with(|| cancel.child_token())
.clone(), .clone(),
None => cancel.child_token(), None => cancel.child_token(),
@@ -445,7 +503,11 @@ pub async fn quic_accept_loop(
let metrics = Arc::clone(&metrics); let metrics = Arc::clone(&metrics);
let conn_tracker = Arc::clone(&conn_tracker); let conn_tracker = Arc::clone(&conn_tracker);
let h3_svc = h3_service.clone(); let h3_svc = h3_service.clone();
let real_client_addr = if real_addr != remote_addr { Some(real_addr) } else { None }; let real_client_addr = if real_addr != remote_addr {
Some(real_addr)
} else {
None
};
tokio::spawn(async move { tokio::spawn(async move {
// Register in connection registry (RAII guard removes on drop) // Register in connection registry (RAII guard removes on drop)
@@ -462,7 +524,8 @@ pub async fn quic_accept_loop(
impl Drop for QuicConnGuard { impl Drop for QuicConnGuard {
fn drop(&mut self) { fn drop(&mut self) {
self.tracker.connection_closed(&self.ip); self.tracker.connection_closed(&self.ip);
self.metrics.connection_closed(self.route_id.as_deref(), Some(&self.ip_str)); self.metrics
.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
} }
} }
let _guard = QuicConnGuard { let _guard = QuicConnGuard {
@@ -473,7 +536,17 @@ pub async fn quic_accept_loop(
route_id, route_id,
}; };
match handle_quic_connection(incoming, route, port, Arc::clone(&metrics), &conn_cancel, h3_svc, real_client_addr).await { match handle_quic_connection(
incoming,
route,
port,
Arc::clone(&metrics),
&conn_cancel,
h3_svc,
real_client_addr,
)
.await
{
Ok(()) => debug!("QUIC connection from {} completed", real_addr), Ok(()) => debug!("QUIC connection from {} completed", real_addr),
Err(e) => debug!("QUIC connection from {} error: {}", real_addr, e), Err(e) => debug!("QUIC connection from {} error: {}", real_addr, e),
} }
@@ -501,17 +574,28 @@ async fn handle_quic_connection(
debug!("QUIC connection established from {}", effective_addr); debug!("QUIC connection established from {}", effective_addr);
// Check if this route has HTTP/3 enabled // Check if this route has HTTP/3 enabled
let enable_http3 = route.action.udp.as_ref() let enable_http3 = route
.action
.udp
.as_ref()
.and_then(|u| u.quic.as_ref()) .and_then(|u| u.quic.as_ref())
.and_then(|q| q.enable_http3) .and_then(|q| q.enable_http3)
.unwrap_or(false); .unwrap_or(false);
if enable_http3 { if enable_http3 {
if let Some(ref h3_svc) = h3_service { if let Some(ref h3_svc) = h3_service {
debug!("HTTP/3 enabled for route {:?}, dispatching to H3ProxyService", route.name); debug!(
h3_svc.handle_connection(connection, &route, port, real_client_addr, cancel).await "HTTP/3 enabled for route {:?}, dispatching to H3ProxyService",
route.name
);
h3_svc
.handle_connection(connection, &route, port, real_client_addr, cancel)
.await
} else { } else {
warn!("HTTP/3 enabled for route {:?} but H3ProxyService not initialized", route.name); warn!(
"HTTP/3 enabled for route {:?} but H3ProxyService not initialized",
route.name
);
// Keep connection alive until cancelled // Keep connection alive until cancelled
tokio::select! { tokio::select! {
_ = cancel.cancelled() => {} _ = cancel.cancelled() => {}
@@ -523,7 +607,8 @@ async fn handle_quic_connection(
} }
} else { } else {
// Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend // Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr).await handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr)
.await
} }
} }
@@ -541,11 +626,14 @@ async fn handle_quic_stream_forwarding(
real_client_addr: Option<SocketAddr>, real_client_addr: Option<SocketAddr>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let effective_addr = real_client_addr.unwrap_or_else(|| connection.remote_address()); let effective_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
let route_id = route.name.as_deref().or(route.id.as_deref()); let route_id = route.metrics_key();
let metrics_arc = metrics; let metrics_arc = metrics;
// Resolve backend target // Resolve backend target
let target = route.action.targets.as_ref() let target = route
.action
.targets
.as_ref()
.and_then(|t| t.first()) .and_then(|t| t.first())
.ok_or_else(|| anyhow::anyhow!("No target for QUIC route"))?; .ok_or_else(|| anyhow::anyhow!("No target for QUIC route"))?;
let backend_host = target.host.first(); let backend_host = target.host.first();
@@ -576,19 +664,20 @@ async fn handle_quic_stream_forwarding(
// Spawn a task for each QUIC stream → TCP bidirectional forwarding // Spawn a task for each QUIC stream → TCP bidirectional forwarding
tokio::spawn(async move { tokio::spawn(async move {
match forward_quic_stream_to_tcp( match forward_quic_stream_to_tcp(send_stream, recv_stream, &backend_addr, stream_cancel)
send_stream, .await
recv_stream, {
&backend_addr,
stream_cancel,
).await {
Ok((bytes_in, bytes_out)) => { Ok((bytes_in, bytes_out)) => {
stream_metrics.record_bytes( stream_metrics.record_bytes(
bytes_in, bytes_out, bytes_in,
bytes_out,
stream_route_id.as_deref(), stream_route_id.as_deref(),
Some(&ip_str), Some(&ip_str),
); );
debug!("QUIC stream forwarded: {}B in, {}B out", bytes_in, bytes_out); debug!(
"QUIC stream forwarded: {}B in, {}B out",
bytes_in, bytes_out
);
} }
Err(e) => { Err(e) => {
debug!("QUIC stream forwarding error: {}", e); debug!("QUIC stream forwarding error: {}", e);
@@ -640,10 +729,7 @@ async fn forward_quic_stream_to_tcp(
total += n as u64; total += n as u64;
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed); la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
} }
let _ = tokio::time::timeout( let _ = tokio::time::timeout(std::time::Duration::from_secs(2), tcp_write.shutdown()).await;
std::time::Duration::from_secs(2),
tcp_write.shutdown(),
).await;
total total
}); });
@@ -721,8 +807,8 @@ mod tests {
let _ = rustls::crypto::ring::default_provider().install_default(); let _ = rustls::crypto::ring::default_provider().install_default();
// Generate a single self-signed cert and use its key pair // Generate a single self-signed cert and use its key pair
let self_signed = rcgen::generate_simple_self_signed(vec!["localhost".to_string()]) let self_signed =
.unwrap(); rcgen::generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
let cert_der = self_signed.cert.der().clone(); let cert_der = self_signed.cert.der().clone();
let key_der = self_signed.key_pair.serialize_der(); let key_der = self_signed.key_pair.serialize_der();
@@ -737,6 +823,10 @@ mod tests {
// Port 0 = OS assigns a free port // Port 0 = OS assigns a free port
let result = create_quic_endpoint(0, Arc::new(tls_config)); let result = create_quic_endpoint(0, Arc::new(tls_config));
assert!(result.is_ok(), "QUIC endpoint creation failed: {:?}", result.err()); assert!(
result.is_ok(),
"QUIC endpoint creation failed: {:?}",
result.err()
);
} }
} }
@@ -47,9 +47,8 @@ pub fn extract_sni(data: &[u8]) -> SniResult {
} }
// Handshake length (3 bytes) - informational, we parse incrementally // Handshake length (3 bytes) - informational, we parse incrementally
let _handshake_len = ((data[6] as usize) << 16) let _handshake_len =
| ((data[7] as usize) << 8) ((data[6] as usize) << 16) | ((data[7] as usize) << 8) | (data[8] as usize);
| (data[8] as usize);
let hello = &data[9..]; let hello = &data[9..];
@@ -170,7 +169,10 @@ pub fn extract_http_path(data: &[u8]) -> Option<String> {
pub fn extract_http_host(data: &[u8]) -> Option<String> { pub fn extract_http_host(data: &[u8]) -> Option<String> {
let text = std::str::from_utf8(data).ok()?; let text = std::str::from_utf8(data).ok()?;
for line in text.split("\r\n") { for line in text.split("\r\n") {
if let Some(value) = line.strip_prefix("Host: ").or_else(|| line.strip_prefix("host: ")) { if let Some(value) = line
.strip_prefix("Host: ")
.or_else(|| line.strip_prefix("host: "))
{
// Strip port if present // Strip port if present
let host = value.split(':').next().unwrap_or(value).trim(); let host = value.split(':').next().unwrap_or(value).trim();
if !host.is_empty() { if !host.is_empty() {
@@ -196,7 +198,7 @@ pub fn is_http(data: &[u8]) -> bool {
b"PATC", b"PATC",
b"OPTI", b"OPTI",
b"CONN", b"CONN",
b"PRI ", // HTTP/2 connection preface b"PRI ", // HTTP/2 connection preface
]; ];
starts.iter().any(|s| data.starts_with(s)) starts.iter().any(|s| data.starts_with(s))
} }
@@ -213,7 +215,10 @@ mod tests {
#[test] #[test]
fn test_too_short() { fn test_too_short() {
assert!(matches!(extract_sni(&[0x16, 0x03]), SniResult::NeedMoreData)); assert!(matches!(
extract_sni(&[0x16, 0x03]),
SniResult::NeedMoreData
));
} }
#[test] #[test]
@@ -263,7 +268,8 @@ mod tests {
// Extension: type=0x0000 (SNI), length, data // Extension: type=0x0000 (SNI), length, data
let sni_extension = { let sni_extension = {
let mut e = Vec::new(); let mut e = Vec::new();
e.push(0x00); e.push(0x00); // SNI type e.push(0x00);
e.push(0x00); // SNI type
e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8); e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8);
e.push((sni_ext_data.len() & 0xFF) as u8); e.push((sni_ext_data.len() & 0xFF) as u8);
e.extend_from_slice(&sni_ext_data); e.extend_from_slice(&sni_ext_data);
@@ -283,16 +289,20 @@ mod tests {
let hello_body = { let hello_body = {
let mut h = Vec::new(); let mut h = Vec::new();
// Client version TLS 1.2 // Client version TLS 1.2
h.push(0x03); h.push(0x03); h.push(0x03);
h.push(0x03);
// Random (32 bytes) // Random (32 bytes)
h.extend_from_slice(&[0u8; 32]); h.extend_from_slice(&[0u8; 32]);
// Session ID length = 0 // Session ID length = 0
h.push(0x00); h.push(0x00);
// Cipher suites: length=2, one suite // Cipher suites: length=2, one suite
h.push(0x00); h.push(0x02); h.push(0x00);
h.push(0x00); h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA h.push(0x02);
// Compression methods: length=1, null h.push(0x00);
h.push(0x01); h.push(0x00); h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
// Compression methods: length=1, null
h.push(0x01);
h.push(0x00);
// Extensions // Extensions
h.extend_from_slice(&extensions); h.extend_from_slice(&extensions);
h h
@@ -302,7 +312,7 @@ mod tests {
let handshake = { let handshake = {
let mut hs = Vec::new(); let mut hs = Vec::new();
hs.push(0x01); // ClientHello hs.push(0x01); // ClientHello
// 3-byte length // 3-byte length
hs.push(((hello_body.len() >> 16) & 0xFF) as u8); hs.push(((hello_body.len() >> 16) & 0xFF) as u8);
hs.push(((hello_body.len() >> 8) & 0xFF) as u8); hs.push(((hello_body.len() >> 8) & 0xFF) as u8);
hs.push((hello_body.len() & 0xFF) as u8); hs.push((hello_body.len() & 0xFF) as u8);
@@ -313,7 +323,8 @@ mod tests {
// TLS record: type=0x16, version TLS 1.0, length // TLS record: type=0x16, version TLS 1.0, length
let mut record = Vec::new(); let mut record = Vec::new();
record.push(0x16); // Handshake record.push(0x16); // Handshake
record.push(0x03); record.push(0x01); // TLS 1.0 record.push(0x03);
record.push(0x01); // TLS 1.0
record.push(((handshake.len() >> 8) & 0xFF) as u8); record.push(((handshake.len() >> 8) & 0xFF) as u8);
record.push((handshake.len() & 0xFF) as u8); record.push((handshake.len() & 0xFF) as u8);
record.extend_from_slice(&handshake); record.extend_from_slice(&handshake);
File diff suppressed because it is too large Load Diff
@@ -7,7 +7,7 @@ use rustls::server::ResolvesServerCert;
use rustls::sign::CertifiedKey; use rustls::sign::CertifiedKey;
use rustls::ServerConfig; use rustls::ServerConfig;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream}; use tokio_rustls::{server::TlsStream as ServerTlsStream, TlsAcceptor, TlsConnector};
use tracing::{debug, info}; use tracing::{debug, info};
use crate::tcp_listener::TlsCertConfig; use crate::tcp_listener::TlsCertConfig;
@@ -29,7 +29,9 @@ pub struct CertResolver {
impl CertResolver { impl CertResolver {
/// Build a resolver from PEM-encoded cert/key configs. /// Build a resolver from PEM-encoded cert/key configs.
/// Parses all PEM data upfront so connections only do a cheap HashMap lookup. /// Parses all PEM data upfront so connections only do a cheap HashMap lookup.
pub fn new(configs: &HashMap<String, TlsCertConfig>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> { pub fn new(
configs: &HashMap<String, TlsCertConfig>,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider(); ensure_crypto_provider();
let provider = rustls::crypto::ring::default_provider(); let provider = rustls::crypto::ring::default_provider();
let mut certs = HashMap::new(); let mut certs = HashMap::new();
@@ -38,8 +40,10 @@ impl CertResolver {
for (domain, cfg) in configs { for (domain, cfg) in configs {
let cert_chain = load_certs(&cfg.cert_pem)?; let cert_chain = load_certs(&cfg.cert_pem)?;
let key = load_private_key(&cfg.key_pem)?; let key = load_private_key(&cfg.key_pem)?;
let ck = Arc::new(CertifiedKey::from_der(cert_chain, key, &provider) let ck = Arc::new(
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?); CertifiedKey::from_der(cert_chain, key, &provider)
.map_err(|e| format!("CertifiedKey for {}: {}", domain, e))?,
);
if domain == "*" { if domain == "*" {
fallback = Some(Arc::clone(&ck)); fallback = Some(Arc::clone(&ck));
} }
@@ -78,7 +82,9 @@ impl ResolvesServerCert for CertResolver {
/// Build a shared TLS acceptor with SNI resolution, session cache, and session tickets. /// Build a shared TLS acceptor with SNI resolution, session cache, and session tickets.
/// The returned acceptor can be reused across all connections (cheap Arc clone). /// The returned acceptor can be reused across all connections (cheap Arc clone).
pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> { pub fn build_shared_tls_acceptor(
resolver: CertResolver,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider(); ensure_crypto_provider();
let mut config = ServerConfig::builder() let mut config = ServerConfig::builder()
.with_no_client_auth() .with_no_client_auth()
@@ -90,22 +96,30 @@ pub fn build_shared_tls_acceptor(resolver: CertResolver) -> Result<TlsAcceptor,
// Shared session cache — enables session ID resumption across connections // Shared session cache — enables session ID resumption across connections
config.session_storage = rustls::server::ServerSessionMemoryCache::new(4096); config.session_storage = rustls::server::ServerSessionMemoryCache::new(4096);
// Session ticket resumption (12-hour lifetime, Chacha20Poly1305 encrypted) // Session ticket resumption (12-hour lifetime, Chacha20Poly1305 encrypted)
config.ticketer = rustls::crypto::ring::Ticketer::new() config.ticketer =
.map_err(|e| format!("Ticketer: {}", e))?; rustls::crypto::ring::Ticketer::new().map_err(|e| format!("Ticketer: {}", e))?;
info!("Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1"); info!(
"Built shared TLS config with session cache (4096), ticket support, and ALPN h2+http/1.1"
);
Ok(TlsAcceptor::from(Arc::new(config))) Ok(TlsAcceptor::from(Arc::new(config)))
} }
/// Build a TLS acceptor from PEM-encoded cert and key data. /// Build a TLS acceptor from PEM-encoded cert and key data.
/// Advertises both h2 and http/1.1 via ALPN (for client-facing connections). /// Advertises both h2 and http/1.1 via ALPN (for client-facing connections).
pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> { pub fn build_tls_acceptor(
cert_pem: &str,
key_pem: &str,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
build_tls_acceptor_with_config(cert_pem, key_pem, None) build_tls_acceptor_with_config(cert_pem, key_pem, None)
} }
/// Build a TLS acceptor for backend servers that only speak HTTP/1.1. /// Build a TLS acceptor for backend servers that only speak HTTP/1.1.
/// Does NOT advertise h2 in ALPN, preventing false h2 auto-detection. /// Does NOT advertise h2 in ALPN, preventing false h2 auto-detection.
pub fn build_tls_acceptor_h1_only(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> { pub fn build_tls_acceptor_h1_only(
cert_pem: &str,
key_pem: &str,
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
ensure_crypto_provider(); ensure_crypto_provider();
let certs = load_certs(cert_pem)?; let certs = load_certs(cert_pem)?;
let key = load_private_key(key_pem)?; let key = load_private_key(key_pem)?;
@@ -130,9 +144,7 @@ pub fn build_tls_acceptor_with_config(
// Apply TLS version restrictions // Apply TLS version restrictions
let versions = resolve_tls_versions(route_tls.versions.as_deref()); let versions = resolve_tls_versions(route_tls.versions.as_deref());
let builder = ServerConfig::builder_with_protocol_versions(&versions); let builder = ServerConfig::builder_with_protocol_versions(&versions);
builder builder.with_no_client_auth().with_single_cert(certs, key)?
.with_no_client_auth()
.with_single_cert(certs, key)?
} else { } else {
ServerConfig::builder() ServerConfig::builder()
.with_no_client_auth() .with_no_client_auth()
@@ -156,7 +168,9 @@ pub fn build_tls_acceptor_with_config(
} }
/// Resolve TLS version strings to rustls SupportedProtocolVersion. /// Resolve TLS version strings to rustls SupportedProtocolVersion.
fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> { fn resolve_tls_versions(
versions: Option<&[String]>,
) -> Vec<&'static rustls::SupportedProtocolVersion> {
let versions = match versions { let versions = match versions {
Some(v) if !v.is_empty() => v, Some(v) if !v.is_empty() => v,
_ => return vec![&rustls::version::TLS12, &rustls::version::TLS13], _ => return vec![&rustls::version::TLS12, &rustls::version::TLS13],
@@ -207,15 +221,17 @@ pub async fn accept_tls(
static SHARED_CLIENT_CONFIG: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new(); static SHARED_CLIENT_CONFIG: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> { pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
SHARED_CLIENT_CONFIG.get_or_init(|| { SHARED_CLIENT_CONFIG
ensure_crypto_provider(); .get_or_init(|| {
let config = rustls::ClientConfig::builder() ensure_crypto_provider();
.dangerous() let config = rustls::ClientConfig::builder()
.with_custom_certificate_verifier(Arc::new(InsecureVerifier)) .dangerous()
.with_no_client_auth(); .with_custom_certificate_verifier(Arc::new(InsecureVerifier))
info!("Built shared backend TLS client config with session resumption"); .with_no_client_auth();
Arc::new(config) info!("Built shared backend TLS client config with session resumption");
}).clone() Arc::new(config)
})
.clone()
} }
/// Get or create a shared backend TLS `ClientConfig` with ALPN `h2` + `http/1.1`. /// Get or create a shared backend TLS `ClientConfig` with ALPN `h2` + `http/1.1`.
@@ -225,16 +241,20 @@ pub fn shared_backend_tls_config() -> Arc<rustls::ClientConfig> {
static SHARED_CLIENT_CONFIG_ALPN: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new(); static SHARED_CLIENT_CONFIG_ALPN: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
pub fn shared_backend_tls_config_alpn() -> Arc<rustls::ClientConfig> { pub fn shared_backend_tls_config_alpn() -> Arc<rustls::ClientConfig> {
SHARED_CLIENT_CONFIG_ALPN.get_or_init(|| { SHARED_CLIENT_CONFIG_ALPN
ensure_crypto_provider(); .get_or_init(|| {
let mut config = rustls::ClientConfig::builder() ensure_crypto_provider();
.dangerous() let mut config = rustls::ClientConfig::builder()
.with_custom_certificate_verifier(Arc::new(InsecureVerifier)) .dangerous()
.with_no_client_auth(); .with_custom_certificate_verifier(Arc::new(InsecureVerifier))
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; .with_no_client_auth();
info!("Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection"); config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
Arc::new(config) info!(
}).clone() "Built shared backend TLS client config with ALPN h2+http/1.1 for auto-detection"
);
Arc::new(config)
})
.clone()
} }
/// Connect to a backend with TLS (for terminate-and-reencrypt mode). /// Connect to a backend with TLS (for terminate-and-reencrypt mode).
@@ -249,7 +269,8 @@ pub async fn connect_tls(
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?; let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
stream.set_nodelay(true)?; stream.set_nodelay(true)?;
// Apply keepalive with 60s default (tls_handler doesn't have ConnectionConfig access) // Apply keepalive with 60s default (tls_handler doesn't have ConnectionConfig access)
if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60)) { if let Err(e) = crate::socket_opts::apply_keepalive(&stream, std::time::Duration::from_secs(60))
{
debug!("Failed to set keepalive on backend TLS socket: {}", e); debug!("Failed to set keepalive on backend TLS socket: {}", e);
} }
@@ -260,10 +281,12 @@ pub async fn connect_tls(
} }
/// Load certificates from PEM string. /// Load certificates from PEM string.
fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> { fn load_certs(
pem: &str,
) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
let mut reader = BufReader::new(pem.as_bytes()); let mut reader = BufReader::new(pem.as_bytes());
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader) let certs: Vec<CertificateDer<'static>> =
.collect::<Result<Vec<_>, _>>()?; rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?;
if certs.is_empty() { if certs.is_empty() {
return Err("No certificates found in PEM data".into()); return Err("No certificates found in PEM data".into());
} }
@@ -271,11 +294,13 @@ fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::er
} }
/// Load private key from PEM string. /// Load private key from PEM string.
fn load_private_key(pem: &str) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> { fn load_private_key(
pem: &str,
) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
let mut reader = BufReader::new(pem.as_bytes()); let mut reader = BufReader::new(pem.as_bytes());
// Try PKCS8 first, then RSA, then EC // Try PKCS8 first, then RSA, then EC
let key = rustls_pemfile::private_key(&mut reader)? let key =
.ok_or("No private key found in PEM data")?; rustls_pemfile::private_key(&mut reader)?.ok_or("No private key found in PEM data")?;
Ok(key) Ok(key)
} }
@@ -17,14 +17,15 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use arc_swap::ArcSwap; use arc_swap::ArcSwap;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use tokio::task::JoinHandle;
use tokio::sync::{Mutex, RwLock}; use tokio::sync::{Mutex, RwLock};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use rustproxy_config::{RouteActionType, TransportProtocol}; use rustproxy_config::{RouteActionType, TransportProtocol};
use rustproxy_metrics::MetricsCollector; use rustproxy_metrics::MetricsCollector;
use rustproxy_routing::{MatchContext, RouteManager}; use rustproxy_routing::{MatchContext, RouteManager};
use rustproxy_security::IpBlockList;
use rustproxy_http::h3_service::H3ProxyService; use rustproxy_http::h3_service::H3ProxyService;
@@ -62,6 +63,8 @@ pub struct UdpListenerManager {
route_cancels: Arc<DashMap<String, CancellationToken>>, route_cancels: Arc<DashMap<String, CancellationToken>>,
/// Shared connection registry for selective recycling. /// Shared connection registry for selective recycling.
connection_registry: Arc<ConnectionRegistry>, connection_registry: Arc<ConnectionRegistry>,
/// Global ingress block policy, hot-reloadable without restarting listeners.
security_policy: Arc<ArcSwap<IpBlockList>>,
} }
impl Drop for UdpListenerManager { impl Drop for UdpListenerManager {
@@ -99,17 +102,26 @@ impl UdpListenerManager {
proxy_ips: Arc::new(Vec::new()), proxy_ips: Arc::new(Vec::new()),
route_cancels, route_cancels,
connection_registry, connection_registry,
security_policy: Arc::new(ArcSwap::from(Arc::new(IpBlockList::empty()))),
} }
} }
/// Set the trusted proxy IPs for PROXY protocol v2 detection. /// Set the trusted proxy IPs for PROXY protocol v2 detection.
pub fn set_proxy_ips(&mut self, ips: Vec<IpAddr>) { pub fn set_proxy_ips(&mut self, ips: Vec<IpAddr>) {
if !ips.is_empty() { if !ips.is_empty() {
info!("UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs", ips.len()); info!(
"UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs",
ips.len()
);
} }
self.proxy_ips = Arc::new(ips); self.proxy_ips = Arc::new(ips);
} }
/// Set the shared global ingress security policy.
pub fn set_security_policy(&mut self, policy: Arc<ArcSwap<IpBlockList>>) {
self.security_policy = policy;
}
/// Set the H3 proxy service for HTTP/3 request handling. /// Set the H3 proxy service for HTTP/3 request handling.
pub fn set_h3_service(&mut self, svc: Arc<H3ProxyService>) { pub fn set_h3_service(&mut self, svc: Arc<H3ProxyService>) {
self.h3_service = Some(svc); self.h3_service = Some(svc);
@@ -142,7 +154,9 @@ impl UdpListenerManager {
// Check if any route on this port uses QUIC // Check if any route on this port uses QUIC
let rm = self.route_manager.load(); let rm = self.route_manager.load();
let has_quic = rm.routes_for_port(port).iter().any(|r| { let has_quic = rm.routes_for_port(port).iter().any(|r| {
r.action.udp.as_ref() r.action
.udp
.as_ref()
.and_then(|u| u.quic.as_ref()) .and_then(|u| u.quic.as_ref())
.is_some() .is_some()
}); });
@@ -164,8 +178,10 @@ impl UdpListenerManager {
None, None,
Arc::clone(&self.route_cancels), Arc::clone(&self.route_cancels),
Arc::clone(&self.connection_registry), Arc::clone(&self.connection_registry),
Arc::clone(&self.security_policy),
)); ));
self.listeners.insert(port, (handle, Some(endpoint_for_updates))); self.listeners
.insert(port, (handle, Some(endpoint_for_updates)));
info!("QUIC endpoint started on port {}", port); info!("QUIC endpoint started on port {}", port);
} else { } else {
// Proxy relay path: we own external socket, quinn on localhost // Proxy relay path: we own external socket, quinn on localhost
@@ -173,6 +189,7 @@ impl UdpListenerManager {
port, port,
tls, tls,
Arc::clone(&self.proxy_ips), Arc::clone(&self.proxy_ips),
Arc::clone(&self.security_policy),
self.cancel_token.child_token(), self.cancel_token.child_token(),
)?; )?;
let endpoint_for_updates = relay.endpoint.clone(); let endpoint_for_updates = relay.endpoint.clone();
@@ -187,13 +204,18 @@ impl UdpListenerManager {
Some(relay.real_client_map), Some(relay.real_client_map),
Arc::clone(&self.route_cancels), Arc::clone(&self.route_cancels),
Arc::clone(&self.connection_registry), Arc::clone(&self.connection_registry),
Arc::clone(&self.security_policy),
)); ));
self.listeners.insert(port, (handle, Some(endpoint_for_updates))); self.listeners
.insert(port, (handle, Some(endpoint_for_updates)));
info!("QUIC endpoint with PROXY relay started on port {}", port); info!("QUIC endpoint with PROXY relay started on port {}", port);
} }
return Ok(()); return Ok(());
} else { } else {
warn!("QUIC routes on port {} but no TLS config provided, falling back to raw UDP", port); warn!(
"QUIC routes on port {} but no TLS config provided, falling back to raw UDP",
port
);
} }
} }
@@ -214,6 +236,7 @@ impl UdpListenerManager {
Arc::clone(&self.relay_writer), Arc::clone(&self.relay_writer),
self.cancel_token.child_token(), self.cancel_token.child_token(),
Arc::clone(&self.proxy_ips), Arc::clone(&self.proxy_ips),
Arc::clone(&self.security_policy),
)); ));
self.listeners.insert(port, (handle, None)); self.listeners.insert(port, (handle, None));
@@ -254,8 +277,10 @@ impl UdpListenerManager {
} }
debug!("UDP listener stopped on port {}", port); debug!("UDP listener stopped on port {}", port);
} }
info!("All UDP listeners stopped, {} sessions remaining", info!(
self.session_table.session_count()); "All UDP listeners stopped, {} sessions remaining",
self.session_table.session_count()
);
} }
/// Update TLS config on all active QUIC endpoints (cert refresh). /// Update TLS config on all active QUIC endpoints (cert refresh).
@@ -288,11 +313,15 @@ impl UdpListenerManager {
pub async fn upgrade_raw_to_quic(&mut self, tls_config: Arc<rustls::ServerConfig>) { pub async fn upgrade_raw_to_quic(&mut self, tls_config: Arc<rustls::ServerConfig>) {
// Find ports that are raw UDP fallback (endpoint=None) but have QUIC routes // Find ports that are raw UDP fallback (endpoint=None) but have QUIC routes
let rm = self.route_manager.load(); let rm = self.route_manager.load();
let upgrade_ports: Vec<u16> = self.listeners.iter() let upgrade_ports: Vec<u16> = self
.listeners
.iter()
.filter(|(_, (_, endpoint))| endpoint.is_none()) .filter(|(_, (_, endpoint))| endpoint.is_none())
.filter(|(port, _)| { .filter(|(port, _)| {
rm.routes_for_port(**port).iter().any(|r| { rm.routes_for_port(**port).iter().any(|r| {
r.action.udp.as_ref() r.action
.udp
.as_ref()
.and_then(|u| u.quic.as_ref()) .and_then(|u| u.quic.as_ref())
.is_some() .is_some()
}) })
@@ -301,17 +330,23 @@ impl UdpListenerManager {
.collect(); .collect();
for port in upgrade_ports { for port in upgrade_ports {
info!("Upgrading raw UDP listener on port {} to QUIC endpoint", port); info!(
"Upgrading raw UDP listener on port {} to QUIC endpoint",
port
);
// Stop the raw UDP listener task and drain sessions to release the socket // Stop the raw UDP listener task and drain sessions to release the socket
if let Some((handle, _)) = self.listeners.remove(&port) { if let Some((handle, _)) = self.listeners.remove(&port) {
handle.abort(); handle.abort();
} }
let drained = self.session_table.drain_port( let drained = self
port, &self.metrics, &self.conn_tracker, .session_table
); .drain_port(port, &self.metrics, &self.conn_tracker);
if drained > 0 { if drained > 0 {
debug!("Drained {} UDP sessions on port {} for QUIC upgrade", drained, port); debug!(
"Drained {} UDP sessions on port {} for QUIC upgrade",
drained, port
);
} }
// Brief yield to let aborted tasks drop their socket references // Brief yield to let aborted tasks drop their socket references
@@ -326,11 +361,17 @@ impl UdpListenerManager {
match create_result { match create_result {
Ok(()) => { Ok(()) => {
info!("QUIC endpoint started on port {} (upgraded from raw UDP)", port); info!(
"QUIC endpoint started on port {} (upgraded from raw UDP)",
port
);
} }
Err(e) => { Err(e) => {
// Port may still be held — retry once after a brief delay // Port may still be held — retry once after a brief delay
warn!("QUIC endpoint creation failed on port {}, retrying: {}", port, e); warn!(
"QUIC endpoint creation failed on port {}, retrying: {}",
port, e
);
tokio::time::sleep(std::time::Duration::from_millis(50)).await; tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let retry_result = if self.proxy_ips.is_empty() { let retry_result = if self.proxy_ips.is_empty() {
@@ -341,11 +382,17 @@ impl UdpListenerManager {
match retry_result { match retry_result {
Ok(()) => { Ok(()) => {
info!("QUIC endpoint started on port {} (upgraded from raw UDP, retry)", port); info!(
"QUIC endpoint started on port {} (upgraded from raw UDP, retry)",
port
);
} }
Err(e2) => { Err(e2) => {
error!("Failed to upgrade port {} to QUIC after retry: {}. \ error!(
Rebinding as raw UDP.", port, e2); "Failed to upgrade port {} to QUIC after retry: {}. \
Rebinding as raw UDP.",
port, e2
);
// Fallback: rebind as raw UDP so the port isn't dead // Fallback: rebind as raw UDP so the port isn't dead
if let Ok(()) = self.rebind_raw_udp(port).await { if let Ok(()) = self.rebind_raw_udp(port).await {
warn!("Port {} rebound as raw UDP (QUIC upgrade failed)", port); warn!("Port {} rebound as raw UDP (QUIC upgrade failed)", port);
@@ -358,7 +405,11 @@ impl UdpListenerManager {
} }
/// Create a direct QUIC endpoint (quinn owns the socket). /// Create a direct QUIC endpoint (quinn owns the socket).
fn create_quic_direct(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> { fn create_quic_direct(
&mut self,
port: u16,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<()> {
let endpoint = crate::quic_handler::create_quic_endpoint(port, tls_config)?; let endpoint = crate::quic_handler::create_quic_endpoint(port, tls_config)?;
let endpoint_for_updates = endpoint.clone(); let endpoint_for_updates = endpoint.clone();
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop( let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
@@ -372,17 +423,24 @@ impl UdpListenerManager {
None, None,
Arc::clone(&self.route_cancels), Arc::clone(&self.route_cancels),
Arc::clone(&self.connection_registry), Arc::clone(&self.connection_registry),
Arc::clone(&self.security_policy),
)); ));
self.listeners.insert(port, (handle, Some(endpoint_for_updates))); self.listeners
.insert(port, (handle, Some(endpoint_for_updates)));
Ok(()) Ok(())
} }
/// Create a QUIC endpoint with PROXY protocol relay. /// Create a QUIC endpoint with PROXY protocol relay.
fn create_quic_with_relay(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> { fn create_quic_with_relay(
&mut self,
port: u16,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<()> {
let relay = crate::quic_handler::create_quic_endpoint_with_proxy_relay( let relay = crate::quic_handler::create_quic_endpoint_with_proxy_relay(
port, port,
tls_config, tls_config,
Arc::clone(&self.proxy_ips), Arc::clone(&self.proxy_ips),
Arc::clone(&self.security_policy),
self.cancel_token.child_token(), self.cancel_token.child_token(),
)?; )?;
let endpoint_for_updates = relay.endpoint.clone(); let endpoint_for_updates = relay.endpoint.clone();
@@ -397,8 +455,10 @@ impl UdpListenerManager {
Some(relay.real_client_map), Some(relay.real_client_map),
Arc::clone(&self.route_cancels), Arc::clone(&self.route_cancels),
Arc::clone(&self.connection_registry), Arc::clone(&self.connection_registry),
Arc::clone(&self.security_policy),
)); ));
self.listeners.insert(port, (handle, Some(endpoint_for_updates))); self.listeners
.insert(port, (handle, Some(endpoint_for_updates)));
Ok(()) Ok(())
} }
@@ -419,6 +479,7 @@ impl UdpListenerManager {
Arc::clone(&self.relay_writer), Arc::clone(&self.relay_writer),
self.cancel_token.child_token(), self.cancel_token.child_token(),
Arc::clone(&self.proxy_ips), Arc::clone(&self.proxy_ips),
Arc::clone(&self.security_policy),
)); ));
self.listeners.insert(port, (handle, None)); self.listeners.insert(port, (handle, None));
@@ -458,7 +519,10 @@ impl UdpListenerManager {
info!("Datagram handler relay connected to {}", path); info!("Datagram handler relay connected to {}", path);
} }
Err(e) => { Err(e) => {
error!("Failed to connect datagram handler relay to {}: {}", path, e); error!(
"Failed to connect datagram handler relay to {}: {}",
path, e
);
} }
} }
} }
@@ -514,6 +578,7 @@ impl UdpListenerManager {
relay_writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>, relay_writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>,
cancel: CancellationToken, cancel: CancellationToken,
proxy_ips: Arc<Vec<IpAddr>>, proxy_ips: Arc<Vec<IpAddr>>,
security_policy: Arc<ArcSwap<IpBlockList>>,
) { ) {
// Use a reasonably large buffer; actual max is per-route but we need a single buffer // Use a reasonably large buffer; actual max is per-route but we need a single buffer
let mut buf = vec![0u8; 65535]; let mut buf = vec![0u8; 65535];
@@ -528,9 +593,11 @@ impl UdpListenerManager {
loop { loop {
// Periodic cleanup: remove proxy_addr_map entries with no active session // Periodic cleanup: remove proxy_addr_map entries with no active session
if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval { if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval
{
last_proxy_cleanup = tokio::time::Instant::now(); last_proxy_cleanup = tokio::time::Instant::now();
let stale: Vec<SocketAddr> = proxy_addr_map.iter() let stale: Vec<SocketAddr> = proxy_addr_map
.iter()
.filter(|entry| { .filter(|entry| {
let key: SessionKey = (*entry.key(), port); let key: SessionKey = (*entry.key(), port);
session_table.get(&key).is_none() session_table.get(&key).is_none()
@@ -538,7 +605,11 @@ impl UdpListenerManager {
.map(|entry| *entry.key()) .map(|entry| *entry.key())
.collect(); .collect();
if !stale.is_empty() { if !stale.is_empty() {
debug!("UDP proxy_addr_map cleanup: removing {} stale entries on port {}", stale.len(), port); debug!(
"UDP proxy_addr_map cleanup: removing {} stale entries on port {}",
stale.len(),
port
);
for addr in stale { for addr in stale {
proxy_addr_map.remove(&addr); proxy_addr_map.remove(&addr);
} }
@@ -564,34 +635,50 @@ impl UdpListenerManager {
let datagram = &buf[..len]; let datagram = &buf[..len];
// PROXY protocol v2 detection for datagrams from trusted proxy IPs // PROXY protocol v2 detection for datagrams from trusted proxy IPs
let effective_client_ip = if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) { let effective_client_ip =
let session_key: SessionKey = (client_addr, port); if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) {
if session_table.get(&session_key).is_none() && !proxy_addr_map.contains_key(&client_addr) { let session_key: SessionKey = (client_addr, port);
// No session and no prior PROXY header — check for PROXY v2 if session_table.get(&session_key).is_none()
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) { && !proxy_addr_map.contains_key(&client_addr)
match crate::proxy_protocol::parse_v2(datagram) { {
Ok((header, _consumed)) => { // No session and no prior PROXY header — check for PROXY v2
debug!("UDP PROXY v2 from {}: real client {}", client_addr, header.source_addr); if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
proxy_addr_map.insert(client_addr, header.source_addr); match crate::proxy_protocol::parse_v2(datagram) {
continue; // discard the PROXY v2 datagram Ok((header, _consumed)) => {
} debug!(
Err(e) => { "UDP PROXY v2 from {}: real client {}",
debug!("UDP PROXY v2 parse error from {}: {}", client_addr, e); client_addr, header.source_addr
client_addr.ip() );
proxy_addr_map.insert(client_addr, header.source_addr);
continue; // discard the PROXY v2 datagram
}
Err(e) => {
debug!("UDP PROXY v2 parse error from {}: {}", client_addr, e);
client_addr.ip()
}
} }
} else {
client_addr.ip()
} }
} else { } else {
client_addr.ip() // Use real client IP if we've previously seen a PROXY v2 header
proxy_addr_map
.get(&client_addr)
.map(|r| r.ip())
.unwrap_or_else(|| client_addr.ip())
} }
} else { } else {
// Use real client IP if we've previously seen a PROXY v2 header client_addr.ip()
proxy_addr_map.get(&client_addr) };
.map(|r| r.ip())
.unwrap_or_else(|| client_addr.ip()) let block_list = security_policy.load();
} if !block_list.is_empty() && block_list.is_blocked(&effective_client_ip) {
} else { debug!(
client_addr.ip() "UDP datagram from {} blocked by global security policy",
}; effective_client_ip
);
continue;
}
// Route matching — use effective (real) client IP // Route matching — use effective (real) client IP
let rm = route_manager.load(); let rm = route_manager.load();
@@ -611,13 +698,16 @@ impl UdpListenerManager {
let route_match = match rm.find_route(&ctx) { let route_match = match rm.find_route(&ctx) {
Some(m) => m, Some(m) => m,
None => { None => {
debug!("No UDP route matched for port {} from {}", port, client_addr); debug!(
"No UDP route matched for port {} from {}",
port, client_addr
);
continue; continue;
} }
}; };
let route = route_match.route; let route = route_match.route;
let route_id = route.name.as_deref().or(route.id.as_deref()); let route_id = route.metrics_key();
// Socket handler routes → relay datagram to TS via persistent Unix socket // Socket handler routes → relay datagram to TS via persistent Unix socket
if route.action.action_type == RouteActionType::SocketHandler { if route.action.action_type == RouteActionType::SocketHandler {
@@ -627,7 +717,9 @@ impl UdpListenerManager {
&client_addr, &client_addr,
port, port,
datagram, datagram,
).await { )
.await
{
debug!("Failed to relay UDP datagram to TS: {}", e); debug!("Failed to relay UDP datagram to TS: {}", e);
} }
continue; continue;
@@ -638,8 +730,10 @@ impl UdpListenerManager {
// Check datagram size // Check datagram size
if len as u32 > udp_config.max_datagram_size { if len as u32 > udp_config.max_datagram_size {
debug!("UDP datagram too large ({} > {}) from {}, dropping", debug!(
len, udp_config.max_datagram_size, client_addr); "UDP datagram too large ({} > {}) from {}, dropping",
len, udp_config.max_datagram_size, client_addr
);
continue; continue;
} }
@@ -651,21 +745,27 @@ impl UdpListenerManager {
None => { None => {
// New session — check per-IP limits using the real client IP // New session — check per-IP limits using the real client IP
if !conn_tracker.try_accept(&effective_client_ip) { if !conn_tracker.try_accept(&effective_client_ip) {
debug!("UDP session rejected for {} (rate limit)", effective_client_ip); debug!(
"UDP session rejected for {} (rate limit)",
effective_client_ip
);
continue; continue;
} }
if !session_table.can_create_session( if !session_table
&effective_client_ip, .can_create_session(&effective_client_ip, udp_config.max_sessions_per_ip)
udp_config.max_sessions_per_ip, {
) { debug!(
debug!("UDP session rejected for {} (per-IP session limit)", effective_client_ip); "UDP session rejected for {} (per-IP session limit)",
effective_client_ip
);
continue; continue;
} }
// Resolve target // Resolve target
let target = match route_match.target.or_else(|| { let target = match route_match
route.action.targets.as_ref().and_then(|t| t.first()) .target
}) { .or_else(|| route.action.targets.as_ref().and_then(|t| t.first()))
{
Some(t) => t, Some(t) => t,
None => { None => {
warn!("No target for UDP route {:?}", route_id); warn!("No target for UDP route {:?}", route_id);
@@ -686,13 +786,18 @@ impl UdpListenerManager {
} }
}; };
if let Err(e) = backend_socket.connect(&backend_addr).await { if let Err(e) = backend_socket.connect(&backend_addr).await {
error!("Failed to connect backend UDP socket to {}: {}", backend_addr, e); error!(
"Failed to connect backend UDP socket to {}: {}",
backend_addr, e
);
continue; continue;
} }
let backend_socket = Arc::new(backend_socket); let backend_socket = Arc::new(backend_socket);
debug!("New UDP session: {} -> {} (via port {}, real client {})", debug!(
client_addr, backend_addr, port, effective_client_ip); "New UDP session: {} -> {} (via port {}, real client {})",
client_addr, backend_addr, port, effective_client_ip
);
// Spawn return-path relay task // Spawn return-path relay task
let session_cancel = CancellationToken::new(); let session_cancel = CancellationToken::new();
@@ -709,7 +814,9 @@ impl UdpListenerManager {
let session = Arc::new(UdpSession { let session = Arc::new(UdpSession {
backend_socket, backend_socket,
last_activity: std::sync::atomic::AtomicU64::new(session_table.elapsed_ms()), last_activity: std::sync::atomic::AtomicU64::new(
session_table.elapsed_ms(),
),
created_at: std::time::Instant::now(), created_at: std::time::Instant::now(),
route_id: route_id.map(|s| s.to_string()), route_id: route_id.map(|s| s.to_string()),
source_ip: effective_client_ip, source_ip: effective_client_ip,
@@ -718,7 +825,11 @@ impl UdpListenerManager {
cancel: session_cancel, cancel: session_cancel,
}); });
if !session_table.insert(session_key, Arc::clone(&session), udp_config.max_sessions_per_ip) { if !session_table.insert(
session_key,
Arc::clone(&session),
udp_config.max_sessions_per_ip,
) {
warn!("Failed to insert UDP session (race condition)"); warn!("Failed to insert UDP session (race condition)");
continue; continue;
} }
@@ -735,7 +846,9 @@ impl UdpListenerManager {
// Forward datagram to backend // Forward datagram to backend
match session.backend_socket.send(datagram).await { match session.backend_socket.send(datagram).await {
Ok(_) => { Ok(_) => {
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed); session
.last_activity
.store(session_table.elapsed_ms(), Ordering::Relaxed);
metrics.record_bytes(len as u64, 0, route_id, Some(&ip_str)); metrics.record_bytes(len as u64, 0, route_id, Some(&ip_str));
metrics.record_datagram_in(); metrics.record_datagram_in();
} }
@@ -779,7 +892,9 @@ impl UdpListenerManager {
Ok(_) => { Ok(_) => {
// Update session activity // Update session activity
if let Some(session) = session_table.get(&session_key) { if let Some(session) = session_table.get(&session_key) {
session.last_activity.store(session_table.elapsed_ms(), Ordering::Relaxed); session
.last_activity
.store(session_table.elapsed_ms(), Ordering::Relaxed);
} }
metrics.record_bytes(0, len as u64, route_id.as_deref(), Some(&ip_str)); metrics.record_bytes(0, len as u64, route_id.as_deref(), Some(&ip_str));
metrics.record_datagram_out(); metrics.record_datagram_out();
@@ -814,7 +929,8 @@ impl UdpListenerManager {
let json = serde_json::to_vec(&msg)?; let json = serde_json::to_vec(&msg)?;
let mut guard = writer.lock().await; let mut guard = writer.lock().await;
let stream = guard.as_mut() let stream = guard
.as_mut()
.ok_or_else(|| anyhow::anyhow!("Datagram relay not connected"))?; .ok_or_else(|| anyhow::anyhow!("Datagram relay not connected"))?;
// Length-prefixed frame // Length-prefixed frame
@@ -879,9 +995,15 @@ impl UdpListenerManager {
} }
let source_ip = reply.get("sourceIp").and_then(|v| v.as_str()).unwrap_or(""); let source_ip = reply.get("sourceIp").and_then(|v| v.as_str()).unwrap_or("");
let source_port = reply.get("sourcePort").and_then(|v| v.as_u64()).unwrap_or(0) as u16; let source_port = reply
.get("sourcePort")
.and_then(|v| v.as_u64())
.unwrap_or(0) as u16;
let dest_port = reply.get("destPort").and_then(|v| v.as_u64()).unwrap_or(0) as u16; let dest_port = reply.get("destPort").and_then(|v| v.as_u64()).unwrap_or(0) as u16;
let payload_b64 = reply.get("payloadBase64").and_then(|v| v.as_str()).unwrap_or(""); let payload_b64 = reply
.get("payloadBase64")
.and_then(|v| v.as_str())
.unwrap_or("");
let payload = match base64::engine::general_purpose::STANDARD.decode(payload_b64) { let payload = match base64::engine::general_purpose::STANDARD.decode(payload_b64) {
Ok(p) => p, Ok(p) => p,
@@ -111,12 +111,15 @@ impl UdpSessionTable {
/// Look up an existing session. /// Look up an existing session.
pub fn get(&self, key: &SessionKey) -> Option<Arc<UdpSession>> { pub fn get(&self, key: &SessionKey) -> Option<Arc<UdpSession>> {
self.sessions.get(key).map(|entry| Arc::clone(entry.value())) self.sessions
.get(key)
.map(|entry| Arc::clone(entry.value()))
} }
/// Check if we can create a new session for this IP (under the per-IP limit). /// Check if we can create a new session for this IP (under the per-IP limit).
pub fn can_create_session(&self, ip: &IpAddr, max_per_ip: u32) -> bool { pub fn can_create_session(&self, ip: &IpAddr, max_per_ip: u32) -> bool {
let count = self.ip_session_counts let count = self
.ip_session_counts
.get(ip) .get(ip)
.map(|c| *c.value()) .map(|c| *c.value())
.unwrap_or(0); .unwrap_or(0);
@@ -124,12 +127,7 @@ impl UdpSessionTable {
} }
/// Insert a new session. Returns false if per-IP limit exceeded. /// Insert a new session. Returns false if per-IP limit exceeded.
pub fn insert( pub fn insert(&self, key: SessionKey, session: Arc<UdpSession>, max_per_ip: u32) -> bool {
&self,
key: SessionKey,
session: Arc<UdpSession>,
max_per_ip: u32,
) -> bool {
let ip = session.source_ip; let ip = session.source_ip;
// Atomically check and increment per-IP count // Atomically check and increment per-IP count
@@ -173,7 +171,9 @@ impl UdpSessionTable {
let mut removed = 0; let mut removed = 0;
// Collect keys to remove (avoid holding DashMap refs during removal) // Collect keys to remove (avoid holding DashMap refs during removal)
let stale_keys: Vec<SessionKey> = self.sessions.iter() let stale_keys: Vec<SessionKey> = self
.sessions
.iter()
.filter(|entry| { .filter(|entry| {
let last = entry.value().last_activity.load(Ordering::Relaxed); let last = entry.value().last_activity.load(Ordering::Relaxed);
now_ms.saturating_sub(last) >= timeout_ms now_ms.saturating_sub(last) >= timeout_ms
@@ -185,7 +185,8 @@ impl UdpSessionTable {
if let Some(session) = self.remove(&key) { if let Some(session) = self.remove(&key) {
debug!( debug!(
"UDP session expired: {} -> port {} (idle {}ms)", "UDP session expired: {} -> port {} (idle {}ms)",
session.client_addr, key.1, session.client_addr,
key.1,
now_ms.saturating_sub(session.last_activity.load(Ordering::Relaxed)) now_ms.saturating_sub(session.last_activity.load(Ordering::Relaxed))
); );
conn_tracker.connection_closed(&session.source_ip); conn_tracker.connection_closed(&session.source_ip);
@@ -210,7 +211,9 @@ impl UdpSessionTable {
metrics: &MetricsCollector, metrics: &MetricsCollector,
conn_tracker: &ConnectionTracker, conn_tracker: &ConnectionTracker,
) -> usize { ) -> usize {
let keys: Vec<SessionKey> = self.sessions.iter() let keys: Vec<SessionKey> = self
.sessions
.iter()
.filter(|entry| entry.key().1 == port) .filter(|entry| entry.key().1 == port)
.map(|entry| *entry.key()) .map(|entry| *entry.key())
.collect(); .collect();
@@ -257,9 +260,8 @@ mod tests {
.enable_all() .enable_all()
.build() .build()
.unwrap(); .unwrap();
let backend_socket = rt.block_on(async { let backend_socket =
Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()) rt.block_on(async { Arc::new(UdpSocket::bind("127.0.0.1:0").await.unwrap()) });
});
let child_cancel = cancel.child_token(); let child_cancel = cancel.child_token();
let return_task = rt.spawn(async move { let return_task = rt.spawn(async move {
+1 -1
View File
@@ -3,7 +3,7 @@
//! Route matching engine for RustProxy. //! Route matching engine for RustProxy.
//! Provides domain/path/IP/header matchers and a port-indexed RouteManager. //! Provides domain/path/IP/header matchers and a port-indexed RouteManager.
pub mod route_manager;
pub mod matchers; pub mod matchers;
pub mod route_manager;
pub use route_manager::*; pub use route_manager::*;
@@ -20,7 +20,7 @@ pub fn domain_matches(pattern: &str, domain: &str) -> bool {
// Wildcard patterns // Wildcard patterns
if pattern.starts_with("*.") || pattern.starts_with("*.") { if pattern.starts_with("*.") || pattern.starts_with("*.") {
let suffix = &pattern[2..]; // e.g., "example.com" let suffix = &pattern[2..]; // e.g., "example.com"
// Match exact parent or any single-level subdomain // Match exact parent or any single-level subdomain
if domain.eq_ignore_ascii_case(suffix) { if domain.eq_ignore_ascii_case(suffix) {
return true; return true;
} }
@@ -1,5 +1,42 @@
use std::collections::HashMap;
use regex::Regex; use regex::Regex;
use std::collections::HashMap;
fn compile_regex_pattern(pattern: &str) -> Option<Regex> {
if !pattern.starts_with('/') {
return None;
}
let last_slash = pattern.rfind('/')?;
if last_slash == 0 {
return None;
}
let regex_body = &pattern[1..last_slash];
let flags = &pattern[last_slash + 1..];
let mut inline_flags = String::new();
for flag in flags.chars() {
match flag {
'i' | 'm' | 's' | 'u' => {
if !inline_flags.contains(flag) {
inline_flags.push(flag);
}
}
'g' => {
// Global has no effect for single header matching.
}
_ => return None,
}
}
let compiled = if inline_flags.is_empty() {
regex_body.to_string()
} else {
format!("(?{}){}", inline_flags, regex_body)
};
Regex::new(&compiled).ok()
}
/// Match HTTP headers against a set of patterns. /// Match HTTP headers against a set of patterns.
/// ///
@@ -24,16 +61,15 @@ pub fn headers_match(
None => return false, // Required header not present None => return false, // Required header not present
}; };
// Check if pattern is a regex (surrounded by /) // Check if pattern is a regex literal (/pattern/ or /pattern/flags)
if pattern.starts_with('/') && pattern.ends_with('/') && pattern.len() > 2 { if pattern.starts_with('/') && pattern.len() > 2 {
let regex_str = &pattern[1..pattern.len() - 1]; match compile_regex_pattern(pattern) {
match Regex::new(regex_str) { Some(re) => {
Ok(re) => {
if !re.is_match(header_value) { if !re.is_match(header_value) {
return false; return false;
} }
} }
Err(_) => { None => {
// Invalid regex, fall back to exact match // Invalid regex, fall back to exact match
if header_value != pattern { if header_value != pattern {
return false; return false;
@@ -85,6 +121,24 @@ mod tests {
assert!(headers_match(&patterns, &headers)); assert!(headers_match(&patterns, &headers));
} }
#[test]
fn test_regex_header_match_with_flags() {
let patterns: HashMap<String, String> = {
let mut m = HashMap::new();
m.insert(
"Content-Type".to_string(),
"/^application\\/json$/i".to_string(),
);
m
};
let headers: HashMap<String, String> = {
let mut m = HashMap::new();
m.insert("content-type".to_string(), "Application/JSON".to_string());
m
};
assert!(headers_match(&patterns, &headers));
}
#[test] #[test]
fn test_missing_header() { fn test_missing_header() {
let patterns: HashMap<String, String> = { let patterns: HashMap<String, String> = {
@@ -1,6 +1,6 @@
use ipnet::IpNet;
use std::net::IpAddr; use std::net::IpAddr;
use std::str::FromStr; use std::str::FromStr;
use ipnet::IpNet;
/// Match an IP address against a pattern. /// Match an IP address against a pattern.
/// ///
@@ -85,7 +85,10 @@ fn wildcard_to_cidr(pattern: &str) -> Option<String> {
} }
} }
Some(format!("{}.{}.{}.{}/{}", octets[0], octets[1], octets[2], octets[3], prefix_len)) Some(format!(
"{}.{}.{}.{}/{}",
octets[0], octets[1], octets[2], octets[3], prefix_len
))
} }
#[cfg(test)] #[cfg(test)]
@@ -1,9 +1,9 @@
pub mod domain; pub mod domain;
pub mod path;
pub mod ip;
pub mod header; pub mod header;
pub mod ip;
pub mod path;
pub use domain::*; pub use domain::*;
pub use path::*;
pub use ip::*;
pub use header::*; pub use header::*;
pub use ip::*;
pub use path::*;
@@ -1,7 +1,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use rustproxy_config::{RouteConfig, RouteTarget, TransportProtocol, TlsMode};
use crate::matchers; use crate::matchers;
use rustproxy_config::{RouteConfig, RouteTarget, TlsMode, TransportProtocol};
/// Context for route matching (subset of connection info). /// Context for route matching (subset of connection info).
pub struct MatchContext<'a> { pub struct MatchContext<'a> {
@@ -42,19 +42,14 @@ impl RouteManager {
}; };
// Filter enabled routes and sort by priority // Filter enabled routes and sort by priority
let mut enabled_routes: Vec<RouteConfig> = routes let mut enabled_routes: Vec<RouteConfig> =
.into_iter() routes.into_iter().filter(|r| r.is_enabled()).collect();
.filter(|r| r.is_enabled())
.collect();
enabled_routes.sort_by(|a, b| b.effective_priority().cmp(&a.effective_priority())); enabled_routes.sort_by(|a, b| b.effective_priority().cmp(&a.effective_priority()));
// Build port index // Build port index
for (idx, route) in enabled_routes.iter().enumerate() { for (idx, route) in enabled_routes.iter().enumerate() {
for port in route.listening_ports() { for port in route.listening_ports() {
manager.port_index manager.port_index.entry(port).or_default().push(idx);
.entry(port)
.or_default()
.push(idx);
} }
} }
@@ -66,7 +61,9 @@ impl RouteManager {
/// Used to skip expensive header HashMap construction when no route needs it. /// Used to skip expensive header HashMap construction when no route needs it.
pub fn any_route_has_headers(&self, port: u16) -> bool { pub fn any_route_has_headers(&self, port: u16) -> bool {
if let Some(indices) = self.port_index.get(&port) { if let Some(indices) = self.port_index.get(&port) {
indices.iter().any(|&idx| self.routes[idx].route_match.headers.is_some()) indices
.iter()
.any(|&idx| self.routes[idx].route_match.headers.is_some())
} else { } else {
false false
} }
@@ -99,8 +96,8 @@ impl RouteManager {
let ctx_transport = ctx.transport.as_ref(); let ctx_transport = ctx.transport.as_ref();
match (route_transport, ctx_transport) { match (route_transport, ctx_transport) {
// Route requires UDP only — reject non-UDP contexts // Route requires UDP only — reject non-UDP contexts
(Some(TransportProtocol::Udp), None) | (Some(TransportProtocol::Udp), None)
(Some(TransportProtocol::Udp), Some(TransportProtocol::Tcp)) => return false, | (Some(TransportProtocol::Udp), Some(TransportProtocol::Tcp)) => return false,
// Route requires TCP only — reject UDP contexts // Route requires TCP only — reject UDP contexts
(Some(TransportProtocol::Tcp), Some(TransportProtocol::Udp)) => return false, (Some(TransportProtocol::Tcp), Some(TransportProtocol::Udp)) => return false,
// Route has no transport (default = TCP) — reject UDP contexts // Route has no transport (default = TCP) — reject UDP contexts
@@ -196,7 +193,11 @@ impl RouteManager {
} }
/// Find the best matching target within a route. /// Find the best matching target within a route.
fn find_target<'a>(&self, route: &'a RouteConfig, ctx: &MatchContext<'_>) -> Option<&'a RouteTarget> { fn find_target<'a>(
&self,
route: &'a RouteConfig,
ctx: &MatchContext<'_>,
) -> Option<&'a RouteTarget> {
let targets = route.action.targets.as_ref()?; let targets = route.action.targets.as_ref()?;
if targets.len() == 1 && targets[0].target_match.is_none() { if targets.len() == 1 && targets[0].target_match.is_none() {
@@ -223,17 +224,11 @@ impl RouteManager {
} }
// Fall back to first target without match criteria // Fall back to first target without match criteria
best.or_else(|| { best.or_else(|| targets.iter().find(|t| t.target_match.is_none()))
targets.iter().find(|t| t.target_match.is_none())
})
} }
/// Check if a target match criteria matches the context. /// Check if a target match criteria matches the context.
fn matches_target( fn matches_target(&self, tm: &rustproxy_config::TargetMatch, ctx: &MatchContext<'_>) -> bool {
&self,
tm: &rustproxy_config::TargetMatch,
ctx: &MatchContext<'_>,
) -> bool {
// Port matching // Port matching
if let Some(ref ports) = tm.ports { if let Some(ref ports) = tm.ports {
if !ports.contains(&ctx.port) { if !ports.contains(&ctx.port) {
@@ -298,9 +293,7 @@ impl RouteManager {
// If multiple passthrough routes on same port, SNI is needed // If multiple passthrough routes on same port, SNI is needed
let passthrough_routes: Vec<_> = routes let passthrough_routes: Vec<_> = routes
.iter() .iter()
.filter(|r| { .filter(|r| r.tls_mode() == Some(&TlsMode::Passthrough))
r.tls_mode() == Some(&TlsMode::Passthrough)
})
.collect(); .collect();
if passthrough_routes.len() > 1 { if passthrough_routes.len() > 1 {
@@ -419,7 +412,11 @@ mod tests {
let result = manager.find_route(&ctx).unwrap(); let result = manager.find_route(&ctx).unwrap();
// Should match the higher-priority specific route // Should match the higher-priority specific route
assert!(result.route.route_match.domains.as_ref() assert!(result
.route
.route_match
.domains
.as_ref()
.map(|d| d.to_vec()) .map(|d| d.to_vec())
.unwrap() .unwrap()
.contains(&"api.example.com")); .contains(&"api.example.com"));
@@ -619,8 +616,14 @@ mod tests {
let result = manager.find_route(&ctx); let result = manager.find_route(&ctx);
assert!(result.is_some()); assert!(result.is_some());
let matched_domains = result.unwrap().route.route_match.domains.as_ref() let matched_domains = result
.map(|d| d.to_vec()).unwrap(); .unwrap()
.route
.route_match
.domains
.as_ref()
.map(|d| d.to_vec())
.unwrap();
assert!(matched_domains.contains(&"*")); assert!(matched_domains.contains(&"*"));
} }
@@ -735,7 +738,11 @@ mod tests {
assert_eq!(result.target.unwrap().host.first(), "default-backend"); assert_eq!(result.target.unwrap().host.first(), "default-backend");
} }
fn make_route_with_protocol(port: u16, domain: Option<&str>, protocol: Option<&str>) -> RouteConfig { fn make_route_with_protocol(
port: u16,
domain: Option<&str>,
protocol: Option<&str>,
) -> RouteConfig {
let mut route = make_route(port, domain, 0); let mut route = make_route(port, domain, 0);
route.route_match.protocol = protocol.map(|s| s.to_string()); route.route_match.protocol = protocol.map(|s| s.to_string());
route route
@@ -1029,8 +1036,10 @@ mod tests {
transport: Some(TransportProtocol::Udp), transport: Some(TransportProtocol::Udp),
}; };
assert!(manager.find_route(&ctx).is_some(), assert!(
"QUIC (UDP) with is_tls=true and domain=None should match domain-restricted routes"); manager.find_route(&ctx).is_some(),
"QUIC (UDP) with is_tls=true and domain=None should match domain-restricted routes"
);
} }
#[test] #[test]
@@ -1051,7 +1060,9 @@ mod tests {
transport: None, // TCP (default) transport: None, // TCP (default)
}; };
assert!(manager.find_route(&ctx).is_none(), assert!(
"TCP TLS without SNI should NOT match domain-restricted routes"); manager.find_route(&ctx).is_none(),
"TCP TLS without SNI should NOT match domain-restricted routes"
);
} }
} }
@@ -1,5 +1,5 @@
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64; use base64::engine::general_purpose::STANDARD as BASE64;
use base64::Engine;
/// Basic auth validator. /// Basic auth validator.
pub struct BasicAuthValidator { pub struct BasicAuthValidator {
+217 -41
View File
@@ -2,14 +2,26 @@ use ipnet::IpNet;
use std::net::IpAddr; use std::net::IpAddr;
use std::str::FromStr; use std::str::FromStr;
use rustproxy_config::IpAllowEntry;
/// IP filter supporting CIDR ranges, wildcards, and exact matches. /// IP filter supporting CIDR ranges, wildcards, and exact matches.
/// Supports domain-scoped allow entries that restrict an IP to specific domains.
pub struct IpFilter { pub struct IpFilter {
/// Plain allow entries — IP allowed for any domain on the route
allow_list: Vec<IpPattern>, allow_list: Vec<IpPattern>,
/// Domain-scoped allow entries — IP allowed only for matching domains
domain_scoped: Vec<DomainScopedEntry>,
block_list: Vec<IpPattern>, block_list: Vec<IpPattern>,
} }
/// A domain-scoped allow entry: IP + list of allowed domain patterns.
struct DomainScopedEntry {
pattern: IpPattern,
domains: Vec<String>,
}
/// Represents an IP pattern for matching. /// Represents an IP pattern for matching.
#[derive(Debug)] #[derive(Debug, Clone)]
enum IpPattern { enum IpPattern {
/// Exact IP match /// Exact IP match
Exact(IpAddr), Exact(IpAddr),
@@ -19,6 +31,37 @@ enum IpPattern {
Wildcard, Wildcard,
} }
/// Compiled block list for early ingress filtering.
#[derive(Debug, Clone)]
pub struct IpBlockList {
block_list: Vec<IpPattern>,
}
impl IpBlockList {
pub fn new(block_list: &[String]) -> Self {
Self {
block_list: block_list.iter().map(|s| IpPattern::parse(s)).collect(),
}
}
pub fn empty() -> Self {
Self {
block_list: Vec::new(),
}
}
pub fn is_empty(&self) -> bool {
self.block_list.is_empty()
}
pub fn is_blocked(&self, ip: &IpAddr) -> bool {
let normalized = IpFilter::normalize_ip(ip);
self.block_list
.iter()
.any(|pattern| pattern.matches(&normalized))
}
}
impl IpPattern { impl IpPattern {
fn parse(s: &str) -> Self { fn parse(s: &str) -> Self {
let s = s.trim(); let s = s.trim();
@@ -31,10 +74,6 @@ impl IpPattern {
if let Ok(addr) = IpAddr::from_str(s) { if let Ok(addr) = IpAddr::from_str(s) {
return IpPattern::Exact(addr); return IpPattern::Exact(addr);
} }
// Try as CIDR by appending default prefix
if let Ok(addr) = IpAddr::from_str(s) {
return IpPattern::Exact(addr);
}
// Fallback: treat as exact, will never match an invalid string // Fallback: treat as exact, will never match an invalid string
IpPattern::Exact(IpAddr::from_str("0.0.0.0").unwrap()) IpPattern::Exact(IpAddr::from_str("0.0.0.0").unwrap())
} }
@@ -48,19 +87,55 @@ impl IpPattern {
} }
} }
/// Simple domain pattern matching (exact, `*`, or `*.suffix`).
fn domain_matches_pattern(pattern: &str, domain: &str) -> bool {
let p = pattern.trim();
let d = domain.trim();
if p == "*" {
return true;
}
if p.eq_ignore_ascii_case(d) {
return true;
}
if p.starts_with("*.") {
let suffix = &p[1..]; // e.g., ".abc.xyz"
d.len() > suffix.len() && d[d.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
} else {
false
}
}
impl IpFilter { impl IpFilter {
/// Create a new IP filter from allow and block lists. /// Create a new IP filter from allow entries and a block list.
pub fn new(allow_list: &[String], block_list: &[String]) -> Self { pub fn new(allow_entries: &[IpAllowEntry], block_list: &[String]) -> Self {
let mut allow_list = Vec::new();
let mut domain_scoped = Vec::new();
for entry in allow_entries {
match entry {
IpAllowEntry::Plain(ip) => {
allow_list.push(IpPattern::parse(ip));
}
IpAllowEntry::DomainScoped { ip, domains } => {
domain_scoped.push(DomainScopedEntry {
pattern: IpPattern::parse(ip),
domains: domains.clone(),
});
}
}
}
Self { Self {
allow_list: allow_list.iter().map(|s| IpPattern::parse(s)).collect(), allow_list,
domain_scoped,
block_list: block_list.iter().map(|s| IpPattern::parse(s)).collect(), block_list: block_list.iter().map(|s| IpPattern::parse(s)).collect(),
} }
} }
/// Check if an IP is allowed. /// Check if an IP is allowed, considering domain-scoped entries.
/// If allow_list is non-empty, IP must match at least one entry. /// If `domain` is Some, domain-scoped entries are evaluated against it.
/// If block_list is non-empty, IP must NOT match any entry. /// If `domain` is None, only plain allow entries are considered.
pub fn is_allowed(&self, ip: &IpAddr) -> bool { pub fn is_allowed_for_domain(&self, ip: &IpAddr, domain: Option<&str>) -> bool {
// Check block list first // Check block list first
if !self.block_list.is_empty() { if !self.block_list.is_empty() {
for pattern in &self.block_list { for pattern in &self.block_list {
@@ -70,14 +145,40 @@ impl IpFilter {
} }
} }
// If allow list is non-empty, must match at least one // If there are any allow entries (plain or domain-scoped), IP must match
if !self.allow_list.is_empty() { let has_any_allow = !self.allow_list.is_empty() || !self.domain_scoped.is_empty();
return self.allow_list.iter().any(|p| p.matches(ip)); if has_any_allow {
// Check plain allow list — grants access to entire route
if self.allow_list.iter().any(|p| p.matches(ip)) {
return true;
}
// Check domain-scoped entries — grants access only if domain matches
if let Some(req_domain) = domain {
for entry in &self.domain_scoped {
if entry.pattern.matches(ip) {
if entry
.domains
.iter()
.any(|d| domain_matches_pattern(d, req_domain))
{
return true;
}
}
}
}
return false;
} }
true true
} }
/// Check if an IP is allowed (backwards-compat wrapper, no domain context).
pub fn is_allowed(&self, ip: &IpAddr) -> bool {
self.is_allowed_for_domain(ip, None)
}
/// Normalize IPv4-mapped IPv6 addresses (::ffff:x.x.x.x -> x.x.x.x) /// Normalize IPv4-mapped IPv6 addresses (::ffff:x.x.x.x -> x.x.x.x)
pub fn normalize_ip(ip: &IpAddr) -> IpAddr { pub fn normalize_ip(ip: &IpAddr) -> IpAddr {
match ip { match ip {
@@ -97,19 +198,28 @@ impl IpFilter {
mod tests { mod tests {
use super::*; use super::*;
fn plain(s: &str) -> IpAllowEntry {
IpAllowEntry::Plain(s.to_string())
}
fn scoped(ip: &str, domains: &[&str]) -> IpAllowEntry {
IpAllowEntry::DomainScoped {
ip: ip.to_string(),
domains: domains.iter().map(|s| s.to_string()).collect(),
}
}
#[test] #[test]
fn test_empty_lists_allow_all() { fn test_empty_lists_allow_all() {
let filter = IpFilter::new(&[], &[]); let filter = IpFilter::new(&[], &[]);
let ip: IpAddr = "192.168.1.1".parse().unwrap(); let ip: IpAddr = "192.168.1.1".parse().unwrap();
assert!(filter.is_allowed(&ip)); assert!(filter.is_allowed(&ip));
assert!(filter.is_allowed_for_domain(&ip, Some("example.com")));
} }
#[test] #[test]
fn test_allow_list_exact() { fn test_plain_allow_list_exact() {
let filter = IpFilter::new( let filter = IpFilter::new(&[plain("10.0.0.1")], &[]);
&["10.0.0.1".to_string()],
&[],
);
let allowed: IpAddr = "10.0.0.1".parse().unwrap(); let allowed: IpAddr = "10.0.0.1".parse().unwrap();
let denied: IpAddr = "10.0.0.2".parse().unwrap(); let denied: IpAddr = "10.0.0.2".parse().unwrap();
assert!(filter.is_allowed(&allowed)); assert!(filter.is_allowed(&allowed));
@@ -117,11 +227,8 @@ mod tests {
} }
#[test] #[test]
fn test_allow_list_cidr() { fn test_plain_allow_list_cidr() {
let filter = IpFilter::new( let filter = IpFilter::new(&[plain("10.0.0.0/8")], &[]);
&["10.0.0.0/8".to_string()],
&[],
);
let allowed: IpAddr = "10.255.255.255".parse().unwrap(); let allowed: IpAddr = "10.255.255.255".parse().unwrap();
let denied: IpAddr = "192.168.1.1".parse().unwrap(); let denied: IpAddr = "192.168.1.1".parse().unwrap();
assert!(filter.is_allowed(&allowed)); assert!(filter.is_allowed(&allowed));
@@ -130,10 +237,7 @@ mod tests {
#[test] #[test]
fn test_block_list() { fn test_block_list() {
let filter = IpFilter::new( let filter = IpFilter::new(&[], &["192.168.1.100".to_string()]);
&[],
&["192.168.1.100".to_string()],
);
let blocked: IpAddr = "192.168.1.100".parse().unwrap(); let blocked: IpAddr = "192.168.1.100".parse().unwrap();
let allowed: IpAddr = "192.168.1.101".parse().unwrap(); let allowed: IpAddr = "192.168.1.101".parse().unwrap();
assert!(!filter.is_allowed(&blocked)); assert!(!filter.is_allowed(&blocked));
@@ -142,10 +246,7 @@ mod tests {
#[test] #[test]
fn test_block_trumps_allow() { fn test_block_trumps_allow() {
let filter = IpFilter::new( let filter = IpFilter::new(&[plain("10.0.0.0/8")], &["10.0.0.5".to_string()]);
&["10.0.0.0/8".to_string()],
&["10.0.0.5".to_string()],
);
let blocked: IpAddr = "10.0.0.5".parse().unwrap(); let blocked: IpAddr = "10.0.0.5".parse().unwrap();
let allowed: IpAddr = "10.0.0.6".parse().unwrap(); let allowed: IpAddr = "10.0.0.6".parse().unwrap();
assert!(!filter.is_allowed(&blocked)); assert!(!filter.is_allowed(&blocked));
@@ -154,20 +255,14 @@ mod tests {
#[test] #[test]
fn test_wildcard_allow() { fn test_wildcard_allow() {
let filter = IpFilter::new( let filter = IpFilter::new(&[plain("*")], &[]);
&["*".to_string()],
&[],
);
let ip: IpAddr = "1.2.3.4".parse().unwrap(); let ip: IpAddr = "1.2.3.4".parse().unwrap();
assert!(filter.is_allowed(&ip)); assert!(filter.is_allowed(&ip));
} }
#[test] #[test]
fn test_wildcard_block() { fn test_wildcard_block() {
let filter = IpFilter::new( let filter = IpFilter::new(&[], &["*".to_string()]);
&[],
&["*".to_string()],
);
let ip: IpAddr = "1.2.3.4".parse().unwrap(); let ip: IpAddr = "1.2.3.4".parse().unwrap();
assert!(!filter.is_allowed(&ip)); assert!(!filter.is_allowed(&ip));
} }
@@ -186,4 +281,85 @@ mod tests {
let normalized = IpFilter::normalize_ip(&ip); let normalized = IpFilter::normalize_ip(&ip);
assert_eq!(normalized, ip); assert_eq!(normalized, ip);
} }
// Domain-scoped tests
#[test]
fn test_domain_scoped_allows_matching_domain() {
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
let ip: IpAddr = "10.8.0.2".parse().unwrap();
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
}
#[test]
fn test_domain_scoped_denies_non_matching_domain() {
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
let ip: IpAddr = "10.8.0.2".parse().unwrap();
assert!(!filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
}
#[test]
fn test_domain_scoped_denies_without_domain() {
let filter = IpFilter::new(&[scoped("10.8.0.2", &["outline.abc.xyz"])], &[]);
let ip: IpAddr = "10.8.0.2".parse().unwrap();
// Without domain context, domain-scoped entries cannot match
assert!(!filter.is_allowed_for_domain(&ip, None));
}
#[test]
fn test_domain_scoped_wildcard_domain() {
let filter = IpFilter::new(&[scoped("10.8.0.2", &["*.abc.xyz"])], &[]);
let ip: IpAddr = "10.8.0.2".parse().unwrap();
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
assert!(filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
assert!(!filter.is_allowed_for_domain(&ip, Some("other.com")));
}
#[test]
fn test_plain_and_domain_scoped_coexist() {
let filter = IpFilter::new(
&[
plain("1.2.3.4"), // full route access
scoped("10.8.0.2", &["outline.abc.xyz"]), // scoped access
],
&[],
);
let admin: IpAddr = "1.2.3.4".parse().unwrap();
let vpn: IpAddr = "10.8.0.2".parse().unwrap();
let other: IpAddr = "9.9.9.9".parse().unwrap();
// Admin IP has full access
assert!(filter.is_allowed_for_domain(&admin, Some("anything.abc.xyz")));
assert!(filter.is_allowed_for_domain(&admin, Some("outline.abc.xyz")));
// VPN IP only has scoped access
assert!(filter.is_allowed_for_domain(&vpn, Some("outline.abc.xyz")));
assert!(!filter.is_allowed_for_domain(&vpn, Some("app.abc.xyz")));
// Unknown IP denied
assert!(!filter.is_allowed_for_domain(&other, Some("outline.abc.xyz")));
}
#[test]
fn test_block_trumps_domain_scoped() {
let filter = IpFilter::new(
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
&["10.8.0.2".to_string()],
);
let ip: IpAddr = "10.8.0.2".parse().unwrap();
assert!(!filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
}
#[test]
fn test_domain_matches_pattern_fn() {
assert!(domain_matches_pattern("example.com", "example.com"));
assert!(domain_matches_pattern("*.abc.xyz", "outline.abc.xyz"));
assert!(domain_matches_pattern("*.abc.xyz", "app.abc.xyz"));
assert!(!domain_matches_pattern("*.abc.xyz", "abc.xyz")); // suffix only, not exact parent
assert!(domain_matches_pattern("*", "anything.com"));
assert!(!domain_matches_pattern("outline.abc.xyz", "app.abc.xyz"));
// Case insensitive
assert!(domain_matches_pattern("*.ABC.XYZ", "outline.abc.xyz"));
}
} }
@@ -1,4 +1,4 @@
use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm}; use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// JWT claims (minimal structure). /// JWT claims (minimal structure).
@@ -160,10 +160,7 @@ mod tests {
#[test] #[test]
fn test_extract_token_bearer() { fn test_extract_token_bearer() {
assert_eq!( assert_eq!(JwtValidator::extract_token("Bearer abc123"), Some("abc123"));
JwtValidator::extract_token("Bearer abc123"),
Some("abc123")
);
} }
#[test] #[test]
+4 -4
View File
@@ -2,12 +2,12 @@
//! //!
//! IP filtering, rate limiting, and authentication for RustProxy. //! IP filtering, rate limiting, and authentication for RustProxy.
pub mod ip_filter;
pub mod rate_limiter;
pub mod basic_auth; pub mod basic_auth;
pub mod ip_filter;
pub mod jwt_auth; pub mod jwt_auth;
pub mod rate_limiter;
pub use ip_filter::*;
pub use rate_limiter::*;
pub use basic_auth::*; pub use basic_auth::*;
pub use ip_filter::*;
pub use jwt_auth::*; pub use jwt_auth::*;
pub use rate_limiter::*;
@@ -79,7 +79,7 @@ mod tests {
assert!(limiter.check("client-a")); assert!(limiter.check("client-a"));
assert!(limiter.check("client-a")); assert!(limiter.check("client-a"));
assert!(!limiter.check("client-a")); // blocked assert!(!limiter.check("client-a")); // blocked
// Different key should still be allowed // Different key should still be allowed
assert!(limiter.check("client-b")); assert!(limiter.check("client-b"));
assert!(limiter.check("client-b")); assert!(limiter.check("client-b"));
} }
+14 -13
View File
@@ -4,8 +4,7 @@
//! Account credentials are ephemeral — the consumer owns all persistence. //! Account credentials are ephemeral — the consumer owns all persistence.
use instant_acme::{ use instant_acme::{
Account, NewAccount, NewOrder, Identifier, ChallengeType, OrderStatus, Account, AccountCredentials, ChallengeType, Identifier, NewAccount, NewOrder, OrderStatus,
AccountCredentials,
}; };
use rcgen::{CertificateParams, KeyPair}; use rcgen::{CertificateParams, KeyPair};
use thiserror::Error; use thiserror::Error;
@@ -89,7 +88,11 @@ impl AcmeClient {
F: FnOnce(PendingChallenge) -> Fut, F: FnOnce(PendingChallenge) -> Fut,
Fut: std::future::Future<Output = Result<(), AcmeError>>, Fut: std::future::Future<Output = Result<(), AcmeError>>,
{ {
info!("Starting ACME provisioning for {} via {}", domain, self.directory_url()); info!(
"Starting ACME provisioning for {} via {}",
domain,
self.directory_url()
);
// 1. Get or create ACME account // 1. Get or create ACME account
let account = self.get_or_create_account().await?; let account = self.get_or_create_account().await?;
@@ -170,14 +173,14 @@ impl AcmeClient {
debug!("Order ready, finalizing..."); debug!("Order ready, finalizing...");
// 6. Generate CSR and finalize // 6. Generate CSR and finalize
let key_pair = KeyPair::generate().map_err(|e| { let key_pair = KeyPair::generate()
AcmeError::FinalizationFailed(format!("Key generation failed: {}", e)) .map_err(|e| AcmeError::FinalizationFailed(format!("Key generation failed: {}", e)))?;
})?;
let mut params = CertificateParams::new(vec![domain.to_string()]).map_err(|e| { let mut params = CertificateParams::new(vec![domain.to_string()])
AcmeError::FinalizationFailed(format!("CSR params failed: {}", e)) .map_err(|e| AcmeError::FinalizationFailed(format!("CSR params failed: {}", e)))?;
})?; params
params.distinguished_name.push(rcgen::DnType::CommonName, domain); .distinguished_name
.push(rcgen::DnType::CommonName, domain);
let csr = params.serialize_request(&key_pair).map_err(|e| { let csr = params.serialize_request(&key_pair).map_err(|e| {
AcmeError::FinalizationFailed(format!("CSR serialization failed: {}", e)) AcmeError::FinalizationFailed(format!("CSR serialization failed: {}", e))
@@ -219,9 +222,7 @@ impl AcmeClient {
.certificate() .certificate()
.await .await
.map_err(|e| AcmeError::FinalizationFailed(e.to_string()))? .map_err(|e| AcmeError::FinalizationFailed(e.to_string()))?
.ok_or_else(|| { .ok_or_else(|| AcmeError::FinalizationFailed("No certificate returned".to_string()))?;
AcmeError::FinalizationFailed("No certificate returned".to_string())
})?;
let private_key_pem = key_pair.serialize_pem(); let private_key_pem = key_pair.serialize_pem();
+20 -22
View File
@@ -2,8 +2,8 @@ use std::time::{SystemTime, UNIX_EPOCH};
use thiserror::Error; use thiserror::Error;
use tracing::info; use tracing::info;
use crate::cert_store::{CertStore, CertBundle, CertMetadata, CertSource};
use crate::acme::AcmeClient; use crate::acme::AcmeClient;
use crate::cert_store::{CertBundle, CertMetadata, CertSource, CertStore};
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum CertManagerError { pub enum CertManagerError {
@@ -45,17 +45,13 @@ impl CertManager {
/// Create an ACME client using this manager's configuration. /// Create an ACME client using this manager's configuration.
/// Returns None if no ACME email is configured. /// Returns None if no ACME email is configured.
pub fn acme_client(&self) -> Option<AcmeClient> { pub fn acme_client(&self) -> Option<AcmeClient> {
self.acme_email.as_ref().map(|email| { self.acme_email
AcmeClient::new(email.clone(), self.use_production) .as_ref()
}) .map(|email| AcmeClient::new(email.clone(), self.use_production))
} }
/// Load a static certificate into the store (infallible — pure cache insert). /// Load a static certificate into the store (infallible — pure cache insert).
pub fn load_static( pub fn load_static(&mut self, domain: String, bundle: CertBundle) {
&mut self,
domain: String,
bundle: CertBundle,
) {
self.store.store(domain, bundle); self.store.store(domain, bundle);
} }
@@ -108,23 +104,25 @@ impl CertManager {
F: FnOnce(String, String) -> Fut, F: FnOnce(String, String) -> Fut,
Fut: std::future::Future<Output = ()>, Fut: std::future::Future<Output = ()>,
{ {
let acme_client = self.acme_client() let acme_client = self.acme_client().ok_or(CertManagerError::NoEmail)?;
.ok_or(CertManagerError::NoEmail)?;
info!("Renewing certificate for {}", domain); info!("Renewing certificate for {}", domain);
let domain_owned = domain.to_string(); let domain_owned = domain.to_string();
let result = acme_client.provision(&domain_owned, |pending| { let result = acme_client
let token = pending.token.clone(); .provision(&domain_owned, |pending| {
let key_auth = pending.key_authorization.clone(); let token = pending.token.clone();
async move { let key_auth = pending.key_authorization.clone();
challenge_setup(token, key_auth).await; async move {
Ok(()) challenge_setup(token, key_auth).await;
} Ok(())
}).await.map_err(|e| CertManagerError::AcmeFailure { }
domain: domain.to_string(), })
message: e.to_string(), .await
})?; .map_err(|e| CertManagerError::AcmeFailure {
domain: domain.to_string(),
message: e.to_string(),
})?;
let (cert_pem, key_pem) = result; let (cert_pem, key_pem) = result;
let now = SystemTime::now() let now = SystemTime::now()
+15 -6
View File
@@ -1,5 +1,5 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Certificate metadata stored alongside certs. /// Certificate metadata stored alongside certs.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -90,8 +90,10 @@ mod tests {
fn make_test_bundle(domain: &str) -> CertBundle { fn make_test_bundle(domain: &str) -> CertBundle {
CertBundle { CertBundle {
key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n".to_string(), key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n"
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n".to_string(), .to_string(),
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n"
.to_string(),
ca_pem: None, ca_pem: None,
metadata: CertMetadata { metadata: CertMetadata {
domain: domain.to_string(), domain: domain.to_string(),
@@ -122,7 +124,8 @@ mod tests {
let mut store = CertStore::new(); let mut store = CertStore::new();
let mut bundle = make_test_bundle("secure.com"); let mut bundle = make_test_bundle("secure.com");
bundle.ca_pem = Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string()); bundle.ca_pem =
Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string());
store.store("secure.com".to_string(), bundle); store.store("secure.com".to_string(), bundle);
let loaded = store.get("secure.com").unwrap(); let loaded = store.get("secure.com").unwrap();
@@ -147,7 +150,10 @@ mod tests {
fn test_remove_cert() { fn test_remove_cert() {
let mut store = CertStore::new(); let mut store = CertStore::new();
store.store("remove-me.com".to_string(), make_test_bundle("remove-me.com")); store.store(
"remove-me.com".to_string(),
make_test_bundle("remove-me.com"),
);
assert!(store.has("remove-me.com")); assert!(store.has("remove-me.com"));
let removed = store.remove("remove-me.com"); let removed = store.remove("remove-me.com");
@@ -165,7 +171,10 @@ mod tests {
fn test_wildcard_domain() { fn test_wildcard_domain() {
let mut store = CertStore::new(); let mut store = CertStore::new();
store.store("*.example.com".to_string(), make_test_bundle("*.example.com")); store.store(
"*.example.com".to_string(),
make_test_bundle("*.example.com"),
);
assert!(store.has("*.example.com")); assert!(store.has("*.example.com"));
let loaded = store.get("*.example.com").unwrap(); let loaded = store.get("*.example.com").unwrap();
+3 -3
View File
@@ -3,11 +3,11 @@
//! TLS certificate management for RustProxy. //! TLS certificate management for RustProxy.
//! Handles ACME (Let's Encrypt), static certificates, and dynamic SNI resolution. //! Handles ACME (Let's Encrypt), static certificates, and dynamic SNI resolution.
pub mod cert_store;
pub mod cert_manager;
pub mod acme; pub mod acme;
pub mod cert_manager;
pub mod cert_store;
pub mod sni_resolver; pub mod sni_resolver;
pub use cert_store::*;
pub use cert_manager::*; pub use cert_manager::*;
pub use cert_store::*;
pub use sni_resolver::*; pub use sni_resolver::*;
+12 -8
View File
@@ -13,7 +13,7 @@ use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo; use hyper_util::rt::TokioIo;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{debug, info, error}; use tracing::{debug, error, info};
/// ACME HTTP-01 challenge server. /// ACME HTTP-01 challenge server.
pub struct ChallengeServer { pub struct ChallengeServer {
@@ -47,7 +47,10 @@ impl ChallengeServer {
} }
/// Start the challenge server on the given port. /// Start the challenge server on the given port.
pub async fn start(&mut self, port: u16) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { pub async fn start(
&mut self,
port: u16,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let addr = format!("0.0.0.0:{}", port); let addr = format!("0.0.0.0:{}", port);
let listener = TcpListener::bind(&addr).await?; let listener = TcpListener::bind(&addr).await?;
info!("ACME challenge server listening on port {}", port); info!("ACME challenge server listening on port {}", port);
@@ -101,10 +104,7 @@ impl ChallengeServer {
pub async fn stop(&mut self) { pub async fn stop(&mut self) {
self.cancel.cancel(); self.cancel.cancel();
if let Some(handle) = self.handle.take() { if let Some(handle) = self.handle.take() {
let _ = tokio::time::timeout( let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
std::time::Duration::from_secs(5),
handle,
).await;
} }
self.challenges.clear(); self.challenges.clear();
self.cancel = CancellationToken::new(); self.cancel = CancellationToken::new();
@@ -154,10 +154,14 @@ mod tests {
tokio::time::sleep(std::time::Duration::from_millis(50)).await; tokio::time::sleep(std::time::Duration::from_millis(50)).await;
// Fetch the challenge // Fetch the challenge
let client = tokio::net::TcpStream::connect("127.0.0.1:19900").await.unwrap(); let client = tokio::net::TcpStream::connect("127.0.0.1:19900")
.await
.unwrap();
let io = TokioIo::new(client); let io = TokioIo::new(client);
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap(); let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
tokio::spawn(async move { let _ = conn.await; }); tokio::spawn(async move {
let _ = conn.await;
});
let req = Request::get("/.well-known/acme-challenge/test-token") let req = Request::get("/.well-known/acme-challenge/test-token")
.body(Full::new(Bytes::new())) .body(Full::new(Bytes::new()))
+303 -140
View File
@@ -57,24 +57,27 @@ use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use arc_swap::ArcSwap;
use anyhow::Result; use anyhow::Result;
use tracing::{info, warn, debug, error}; use arc_swap::ArcSwap;
use tracing::{debug, error, info, warn};
// Re-export key types // Re-export key types
pub use rustproxy_config; pub use rustproxy_config;
pub use rustproxy_routing;
pub use rustproxy_passthrough;
pub use rustproxy_tls;
pub use rustproxy_http; pub use rustproxy_http;
pub use rustproxy_metrics; pub use rustproxy_metrics;
pub use rustproxy_passthrough;
pub use rustproxy_routing;
pub use rustproxy_security; pub use rustproxy_security;
pub use rustproxy_tls;
use rustproxy_config::{RouteConfig, RustProxyOptions, TlsMode, CertificateSpec}; use rustproxy_config::{CertificateSpec, RouteConfig, RustProxyOptions, TlsMode};
use rustproxy_metrics::{Metrics, MetricsCollector, Statistics};
use rustproxy_passthrough::{
ConnectionConfig, TcpListenerManager, TlsCertConfig, UdpListenerManager,
};
use rustproxy_routing::RouteManager; use rustproxy_routing::RouteManager;
use rustproxy_passthrough::{TcpListenerManager, UdpListenerManager, TlsCertConfig, ConnectionConfig}; use rustproxy_security::IpBlockList;
use rustproxy_metrics::{MetricsCollector, Metrics, Statistics}; use rustproxy_tls::{CertBundle, CertManager, CertMetadata, CertSource, CertStore};
use rustproxy_tls::{CertManager, CertStore, CertBundle, CertMetadata, CertSource};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
/// Certificate status. /// Certificate status.
@@ -106,6 +109,8 @@ pub struct RustProxy {
loaded_certs: HashMap<String, TlsCertConfig>, loaded_certs: HashMap<String, TlsCertConfig>,
/// Cancellation token for cooperative shutdown of background tasks. /// Cancellation token for cooperative shutdown of background tasks.
cancel_token: CancellationToken, cancel_token: CancellationToken,
/// Shared global ingress blocklist, hot-reloadable across TCP/UDP listeners.
security_policy: Arc<ArcSwap<IpBlockList>>,
} }
impl RustProxy { impl RustProxy {
@@ -127,13 +132,19 @@ impl RustProxy {
let route_manager = RouteManager::new(options.routes.clone()); let route_manager = RouteManager::new(options.routes.clone());
// Set up certificate manager if ACME is configured // Set up certificate manager if ACME is configured
let cert_manager = Self::build_cert_manager(&options) let cert_manager =
.map(|cm| Arc::new(tokio::sync::Mutex::new(cm))); Self::build_cert_manager(&options).map(|cm| Arc::new(tokio::sync::Mutex::new(cm)));
let retention = options.metrics.as_ref() let retention = options
.metrics
.as_ref()
.and_then(|m| m.retention_seconds) .and_then(|m| m.retention_seconds)
.unwrap_or(3600) as usize; .unwrap_or(3600) as usize;
let security_policy = Arc::new(ArcSwap::from(Arc::new(Self::build_ip_block_list(
options.security_policy.as_ref(),
))));
Ok(Self { Ok(Self {
options, options,
route_table: ArcSwap::from(Arc::new(route_manager)), route_table: ArcSwap::from(Arc::new(route_manager)),
@@ -149,6 +160,7 @@ impl RustProxy {
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)), socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
loaded_certs: HashMap::new(), loaded_certs: HashMap::new(),
cancel_token: CancellationToken::new(), cancel_token: CancellationToken::new(),
security_policy,
}) })
} }
@@ -163,24 +175,25 @@ impl RustProxy {
// Apply default target if route has no targets // Apply default target if route has no targets
if route.action.targets.is_none() { if route.action.targets.is_none() {
if let Some(ref default_target) = defaults.target { if let Some(ref default_target) = defaults.target {
debug!("Applying default target {}:{} to route {:?}", debug!(
default_target.host, default_target.port, "Applying default target {}:{} to route {:?}",
route.name.as_deref().unwrap_or("unnamed")); default_target.host,
route.action.targets = Some(vec![ default_target.port,
rustproxy_config::RouteTarget { route.name.as_deref().unwrap_or("unnamed")
target_match: None, );
host: rustproxy_config::HostSpec::Single(default_target.host.clone()), route.action.targets = Some(vec![rustproxy_config::RouteTarget {
port: rustproxy_config::PortSpec::Fixed(default_target.port), target_match: None,
tls: None, host: rustproxy_config::HostSpec::Single(default_target.host.clone()),
websocket: None, port: rustproxy_config::PortSpec::Fixed(default_target.port),
load_balancing: None, tls: None,
send_proxy_protocol: None, websocket: None,
headers: None, load_balancing: None,
advanced: None, send_proxy_protocol: None,
backend_transport: None, headers: None,
priority: None, advanced: None,
} backend_transport: None,
]); priority: None,
}]);
} }
} }
@@ -198,7 +211,12 @@ impl RustProxy {
}; };
if let Some(ref allow_list) = default_security.ip_allow_list { if let Some(ref allow_list) = default_security.ip_allow_list {
security.ip_allow_list = Some(allow_list.clone()); security.ip_allow_list = Some(
allow_list
.iter()
.map(|s| rustproxy_config::IpAllowEntry::Plain(s.clone()))
.collect(),
);
} }
if let Some(ref block_list) = default_security.ip_block_list { if let Some(ref block_list) = default_security.ip_block_list {
security.ip_block_list = Some(block_list.clone()); security.ip_block_list = Some(block_list.clone());
@@ -206,8 +224,10 @@ impl RustProxy {
// Only apply if there's something meaningful // Only apply if there's something meaningful
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() { if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
debug!("Applying default security to route {:?}", debug!(
route.name.as_deref().unwrap_or("unnamed")); "Applying default security to route {:?}",
route.name.as_deref().unwrap_or("unnamed")
);
route.security = Some(security); route.security = Some(security);
} }
} }
@@ -222,13 +242,17 @@ impl RustProxy {
return None; return None;
} }
let email = acme.email.clone() let email = acme.email.clone().or_else(|| acme.account_email.clone());
.or_else(|| acme.account_email.clone());
let use_production = acme.use_production.unwrap_or(false); let use_production = acme.use_production.unwrap_or(false);
let renew_before_days = acme.renew_threshold_days.unwrap_or(30); let renew_before_days = acme.renew_threshold_days.unwrap_or(30);
let store = CertStore::new(); let store = CertStore::new();
Some(CertManager::new(store, email, use_production, renew_before_days)) Some(CertManager::new(
store,
email,
use_production,
renew_before_days,
))
} }
/// Build ConnectionConfig from RustProxyOptions. /// Build ConnectionConfig from RustProxyOptions.
@@ -246,7 +270,10 @@ impl RustProxy {
extended_keep_alive_lifetime_ms: options.extended_keep_alive_lifetime, extended_keep_alive_lifetime_ms: options.extended_keep_alive_lifetime,
accept_proxy_protocol: options.accept_proxy_protocol.unwrap_or(false), accept_proxy_protocol: options.accept_proxy_protocol.unwrap_or(false),
send_proxy_protocol: options.send_proxy_protocol.unwrap_or(false), send_proxy_protocol: options.send_proxy_protocol.unwrap_or(false),
proxy_ips: options.proxy_ips.as_deref().unwrap_or(&[]) proxy_ips: options
.proxy_ips
.as_deref()
.unwrap_or(&[])
.iter() .iter()
.filter_map(|s| s.parse::<std::net::IpAddr>().ok()) .filter_map(|s| s.parse::<std::net::IpAddr>().ok())
.collect(), .collect(),
@@ -256,6 +283,22 @@ impl RustProxy {
} }
} }
fn build_ip_block_list(policy: Option<&rustproxy_config::SecurityPolicy>) -> IpBlockList {
let Some(policy) = policy else {
return IpBlockList::empty();
};
let mut entries = Vec::new();
if let Some(blocked_ips) = &policy.blocked_ips {
entries.extend(blocked_ips.iter().cloned());
}
if let Some(blocked_cidrs) = &policy.blocked_cidrs {
entries.extend(blocked_cidrs.iter().cloned());
}
IpBlockList::new(&entries)
}
/// Start the proxy, binding to all configured ports. /// Start the proxy, binding to all configured ports.
pub async fn start(&mut self) -> Result<()> { pub async fn start(&mut self) -> Result<()> {
if self.started { if self.started {
@@ -270,7 +313,11 @@ impl RustProxy {
let route_manager = self.route_table.load(); let route_manager = self.route_table.load();
let ports = route_manager.listening_ports(); let ports = route_manager.listening_ports();
info!("Configured {} routes on {} ports", route_manager.route_count(), ports.len()); info!(
"Configured {} routes on {} ports",
route_manager.route_count(),
ports.len()
);
// Create TCP listener manager with metrics // Create TCP listener manager with metrics
let mut listener = TcpListenerManager::with_metrics( let mut listener = TcpListenerManager::with_metrics(
@@ -280,7 +327,8 @@ impl RustProxy {
// Apply connection config from options // Apply connection config from options
let conn_config = Self::build_connection_config(&self.options); let conn_config = Self::build_connection_config(&self.options);
debug!("Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms", debug!(
"Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms",
conn_config.connection_timeout_ms, conn_config.connection_timeout_ms,
conn_config.initial_data_timeout_ms, conn_config.initial_data_timeout_ms,
conn_config.socket_timeout_ms, conn_config.socket_timeout_ms,
@@ -289,6 +337,7 @@ impl RustProxy {
// Clone proxy_ips before conn_config is moved into the TCP listener // Clone proxy_ips before conn_config is moved into the TCP listener
let udp_proxy_ips = conn_config.proxy_ips.clone(); let udp_proxy_ips = conn_config.proxy_ips.clone();
listener.set_connection_config(conn_config); listener.set_connection_config(conn_config);
listener.set_security_policy(Arc::clone(&self.security_policy));
// Share the socket-handler relay path with the listener // Share the socket-handler relay path with the listener
listener.set_socket_handler_relay(Arc::clone(&self.socket_handler_relay)); listener.set_socket_handler_relay(Arc::clone(&self.socket_handler_relay));
@@ -301,10 +350,13 @@ impl RustProxy {
let cm = cm.lock().await; let cm = cm.lock().await;
for (domain, bundle) in cm.store().iter() { for (domain, bundle) in cm.store().iter() {
if !tls_configs.contains_key(domain) { if !tls_configs.contains_key(domain) {
tls_configs.insert(domain.clone(), TlsCertConfig { tls_configs.insert(
cert_pem: bundle.cert_pem.clone(), domain.clone(),
key_pem: bundle.key_pem.clone(), TlsCertConfig {
}); cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
},
);
} }
} }
} }
@@ -328,7 +380,9 @@ impl RustProxy {
let mut tcp_ports = std::collections::HashSet::new(); let mut tcp_ports = std::collections::HashSet::new();
let mut udp_ports = std::collections::HashSet::new(); let mut udp_ports = std::collections::HashSet::new();
for route in &self.options.routes { for route in &self.options.routes {
if !route.is_enabled() { continue; } if !route.is_enabled() {
continue;
}
let transport = route.route_match.transport.as_ref(); let transport = route.route_match.transport.as_ref();
let route_ports = route.route_match.ports.to_ports(); let route_ports = route.route_match.ports.to_ports();
for port in route_ports { for port in route_ports {
@@ -369,6 +423,7 @@ impl RustProxy {
connection_registry, connection_registry,
); );
udp_mgr.set_proxy_ips(udp_proxy_ips.clone()); udp_mgr.set_proxy_ips(udp_proxy_ips.clone());
udp_mgr.set_security_policy(Arc::clone(&self.security_policy));
// Share HttpProxyService with H3 — same route matching, connection // Share HttpProxyService with H3 — same route matching, connection
// pool, and ALPN protocol detection as the TCP/HTTP path. // pool, and ALPN protocol detection as the TCP/HTTP path.
@@ -377,10 +432,15 @@ impl RustProxy {
udp_mgr.set_h3_service(Arc::new(h3_svc)); udp_mgr.set_h3_service(Arc::new(h3_svc));
for port in &udp_ports { for port in &udp_ports {
udp_mgr.add_port_with_tls(*port, quic_tls_config.clone()).await?; udp_mgr
.add_port_with_tls(*port, quic_tls_config.clone())
.await?;
} }
info!("UDP listeners started on {} ports: {:?}", info!(
udp_ports.len(), udp_mgr.listening_ports()); "UDP listeners started on {} ports: {:?}",
udp_ports.len(),
udp_mgr.listening_ports()
);
self.udp_listener_manager = Some(udp_mgr); self.udp_listener_manager = Some(udp_mgr);
} }
@@ -389,16 +449,22 @@ impl RustProxy {
// Start the throughput sampling task with cooperative cancellation // Start the throughput sampling task with cooperative cancellation
let metrics = Arc::clone(&self.metrics); let metrics = Arc::clone(&self.metrics);
let conn_tracker = self.listener_manager.as_ref().unwrap().conn_tracker().clone(); let conn_tracker = self
.listener_manager
.as_ref()
.unwrap()
.conn_tracker()
.clone();
let http_proxy = self.listener_manager.as_ref().unwrap().http_proxy().clone(); let http_proxy = self.listener_manager.as_ref().unwrap().http_proxy().clone();
let interval_ms = self.options.metrics.as_ref() let interval_ms = self
.options
.metrics
.as_ref()
.and_then(|m| m.sample_interval_ms) .and_then(|m| m.sample_interval_ms)
.unwrap_or(1000); .unwrap_or(1000);
let sampling_cancel = self.cancel_token.clone(); let sampling_cancel = self.cancel_token.clone();
self.sampling_handle = Some(tokio::spawn(async move { self.sampling_handle = Some(tokio::spawn(async move {
let mut interval = tokio::time::interval( let mut interval = tokio::time::interval(std::time::Duration::from_millis(interval_ms));
std::time::Duration::from_millis(interval_ms)
);
loop { loop {
tokio::select! { tokio::select! {
_ = sampling_cancel.cancelled() => break, _ = sampling_cancel.cancelled() => break,
@@ -440,7 +506,10 @@ impl RustProxy {
continue; continue;
} }
let cert_spec = route.action.tls.as_ref() let cert_spec = route
.action
.tls
.as_ref()
.and_then(|tls| tls.certificate.as_ref()); .and_then(|tls| tls.certificate.as_ref());
if let Some(CertificateSpec::Auto(_)) = cert_spec { if let Some(CertificateSpec::Auto(_)) = cert_spec {
@@ -464,16 +533,25 @@ impl RustProxy {
return; return;
} }
info!("Auto-provisioning certificates for {} domains", domains_to_provision.len()); info!(
"Auto-provisioning certificates for {} domains",
domains_to_provision.len()
);
// Start challenge server // Start challenge server
let acme_port = self.options.acme.as_ref() let acme_port = self
.options
.acme
.as_ref()
.and_then(|a| a.port) .and_then(|a| a.port)
.unwrap_or(80); .unwrap_or(80);
let mut challenge_server = challenge_server::ChallengeServer::new(); let mut challenge_server = challenge_server::ChallengeServer::new();
if let Err(e) = challenge_server.start(acme_port).await { if let Err(e) = challenge_server.start(acme_port).await {
error!("Failed to start ACME challenge server on port {}: {}", acme_port, e); error!(
"Failed to start ACME challenge server on port {}: {}",
acme_port, e
);
return; return;
} }
@@ -486,13 +564,15 @@ impl RustProxy {
if let Some(acme_client) = acme_client { if let Some(acme_client) = acme_client {
let challenge_server_ref = &challenge_server; let challenge_server_ref = &challenge_server;
let result = acme_client.provision(domain, |pending| { let result = acme_client
challenge_server_ref.set_challenge( .provision(domain, |pending| {
pending.token.clone(), challenge_server_ref.set_challenge(
pending.key_authorization.clone(), pending.token.clone(),
); pending.key_authorization.clone(),
async move { Ok(()) } );
}).await; async move { Ok(()) }
})
.await;
match result { match result {
Ok((cert_pem, key_pem)) => { Ok((cert_pem, key_pem)) => {
@@ -537,7 +617,10 @@ impl RustProxy {
None => return, None => return,
}; };
let auto_renew = self.options.acme.as_ref() let auto_renew = self
.options
.acme
.as_ref()
.and_then(|a| a.auto_renew) .and_then(|a| a.auto_renew)
.unwrap_or(true); .unwrap_or(true);
@@ -545,11 +628,17 @@ impl RustProxy {
return; return;
} }
let check_interval_hours = self.options.acme.as_ref() let check_interval_hours = self
.options
.acme
.as_ref()
.and_then(|a| a.renew_check_interval_hours) .and_then(|a| a.renew_check_interval_hours)
.unwrap_or(24); .unwrap_or(24);
let acme_port = self.options.acme.as_ref() let acme_port = self
.options
.acme
.as_ref()
.and_then(|a| a.port) .and_then(|a| a.port)
.unwrap_or(80); .unwrap_or(80);
@@ -662,17 +751,19 @@ impl RustProxy {
/// Update routes atomically (hot-reload). /// Update routes atomically (hot-reload).
pub async fn update_routes(&mut self, routes: Vec<RouteConfig>) -> Result<()> { pub async fn update_routes(&mut self, routes: Vec<RouteConfig>) -> Result<()> {
// Validate new routes // Validate new routes
rustproxy_config::validate_routes(&routes) rustproxy_config::validate_routes(&routes).map_err(|errors| {
.map_err(|errors| { let msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
let msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect(); anyhow::anyhow!("Route validation failed: {}", msgs.join(", "))
anyhow::anyhow!("Route validation failed: {}", msgs.join(", ")) })?;
})?;
let new_manager = RouteManager::new(routes.clone()); let new_manager = RouteManager::new(routes.clone());
let new_ports = new_manager.listening_ports(); let new_ports = new_manager.listening_ports();
info!("Updating routes: {} routes on {} ports", info!(
new_manager.route_count(), new_ports.len()); "Updating routes: {} routes on {} ports",
new_manager.route_count(),
new_ports.len()
);
// Get old ports // Get old ports
let old_ports: Vec<u16> = if let Some(ref listener) = self.listener_manager { let old_ports: Vec<u16> = if let Some(ref listener) = self.listener_manager {
@@ -682,28 +773,35 @@ impl RustProxy {
}; };
// Prune per-route metrics for route IDs that no longer exist // Prune per-route metrics for route IDs that no longer exist
let active_route_ids: HashSet<String> = routes.iter() let active_route_ids: HashSet<String> =
.filter_map(|r| r.id.clone()) routes.iter().filter_map(|r| r.id.clone()).collect();
.collect();
self.metrics.retain_routes(&active_route_ids); self.metrics.retain_routes(&active_route_ids);
// Prune per-backend metrics for backends no longer in any route target. // Prune per-backend metrics for backends no longer in any route target.
// For PortSpec::Preserve routes, expand across all listening ports since // For PortSpec::Preserve routes, expand across all listening ports since
// the actual runtime port depends on the incoming connection. // the actual runtime port depends on the incoming connection.
let listening_ports = self.get_listening_ports(); let listening_ports = self.get_listening_ports();
let active_backends: HashSet<String> = routes.iter() let active_backends: HashSet<String> = routes
.iter()
.filter_map(|r| r.action.targets.as_ref()) .filter_map(|r| r.action.targets.as_ref())
.flat_map(|targets| targets.iter()) .flat_map(|targets| targets.iter())
.flat_map(|target| { .flat_map(|target| {
let hosts: Vec<String> = target.host.to_vec().into_iter().map(|s| s.to_string()).collect(); let hosts: Vec<String> = target
.host
.to_vec()
.into_iter()
.map(|s| s.to_string())
.collect();
match &target.port { match &target.port {
rustproxy_config::PortSpec::Fixed(p) => { rustproxy_config::PortSpec::Fixed(p) => hosts
hosts.into_iter().map(|h| format!("{}:{}", h, p)).collect::<Vec<_>>() .into_iter()
} .map(|h| format!("{}:{}", h, p))
.collect::<Vec<_>>(),
_ => { _ => {
// Preserve/special: expand across all listening ports // Preserve/special: expand across all listening ports
let lp = &listening_ports; let lp = &listening_ports;
hosts.into_iter() hosts
.into_iter()
.flat_map(|h| lp.iter().map(move |p| format!("{}:{}", h, *p))) .flat_map(|h| lp.iter().map(move |p| format!("{}:{}", h, *p)))
.collect::<Vec<_>>() .collect::<Vec<_>>()
} }
@@ -731,10 +829,13 @@ impl RustProxy {
let cm = cm_arc.lock().await; let cm = cm_arc.lock().await;
for (domain, bundle) in cm.store().iter() { for (domain, bundle) in cm.store().iter() {
if !tls_configs.contains_key(domain) { if !tls_configs.contains_key(domain) {
tls_configs.insert(domain.clone(), TlsCertConfig { tls_configs.insert(
cert_pem: bundle.cert_pem.clone(), domain.clone(),
key_pem: bundle.key_pem.clone(), TlsCertConfig {
}); cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
},
);
} }
} }
} }
@@ -751,7 +852,9 @@ impl RustProxy {
// Cancel connections on routes that were removed or disabled // Cancel connections on routes that were removed or disabled
listener.invalidate_removed_routes(&active_route_ids); listener.invalidate_removed_routes(&active_route_ids);
// Clean up registry entries for removed routes // Clean up registry entries for removed routes
listener.connection_registry().cleanup_removed_routes(&active_route_ids); listener
.connection_registry()
.cleanup_removed_routes(&active_route_ids);
// Prune HTTP proxy caches (rate limiters, regex cache, round-robin counters) // Prune HTTP proxy caches (rate limiters, regex cache, round-robin counters)
listener.prune_http_proxy_caches(&active_route_ids); listener.prune_http_proxy_caches(&active_route_ids);
@@ -764,9 +867,10 @@ impl RustProxy {
None => continue, None => continue,
}; };
// Find corresponding old route // Find corresponding old route
let old_route = old_manager.routes().iter().find(|r| { let old_route = old_manager
r.id.as_deref() == Some(new_id) .routes()
}); .iter()
.find(|r| r.id.as_deref() == Some(new_id));
let old_route = match old_route { let old_route = match old_route {
Some(r) => r, Some(r) => r,
None => continue, // new route, no existing connections to recycle None => continue, // new route, no existing connections to recycle
@@ -810,11 +914,13 @@ impl RustProxy {
{ {
let mut new_udp_ports = HashSet::new(); let mut new_udp_ports = HashSet::new();
for route in &routes { for route in &routes {
if !route.is_enabled() { continue; } if !route.is_enabled() {
continue;
}
let transport = route.route_match.transport.as_ref(); let transport = route.route_match.transport.as_ref();
match transport { match transport {
Some(rustproxy_config::TransportProtocol::Udp) | Some(rustproxy_config::TransportProtocol::Udp)
Some(rustproxy_config::TransportProtocol::All) => { | Some(rustproxy_config::TransportProtocol::All) => {
for port in route.route_match.ports.to_ports() { for port in route.route_match.ports.to_ports() {
new_udp_ports.insert(port); new_udp_ports.insert(port);
} }
@@ -823,7 +929,8 @@ impl RustProxy {
} }
} }
let old_udp_ports: HashSet<u16> = self.udp_listener_manager let old_udp_ports: HashSet<u16> = self
.udp_listener_manager
.as_ref() .as_ref()
.map(|u| u.listening_ports().into_iter().collect()) .map(|u| u.listening_ports().into_iter().collect())
.unwrap_or_default(); .unwrap_or_default();
@@ -845,6 +952,11 @@ impl RustProxy {
connection_registry, connection_registry,
); );
udp_mgr.set_proxy_ips(conn_config.proxy_ips); udp_mgr.set_proxy_ips(conn_config.proxy_ips);
udp_mgr.set_security_policy(Arc::clone(&self.security_policy));
// Wire up H3ProxyService so QUIC connections can serve HTTP/3
let http_proxy = listener.http_proxy().clone();
let h3_svc = rustproxy_http::h3_service::H3ProxyService::new(http_proxy);
udp_mgr.set_h3_service(std::sync::Arc::new(h3_svc));
self.udp_listener_manager = Some(udp_mgr); self.udp_listener_manager = Some(udp_mgr);
} }
} }
@@ -892,56 +1004,77 @@ impl RustProxy {
/// Provision a certificate for a named route. /// Provision a certificate for a named route.
pub async fn provision_certificate(&mut self, route_name: &str) -> Result<()> { pub async fn provision_certificate(&mut self, route_name: &str) -> Result<()> {
let cm_arc = self.cert_manager.as_ref() let cm_arc = self.cert_manager.as_ref().ok_or_else(|| {
.ok_or_else(|| anyhow::anyhow!("No certificate manager configured (ACME not enabled)"))?; anyhow::anyhow!("No certificate manager configured (ACME not enabled)")
})?;
// Find the route by name // Find the route by name
let route = self.options.routes.iter() let route = self
.options
.routes
.iter()
.find(|r| r.name.as_deref() == Some(route_name)) .find(|r| r.name.as_deref() == Some(route_name))
.ok_or_else(|| anyhow::anyhow!("Route '{}' not found", route_name))?; .ok_or_else(|| anyhow::anyhow!("Route '{}' not found", route_name))?;
let domain = route.route_match.domains.as_ref() let domain = route
.route_match
.domains
.as_ref()
.and_then(|d| d.to_vec().first().map(|s| s.to_string())) .and_then(|d| d.to_vec().first().map(|s| s.to_string()))
.ok_or_else(|| anyhow::anyhow!("Route '{}' has no domain", route_name))?; .ok_or_else(|| anyhow::anyhow!("Route '{}' has no domain", route_name))?;
info!("Provisioning certificate for route '{}' (domain: {})", route_name, domain); info!(
"Provisioning certificate for route '{}' (domain: {})",
route_name, domain
);
// Start challenge server // Start challenge server
let acme_port = self.options.acme.as_ref() let acme_port = self
.options
.acme
.as_ref()
.and_then(|a| a.port) .and_then(|a| a.port)
.unwrap_or(80); .unwrap_or(80);
let mut cs = challenge_server::ChallengeServer::new(); let mut cs = challenge_server::ChallengeServer::new();
cs.start(acme_port).await cs.start(acme_port)
.await
.map_err(|e| anyhow::anyhow!("Failed to start challenge server: {}", e))?; .map_err(|e| anyhow::anyhow!("Failed to start challenge server: {}", e))?;
let cs_ref = &cs; let cs_ref = &cs;
let mut cm = cm_arc.lock().await; let mut cm = cm_arc.lock().await;
let result = cm.renew_domain(&domain, |token, key_auth| { let result = cm
cs_ref.set_challenge(token, key_auth); .renew_domain(&domain, |token, key_auth| {
async {} cs_ref.set_challenge(token, key_auth);
}).await; async {}
})
.await;
drop(cm); drop(cm);
cs.stop().await; cs.stop().await;
let bundle = result let bundle = result.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
// Hot-swap into TLS configs // Hot-swap into TLS configs
let mut tls_configs = Self::extract_tls_configs(&self.options.routes); let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
tls_configs.insert(domain.clone(), TlsCertConfig { tls_configs.insert(
cert_pem: bundle.cert_pem.clone(), domain.clone(),
key_pem: bundle.key_pem.clone(), TlsCertConfig {
}); cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
},
);
{ {
let cm = cm_arc.lock().await; let cm = cm_arc.lock().await;
for (d, b) in cm.store().iter() { for (d, b) in cm.store().iter() {
if !tls_configs.contains_key(d) { if !tls_configs.contains_key(d) {
tls_configs.insert(d.clone(), TlsCertConfig { tls_configs.insert(
cert_pem: b.cert_pem.clone(), d.clone(),
key_pem: b.key_pem.clone(), TlsCertConfig {
}); cert_pem: b.cert_pem.clone(),
key_pem: b.key_pem.clone(),
},
);
} }
} }
} }
@@ -960,7 +1093,10 @@ impl RustProxy {
} }
} }
info!("Certificate provisioned and loaded for route '{}'", route_name); info!(
"Certificate provisioned and loaded for route '{}'",
route_name
);
Ok(()) Ok(())
} }
@@ -972,10 +1108,16 @@ impl RustProxy {
/// Get the status of a certificate for a named route. /// Get the status of a certificate for a named route.
pub async fn get_certificate_status(&self, route_name: &str) -> Option<CertStatus> { pub async fn get_certificate_status(&self, route_name: &str) -> Option<CertStatus> {
let route = self.options.routes.iter() let route = self
.options
.routes
.iter()
.find(|r| r.name.as_deref() == Some(route_name))?; .find(|r| r.name.as_deref() == Some(route_name))?;
let domain = route.route_match.domains.as_ref() let domain = route
.route_match
.domains
.as_ref()
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))?; .and_then(|d| d.to_vec().first().map(|s| s.to_string()))?;
if let Some(ref cm_arc) = self.cert_manager { if let Some(ref cm_arc) = self.cert_manager {
@@ -1004,8 +1146,9 @@ impl RustProxy {
let mut metrics = self.metrics.snapshot(); let mut metrics = self.metrics.snapshot();
if let Some(ref lm) = self.listener_manager { if let Some(ref lm) = self.listener_manager {
let entries = lm.http_proxy().protocol_cache_snapshot(); let entries = lm.http_proxy().protocol_cache_snapshot();
metrics.detected_protocols = entries.into_iter().map(|e| { metrics.detected_protocols = entries
rustproxy_metrics::ProtocolCacheEntryMetric { .into_iter()
.map(|e| rustproxy_metrics::ProtocolCacheEntryMetric {
host: e.host, host: e.host,
port: e.port, port: e.port,
domain: e.domain, domain: e.domain,
@@ -1020,8 +1163,8 @@ impl RustProxy {
h3_cooldown_remaining_secs: e.h3_cooldown_remaining_secs, h3_cooldown_remaining_secs: e.h3_cooldown_remaining_secs,
h2_consecutive_failures: e.h2_consecutive_failures, h2_consecutive_failures: e.h2_consecutive_failures,
h3_consecutive_failures: e.h3_consecutive_failures, h3_consecutive_failures: e.h3_consecutive_failures,
} })
}).collect(); .collect();
} }
metrics metrics
} }
@@ -1052,9 +1195,7 @@ impl RustProxy {
/// Get statistics snapshot. /// Get statistics snapshot.
pub fn get_statistics(&self) -> Statistics { pub fn get_statistics(&self) -> Statistics {
let uptime = self.started_at let uptime = self.started_at.map(|t| t.elapsed().as_secs()).unwrap_or(0);
.map(|t| t.elapsed().as_secs())
.unwrap_or(0);
Statistics { Statistics {
active_connections: self.metrics.active_connections(), active_connections: self.metrics.active_connections(),
@@ -1065,6 +1206,13 @@ impl RustProxy {
} }
} }
/// Update the global ingress security policy.
pub fn set_security_policy(&mut self, policy: rustproxy_config::SecurityPolicy) {
self.security_policy
.store(Arc::new(Self::build_ip_block_list(Some(&policy))));
self.options.security_policy = Some(policy);
}
/// Set the Unix domain socket path for relaying socket-handler connections to TypeScript. /// Set the Unix domain socket path for relaying socket-handler connections to TypeScript.
/// The path is shared with the TcpListenerManager via Arc<RwLock>, so updates /// The path is shared with the TcpListenerManager via Arc<RwLock>, so updates
/// take effect immediately for all new connections. /// take effect immediately for all new connections.
@@ -1124,10 +1272,13 @@ impl RustProxy {
let cm = cm_arc.lock().await; let cm = cm_arc.lock().await;
for (d, b) in cm.store().iter() { for (d, b) in cm.store().iter() {
if !configs.contains_key(d) { if !configs.contains_key(d) {
configs.insert(d.clone(), TlsCertConfig { configs.insert(
cert_pem: b.cert_pem.clone(), d.clone(),
key_pem: b.key_pem.clone(), TlsCertConfig {
}); cert_pem: b.cert_pem.clone(),
key_pem: b.key_pem.clone(),
},
);
} }
} }
} }
@@ -1160,7 +1311,8 @@ impl RustProxy {
info!("Loading certificate for domain: {}", domain); info!("Loading certificate for domain: {}", domain);
// Check if the cert actually changed (for selective connection recycling) // Check if the cert actually changed (for selective connection recycling)
let cert_changed = self.loaded_certs let cert_changed = self
.loaded_certs
.get(domain) .get(domain)
.map(|existing| existing.cert_pem != cert_pem) .map(|existing| existing.cert_pem != cert_pem)
.unwrap_or(false); // new domain = no existing connections to recycle .unwrap_or(false); // new domain = no existing connections to recycle
@@ -1190,10 +1342,13 @@ impl RustProxy {
} }
// Persist in loaded_certs so future rebuild calls include this cert // Persist in loaded_certs so future rebuild calls include this cert
self.loaded_certs.insert(domain.to_string(), TlsCertConfig { self.loaded_certs.insert(
cert_pem: cert_pem.clone(), domain.to_string(),
key_pem: key_pem.clone(), TlsCertConfig {
}); cert_pem: cert_pem.clone(),
key_pem: key_pem.clone(),
},
);
// Hot-swap TLS config on TCP and QUIC listeners // Hot-swap TLS config on TCP and QUIC listeners
let tls_configs = self.current_tls_configs().await; let tls_configs = self.current_tls_configs().await;
@@ -1216,7 +1371,9 @@ impl RustProxy {
// Recycle existing connections if cert actually changed // Recycle existing connections if cert actually changed
if cert_changed { if cert_changed {
if let Some(ref listener) = self.listener_manager { if let Some(ref listener) = self.listener_manager {
listener.connection_registry().recycle_for_cert_change(domain); listener
.connection_registry()
.recycle_for_cert_change(domain);
} }
} }
@@ -1238,16 +1395,22 @@ impl RustProxy {
continue; continue;
} }
let cert_spec = route.action.tls.as_ref() let cert_spec = route
.action
.tls
.as_ref()
.and_then(|tls| tls.certificate.as_ref()); .and_then(|tls| tls.certificate.as_ref());
if let Some(CertificateSpec::Static(cert_config)) = cert_spec { if let Some(CertificateSpec::Static(cert_config)) = cert_spec {
if let Some(ref domains) = route.route_match.domains { if let Some(ref domains) = route.route_match.domains {
for domain in domains.to_vec() { for domain in domains.to_vec() {
configs.insert(domain.to_string(), TlsCertConfig { configs.insert(
cert_pem: cert_config.cert.clone(), domain.to_string(),
key_pem: cert_config.key.clone(), TlsCertConfig {
}); cert_pem: cert_config.cert.clone(),
key_pem: cert_config.key.clone(),
},
);
} }
} }
} }
+4 -9
View File
@@ -1,12 +1,12 @@
#[global_allocator] #[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
use anyhow::Result;
use clap::Parser; use clap::Parser;
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
use anyhow::Result;
use rustproxy::RustProxy;
use rustproxy::management; use rustproxy::management;
use rustproxy::RustProxy;
use rustproxy_config::RustProxyOptions; use rustproxy_config::RustProxyOptions;
/// RustProxy - High-performance multi-protocol proxy /// RustProxy - High-performance multi-protocol proxy
@@ -43,8 +43,7 @@ async fn main() -> Result<()> {
tracing_subscriber::fmt() tracing_subscriber::fmt()
.with_writer(std::io::stderr) .with_writer(std::io::stderr)
.with_env_filter( .with_env_filter(
EnvFilter::try_from_default_env() EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&cli.log_level)),
.unwrap_or_else(|_| EnvFilter::new(&cli.log_level))
) )
.init(); .init();
@@ -60,11 +59,7 @@ async fn main() -> Result<()> {
let options = RustProxyOptions::from_file(&cli.config) let options = RustProxyOptions::from_file(&cli.config)
.map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?; .map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?;
tracing::info!( tracing::info!("Loaded {} routes from {}", options.routes.len(), cli.config);
"Loaded {} routes from {}",
options.routes.len(),
cli.config
);
// Validate-only mode // Validate-only mode
if cli.validate { if cli.validate {
+157 -65
View File
@@ -1,7 +1,7 @@
use anyhow::Result; use anyhow::Result;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::io::{AsyncBufReadExt, BufReader};
use tracing::{info, error}; use tracing::{error, info};
use crate::RustProxy; use crate::RustProxy;
use rustproxy_config::RustProxyOptions; use rustproxy_config::RustProxyOptions;
@@ -141,14 +141,19 @@ async fn handle_request(
"start" => handle_start(&id, &request.params, proxy).await, "start" => handle_start(&id, &request.params, proxy).await,
"stop" => handle_stop(&id, proxy).await, "stop" => handle_stop(&id, proxy).await,
"updateRoutes" => handle_update_routes(&id, &request.params, proxy).await, "updateRoutes" => handle_update_routes(&id, &request.params, proxy).await,
"setSecurityPolicy" => handle_set_security_policy(&id, &request.params, proxy),
"getMetrics" => handle_get_metrics(&id, proxy), "getMetrics" => handle_get_metrics(&id, proxy),
"getStatistics" => handle_get_statistics(&id, proxy), "getStatistics" => handle_get_statistics(&id, proxy),
"provisionCertificate" => handle_provision_certificate(&id, &request.params, proxy).await, "provisionCertificate" => handle_provision_certificate(&id, &request.params, proxy).await,
"renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await, "renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await,
"getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await, "getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await,
"getListeningPorts" => handle_get_listening_ports(&id, proxy), "getListeningPorts" => handle_get_listening_ports(&id, proxy),
"setSocketHandlerRelay" => handle_set_socket_handler_relay(&id, &request.params, proxy).await, "setSocketHandlerRelay" => {
"setDatagramHandlerRelay" => handle_set_datagram_handler_relay(&id, &request.params, proxy).await, handle_set_socket_handler_relay(&id, &request.params, proxy).await
}
"setDatagramHandlerRelay" => {
handle_set_datagram_handler_relay(&id, &request.params, proxy).await
}
"addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await, "addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await,
"removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await, "removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await,
"loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await, "loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await,
@@ -167,7 +172,12 @@ async fn handle_start(
let config = match params.get("config") { let config = match params.get("config") {
Some(config) => config, Some(config) => config,
None => return ManagementResponse::err(id.to_string(), "Missing 'config' parameter".to_string()), None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'config' parameter".to_string(),
)
}
}; };
let options: RustProxyOptions = match serde_json::from_value(config.clone()) { let options: RustProxyOptions = match serde_json::from_value(config.clone()) {
@@ -176,38 +186,31 @@ async fn handle_start(
}; };
match RustProxy::new(options) { match RustProxy::new(options) {
Ok(mut p) => { Ok(mut p) => match p.start().await {
match p.start().await { Ok(()) => {
Ok(()) => { send_event("started", serde_json::json!({}));
send_event("started", serde_json::json!({})); *proxy = Some(p);
*proxy = Some(p); ManagementResponse::ok(id.to_string(), serde_json::json!({}))
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
Err(e) => {
send_event("error", serde_json::json!({"message": format!("{}", e)}));
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
}
} }
} Err(e) => {
send_event("error", serde_json::json!({"message": format!("{}", e)}));
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
}
},
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to create proxy: {}", e)), Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to create proxy: {}", e)),
} }
} }
async fn handle_stop( async fn handle_stop(id: &str, proxy: &mut Option<RustProxy>) -> ManagementResponse {
id: &str,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_mut() { match proxy.as_mut() {
Some(p) => { Some(p) => match p.stop().await {
match p.stop().await { Ok(()) => {
Ok(()) => { *proxy = None;
*proxy = None; send_event("stopped", serde_json::json!({}));
send_event("stopped", serde_json::json!({})); ManagementResponse::ok(id.to_string(), serde_json::json!({}))
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
} }
} Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
},
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})), None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
} }
} }
@@ -224,7 +227,12 @@ async fn handle_update_routes(
let routes = match params.get("routes") { let routes = match params.get("routes") {
Some(routes) => routes, Some(routes) => routes,
None => return ManagementResponse::err(id.to_string(), "Missing 'routes' parameter".to_string()), None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routes' parameter".to_string(),
)
}
}; };
let routes: Vec<rustproxy_config::RouteConfig> = match serde_json::from_value(routes.clone()) { let routes: Vec<rustproxy_config::RouteConfig> = match serde_json::from_value(routes.clone()) {
@@ -234,36 +242,72 @@ async fn handle_update_routes(
match p.update_routes(routes).await { match p.update_routes(routes).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e)), Err(e) => {
ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e))
}
} }
} }
fn handle_get_metrics( fn handle_set_security_policy(
id: &str, id: &str,
proxy: &Option<RustProxy>, params: &serde_json::Value,
proxy: &mut Option<RustProxy>,
) -> ManagementResponse { ) -> ManagementResponse {
let p = match proxy.as_mut() {
Some(p) => p,
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
};
let policy = match params.get("policy") {
Some(policy) => policy,
None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'policy' parameter".to_string(),
)
}
};
let policy: rustproxy_config::SecurityPolicy = match serde_json::from_value(policy.clone()) {
Ok(policy) => policy,
Err(e) => {
return ManagementResponse::err(
id.to_string(),
format!("Invalid security policy: {}", e),
)
}
};
p.set_security_policy(policy);
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
}
fn handle_get_metrics(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
match proxy.as_ref() { match proxy.as_ref() {
Some(p) => { Some(p) => {
let metrics = p.get_metrics(); let metrics = p.get_metrics();
match serde_json::to_value(&metrics) { match serde_json::to_value(&metrics) {
Ok(v) => ManagementResponse::ok(id.to_string(), v), Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize metrics: {}", e)), Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to serialize metrics: {}", e),
),
} }
} }
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
} }
} }
fn handle_get_statistics( fn handle_get_statistics(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_ref() { match proxy.as_ref() {
Some(p) => { Some(p) => {
let stats = p.get_statistics(); let stats = p.get_statistics();
match serde_json::to_value(&stats) { match serde_json::to_value(&stats) {
Ok(v) => ManagementResponse::ok(id.to_string(), v), Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize statistics: {}", e)), Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to serialize statistics: {}", e),
),
} }
} }
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
@@ -282,12 +326,20 @@ async fn handle_provision_certificate(
let route_name = match params.get("routeName").and_then(|v| v.as_str()) { let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name.to_string(), Some(name) => name.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()), None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routeName' parameter".to_string(),
)
}
}; };
match p.provision_certificate(&route_name).await { match p.provision_certificate(&route_name).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to provision certificate: {}", e)), Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to provision certificate: {}", e),
),
} }
} }
@@ -303,12 +355,20 @@ async fn handle_renew_certificate(
let route_name = match params.get("routeName").and_then(|v| v.as_str()) { let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name.to_string(), Some(name) => name.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()), None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routeName' parameter".to_string(),
)
}
}; };
match p.renew_certificate(&route_name).await { match p.renew_certificate(&route_name).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to renew certificate: {}", e)), Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to renew certificate: {}", e),
),
} }
} }
@@ -324,24 +384,29 @@ async fn handle_get_certificate_status(
let route_name = match params.get("routeName").and_then(|v| v.as_str()) { let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
Some(name) => name, Some(name) => name,
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()), None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'routeName' parameter".to_string(),
)
}
}; };
match p.get_certificate_status(route_name).await { match p.get_certificate_status(route_name).await {
Some(status) => ManagementResponse::ok(id.to_string(), serde_json::json!({ Some(status) => ManagementResponse::ok(
"domain": status.domain, id.to_string(),
"source": status.source, serde_json::json!({
"expiresAt": status.expires_at, "domain": status.domain,
"isValid": status.is_valid, "source": status.source,
})), "expiresAt": status.expires_at,
"isValid": status.is_valid,
}),
),
None => ManagementResponse::ok(id.to_string(), serde_json::Value::Null), None => ManagementResponse::ok(id.to_string(), serde_json::Value::Null),
} }
} }
fn handle_get_listening_ports( fn handle_get_listening_ports(id: &str, proxy: &Option<RustProxy>) -> ManagementResponse {
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_ref() { match proxy.as_ref() {
Some(p) => { Some(p) => {
let ports = p.get_listening_ports(); let ports = p.get_listening_ports();
@@ -361,7 +426,8 @@ async fn handle_set_socket_handler_relay(
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
}; };
let socket_path = params.get("socketPath") let socket_path = params
.get("socketPath")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.map(|s| s.to_string()); .map(|s| s.to_string());
@@ -381,7 +447,8 @@ async fn handle_set_datagram_handler_relay(
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()), None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
}; };
let socket_path = params.get("socketPath") let socket_path = params
.get("socketPath")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.map(|s| s.to_string()); .map(|s| s.to_string());
@@ -403,12 +470,17 @@ async fn handle_add_listening_port(
let port = match params.get("port").and_then(|v| v.as_u64()) { let port = match params.get("port").and_then(|v| v.as_u64()) {
Some(port) => port as u16, Some(port) => port as u16,
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()), None => {
return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string())
}
}; };
match p.add_listening_port(port).await { match p.add_listening_port(port).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to add port {}: {}", port, e)), Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to add port {}: {}", port, e),
),
} }
} }
@@ -424,12 +496,17 @@ async fn handle_remove_listening_port(
let port = match params.get("port").and_then(|v| v.as_u64()) { let port = match params.get("port").and_then(|v| v.as_u64()) {
Some(port) => port as u16, Some(port) => port as u16,
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()), None => {
return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string())
}
}; };
match p.remove_listening_port(port).await { match p.remove_listening_port(port).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to remove port {}: {}", port, e)), Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to remove port {}: {}", port, e),
),
} }
} }
@@ -445,26 +522,41 @@ async fn handle_load_certificate(
let domain = match params.get("domain").and_then(|v| v.as_str()) { let domain = match params.get("domain").and_then(|v| v.as_str()) {
Some(d) => d.to_string(), Some(d) => d.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'domain' parameter".to_string()), None => {
return ManagementResponse::err(
id.to_string(),
"Missing 'domain' parameter".to_string(),
)
}
}; };
let cert = match params.get("cert").and_then(|v| v.as_str()) { let cert = match params.get("cert").and_then(|v| v.as_str()) {
Some(c) => c.to_string(), Some(c) => c.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string()), None => {
return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string())
}
}; };
let key = match params.get("key").and_then(|v| v.as_str()) { let key = match params.get("key").and_then(|v| v.as_str()) {
Some(k) => k.to_string(), Some(k) => k.to_string(),
None => return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string()), None => {
return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string())
}
}; };
let ca = params.get("ca").and_then(|v| v.as_str()).map(|s| s.to_string()); let ca = params
.get("ca")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
info!("loadCertificate: domain={}", domain); info!("loadCertificate: domain={}", domain);
// Load cert into cert manager and hot-swap TLS config // Load cert into cert manager and hot-swap TLS config
match p.load_certificate(&domain, cert, key, ca).await { match p.load_certificate(&domain, cert, key, ca).await {
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})), Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to load certificate for {}: {}", domain, e)), Err(e) => ManagementResponse::err(
id.to_string(),
format!("Failed to load certificate for {}: {}", domain, e),
),
} }
} }
+8 -8
View File
@@ -136,7 +136,8 @@ pub async fn start_http_echo_backend(port: u16, backend_name: &str) -> JoinHandl
let path = parts.get(1).copied().unwrap_or("/"); let path = parts.get(1).copied().unwrap_or("/");
// Extract Host header // Extract Host header
let host = req_str.lines() let host = req_str
.lines()
.find(|l| l.to_lowercase().starts_with("host:")) .find(|l| l.to_lowercase().starts_with("host:"))
.map(|l| l[5..].trim()) .map(|l| l[5..].trim())
.unwrap_or("unknown"); .unwrap_or("unknown");
@@ -336,7 +337,8 @@ pub async fn start_ws_echo_backend(port: u16) -> JoinHandle<()> {
let req_str = String::from_utf8_lossy(&buf[..n]); let req_str = String::from_utf8_lossy(&buf[..n]);
// Extract Sec-WebSocket-Key for proper handshake // Extract Sec-WebSocket-Key for proper handshake
let ws_key = req_str.lines() let ws_key = req_str
.lines()
.find(|l| l.to_lowercase().starts_with("sec-websocket-key:")) .find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string()) .map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
.unwrap_or_default(); .unwrap_or_default();
@@ -378,7 +380,9 @@ pub fn generate_self_signed_cert(domain: &str) -> (String, String) {
use rcgen::{CertificateParams, KeyPair}; use rcgen::{CertificateParams, KeyPair};
let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap(); let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap();
params.distinguished_name.push(rcgen::DnType::CommonName, domain); params
.distinguished_name
.push(rcgen::DnType::CommonName, domain);
let key_pair = KeyPair::generate().unwrap(); let key_pair = KeyPair::generate().unwrap();
let cert = params.self_signed(&key_pair).unwrap(); let cert = params.self_signed(&key_pair).unwrap();
@@ -458,11 +462,7 @@ pub fn make_tls_terminate_route(
/// Start a TLS WebSocket echo backend: accepts TLS, performs WS handshake, then echoes data. /// Start a TLS WebSocket echo backend: accepts TLS, performs WS handshake, then echoes data.
/// Combines TLS acceptance (like `start_tls_http_backend`) with WebSocket echo (like `start_ws_echo_backend`). /// Combines TLS acceptance (like `start_tls_http_backend`) with WebSocket echo (like `start_ws_echo_backend`).
pub async fn start_tls_ws_echo_backend( pub async fn start_tls_ws_echo_backend(port: u16, cert_pem: &str, key_pem: &str) -> JoinHandle<()> {
port: u16,
cert_pem: &str,
key_pem: &str,
) -> JoinHandle<()> {
use std::sync::Arc; use std::sync::Arc;
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem) let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
@@ -1,9 +1,9 @@
mod common; mod common;
use bytes::Buf;
use common::*; use common::*;
use rustproxy::RustProxy; use rustproxy::RustProxy;
use rustproxy_config::{RustProxyOptions, TransportProtocol, RouteUdp, RouteQuic}; use rustproxy_config::{RouteQuic, RouteUdp, RustProxyOptions, TransportProtocol};
use bytes::Buf;
use std::sync::Arc; use std::sync::Arc;
/// Build a route that listens on UDP with HTTP/3 enabled and TLS terminate. /// Build a route that listens on UDP with HTTP/3 enabled and TLS terminate.
@@ -14,7 +14,14 @@ fn make_h3_route(
cert_pem: &str, cert_pem: &str,
key_pem: &str, key_pem: &str,
) -> rustproxy_config::RouteConfig { ) -> rustproxy_config::RouteConfig {
let mut route = make_tls_terminate_route(port, "localhost", target_host, target_port, cert_pem, key_pem); let mut route = make_tls_terminate_route(
port,
"localhost",
target_host,
target_port,
cert_pem,
key_pem,
);
route.route_match.transport = Some(TransportProtocol::All); route.route_match.transport = Some(TransportProtocol::All);
// Keep domain="localhost" from make_tls_terminate_route — needed for TLS cert extraction // Keep domain="localhost" from make_tls_terminate_route — needed for TLS cert extraction
route.action.udp = Some(RouteUdp { route.action.udp = Some(RouteUdp {
@@ -89,11 +96,9 @@ async fn test_h3_response_stream_finishes() {
.await .await
.expect("QUIC handshake failed"); .expect("QUIC handshake failed");
let (mut driver, mut send_request) = h3::client::new( let (mut driver, mut send_request) = h3::client::new(h3_quinn::Connection::new(connection))
h3_quinn::Connection::new(connection), .await
) .expect("H3 connection setup failed");
.await
.expect("H3 connection setup failed");
// Drive the H3 connection in background // Drive the H3 connection in background
tokio::spawn(async move { tokio::spawn(async move {
@@ -108,33 +113,46 @@ async fn test_h3_response_stream_finishes() {
.body(()) .body(())
.unwrap(); .unwrap();
let mut stream = send_request.send_request(req).await let mut stream = send_request
.send_request(req)
.await
.expect("Failed to send H3 request"); .expect("Failed to send H3 request");
stream.finish().await stream
.finish()
.await
.expect("Failed to finish sending H3 request body"); .expect("Failed to finish sending H3 request body");
// 6. Read response headers // 6. Read response headers
let resp = stream.recv_response().await let resp = stream
.recv_response()
.await
.expect("Failed to receive H3 response"); .expect("Failed to receive H3 response");
assert_eq!(resp.status(), http::StatusCode::OK, assert_eq!(
"Expected 200 OK, got {}", resp.status()); resp.status(),
http::StatusCode::OK,
"Expected 200 OK, got {}",
resp.status()
);
// 7. Read body and verify stream ends (FIN received) // 7. Read body and verify stream ends (FIN received)
// This is the critical assertion: recv_data() must return None (stream ended) // This is the critical assertion: recv_data() must return None (stream ended)
// within the timeout, NOT hang forever waiting for a FIN that never arrives. // within the timeout, NOT hang forever waiting for a FIN that never arrives.
let result = with_timeout(async { let result = with_timeout(
let mut total = 0usize; async {
while let Some(chunk) = stream.recv_data().await.expect("H3 data receive error") { let mut total = 0usize;
total += chunk.remaining(); while let Some(chunk) = stream.recv_data().await.expect("H3 data receive error") {
} total += chunk.remaining();
// recv_data() returned None => stream ended (FIN received) }
total // recv_data() returned None => stream ended (FIN received)
}, 10) total
},
10,
)
.await; .await;
let bytes_received = result.expect( let bytes_received = result.expect(
"TIMEOUT: H3 stream never ended (FIN not received by client). \ "TIMEOUT: H3 stream never ended (FIN not received by client). \
The proxy sent all response data but failed to send the QUIC stream FIN." The proxy sent all response data but failed to send the QUIC stream FIN.",
); );
assert_eq!( assert_eq!(
bytes_received, bytes_received,
@@ -43,17 +43,32 @@ async fn test_http_forward_basic() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await; async {
let body = extract_body(&response); let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await;
body.to_string() let body = extract_body(&response);
}, 10) body.to_string()
},
10,
)
.await .await
.unwrap(); .unwrap();
assert!(result.contains(r#""method":"GET"#), "Expected GET method, got: {}", result); assert!(
assert!(result.contains(r#""path":"/hello"#), "Expected /hello path, got: {}", result); result.contains(r#""method":"GET"#),
assert!(result.contains(r#""backend":"main"#), "Expected main backend, got: {}", result); "Expected GET method, got: {}",
result
);
assert!(
result.contains(r#""path":"/hello"#),
"Expected /hello path, got: {}",
result
);
assert!(
result.contains(r#""backend":"main"#),
"Expected main backend, got: {}",
result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -69,8 +84,18 @@ async fn test_http_forward_host_routing() {
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![ routes: vec![
make_test_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", backend1_port), make_test_route(
make_test_route(proxy_port, Some("beta.example.com"), "127.0.0.1", backend2_port), proxy_port,
Some("alpha.example.com"),
"127.0.0.1",
backend1_port,
),
make_test_route(
proxy_port,
Some("beta.example.com"),
"127.0.0.1",
backend2_port,
),
], ],
..Default::default() ..Default::default()
}; };
@@ -80,24 +105,38 @@ async fn test_http_forward_host_routing() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain // Test alpha domain
let alpha_result = with_timeout(async { let alpha_result = with_timeout(
let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await; async {
extract_body(&response).to_string() let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await;
}, 10) extract_body(&response).to_string()
},
10,
)
.await .await
.unwrap(); .unwrap();
assert!(alpha_result.contains(r#""backend":"alpha"#), "Expected alpha backend, got: {}", alpha_result); assert!(
alpha_result.contains(r#""backend":"alpha"#),
"Expected alpha backend, got: {}",
alpha_result
);
// Test beta domain // Test beta domain
let beta_result = with_timeout(async { let beta_result = with_timeout(
let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await; async {
extract_body(&response).to_string() let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await;
}, 10) extract_body(&response).to_string()
},
10,
)
.await .await
.unwrap(); .unwrap();
assert!(beta_result.contains(r#""backend":"beta"#), "Expected beta backend, got: {}", beta_result); assert!(
beta_result.contains(r#""backend":"beta"#),
"Expected beta backend, got: {}",
beta_result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -127,24 +166,38 @@ async fn test_http_forward_path_routing() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Test API path // Test API path
let api_result = with_timeout(async { let api_result = with_timeout(
let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await; async {
extract_body(&response).to_string() let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await;
}, 10) extract_body(&response).to_string()
},
10,
)
.await .await
.unwrap(); .unwrap();
assert!(api_result.contains(r#""backend":"api"#), "Expected api backend, got: {}", api_result); assert!(
api_result.contains(r#""backend":"api"#),
"Expected api backend, got: {}",
api_result
);
// Test web path (no /api prefix) // Test web path (no /api prefix)
let web_result = with_timeout(async { let web_result = with_timeout(
let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await; async {
extract_body(&response).to_string() let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await;
}, 10) extract_body(&response).to_string()
},
10,
)
.await .await
.unwrap(); .unwrap();
assert!(web_result.contains(r#""backend":"web"#), "Expected web backend, got: {}", web_result); assert!(
web_result.contains(r#""backend":"web"#),
"Expected web backend, got: {}",
web_result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -184,9 +237,18 @@ async fn test_http_forward_cors_preflight() {
.unwrap(); .unwrap();
// Should get 204 No Content with CORS headers // Should get 204 No Content with CORS headers
assert!(result.contains("204"), "Expected 204 status, got: {}", result); assert!(
assert!(result.to_lowercase().contains("access-control-allow-origin"), result.contains("204"),
"Expected CORS header, got: {}", result); "Expected 204 status, got: {}",
result
);
assert!(
result
.to_lowercase()
.contains("access-control-allow-origin"),
"Expected CORS header, got: {}",
result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -208,15 +270,22 @@ async fn test_http_forward_backend_error() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await; async {
response let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await;
}, 10) response
},
10,
)
.await .await
.unwrap(); .unwrap();
// Proxy should relay the 500 from backend // Proxy should relay the 500 from backend
assert!(result.contains("500"), "Expected 500 status, got: {}", result); assert!(
result.contains("500"),
"Expected 500 status, got: {}",
result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -227,7 +296,12 @@ async fn test_http_forward_no_route_matched() {
// Create a route only for a specific domain // Create a route only for a specific domain
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![make_test_route(proxy_port, Some("known.example.com"), "127.0.0.1", 9999)], routes: vec![make_test_route(
proxy_port,
Some("known.example.com"),
"127.0.0.1",
9999,
)],
..Default::default() ..Default::default()
}; };
@@ -235,15 +309,22 @@ async fn test_http_forward_no_route_matched() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await; async {
response let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await;
}, 10) response
},
10,
)
.await .await
.unwrap(); .unwrap();
// Should get 502 Bad Gateway (no route matched) // Should get 502 Bad Gateway (no route matched)
assert!(result.contains("502"), "Expected 502 status, got: {}", result); assert!(
result.contains("502"),
"Expected 502 status, got: {}",
result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -262,15 +343,22 @@ async fn test_http_forward_backend_unavailable() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let response = send_http_request(proxy_port, "example.com", "GET", "/").await; async {
response let response = send_http_request(proxy_port, "example.com", "GET", "/").await;
}, 10) response
},
10,
)
.await .await
.unwrap(); .unwrap();
// Should get 502 Bad Gateway (backend unavailable) // Should get 502 Bad Gateway (backend unavailable)
assert!(result.contains("502"), "Expected 502 status, got: {}", result); assert!(
result.contains("502"),
"Expected 502 status, got: {}",
result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -286,7 +374,12 @@ async fn test_https_terminate_http_forward() {
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![make_tls_terminate_route( routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem, proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)], )],
..Default::default() ..Default::default()
}; };
@@ -295,38 +388,53 @@ async fn test_https_terminate_http_forward() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let _ = rustls::crypto::ring::default_provider().install_default(); async {
let tls_config = rustls::ClientConfig::builder() let _ = rustls::crypto::ring::default_provider().install_default();
.dangerous() let tls_config = rustls::ClientConfig::builder()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier)) .dangerous()
.with_no_client_auth(); .with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config)); .with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await .await
.unwrap(); .unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap(); let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
// Send HTTP request through TLS // Send HTTP request through TLS
let request = format!( let request = format!(
"GET /api/data HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n", "GET /api/data HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
domain domain
); );
tls_stream.write_all(request.as_bytes()).await.unwrap(); tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new(); let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap(); tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string() String::from_utf8_lossy(&response).to_string()
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
let body = extract_body(&result); let body = extract_body(&result);
assert!(body.contains(r#""method":"GET"#), "Expected GET, got: {}", body); assert!(
assert!(body.contains(r#""path":"/api/data"#), "Expected /api/data, got: {}", body); body.contains(r#""method":"GET"#),
assert!(body.contains(r#""backend":"tls-backend"#), "Expected tls-backend, got: {}", body); "Expected GET, got: {}",
body
);
assert!(
body.contains(r#""path":"/api/data"#),
"Expected /api/data, got: {}",
body
);
assert!(
body.contains(r#""backend":"tls-backend"#),
"Expected tls-backend, got: {}",
body
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -347,59 +455,68 @@ async fn test_websocket_through_proxy() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
.unwrap();
// Send WebSocket upgrade request // Send WebSocket upgrade request
let request = format!( let request = format!(
"GET /ws HTTP/1.1\r\n\ "GET /ws HTTP/1.1\r\n\
Host: example.com\r\n\ Host: example.com\r\n\
Upgrade: websocket\r\n\ Upgrade: websocket\r\n\
Connection: Upgrade\r\n\ Connection: Upgrade\r\n\
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
Sec-WebSocket-Version: 13\r\n\ Sec-WebSocket-Version: 13\r\n\
\r\n" \r\n"
); );
stream.write_all(request.as_bytes()).await.unwrap(); stream.write_all(request.as_bytes()).await.unwrap();
// Read the 101 response // Read the 101 response
let mut response_buf = Vec::with_capacity(4096); let mut response_buf = Vec::with_capacity(4096);
let mut temp = [0u8; 1]; let mut temp = [0u8; 1];
loop { loop {
let n = stream.read(&mut temp).await.unwrap(); let n = stream.read(&mut temp).await.unwrap();
if n == 0 { break; } if n == 0 {
response_buf.push(temp[0]);
if response_buf.len() >= 4 {
let len = response_buf.len();
if response_buf[len-4..] == *b"\r\n\r\n" {
break; break;
} }
response_buf.push(temp[0]);
if response_buf.len() >= 4 {
let len = response_buf.len();
if response_buf[len - 4..] == *b"\r\n\r\n" {
break;
}
}
} }
}
let response_str = String::from_utf8_lossy(&response_buf).to_string(); let response_str = String::from_utf8_lossy(&response_buf).to_string();
assert!(response_str.contains("101"), "Expected 101 Switching Protocols, got: {}", response_str); assert!(
assert!( response_str.contains("101"),
response_str.to_lowercase().contains("upgrade: websocket"), "Expected 101 Switching Protocols, got: {}",
"Expected Upgrade header, got: {}", response_str
response_str );
); assert!(
response_str.to_lowercase().contains("upgrade: websocket"),
"Expected Upgrade header, got: {}",
response_str
);
// After upgrade, send data and verify echo // After upgrade, send data and verify echo
let test_data = b"Hello WebSocket!"; let test_data = b"Hello WebSocket!";
stream.write_all(test_data).await.unwrap(); stream.write_all(test_data).await.unwrap();
// Read echoed data // Read echoed data
let mut echo_buf = vec![0u8; 256]; let mut echo_buf = vec![0u8; 256];
let n = stream.read(&mut echo_buf).await.unwrap(); let n = stream.read(&mut echo_buf).await.unwrap();
let echoed = &echo_buf[..n]; let echoed = &echo_buf[..n];
assert_eq!(echoed, test_data, "Expected echo of sent data"); assert_eq!(echoed, test_data, "Expected echo of sent data");
"ok".to_string() "ok".to_string()
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
@@ -431,12 +548,22 @@ async fn test_terminate_and_reencrypt_http_routing() {
// Create terminate-and-reencrypt routes // Create terminate-and-reencrypt routes
let mut route1 = make_tls_terminate_route( let mut route1 = make_tls_terminate_route(
proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1, proxy_port,
"alpha.example.com",
"127.0.0.1",
backend1_port,
&cert1,
&key1,
); );
route1.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt; route1.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
let mut route2 = make_tls_terminate_route( let mut route2 = make_tls_terminate_route(
proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2, proxy_port,
"beta.example.com",
"127.0.0.1",
backend2_port,
&cert2,
&key2,
); );
route2.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt; route2.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
@@ -450,27 +577,32 @@ async fn test_terminate_and_reencrypt_http_routing() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain - HTTP request through TLS terminate-and-reencrypt // Test alpha domain - HTTP request through TLS terminate-and-reencrypt
let alpha_result = with_timeout(async { let alpha_result = with_timeout(
let _ = rustls::crypto::ring::default_provider().install_default(); async {
let tls_config = rustls::ClientConfig::builder() let _ = rustls::crypto::ring::default_provider().install_default();
.dangerous() let tls_config = rustls::ClientConfig::builder()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier)) .dangerous()
.with_no_client_auth(); .with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config)); .with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await .await
.unwrap(); .unwrap();
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap(); let server_name =
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let request = "GET /api/data HTTP/1.1\r\nHost: alpha.example.com\r\nConnection: close\r\n\r\n"; let request =
tls_stream.write_all(request.as_bytes()).await.unwrap(); "GET /api/data HTTP/1.1\r\nHost: alpha.example.com\r\nConnection: close\r\n\r\n";
tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new(); let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap(); tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string() String::from_utf8_lossy(&response).to_string()
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
@@ -498,27 +630,32 @@ async fn test_terminate_and_reencrypt_http_routing() {
); );
// Test beta domain - different host goes to different backend // Test beta domain - different host goes to different backend
let beta_result = with_timeout(async { let beta_result = with_timeout(
let _ = rustls::crypto::ring::default_provider().install_default(); async {
let tls_config = rustls::ClientConfig::builder() let _ = rustls::crypto::ring::default_provider().install_default();
.dangerous() let tls_config = rustls::ClientConfig::builder()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier)) .dangerous()
.with_no_client_auth(); .with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config)); .with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await .await
.unwrap(); .unwrap();
let server_name = rustls::pki_types::ServerName::try_from("beta.example.com".to_string()).unwrap(); let server_name =
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); rustls::pki_types::ServerName::try_from("beta.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let request = "GET /other HTTP/1.1\r\nHost: beta.example.com\r\nConnection: close\r\n\r\n"; let request =
tls_stream.write_all(request.as_bytes()).await.unwrap(); "GET /other HTTP/1.1\r\nHost: beta.example.com\r\nConnection: close\r\n\r\n";
tls_stream.write_all(request.as_bytes()).await.unwrap();
let mut response = Vec::new(); let mut response = Vec::new();
tls_stream.read_to_end(&mut response).await.unwrap(); tls_stream.read_to_end(&mut response).await.unwrap();
String::from_utf8_lossy(&response).to_string() String::from_utf8_lossy(&response).to_string()
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
@@ -589,14 +726,12 @@ async fn test_terminate_and_reencrypt_websocket() {
.dangerous() .dangerous()
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier)) .with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
.with_no_client_auth(); .with_no_client_auth();
let connector = let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await .await
.unwrap(); .unwrap();
let server_name = let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
// Send WebSocket upgrade request through TLS // Send WebSocket upgrade request through TLS
@@ -685,10 +820,13 @@ async fn test_protocol_field_in_route_config() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// HTTP request should match the route and get proxied // HTTP request should match the route and get proxied
let result = with_timeout(async { let result = with_timeout(
let response = send_http_request(proxy_port, "example.com", "GET", "/test").await; async {
extract_body(&response).to_string() let response = send_http_request(proxy_port, "example.com", "GET", "/test").await;
}, 10) extract_body(&response).to_string()
},
10,
)
.await .await
.unwrap(); .unwrap();
@@ -20,13 +20,19 @@ async fn test_start_and_stop() {
assert!(!wait_for_port(port, 200).await); assert!(!wait_for_port(port, 200).await);
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(port, 2000).await, "Port should be listening after start"); assert!(
wait_for_port(port, 2000).await,
"Port should be listening after start"
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
// Give the OS a moment to release the port // Give the OS a moment to release the port
tokio::time::sleep(std::time::Duration::from_millis(100)).await; tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port, 200).await, "Port should not be listening after stop"); assert!(
!wait_for_port(port, 200).await,
"Port should not be listening after stop"
);
} }
#[tokio::test] #[tokio::test]
@@ -54,7 +60,12 @@ async fn test_update_routes_hot_reload() {
let port = next_port(); let port = next_port();
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![make_test_route(port, Some("old.example.com"), "127.0.0.1", 8080)], routes: vec![make_test_route(
port,
Some("old.example.com"),
"127.0.0.1",
8080,
)],
..Default::default() ..Default::default()
}; };
@@ -62,9 +73,12 @@ async fn test_update_routes_hot_reload() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
// Update routes atomically // Update routes atomically
let new_routes = vec![ let new_routes = vec![make_test_route(
make_test_route(port, Some("new.example.com"), "127.0.0.1", 9090), port,
]; Some("new.example.com"),
"127.0.0.1",
9090,
)];
let result = proxy.update_routes(new_routes).await; let result = proxy.update_routes(new_routes).await;
assert!(result.is_ok()); assert!(result.is_ok());
@@ -87,15 +101,24 @@ async fn test_add_remove_listening_port() {
// Add a new port // Add a new port
proxy.add_listening_port(port2).await.unwrap(); proxy.add_listening_port(port2).await.unwrap();
assert!(wait_for_port(port2, 2000).await, "New port should be listening"); assert!(
wait_for_port(port2, 2000).await,
"New port should be listening"
);
// Remove the port // Remove the port
proxy.remove_listening_port(port2).await.unwrap(); proxy.remove_listening_port(port2).await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(100)).await; tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port2, 200).await, "Removed port should not be listening"); assert!(
!wait_for_port(port2, 200).await,
"Removed port should not be listening"
);
// Original port should still be listening // Original port should still be listening
assert!(wait_for_port(port1, 200).await, "Original port should still be listening"); assert!(
wait_for_port(port1, 200).await,
"Original port should still be listening"
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -168,7 +191,11 @@ async fn test_metrics_track_connections() {
tokio::time::sleep(std::time::Duration::from_millis(100)).await; tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let stats = proxy.get_statistics(); let stats = proxy.get_statistics();
assert!(stats.total_connections > 0, "Expected total_connections > 0, got {}", stats.total_connections); assert!(
stats.total_connections > 0,
"Expected total_connections > 0, got {}",
stats.total_connections
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -205,8 +232,11 @@ async fn test_metrics_track_bytes() {
tokio::time::sleep(std::time::Duration::from_millis(100)).await; tokio::time::sleep(std::time::Duration::from_millis(100)).await;
let stats = proxy.get_statistics(); let stats = proxy.get_statistics();
assert!(stats.total_connections > 0, assert!(
"Expected some connections tracked, got {}", stats.total_connections); stats.total_connections > 0,
"Expected some connections tracked, got {}",
stats.total_connections
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -228,23 +258,38 @@ async fn test_hot_reload_port_changes() {
let mut proxy = RustProxy::new(options).unwrap(); let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(port1, 2000).await); assert!(wait_for_port(port1, 2000).await);
assert!(!wait_for_port(port2, 200).await, "port2 should not be listening yet"); assert!(
!wait_for_port(port2, 200).await,
"port2 should not be listening yet"
);
// Update routes to use port2 instead // Update routes to use port2 instead
let new_routes = vec![ let new_routes = vec![make_test_route(port2, None, "127.0.0.1", backend_port)];
make_test_route(port2, None, "127.0.0.1", backend_port),
];
proxy.update_routes(new_routes).await.unwrap(); proxy.update_routes(new_routes).await.unwrap();
// Port2 should now be listening, port1 should be closed // Port2 should now be listening, port1 should be closed
assert!(wait_for_port(port2, 2000).await, "port2 should be listening after reload"); assert!(
wait_for_port(port2, 2000).await,
"port2 should be listening after reload"
);
tokio::time::sleep(std::time::Duration::from_millis(100)).await; tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(!wait_for_port(port1, 200).await, "port1 should be closed after reload"); assert!(
!wait_for_port(port1, 200).await,
"port1 should be closed after reload"
);
// Verify port2 works // Verify port2 works
let ports = proxy.get_listening_ports(); let ports = proxy.get_listening_ports();
assert!(ports.contains(&port2), "Expected port2 in listening ports: {:?}", ports); assert!(
assert!(!ports.contains(&port1), "port1 should not be in listening ports: {:?}", ports); ports.contains(&port2),
"Expected port2 in listening ports: {:?}",
ports
);
assert!(
!ports.contains(&port1),
"port1 should not be in listening ports: {:?}",
ports
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -24,19 +24,25 @@ async fn test_tcp_forward_echo() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
// Wait for proxy to be ready // Wait for proxy to be ready
assert!(wait_for_port(proxy_port, 2000).await, "Proxy port not ready"); assert!(
wait_for_port(proxy_port, 2000).await,
"Proxy port not ready"
);
// Connect and send data // Connect and send data
let result = with_timeout(async { let result = with_timeout(
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
stream.write_all(b"hello world").await.unwrap(); .unwrap();
stream.write_all(b"hello world").await.unwrap();
let mut buf = vec![0u8; 1024]; let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap(); let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 5) },
5,
)
.await .await
.unwrap(); .unwrap();
@@ -61,21 +67,24 @@ async fn test_tcp_forward_large_payload() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
.unwrap();
// Send 1MB of data // Send 1MB of data
let data = vec![b'A'; 1_000_000]; let data = vec![b'A'; 1_000_000];
stream.write_all(&data).await.unwrap(); stream.write_all(&data).await.unwrap();
stream.shutdown().await.unwrap(); stream.shutdown().await.unwrap();
// Read all back // Read all back
let mut received = Vec::new(); let mut received = Vec::new();
stream.read_to_end(&mut received).await.unwrap(); stream.read_to_end(&mut received).await.unwrap();
received.len() received.len()
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
@@ -100,29 +109,32 @@ async fn test_tcp_forward_multiple_connections() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let mut handles = Vec::new(); async {
for i in 0..10 { let mut handles = Vec::new();
let port = proxy_port; for i in 0..10 {
handles.push(tokio::spawn(async move { let port = proxy_port;
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port)) handles.push(tokio::spawn(async move {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
.unwrap(); .await
let msg = format!("connection-{}", i); .unwrap();
stream.write_all(msg.as_bytes()).await.unwrap(); let msg = format!("connection-{}", i);
stream.write_all(msg.as_bytes()).await.unwrap();
let mut buf = vec![0u8; 1024]; let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap(); let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
})); }));
} }
let mut results = Vec::new(); let mut results = Vec::new();
for handle in handles { for handle in handles {
results.push(handle.await.unwrap()); results.push(handle.await.unwrap());
} }
results results
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
@@ -149,14 +161,20 @@ async fn test_tcp_forward_backend_unreachable() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Connection should complete (proxy accepts it) but data should not flow // Connection should complete (proxy accepts it) but data should not flow
let result = with_timeout(async { let result = with_timeout(
let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await; async {
stream.is_ok() let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await;
}, 5) stream.is_ok()
},
5,
)
.await .await
.unwrap(); .unwrap();
assert!(result, "Should be able to connect to proxy even if backend is down"); assert!(
result,
"Should be able to connect to proxy even if backend is down"
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -178,16 +196,19 @@ async fn test_tcp_forward_bidirectional() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
stream.write_all(b"test data").await.unwrap(); .unwrap();
stream.write_all(b"test data").await.unwrap();
let mut buf = vec![0u8; 1024]; let mut buf = vec![0u8; 1024];
let n = stream.read(&mut buf).await.unwrap(); let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 5) },
5,
)
.await .await
.unwrap(); .unwrap();
@@ -65,8 +65,18 @@ async fn test_tls_passthrough_sni_routing() {
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![ routes: vec![
make_tls_passthrough_route(proxy_port, Some("one.example.com"), "127.0.0.1", backend1_port), make_tls_passthrough_route(
make_tls_passthrough_route(proxy_port, Some("two.example.com"), "127.0.0.1", backend2_port), proxy_port,
Some("one.example.com"),
"127.0.0.1",
backend1_port,
),
make_tls_passthrough_route(
proxy_port,
Some("two.example.com"),
"127.0.0.1",
backend2_port,
),
], ],
..Default::default() ..Default::default()
}; };
@@ -76,39 +86,53 @@ async fn test_tls_passthrough_sni_routing() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Send a fake ClientHello with SNI "one.example.com" // Send a fake ClientHello with SNI "one.example.com"
let result = with_timeout(async { let result = with_timeout(
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
let hello = build_client_hello("one.example.com"); .unwrap();
stream.write_all(&hello).await.unwrap(); let hello = build_client_hello("one.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096]; let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap(); let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 5) },
5,
)
.await .await
.unwrap(); .unwrap();
// Backend1 should have received the ClientHello and prefixed its response // Backend1 should have received the ClientHello and prefixed its response
assert!(result.starts_with("BACKEND1:"), "Expected BACKEND1 prefix, got: {}", result); assert!(
result.starts_with("BACKEND1:"),
"Expected BACKEND1 prefix, got: {}",
result
);
// Now test routing to backend2 // Now test routing to backend2
let result2 = with_timeout(async { let result2 = with_timeout(
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
let hello = build_client_hello("two.example.com"); .unwrap();
stream.write_all(&hello).await.unwrap(); let hello = build_client_hello("two.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096]; let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap(); let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 5) },
5,
)
.await .await
.unwrap(); .unwrap();
assert!(result2.starts_with("BACKEND2:"), "Expected BACKEND2 prefix, got: {}", result2); assert!(
result2.starts_with("BACKEND2:"),
"Expected BACKEND2 prefix, got: {}",
result2
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -121,9 +145,12 @@ async fn test_tls_passthrough_unknown_sni() {
let _backend = start_echo_server(backend_port).await; let _backend = start_echo_server(backend_port).await;
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![ routes: vec![make_tls_passthrough_route(
make_tls_passthrough_route(proxy_port, Some("known.example.com"), "127.0.0.1", backend_port), proxy_port,
], Some("known.example.com"),
"127.0.0.1",
backend_port,
)],
..Default::default() ..Default::default()
}; };
@@ -132,21 +159,24 @@ async fn test_tls_passthrough_unknown_sni() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Send ClientHello with unknown SNI - should get no response (connection dropped) // Send ClientHello with unknown SNI - should get no response (connection dropped)
let result = with_timeout(async { let result = with_timeout(
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
let hello = build_client_hello("unknown.example.com"); .unwrap();
stream.write_all(&hello).await.unwrap(); let hello = build_client_hello("unknown.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096]; let mut buf = vec![0u8; 4096];
// Should either get 0 bytes (closed) or an error // Should either get 0 bytes (closed) or an error
match stream.read(&mut buf).await { match stream.read(&mut buf).await {
Ok(0) => true, // Connection closed = no route matched Ok(0) => true, // Connection closed = no route matched
Ok(_) => false, // Got data = route shouldn't have matched Ok(_) => false, // Got data = route shouldn't have matched
Err(_) => true, // Error = connection dropped Err(_) => true, // Error = connection dropped
} }
}, 5) },
5,
)
.await .await
.unwrap(); .unwrap();
@@ -163,9 +193,12 @@ async fn test_tls_passthrough_wildcard_domain() {
let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await; let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await;
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![ routes: vec![make_tls_passthrough_route(
make_tls_passthrough_route(proxy_port, Some("*.example.com"), "127.0.0.1", backend_port), proxy_port,
], Some("*.example.com"),
"127.0.0.1",
backend_port,
)],
..Default::default() ..Default::default()
}; };
@@ -174,21 +207,28 @@ async fn test_tls_passthrough_wildcard_domain() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Should match any subdomain of example.com // Should match any subdomain of example.com
let result = with_timeout(async { let result = with_timeout(
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
let hello = build_client_hello("anything.example.com"); .unwrap();
stream.write_all(&hello).await.unwrap(); let hello = build_client_hello("anything.example.com");
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096]; let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap(); let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 5) },
5,
)
.await .await
.unwrap(); .unwrap();
assert!(result.starts_with("WILDCARD:"), "Expected WILDCARD prefix, got: {}", result); assert!(
result.starts_with("WILDCARD:"),
"Expected WILDCARD prefix, got: {}",
result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -222,24 +262,29 @@ async fn test_tls_passthrough_multiple_domains() {
("beta.example.com", "B2:"), ("beta.example.com", "B2:"),
("gamma.example.com", "B3:"), ("gamma.example.com", "B3:"),
] { ] {
let result = with_timeout(async { let result = with_timeout(
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) async {
.await let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.unwrap(); .await
let hello = build_client_hello(domain); .unwrap();
stream.write_all(&hello).await.unwrap(); let hello = build_client_hello(domain);
stream.write_all(&hello).await.unwrap();
let mut buf = vec![0u8; 4096]; let mut buf = vec![0u8; 4096];
let n = stream.read(&mut buf).await.unwrap(); let n = stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 5) },
5,
)
.await .await
.unwrap(); .unwrap();
assert!( assert!(
result.starts_with(expected_prefix), result.starts_with(expected_prefix),
"Domain {} should route to {}, got: {}", "Domain {} should route to {}, got: {}",
domain, expected_prefix, result domain,
expected_prefix,
result
); );
} }
@@ -74,7 +74,12 @@ async fn test_tls_terminate_basic() {
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![make_tls_terminate_route( routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem, proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)], )],
..Default::default() ..Default::default()
}; };
@@ -84,23 +89,26 @@ async fn test_tls_terminate_basic() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Connect with TLS client // Connect with TLS client
let result = with_timeout(async { let result = with_timeout(
let tls_config = make_insecure_tls_client_config(); async {
let connector = tokio_rustls::TlsConnector::from(tls_config); let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await .await
.unwrap(); .unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap(); let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
tls_stream.write_all(b"hello TLS").await.unwrap(); tls_stream.write_all(b"hello TLS").await.unwrap();
let mut buf = vec![0u8; 1024]; let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap(); let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
@@ -125,7 +133,12 @@ async fn test_tls_terminate_and_reencrypt() {
// Create terminate-and-reencrypt route // Create terminate-and-reencrypt route
let mut route = make_tls_terminate_route( let mut route = make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &proxy_cert, &proxy_key, proxy_port,
domain,
"127.0.0.1",
backend_port,
&proxy_cert,
&proxy_key,
); );
route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt; route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
@@ -138,23 +151,26 @@ async fn test_tls_terminate_and_reencrypt() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let tls_config = make_insecure_tls_client_config(); async {
let connector = tokio_rustls::TlsConnector::from(tls_config); let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await .await
.unwrap(); .unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap(); let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
tls_stream.write_all(b"hello reencrypt").await.unwrap(); tls_stream.write_all(b"hello reencrypt").await.unwrap();
let mut buf = vec![0u8; 1024]; let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap(); let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
@@ -177,8 +193,22 @@ async fn test_tls_terminate_sni_cert_selection() {
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![ routes: vec![
make_tls_terminate_route(proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1), make_tls_terminate_route(
make_tls_terminate_route(proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2), proxy_port,
"alpha.example.com",
"127.0.0.1",
backend1_port,
&cert1,
&key1,
),
make_tls_terminate_route(
proxy_port,
"beta.example.com",
"127.0.0.1",
backend2_port,
&cert2,
&key2,
),
], ],
..Default::default() ..Default::default()
}; };
@@ -188,27 +218,35 @@ async fn test_tls_terminate_sni_cert_selection() {
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
// Test alpha domain // Test alpha domain
let result = with_timeout(async { let result = with_timeout(
let tls_config = make_insecure_tls_client_config(); async {
let connector = tokio_rustls::TlsConnector::from(tls_config); let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await .await
.unwrap(); .unwrap();
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap(); let server_name =
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
tls_stream.write_all(b"test").await.unwrap(); tls_stream.write_all(b"test").await.unwrap();
let mut buf = vec![0u8; 1024]; let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap(); let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
}, 10) },
10,
)
.await .await
.unwrap(); .unwrap();
assert!(result.starts_with("ALPHA:"), "Expected ALPHA prefix, got: {}", result); assert!(
result.starts_with("ALPHA:"),
"Expected ALPHA prefix, got: {}",
result
);
proxy.stop().await.unwrap(); proxy.stop().await.unwrap();
} }
@@ -224,7 +262,12 @@ async fn test_tls_terminate_large_payload() {
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![make_tls_terminate_route( routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem, proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)], )],
..Default::default() ..Default::default()
}; };
@@ -233,26 +276,29 @@ async fn test_tls_terminate_large_payload() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let tls_config = make_insecure_tls_client_config(); async {
let connector = tokio_rustls::TlsConnector::from(tls_config); let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
.await .await
.unwrap(); .unwrap();
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap(); let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
// Send 1MB of data // Send 1MB of data
let data = vec![b'X'; 1_000_000]; let data = vec![b'X'; 1_000_000];
tls_stream.write_all(&data).await.unwrap(); tls_stream.write_all(&data).await.unwrap();
tls_stream.shutdown().await.unwrap(); tls_stream.shutdown().await.unwrap();
let mut received = Vec::new(); let mut received = Vec::new();
tls_stream.read_to_end(&mut received).await.unwrap(); tls_stream.read_to_end(&mut received).await.unwrap();
received.len() received.len()
}, 15) },
15,
)
.await .await
.unwrap(); .unwrap();
@@ -272,7 +318,12 @@ async fn test_tls_terminate_concurrent() {
let options = RustProxyOptions { let options = RustProxyOptions {
routes: vec![make_tls_terminate_route( routes: vec![make_tls_terminate_route(
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem, proxy_port,
domain,
"127.0.0.1",
backend_port,
&cert_pem,
&key_pem,
)], )],
..Default::default() ..Default::default()
}; };
@@ -281,37 +332,40 @@ async fn test_tls_terminate_concurrent() {
proxy.start().await.unwrap(); proxy.start().await.unwrap();
assert!(wait_for_port(proxy_port, 2000).await); assert!(wait_for_port(proxy_port, 2000).await);
let result = with_timeout(async { let result = with_timeout(
let mut handles = Vec::new(); async {
for i in 0..10 { let mut handles = Vec::new();
let port = proxy_port; for i in 0..10 {
let dom = domain.to_string(); let port = proxy_port;
handles.push(tokio::spawn(async move { let dom = domain.to_string();
let tls_config = make_insecure_tls_client_config(); handles.push(tokio::spawn(async move {
let connector = tokio_rustls::TlsConnector::from(tls_config); let tls_config = make_insecure_tls_client_config();
let connector = tokio_rustls::TlsConnector::from(tls_config);
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)) let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
.await .await
.unwrap(); .unwrap();
let server_name = rustls::pki_types::ServerName::try_from(dom).unwrap(); let server_name = rustls::pki_types::ServerName::try_from(dom).unwrap();
let mut tls_stream = connector.connect(server_name, stream).await.unwrap(); let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
let msg = format!("conn-{}", i); let msg = format!("conn-{}", i);
tls_stream.write_all(msg.as_bytes()).await.unwrap(); tls_stream.write_all(msg.as_bytes()).await.unwrap();
let mut buf = vec![0u8; 1024]; let mut buf = vec![0u8; 1024];
let n = tls_stream.read(&mut buf).await.unwrap(); let n = tls_stream.read(&mut buf).await.unwrap();
String::from_utf8_lossy(&buf[..n]).to_string() String::from_utf8_lossy(&buf[..n]).to_string()
})); }));
} }
let mut results = Vec::new(); let mut results = Vec::new();
for handle in handles { for handle in handles {
results.push(handle.await.unwrap()); results.push(handle.await.unwrap());
} }
results results
}, 15) },
15,
)
.await .await
.unwrap(); .unwrap();
+191
View File
@@ -0,0 +1,191 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import { SmartProxy } from '../ts/index.js';
import * as http from 'http';
import * as net from 'net';
import * as tls from 'tls';
import * as fs from 'fs';
import * as path from 'path';
import { fileURLToPath } from 'url';
import { assertPortsFree, findFreePorts } from './helpers/port-allocator.js';
// Resolve this test file's directory (ESM modules have no built-in __dirname).
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);

// Self-signed certificate pair loaded from the repo's test assets; used by the
// TLS echo backend below.
const CERT_PEM = fs.readFileSync(path.join(__dirname, '..', 'assets', 'certs', 'cert.pem'), 'utf8');
const KEY_PEM = fs.readFileSync(path.join(__dirname, '..', 'assets', 'certs', 'key.pem'), 'utf8');

// Shared state for the sequential tap tests: ports are allocated and servers
// started in the setup tests, then reused (and torn down) by later ones.
let httpBackendPort: number;
let tlsBackendPort: number;
let httpProxyPort: number;
let tlsProxyPort: number;
let httpBackend: http.Server;
let tlsBackend: tls.Server;
let proxy: SmartProxy;
/**
 * Force an immediate metrics poll on the proxy's (private) metrics adapter,
 * bypassing the sampling interval so tests can observe fresh snapshots.
 */
async function pollMetrics(proxyToPoll: SmartProxy): Promise<void> {
  const adapter = (proxyToPoll as any).metricsAdapter;
  await adapter.poll();
}
/**
 * Repeatedly evaluate `callback` until it resolves to true or `timeoutMs`
 * elapses. Re-checks every `stepMs` milliseconds; throws on timeout.
 */
async function waitForCondition(
  callback: () => Promise<boolean>,
  timeoutMs: number = 5000,
  stepMs: number = 100,
): Promise<void> {
  const sleep = (ms: number) => new Promise<void>((resolve) => setTimeout(resolve, ms));
  for (const deadline = Date.now() + timeoutMs; Date.now() < deadline; ) {
    const satisfied = await callback();
    if (satisfied) {
      return;
    }
    await sleep(stepMs);
  }
  throw new Error(`Condition not met within ${timeoutMs}ms`);
}
function hasIpDomainRequest(domain: string): boolean {
const byIp = proxy.getMetrics().connections.domainRequestsByIP();
for (const domainMap of byIp.values()) {
if (domainMap.has(domain)) {
return true;
}
}
return false;
}
// Allocate four free ports, then start the two origin servers the proxy will
// front: a plain-HTTP echo backend and a TLS echo backend (for passthrough).
tap.test('setup - backend servers for HTTP domain rate metrics', async () => {
  [httpBackendPort, tlsBackendPort, httpProxyPort, tlsProxyPort] = await findFreePorts(4);
  // HTTP backend: buffers the request body and echoes it back as `ok:<body>`.
  httpBackend = http.createServer((req, res) => {
    let body = '';
    req.on('data', (chunk) => {
      body += chunk;
    });
    req.on('end', () => {
      res.writeHead(200, { 'Content-Type': 'text/plain' });
      res.end(`ok:${body}`);
    });
  });
  await new Promise<void>((resolve) => {
    httpBackend.listen(httpBackendPort, () => resolve());
  });
  // TLS backend: raw byte echo over a terminated TLS socket; errors are
  // swallowed because the client side tears the socket down abruptly later.
  tlsBackend = tls.createServer({ cert: CERT_PEM, key: KEY_PEM }, (socket) => {
    socket.on('data', (data) => {
      socket.write(data);
    });
    socket.on('error', () => {});
  });
  await new Promise<void>((resolve) => {
    tlsBackend.listen(tlsBackendPort, () => resolve());
  });
});
// Start the proxy under test with two routes: an HTTP forward route (domain
// matched, metrics should count HTTP requests) and a TLS passthrough route
// (SNI matched, must NOT count as HTTP requests).
tap.test('setup - start proxy with HTTP and TLS passthrough routes', async () => {
  proxy = new SmartProxy({
    routes: [
      {
        id: 'http-domain-rates',
        name: 'http-domain-rates',
        match: { ports: httpProxyPort, domains: 'example.com' },
        action: {
          type: 'forward',
          targets: [{ host: 'localhost', port: httpBackendPort }],
        },
      },
      {
        id: 'tls-passthrough-domain-rates',
        name: 'tls-passthrough-domain-rates',
        match: { ports: tlsProxyPort, domains: 'passthrough.example.com' },
        action: {
          type: 'forward',
          tls: { mode: 'passthrough' },
          targets: [{ host: 'localhost', port: tlsBackendPort }],
        },
      },
    ],
    // Fast sampling so the test's metric polls see fresh data quickly.
    metrics: { enabled: true, sampleIntervalMs: 100, retentionSeconds: 60 },
  });
  await proxy.start();
  // Brief settle delay so listeners are fully up before traffic starts.
  await new Promise((resolve) => setTimeout(resolve, 300));
});
// Send three HTTP POSTs through the proxy and assert that per-domain request
// rate metrics appear under the normalized (lowercase) domain key.
tap.test('HTTP requests populate per-domain HTTP request rates', async () => {
  for (let i = 0; i < 3; i++) {
    await new Promise<void>((resolve, reject) => {
      const body = `payload-${i}`;
      const req = http.request(
        {
          hostname: 'localhost',
          port: httpProxyPort,
          path: '/echo',
          method: 'POST',
          headers: {
            // Mixed-case Host header: the later assertion on 'example.com'
            // checks that domain keys are case-normalized.
            Host: 'Example.COM',
            'Content-Type': 'text/plain',
            'Content-Length': String(body.length),
          },
        },
        (res) => {
          res.resume();
          res.on('end', () => resolve());
        },
      );
      req.on('error', reject);
      req.end(body);
    });
  }
  // Poll until the metrics pipeline has counted all three requests.
  await waitForCondition(async () => {
    await pollMetrics(proxy);
    const domainMetrics = proxy.getMetrics().requests.byDomain().get('example.com');
    return (domainMetrics?.lastMinute ?? 0) >= 3 && (domainMetrics?.perSecond ?? 0) > 0;
  });
  const exampleMetrics = proxy.getMetrics().requests.byDomain().get('example.com');
  expect(exampleMetrics).toBeTruthy();
  expect(exampleMetrics?.lastMinute).toEqual(3);
  expect(exampleMetrics?.perSecond).toBeGreaterThan(0);
});
// A TLS passthrough connection (identified only by SNI) must be tracked in the
// per-IP domain connection counters, but must NOT appear in the HTTP
// request-rate map, since no HTTP request was parsed.
tap.test('TLS passthrough SNI does not inflate HTTP domain request rates', async () => {
  const tlsClient = tls.connect({
    host: 'localhost',
    port: tlsProxyPort,
    servername: 'passthrough.example.com',
    // Test cert is self-signed; skip chain validation.
    rejectUnauthorized: false,
  });
  await new Promise<void>((resolve, reject) => {
    tlsClient.once('secureConnect', () => resolve());
    tlsClient.once('error', reject);
  });
  const echoPromise = new Promise<void>((resolve, reject) => {
    tlsClient.once('data', () => resolve());
    tlsClient.once('error', reject);
  });
  tlsClient.write(Buffer.from('hello over tls passthrough'));
  await echoPromise;
  // Wait until per-IP domain accounting has seen the passthrough SNI.
  await waitForCondition(async () => {
    await pollMetrics(proxy);
    return hasIpDomainRequest('passthrough.example.com');
  });
  const requestRates = proxy.getMetrics().requests.byDomain();
  // Passthrough traffic is absent from HTTP rates; the earlier HTTP count
  // of 3 for example.com is unchanged.
  expect(requestRates.has('passthrough.example.com')).toBeFalse();
  expect(requestRates.get('example.com')?.lastMinute).toEqual(3);
  expect(hasIpDomainRequest('passthrough.example.com')).toBeTrue();
  tlsClient.destroy();
});
// Tear everything down and verify no sockets remain bound to the test ports.
tap.test('cleanup - stop proxy and close backend servers', async () => {
  await proxy.stop();
  await new Promise<void>((resolve) => httpBackend.close(() => resolve()));
  await new Promise<void>((resolve) => tlsBackend.close(() => resolve()));
  await assertPortsFree([httpBackendPort, tlsBackendPort, httpProxyPort, tlsProxyPort]);
});
// Consistency fix: terminate with a semicolon like the sibling test files.
export default tap.start();
+4 -1
View File
@@ -83,6 +83,9 @@ tap.test('should verify new metrics API structure', async () => {
expect(metrics.throughput).toHaveProperty('history'); expect(metrics.throughput).toHaveProperty('history');
expect(metrics.throughput).toHaveProperty('byRoute'); expect(metrics.throughput).toHaveProperty('byRoute');
expect(metrics.throughput).toHaveProperty('byIP'); expect(metrics.throughput).toHaveProperty('byIP');
// Check request methods
expect(metrics.requests).toHaveProperty('byDomain');
}); });
tap.test('should track active connections', async (tools) => { tap.test('should track active connections', async (tools) => {
@@ -273,4 +276,4 @@ tap.test('should clean up resources', async () => {
await assertPortsFree([echoServerPort, proxyPort]); await assertPortsFree([echoServerPort, proxyPort]);
}); });
export default tap.start(); export default tap.start();
+25
View File
@@ -537,6 +537,31 @@ tap.test('Route Matching - routeMatchesHeaders', async () => {
'X-Custom-Header': 'value' 'X-Custom-Header': 'value'
})).toBeFalse(); })).toBeFalse();
const regexHeaderRoute: IRouteConfig = {
match: {
domains: 'example.com',
ports: 80,
headers: {
'Content-Type': /^application\/(json|problem\+json)$/i,
}
},
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: 3000
}]
}
};
expect(routeMatchesHeaders(regexHeaderRoute, {
'Content-Type': 'Application/Problem+Json',
})).toBeTrue();
expect(routeMatchesHeaders(regexHeaderRoute, {
'Content-Type': 'text/html',
})).toBeFalse();
// Route without header matching should match any headers // Route without header matching should match any headers
const noHeaderRoute: IRouteConfig = { const noHeaderRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' }, match: { ports: 80, domains: 'example.com' },
+192
View File
@@ -0,0 +1,192 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import type { ISmartProxyOptions } from '../ts/proxies/smart-proxy/models/interfaces.js';
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
import { RoutePreprocessor } from '../ts/proxies/smart-proxy/route-preprocessor.js';
import { buildRustProxyOptions } from '../ts/proxies/smart-proxy/utils/rust-config.js';
// Contract test: the TS-side preprocessor must serialize a fully-featured
// route (regex headers, UDP/QUIC options, proxy-protocol flags) into the
// plain-JSON shape the Rust proxy consumes.
tap.test('Rust contract - preprocessor serializes regex headers for Rust', async () => {
  const route: IRouteConfig = {
    name: 'contract-route',
    match: {
      ports: [443, { from: 8443, to: 8444 }],
      domains: ['api.example.com', '*.example.com'],
      transport: 'udp',
      protocol: 'http3',
      headers: {
        'Content-Type': /^application\/json$/i,
      },
    },
    action: {
      type: 'forward',
      targets: [{
        match: {
          ports: [443],
          path: '/api/*',
          method: ['GET'],
          headers: {
            'X-Env': /^(prod|stage)$/,
          },
        },
        host: ['backend-a', 'backend-b'],
        port: 'preserve',
        sendProxyProtocol: true,
        backendTransport: 'tcp',
      }],
      tls: {
        mode: 'terminate',
        certificate: 'auto',
      },
      sendProxyProtocol: true,
      udp: {
        maxSessionsPerIP: 321,
        quic: {
          enableHttp3: true,
        },
      },
    },
    security: {
      ipAllowList: [{
        ip: '10.0.0.0/8',
        domains: ['api.example.com'],
      }],
    },
  };
  const preprocessor = new RoutePreprocessor();
  const [rustRoute] = preprocessor.preprocessForRust([route]);
  // RegExp values are serialized to their `/pattern/flags` string form.
  expect(rustRoute.match.headers?.['Content-Type']).toEqual('/^application\\/json$/i');
  expect(rustRoute.match.transport).toEqual('udp');
  expect(rustRoute.match.protocol).toEqual('http3');
  expect(rustRoute.action.targets?.[0].match?.headers?.['X-Env']).toEqual('/^(prod|stage)$/');
  expect(rustRoute.action.targets?.[0].port).toEqual('preserve');
  expect(rustRoute.action.targets?.[0].backendTransport).toEqual('tcp');
  expect(rustRoute.action.sendProxyProtocol).toBeTrue();
  // Note the key rename: maxSessionsPerIP (TS) -> maxSessionsPerIp (Rust).
  expect(rustRoute.action.udp?.maxSessionsPerIp).toEqual(321);
});
// Contract test: routes with function-valued host/port (not serializable to
// JSON) must be downgraded to 'socket-handler' so Rust relays the socket back
// to TS, while the original route stays available for TS-side handling.
tap.test('Rust contract - preprocessor converts dynamic targets to relay-safe payloads', async () => {
  const route: IRouteConfig = {
    name: 'dynamic-contract-route',
    match: {
      ports: 8080,
    },
    action: {
      type: 'forward',
      targets: [{
        host: () => 'dynamic-backend.internal',
        port: () => 9443,
      }],
    },
  };
  const preprocessor = new RoutePreprocessor();
  const [rustRoute] = preprocessor.preprocessForRust([route]);
  expect(rustRoute.action.type).toEqual('socket-handler');
  // Placeholder target values replace the non-serializable functions.
  expect(rustRoute.action.targets?.[0].host).toEqual('localhost');
  expect(rustRoute.action.targets?.[0].port).toEqual(0);
  // The untouched original route must remain retrievable by name.
  expect(preprocessor.getOriginalRoute('dynamic-contract-route')).toEqual(route);
});
// Contract test: every top-level SmartProxy option must survive the
// TS -> Rust config translation (buildRustProxyOptions), including the
// IP -> Ip casing renames (preserveSourceIp, proxyIps, maxConnectionsPerIp).
tap.test('Rust contract - top-level config keeps shared SmartProxy settings', async () => {
  const settings: ISmartProxyOptions = {
    routes: [{
      name: 'top-level-contract-route',
      match: {
        ports: 443,
        domains: 'api.example.com',
      },
      action: {
        type: 'forward',
        targets: [{
          host: 'backend.internal',
          port: 8443,
        }],
        tls: {
          mode: 'terminate',
          certificate: 'auto',
        },
      },
    }],
    preserveSourceIP: true,
    proxyIPs: ['10.0.0.1'],
    acceptProxyProtocol: true,
    sendProxyProtocol: true,
    noDelay: true,
    keepAlive: true,
    keepAliveInitialDelay: 1500,
    maxPendingDataSize: 4096,
    disableInactivityCheck: true,
    enableKeepAliveProbes: true,
    enableDetailedLogging: true,
    enableTlsDebugLogging: true,
    enableRandomizedTimeouts: true,
    connectionTimeout: 5000,
    initialDataTimeout: 7000,
    socketTimeout: 9000,
    inactivityCheckInterval: 1100,
    maxConnectionLifetime: 13000,
    inactivityTimeout: 15000,
    gracefulShutdownTimeout: 17000,
    maxConnectionsPerIP: 20,
    connectionRateLimitPerMinute: 30,
    keepAliveTreatment: 'extended',
    keepAliveInactivityMultiplier: 2,
    extendedKeepAliveLifetime: 19000,
    metrics: {
      enabled: true,
      sampleIntervalMs: 250,
      retentionSeconds: 60,
    },
    acme: {
      enabled: true,
      email: 'ops@example.com',
      environment: 'staging',
      useProduction: false,
      skipConfiguredCerts: true,
      renewThresholdDays: 14,
      renewCheckIntervalHours: 12,
      autoRenew: true,
      port: 80,
    },
  };
  const preprocessor = new RoutePreprocessor();
  const routes = preprocessor.preprocessForRust(settings.routes);
  const config = buildRustProxyOptions(settings, routes);
  expect(config.preserveSourceIp).toBeTrue();
  expect(config.proxyIps).toEqual(['10.0.0.1']);
  expect(config.acceptProxyProtocol).toBeTrue();
  expect(config.sendProxyProtocol).toBeTrue();
  expect(config.noDelay).toBeTrue();
  expect(config.keepAlive).toBeTrue();
  expect(config.keepAliveInitialDelay).toEqual(1500);
  expect(config.maxPendingDataSize).toEqual(4096);
  expect(config.disableInactivityCheck).toBeTrue();
  expect(config.enableKeepAliveProbes).toBeTrue();
  expect(config.enableDetailedLogging).toBeTrue();
  expect(config.enableTlsDebugLogging).toBeTrue();
  expect(config.enableRandomizedTimeouts).toBeTrue();
  expect(config.connectionTimeout).toEqual(5000);
  expect(config.initialDataTimeout).toEqual(7000);
  expect(config.socketTimeout).toEqual(9000);
  expect(config.inactivityCheckInterval).toEqual(1100);
  expect(config.maxConnectionLifetime).toEqual(13000);
  expect(config.inactivityTimeout).toEqual(15000);
  expect(config.gracefulShutdownTimeout).toEqual(17000);
  expect(config.maxConnectionsPerIp).toEqual(20);
  expect(config.connectionRateLimitPerMinute).toEqual(30);
  expect(config.keepAliveTreatment).toEqual('extended');
  expect(config.keepAliveInactivityMultiplier).toEqual(2);
  expect(config.extendedKeepAliveLifetime).toEqual(19000);
  expect(config.metrics?.sampleIntervalMs).toEqual(250);
  expect(config.acme?.email).toEqual('ops@example.com');
  expect(config.acme?.environment).toEqual('staging');
  expect(config.acme?.skipConfiguredCerts).toBeTrue();
  expect(config.acme?.renewThresholdDays).toEqual(14);
});
export default tap.start();
+4 -2
View File
@@ -188,10 +188,12 @@ tap.test('TCP forward - real-time byte tracking', async (tools) => {
const byRoute = m.throughput.byRoute(); const byRoute = m.throughput.byRoute();
console.log('TCP forward — throughput byRoute:', Array.from(byRoute.entries())); console.log('TCP forward — throughput byRoute:', Array.from(byRoute.entries()));
// After close, per-IP data should be evicted (memory leak fix) // After close, per-IP buckets are retained briefly for final throughput sampling,
// but active connection counts must already be zero.
const byIPAfter = m.connections.byIP(); const byIPAfter = m.connections.byIP();
console.log('TCP forward — connections byIP after close:', Array.from(byIPAfter.entries())); console.log('TCP forward — connections byIP after close:', Array.from(byIPAfter.entries()));
expect(byIPAfter.size).toEqual(0); expect(byIPAfter.size).toBeGreaterThan(0);
expect(Array.from(byIPAfter.values()).every((count) => count === 0)).toEqual(true);
await proxy.stop(); await proxy.stop();
await tools.delayFor(200); await tools.delayFor(200);
+1 -1
View File
@@ -3,6 +3,6 @@
*/ */
export const commitinfo = { export const commitinfo = {
name: '@push.rocks/smartproxy', name: '@push.rocks/smartproxy',
version: '27.3.0', version: '27.9.0',
description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.' description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.'
} }
+1 -1
View File
@@ -7,7 +7,7 @@ export { SmartProxy } from './proxies/smart-proxy/index.js';
export { SharedRouteManager as RouteManager } from './core/routing/route-manager.js'; export { SharedRouteManager as RouteManager } from './core/routing/route-manager.js';
// Export smart-proxy models // Export smart-proxy models
export type { ISmartProxyOptions, IConnectionRecord, IRouteConfig, IRouteMatch, IRouteAction, IRouteTls, IRouteContext } from './proxies/smart-proxy/models/index.js'; export type { ISmartProxyOptions, ISmartProxySecurityPolicy, IConnectionRecord, IRouteConfig, IRouteMatch, IRouteAction, IRouteTls, IRouteContext } from './proxies/smart-proxy/models/index.js';
export type { TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './proxies/smart-proxy/models/interfaces.js'; export type { TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './proxies/smart-proxy/models/interfaces.js';
export * from './proxies/smart-proxy/utils/index.js'; export * from './proxies/smart-proxy/utils/index.js';
+1 -1
View File
@@ -2,6 +2,6 @@
* SmartProxy models * SmartProxy models
*/ */
// Export everything except IAcmeOptions from interfaces // Export everything except IAcmeOptions from interfaces
export type { ISmartProxyOptions, ISmartProxyCertStore, IConnectionRecord, TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './interfaces.js'; export type { ISmartProxyOptions, ISmartProxySecurityPolicy, ISmartProxyCertStore, IConnectionRecord, TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './interfaces.js';
export * from './route-types.js'; export * from './route-types.js';
export * from './metrics-types.js'; export * from './metrics-types.js';
+7 -1
View File
@@ -29,6 +29,11 @@ export interface ISmartProxyCertStore {
} }
import type { IRouteConfig } from './route-types.js'; import type { IRouteConfig } from './route-types.js';
export interface ISmartProxySecurityPolicy {
blockedIps?: string[];
blockedCidrs?: string[];
}
/** /**
* Provision object for static or HTTP-01 certificate * Provision object for static or HTTP-01 certificate
*/ */
@@ -137,6 +142,7 @@ export interface ISmartProxyOptions {
// Rate limiting and security // Rate limiting and security
maxConnectionsPerIP?: number; // Maximum simultaneous connections from a single IP maxConnectionsPerIP?: number; // Maximum simultaneous connections from a single IP
connectionRateLimitPerMinute?: number; // Max new connections per minute from a single IP connectionRateLimitPerMinute?: number; // Max new connections per minute from a single IP
securityPolicy?: ISmartProxySecurityPolicy; // Global ingress block policy, enforced before routing
// Enhanced keep-alive settings // Enhanced keep-alive settings
keepAliveTreatment?: 'standard' | 'extended' | 'immortal'; // How to treat keep-alive connections keepAliveTreatment?: 'standard' | 'extended' | 'immortal'; // How to treat keep-alive connections
@@ -276,4 +282,4 @@ export interface IConnectionRecord {
path?: string; path?: string;
headers?: Record<string, string>; headers?: Record<string, string>;
}; };
} }
+11 -1
View File
@@ -29,6 +29,11 @@ export interface IThroughputHistoryPoint {
out: number; out: number;
} }
export interface IRequestRateMetrics {
perSecond: number;
lastMinute: number;
}
/** /**
* Main metrics interface with clean, grouped API * Main metrics interface with clean, grouped API
*/ */
@@ -57,6 +62,10 @@ export interface IMetrics {
byRoute(): Map<string, number>; byRoute(): Map<string, number>;
byIP(): Map<string, number>; byIP(): Map<string, number>;
topIPs(limit?: number): Array<{ ip: string; count: number }>; topIPs(limit?: number): Array<{ ip: string; count: number }>;
/** Per-IP domain request counts: IP -> { domain -> count }. */
domainRequestsByIP(): Map<string, Map<string, number>>;
/** Top IP-domain pairs sorted by request count descending. */
topDomainRequests(limit?: number): Array<{ ip: string; domain: string; count: number }>;
frontendProtocols(): IProtocolDistribution; frontendProtocols(): IProtocolDistribution;
backendProtocols(): IProtocolDistribution; backendProtocols(): IProtocolDistribution;
}; };
@@ -77,6 +86,7 @@ export interface IMetrics {
perSecond(): number; perSecond(): number;
perMinute(): number; perMinute(): number;
total(): number; total(): number;
byDomain(): Map<string, IRequestRateMetrics>;
}; };
// Cumulative totals // Cumulative totals
@@ -181,4 +191,4 @@ export interface IByteTracker {
bytesOut: number; bytesOut: number;
startTime: number; startTime: number;
lastUpdate: number; lastUpdate: number;
} }
+4 -2
View File
@@ -141,8 +141,10 @@ export interface IRouteAuthentication {
* Security options for routes * Security options for routes
*/ */
export interface IRouteSecurity { export interface IRouteSecurity {
// Access control lists // Access control lists.
ipAllowList?: string[]; // IP addresses that are allowed to connect // Entries can be plain IP/CIDR strings (full route access) or
// objects { ip, domains } to scope access to specific domains on this route.
ipAllowList?: Array<string | { ip: string; domains: string[] }>;
ipBlockList?: string[]; // IP addresses that are blocked from connecting ipBlockList?: string[]; // IP addresses that are blocked from connecting
// Connection limits // Connection limits
+167
View File
@@ -0,0 +1,167 @@
import type { IProtocolCacheEntry, IProtocolDistribution } from './metrics-types.js';
import type { IAcmeOptions, ISmartProxyOptions, ISmartProxySecurityPolicy } from './interfaces.js';
import type {
IRouteAction,
IRouteConfig,
IRouteMatch,
IRouteTarget,
ITargetMatch,
IRouteUdp,
} from './route-types.js';
/**
 * Header matchers as sent to Rust: always plain strings. TS-side RegExp
 * matchers are serialized to their string form before crossing the boundary
 * (see the route preprocessor / rust-config serialization).
 */
export type TRustHeaderMatchers = Record<string, string>;

/** IRouteMatch with regex headers replaced by their serialized string form. */
export interface IRustRouteMatch extends Omit<IRouteMatch, 'headers'> {
  headers?: TRustHeaderMatchers;
}

/** ITargetMatch with regex headers replaced by their serialized string form. */
export interface IRustTargetMatch extends Omit<ITargetMatch, 'headers'> {
  headers?: TRustHeaderMatchers;
}

/**
 * Route target as Rust sees it: host/port are concrete values only —
 * function-valued (dynamic) hosts/ports are resolved or relayed on the TS
 * side before serialization.
 */
export interface IRustRouteTarget extends Omit<IRouteTarget, 'host' | 'port' | 'match'> {
  host: string | string[];
  port: number | 'preserve';
  match?: IRustTargetMatch;
}

/** IRouteUdp with the key renamed maxSessionsPerIP -> maxSessionsPerIp for Rust. */
export interface IRustRouteUdp extends Omit<IRouteUdp, 'maxSessionsPerIP'> {
  maxSessionsPerIp?: number;
}

/** Route defaults with preserveSourceIP renamed to preserveSourceIp for Rust. */
export interface IRustDefaultConfig extends Omit<NonNullable<ISmartProxyOptions['defaults']>, 'preserveSourceIP'> {
  preserveSourceIp?: boolean;
}

/**
 * Route action stripped of everything that cannot cross the JSON boundary:
 * JS handler functions (socketHandler, datagramHandler, forwardingEngine)
 * and nftables config are omitted.
 */
export interface IRustRouteAction
  extends Omit<IRouteAction, 'targets' | 'socketHandler' | 'datagramHandler' | 'forwardingEngine' | 'nftables' | 'udp'> {
  targets?: IRustRouteTarget[];
  udp?: IRustRouteUdp;
}

/** A fully serialized route config as consumed by the Rust proxy core. */
export interface IRustRouteConfig extends Omit<IRouteConfig, 'match' | 'action'> {
  match: IRustRouteMatch;
  action: IRustRouteAction;
}
/** ACME options for Rust; routeForwards is TS-only and therefore omitted. */
export interface IRustAcmeOptions extends Omit<IAcmeOptions, 'routeForwards'> {}

/**
 * Top-level configuration payload handed to the Rust proxy. Mirrors
 * ISmartProxyOptions field-for-field, with "IP" renamed to "Ip" in key names
 * (preserveSourceIp, proxyIps, maxConnectionsPerIp) to match Rust's casing.
 */
export interface IRustProxyOptions {
  routes: IRustRouteConfig[];
  preserveSourceIp?: boolean;
  proxyIps?: string[];
  acceptProxyProtocol?: boolean;
  sendProxyProtocol?: boolean;
  defaults?: IRustDefaultConfig;
  // Timeouts and lifetimes (milliseconds).
  connectionTimeout?: number;
  initialDataTimeout?: number;
  socketTimeout?: number;
  inactivityCheckInterval?: number;
  maxConnectionLifetime?: number;
  inactivityTimeout?: number;
  gracefulShutdownTimeout?: number;
  // Socket tuning.
  noDelay?: boolean;
  keepAlive?: boolean;
  keepAliveInitialDelay?: number;
  maxPendingDataSize?: number;
  disableInactivityCheck?: boolean;
  enableKeepAliveProbes?: boolean;
  // Diagnostics.
  enableDetailedLogging?: boolean;
  enableTlsDebugLogging?: boolean;
  enableRandomizedTimeouts?: boolean;
  // Per-IP limits.
  maxConnectionsPerIp?: number;
  connectionRateLimitPerMinute?: number;
  keepAliveTreatment?: ISmartProxyOptions['keepAliveTreatment'];
  keepAliveInactivityMultiplier?: number;
  extendedKeepAliveLifetime?: number;
  metrics?: ISmartProxyOptions['metrics'];
  // Global ingress block policy, enforced before routing.
  securityPolicy?: ISmartProxySecurityPolicy;
  acme?: IRustAcmeOptions;
}
/** Coarse runtime statistics reported by the Rust proxy. */
export interface IRustStatistics {
  activeConnections: number;
  totalConnections: number;
  routesCount: number;
  listeningPorts: number[];
  uptimeSeconds: number;
}

/** Status of a single managed certificate. */
export interface IRustCertificateStatus {
  domain: string;
  source: string;
  // Expiry as a numeric timestamp (units determined by the Rust side).
  expiresAt: number;
  isValid: boolean;
}

/** One point in the throughput history time series. */
export interface IRustThroughputSample {
  timestampMs: number;
  bytesIn: number;
  bytesOut: number;
}

/** Per-route connection and throughput counters. */
export interface IRustRouteMetrics {
  activeConnections: number;
  totalConnections: number;
  bytesIn: number;
  bytesOut: number;
  throughputInBytesPerSec: number;
  throughputOutBytesPerSec: number;
  // "Recent" variants reflect a shorter sampling window than the lifetime rates.
  throughputRecentInBytesPerSec: number;
  throughputRecentOutBytesPerSec: number;
}

/** Per-client-IP counters, including per-domain request counts for that IP. */
export interface IRustIpMetrics {
  activeConnections: number;
  totalConnections: number;
  bytesIn: number;
  bytesOut: number;
  throughputInBytesPerSec: number;
  throughputOutBytesPerSec: number;
  // domain -> request count observed from this IP.
  domainRequests: Record<string, number>;
}

/** Per-backend health/perf counters (connect timing, pool usage, errors). */
export interface IRustBackendMetrics {
  activeConnections: number;
  totalConnections: number;
  protocol: string;
  connectErrors: number;
  handshakeErrors: number;
  requestErrors: number;
  totalConnectTimeUs: number;
  connectCount: number;
  poolHits: number;
  poolMisses: number;
  h2Failures: number;
}

/** HTTP request rates for a single domain. */
export interface IRustHttpDomainRequestMetrics {
  requestsPerSecond: number;
  requestsLastMinute: number;
}

/**
 * Full metrics snapshot polled from the Rust proxy; the TS RustMetricsAdapter
 * translates this into the public IMetrics API.
 */
export interface IRustMetricsSnapshot {
  activeConnections: number;
  totalConnections: number;
  bytesIn: number;
  bytesOut: number;
  throughputInBytesPerSec: number;
  throughputOutBytesPerSec: number;
  throughputRecentInBytesPerSec: number;
  throughputRecentOutBytesPerSec: number;
  routes: Record<string, IRustRouteMetrics>;
  ips: Record<string, IRustIpMetrics>;
  backends: Record<string, IRustBackendMetrics>;
  throughputHistory: IRustThroughputSample[];
  // HTTP request counters (parsed requests only, not passthrough bytes).
  totalHttpRequests: number;
  httpRequestsPerSec: number;
  httpRequestsPerSecRecent: number;
  httpDomainRequests: Record<string, IRustHttpDomainRequestMetrics>;
  // UDP session/datagram counters.
  activeUdpSessions: number;
  totalUdpSessions: number;
  totalDatagramsIn: number;
  totalDatagramsOut: number;
  detectedProtocols: IProtocolCacheEntry[];
  frontendProtocols: IProtocolDistribution;
  backendProtocols: IProtocolDistribution;
}
+13 -11
View File
@@ -1,5 +1,6 @@
import type { IRouteConfig, IRouteAction, IRouteTarget } from './models/route-types.js'; import type { IRouteConfig, IRouteAction, IRouteTarget } from './models/route-types.js';
import { logger } from '../../core/utils/logger.js'; import type { IRustRouteConfig } from './models/rust-types.js';
import { serializeRouteForRust } from './utils/rust-config.js';
/** /**
* Preprocesses routes before sending them to Rust. * Preprocesses routes before sending them to Rust.
@@ -24,7 +25,7 @@ export class RoutePreprocessor {
* - Non-serializable fields are stripped * - Non-serializable fields are stripped
* - Original routes are preserved in the local map for handler lookup * - Original routes are preserved in the local map for handler lookup
*/ */
public preprocessForRust(routes: IRouteConfig[]): IRouteConfig[] { public preprocessForRust(routes: IRouteConfig[]): IRustRouteConfig[] {
this.originalRoutes.clear(); this.originalRoutes.clear();
return routes.map((route, index) => this.preprocessRoute(route, index)); return routes.map((route, index) => this.preprocessRoute(route, index));
} }
@@ -43,7 +44,7 @@ export class RoutePreprocessor {
return new Map(this.originalRoutes); return new Map(this.originalRoutes);
} }
private preprocessRoute(route: IRouteConfig, index: number): IRouteConfig { private preprocessRoute(route: IRouteConfig, index: number): IRustRouteConfig {
const routeKey = route.name || route.id || `route_${index}`; const routeKey = route.name || route.id || `route_${index}`;
// Check if this route needs TS-side handling // Check if this route needs TS-side handling
@@ -57,7 +58,7 @@ export class RoutePreprocessor {
// Create a clean copy for Rust // Create a clean copy for Rust
const cleanRoute: IRouteConfig = { const cleanRoute: IRouteConfig = {
...route, ...route,
action: this.cleanAction(route.action, routeKey, needsTsHandling), action: this.cleanAction(route.action, needsTsHandling),
}; };
// Ensure we have a name for handler lookup // Ensure we have a name for handler lookup
@@ -65,7 +66,7 @@ export class RoutePreprocessor {
cleanRoute.name = routeKey; cleanRoute.name = routeKey;
} }
return cleanRoute; return serializeRouteForRust(cleanRoute);
} }
private routeNeedsTsHandling(route: IRouteConfig): boolean { private routeNeedsTsHandling(route: IRouteConfig): boolean {
@@ -91,15 +92,16 @@ export class RoutePreprocessor {
return false; return false;
} }
private cleanAction(action: IRouteAction, routeKey: string, needsTsHandling: boolean): IRouteAction { private cleanAction(action: IRouteAction, needsTsHandling: boolean): IRouteAction {
const cleanAction: IRouteAction = { ...action }; let cleanAction: IRouteAction = { ...action };
if (needsTsHandling) { if (needsTsHandling) {
// Convert to socket-handler type for Rust (Rust will relay back to TS) // Convert to socket-handler type for Rust (Rust will relay back to TS)
cleanAction.type = 'socket-handler'; const { socketHandler: _socketHandler, datagramHandler: _datagramHandler, ...serializableAction } = cleanAction;
// Remove the JS handlers (not serializable) cleanAction = {
delete (cleanAction as any).socketHandler; ...serializableAction,
delete (cleanAction as any).datagramHandler; type: 'socket-handler',
};
} }
// Clean targets - replace functions with static values // Clean targets - replace functions with static values
+83 -52
View File
@@ -1,5 +1,6 @@
import type { IMetrics, IBackendMetrics, IProtocolCacheEntry, IProtocolDistribution, IThroughputData, IThroughputHistoryPoint } from './models/metrics-types.js'; import type { IMetrics, IBackendMetrics, IProtocolCacheEntry, IProtocolDistribution, IRequestRateMetrics, IThroughputData, IThroughputHistoryPoint } from './models/metrics-types.js';
import type { RustProxyBridge } from './rust-proxy-bridge.js'; import type { RustProxyBridge } from './rust-proxy-bridge.js';
import type { IRustBackendMetrics, IRustHttpDomainRequestMetrics, IRustIpMetrics, IRustMetricsSnapshot, IRustRouteMetrics } from './models/rust-types.js';
/** /**
* Adapts Rust JSON metrics to the IMetrics interface. * Adapts Rust JSON metrics to the IMetrics interface.
@@ -14,7 +15,7 @@ import type { RustProxyBridge } from './rust-proxy-bridge.js';
*/ */
export class RustMetricsAdapter implements IMetrics { export class RustMetricsAdapter implements IMetrics {
private bridge: RustProxyBridge; private bridge: RustProxyBridge;
private cache: any = null; private cache: IRustMetricsSnapshot | null = null;
private pollTimer: ReturnType<typeof setInterval> | null = null; private pollTimer: ReturnType<typeof setInterval> | null = null;
private pollIntervalMs: number; private pollIntervalMs: number;
@@ -65,8 +66,8 @@ export class RustMetricsAdapter implements IMetrics {
byRoute: (): Map<string, number> => { byRoute: (): Map<string, number> => {
const result = new Map<string, number>(); const result = new Map<string, number>();
if (this.cache?.routes) { if (this.cache?.routes) {
for (const [name, rm] of Object.entries(this.cache.routes)) { for (const [name, rm] of Object.entries(this.cache.routes) as Array<[string, IRustRouteMetrics]>) {
result.set(name, (rm as any).activeConnections ?? 0); result.set(name, rm.activeConnections ?? 0);
} }
} }
return result; return result;
@@ -74,8 +75,8 @@ export class RustMetricsAdapter implements IMetrics {
byIP: (): Map<string, number> => { byIP: (): Map<string, number> => {
const result = new Map<string, number>(); const result = new Map<string, number>();
if (this.cache?.ips) { if (this.cache?.ips) {
for (const [ip, im] of Object.entries(this.cache.ips)) { for (const [ip, im] of Object.entries(this.cache.ips) as Array<[string, IRustIpMetrics]>) {
result.set(ip, (im as any).activeConnections ?? 0); result.set(ip, im.activeConnections ?? 0);
} }
} }
return result; return result;
@@ -83,8 +84,41 @@ export class RustMetricsAdapter implements IMetrics {
topIPs: (limit: number = 10): Array<{ ip: string; count: number }> => { topIPs: (limit: number = 10): Array<{ ip: string; count: number }> => {
const result: Array<{ ip: string; count: number }> = []; const result: Array<{ ip: string; count: number }> = [];
if (this.cache?.ips) { if (this.cache?.ips) {
for (const [ip, im] of Object.entries(this.cache.ips)) { for (const [ip, im] of Object.entries(this.cache.ips) as Array<[string, IRustIpMetrics]>) {
result.push({ ip, count: (im as any).activeConnections ?? 0 }); result.push({ ip, count: im.activeConnections ?? 0 });
}
}
result.sort((a, b) => b.count - a.count);
return result.slice(0, limit);
},
domainRequestsByIP: (): Map<string, Map<string, number>> => {
const result = new Map<string, Map<string, number>>();
if (this.cache?.ips) {
for (const [ip, im] of Object.entries(this.cache.ips) as Array<[string, IRustIpMetrics]>) {
const dr = im.domainRequests;
if (dr && typeof dr === 'object') {
const domainMap = new Map<string, number>();
for (const [domain, count] of Object.entries(dr)) {
domainMap.set(domain, count as number);
}
if (domainMap.size > 0) {
result.set(ip, domainMap);
}
}
}
}
return result;
},
topDomainRequests: (limit: number = 20): Array<{ ip: string; domain: string; count: number }> => {
const result: Array<{ ip: string; domain: string; count: number }> = [];
if (this.cache?.ips) {
for (const [ip, im] of Object.entries(this.cache.ips) as Array<[string, IRustIpMetrics]>) {
const dr = im.domainRequests;
if (dr && typeof dr === 'object') {
for (const [domain, count] of Object.entries(dr)) {
result.push({ ip, domain, count: count as number });
}
}
} }
} }
result.sort((a, b) => b.count - a.count); result.sort((a, b) => b.count - a.count);
@@ -106,27 +140,14 @@ export class RustMetricsAdapter implements IMetrics {
}; };
}, },
backendProtocols: (): IProtocolDistribution => { backendProtocols: (): IProtocolDistribution => {
// Merge per-backend h1/h2/h3 data with aggregate ws/other counters
const bp = this.cache?.backendProtocols; const bp = this.cache?.backendProtocols;
let h1Active = 0, h1Total = 0;
let h2Active = 0, h2Total = 0;
let h3Active = 0, h3Total = 0;
if (this.cache?.backends) {
for (const bm of Object.values(this.cache.backends)) {
const m = bm as any;
const active = m.activeConnections ?? 0;
const total = m.totalConnections ?? 0;
switch (m.protocol) {
case 'h2': h2Active += active; h2Total += total; break;
case 'h3': h3Active += active; h3Total += total; break;
default: h1Active += active; h1Total += total; break;
}
}
}
return { return {
h1Active, h1Total, h1Active: bp?.h1Active ?? 0,
h2Active, h2Total, h1Total: bp?.h1Total ?? 0,
h3Active, h3Total, h2Active: bp?.h2Active ?? 0,
h2Total: bp?.h2Total ?? 0,
h3Active: bp?.h3Active ?? 0,
h3Total: bp?.h3Total ?? 0,
wsActive: bp?.wsActive ?? 0, wsActive: bp?.wsActive ?? 0,
wsTotal: bp?.wsTotal ?? 0, wsTotal: bp?.wsTotal ?? 0,
otherActive: bp?.otherActive ?? 0, otherActive: bp?.otherActive ?? 0,
@@ -156,7 +177,7 @@ export class RustMetricsAdapter implements IMetrics {
}, },
history: (seconds: number): Array<IThroughputHistoryPoint> => { history: (seconds: number): Array<IThroughputHistoryPoint> => {
if (!this.cache?.throughputHistory) return []; if (!this.cache?.throughputHistory) return [];
return this.cache.throughputHistory.slice(-seconds).map((p: any) => ({ return this.cache.throughputHistory.slice(-seconds).map((p) => ({
timestamp: p.timestampMs, timestamp: p.timestampMs,
in: p.bytesIn, in: p.bytesIn,
out: p.bytesOut, out: p.bytesOut,
@@ -165,10 +186,10 @@ export class RustMetricsAdapter implements IMetrics {
byRoute: (_windowSeconds?: number): Map<string, IThroughputData> => { byRoute: (_windowSeconds?: number): Map<string, IThroughputData> => {
const result = new Map<string, IThroughputData>(); const result = new Map<string, IThroughputData>();
if (this.cache?.routes) { if (this.cache?.routes) {
for (const [name, rm] of Object.entries(this.cache.routes)) { for (const [name, rm] of Object.entries(this.cache.routes) as Array<[string, IRustRouteMetrics]>) {
result.set(name, { result.set(name, {
in: (rm as any).throughputInBytesPerSec ?? 0, in: rm.throughputInBytesPerSec ?? 0,
out: (rm as any).throughputOutBytesPerSec ?? 0, out: rm.throughputOutBytesPerSec ?? 0,
}); });
} }
} }
@@ -177,10 +198,10 @@ export class RustMetricsAdapter implements IMetrics {
byIP: (_windowSeconds?: number): Map<string, IThroughputData> => { byIP: (_windowSeconds?: number): Map<string, IThroughputData> => {
const result = new Map<string, IThroughputData>(); const result = new Map<string, IThroughputData>();
if (this.cache?.ips) { if (this.cache?.ips) {
for (const [ip, im] of Object.entries(this.cache.ips)) { for (const [ip, im] of Object.entries(this.cache.ips) as Array<[string, IRustIpMetrics]>) {
result.set(ip, { result.set(ip, {
in: (im as any).throughputInBytesPerSec ?? 0, in: im.throughputInBytesPerSec ?? 0,
out: (im as any).throughputOutBytesPerSec ?? 0, out: im.throughputOutBytesPerSec ?? 0,
}); });
} }
} }
@@ -198,6 +219,18 @@ export class RustMetricsAdapter implements IMetrics {
total: (): number => { total: (): number => {
return this.cache?.totalHttpRequests ?? this.cache?.totalConnections ?? 0; return this.cache?.totalHttpRequests ?? this.cache?.totalConnections ?? 0;
}, },
byDomain: (): Map<string, IRequestRateMetrics> => {
const result = new Map<string, IRequestRateMetrics>();
if (this.cache?.httpDomainRequests) {
for (const [domain, metrics] of Object.entries(this.cache.httpDomainRequests) as Array<[string, IRustHttpDomainRequestMetrics]>) {
result.set(domain, {
perSecond: metrics.requestsPerSecond ?? 0,
lastMinute: metrics.requestsLastMinute ?? 0,
});
}
}
return result;
},
}; };
public totals = { public totals = {
@@ -216,23 +249,22 @@ export class RustMetricsAdapter implements IMetrics {
byBackend: (): Map<string, IBackendMetrics> => { byBackend: (): Map<string, IBackendMetrics> => {
const result = new Map<string, IBackendMetrics>(); const result = new Map<string, IBackendMetrics>();
if (this.cache?.backends) { if (this.cache?.backends) {
for (const [key, bm] of Object.entries(this.cache.backends)) { for (const [key, bm] of Object.entries(this.cache.backends) as Array<[string, IRustBackendMetrics]>) {
const m = bm as any; const totalTimeUs = bm.totalConnectTimeUs ?? 0;
const totalTimeUs = m.totalConnectTimeUs ?? 0; const count = bm.connectCount ?? 0;
const count = m.connectCount ?? 0; const poolHits = bm.poolHits ?? 0;
const poolHits = m.poolHits ?? 0; const poolMisses = bm.poolMisses ?? 0;
const poolMisses = m.poolMisses ?? 0;
const poolTotal = poolHits + poolMisses; const poolTotal = poolHits + poolMisses;
result.set(key, { result.set(key, {
protocol: m.protocol ?? 'unknown', protocol: bm.protocol ?? 'unknown',
activeConnections: m.activeConnections ?? 0, activeConnections: bm.activeConnections ?? 0,
totalConnections: m.totalConnections ?? 0, totalConnections: bm.totalConnections ?? 0,
connectErrors: m.connectErrors ?? 0, connectErrors: bm.connectErrors ?? 0,
handshakeErrors: m.handshakeErrors ?? 0, handshakeErrors: bm.handshakeErrors ?? 0,
requestErrors: m.requestErrors ?? 0, requestErrors: bm.requestErrors ?? 0,
avgConnectTimeMs: count > 0 ? (totalTimeUs / count) / 1000 : 0, avgConnectTimeMs: count > 0 ? (totalTimeUs / count) / 1000 : 0,
poolHitRate: poolTotal > 0 ? poolHits / poolTotal : 0, poolHitRate: poolTotal > 0 ? poolHits / poolTotal : 0,
h2Failures: m.h2Failures ?? 0, h2Failures: bm.h2Failures ?? 0,
}); });
} }
} }
@@ -241,8 +273,8 @@ export class RustMetricsAdapter implements IMetrics {
protocols: (): Map<string, string> => { protocols: (): Map<string, string> => {
const result = new Map<string, string>(); const result = new Map<string, string>();
if (this.cache?.backends) { if (this.cache?.backends) {
for (const [key, bm] of Object.entries(this.cache.backends)) { for (const [key, bm] of Object.entries(this.cache.backends) as Array<[string, IRustBackendMetrics]>) {
result.set(key, (bm as any).protocol ?? 'unknown'); result.set(key, bm.protocol ?? 'unknown');
} }
} }
return result; return result;
@@ -250,9 +282,8 @@ export class RustMetricsAdapter implements IMetrics {
topByErrors: (limit: number = 10): Array<{ backend: string; errors: number }> => { topByErrors: (limit: number = 10): Array<{ backend: string; errors: number }> => {
const result: Array<{ backend: string; errors: number }> = []; const result: Array<{ backend: string; errors: number }> = [];
if (this.cache?.backends) { if (this.cache?.backends) {
for (const [key, bm] of Object.entries(this.cache.backends)) { for (const [key, bm] of Object.entries(this.cache.backends) as Array<[string, IRustBackendMetrics]>) {
const m = bm as any; const errors = (bm.connectErrors ?? 0) + (bm.handshakeErrors ?? 0) + (bm.requestErrors ?? 0);
const errors = (m.connectErrors ?? 0) + (m.handshakeErrors ?? 0) + (m.requestErrors ?? 0);
if (errors > 0) result.push({ backend: key, errors }); if (errors > 0) result.push({ backend: key, errors });
} }
} }
+30 -18
View File
@@ -1,23 +1,31 @@
import * as plugins from '../../plugins.js'; import * as plugins from '../../plugins.js';
import { logger } from '../../core/utils/logger.js'; import { logger } from '../../core/utils/logger.js';
import type { IRouteConfig } from './models/route-types.js'; import type {
IRustCertificateStatus,
IRustMetricsSnapshot,
IRustProxyOptions,
IRustRouteConfig,
IRustStatistics,
} from './models/rust-types.js';
import type { ISmartProxySecurityPolicy } from './models/interfaces.js';
/** /**
* Type-safe command definitions for the Rust proxy IPC protocol. * Type-safe command definitions for the Rust proxy IPC protocol.
*/ */
type TSmartProxyCommands = { type TSmartProxyCommands = {
start: { params: { config: any }; result: void }; start: { params: { config: IRustProxyOptions }; result: void };
stop: { params: Record<string, never>; result: void }; stop: { params: Record<string, never>; result: void };
updateRoutes: { params: { routes: IRouteConfig[] }; result: void }; updateRoutes: { params: { routes: IRustRouteConfig[] }; result: void };
getMetrics: { params: Record<string, never>; result: any }; setSecurityPolicy: { params: { policy: ISmartProxySecurityPolicy }; result: void };
getStatistics: { params: Record<string, never>; result: any }; getMetrics: { params: Record<string, never>; result: IRustMetricsSnapshot };
provisionCertificate: { params: { routeName: string }; result: void }; getStatistics: { params: Record<string, never>; result: IRustStatistics };
renewCertificate: { params: { routeName: string }; result: void }; provisionCertificate: { params: { routeName: string }; result: void };
getCertificateStatus: { params: { routeName: string }; result: any }; renewCertificate: { params: { routeName: string }; result: void };
getListeningPorts: { params: Record<string, never>; result: { ports: number[] } }; getCertificateStatus: { params: { routeName: string }; result: IRustCertificateStatus | null };
setSocketHandlerRelay: { params: { socketPath: string }; result: void }; getListeningPorts: { params: Record<string, never>; result: { ports: number[] } };
addListeningPort: { params: { port: number }; result: void }; setSocketHandlerRelay: { params: { socketPath: string }; result: void };
removeListeningPort: { params: { port: number }; result: void }; addListeningPort: { params: { port: number }; result: void };
removeListeningPort: { params: { port: number }; result: void };
loadCertificate: { params: { domain: string; cert: string; key: string; ca?: string }; result: void }; loadCertificate: { params: { domain: string; cert: string; key: string; ca?: string }; result: void };
setDatagramHandlerRelay: { params: { socketPath: string }; result: void }; setDatagramHandlerRelay: { params: { socketPath: string }; result: void };
}; };
@@ -121,7 +129,7 @@ export class RustProxyBridge extends plugins.EventEmitter {
// --- Convenience methods for each management command --- // --- Convenience methods for each management command ---
public async startProxy(config: any): Promise<void> { public async startProxy(config: IRustProxyOptions): Promise<void> {
await this.bridge.sendCommand('start', { config }); await this.bridge.sendCommand('start', { config });
} }
@@ -129,15 +137,19 @@ export class RustProxyBridge extends plugins.EventEmitter {
await this.bridge.sendCommand('stop', {} as Record<string, never>); await this.bridge.sendCommand('stop', {} as Record<string, never>);
} }
public async updateRoutes(routes: IRouteConfig[]): Promise<void> { public async updateRoutes(routes: IRustRouteConfig[]): Promise<void> {
await this.bridge.sendCommand('updateRoutes', { routes }); await this.bridge.sendCommand('updateRoutes', { routes });
} }
public async getMetrics(): Promise<any> { public async setSecurityPolicy(policy: ISmartProxySecurityPolicy): Promise<void> {
await this.bridge.sendCommand('setSecurityPolicy', { policy });
}
public async getMetrics(): Promise<IRustMetricsSnapshot> {
return this.bridge.sendCommand('getMetrics', {} as Record<string, never>); return this.bridge.sendCommand('getMetrics', {} as Record<string, never>);
} }
public async getStatistics(): Promise<any> { public async getStatistics(): Promise<IRustStatistics> {
return this.bridge.sendCommand('getStatistics', {} as Record<string, never>); return this.bridge.sendCommand('getStatistics', {} as Record<string, never>);
} }
@@ -149,7 +161,7 @@ export class RustProxyBridge extends plugins.EventEmitter {
await this.bridge.sendCommand('renewCertificate', { routeName }); await this.bridge.sendCommand('renewCertificate', { routeName });
} }
public async getCertificateStatus(routeName: string): Promise<any> { public async getCertificateStatus(routeName: string): Promise<IRustCertificateStatus | null> {
return this.bridge.sendCommand('getCertificateStatus', { routeName }); return this.bridge.sendCommand('getCertificateStatus', { routeName });
} }
+16 -34
View File
@@ -11,14 +11,16 @@ import { RustMetricsAdapter } from './rust-metrics-adapter.js';
// Route management // Route management
import { SharedRouteManager as RouteManager } from '../../core/routing/route-manager.js'; import { SharedRouteManager as RouteManager } from '../../core/routing/route-manager.js';
import { RouteValidator } from './utils/route-validator.js'; import { RouteValidator } from './utils/route-validator.js';
import { buildRustProxyOptions } from './utils/rust-config.js';
import { generateDefaultCertificate } from './utils/default-cert-generator.js'; import { generateDefaultCertificate } from './utils/default-cert-generator.js';
import { Mutex } from './utils/mutex.js'; import { Mutex } from './utils/mutex.js';
import { ConcurrencySemaphore } from './utils/concurrency-semaphore.js'; import { ConcurrencySemaphore } from './utils/concurrency-semaphore.js';
// Types // Types
import type { ISmartProxyOptions, TSmartProxyCertProvisionObject, IAcmeOptions, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './models/interfaces.js'; import type { ISmartProxyOptions, ISmartProxySecurityPolicy, TSmartProxyCertProvisionObject, IAcmeOptions, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './models/interfaces.js';
import type { IRouteConfig } from './models/route-types.js'; import type { IRouteConfig } from './models/route-types.js';
import type { IMetrics } from './models/metrics-types.js'; import type { IMetrics } from './models/metrics-types.js';
import type { IRustCertificateStatus, IRustProxyOptions, IRustStatistics } from './models/rust-types.js';
/** /**
* SmartProxy - Rust-backed proxy engine with TypeScript configuration API. * SmartProxy - Rust-backed proxy engine with TypeScript configuration API.
@@ -348,6 +350,15 @@ export class SmartProxy extends plugins.EventEmitter {
.catch((err) => logger.log('error', `Unexpected error in cert provisioning after route update: ${err.message}`, { component: 'smart-proxy' })); .catch((err) => logger.log('error', `Unexpected error in cert provisioning after route update: ${err.message}`, { component: 'smart-proxy' }));
} }
/**
* Update the global ingress security policy without changing routes.
* The Rust engine applies this before route selection and backend connection.
*/
public async updateSecurityPolicy(policy: ISmartProxySecurityPolicy): Promise<void> {
this.settings.securityPolicy = policy;
await this.bridge.setSecurityPolicy(policy);
}
/** /**
* Provision a certificate for a named route. * Provision a certificate for a named route.
*/ */
@@ -365,7 +376,7 @@ export class SmartProxy extends plugins.EventEmitter {
/** /**
* Get certificate status for a route (async - calls Rust). * Get certificate status for a route (async - calls Rust).
*/ */
public async getCertificateStatus(routeName: string): Promise<any> { public async getCertificateStatus(routeName: string): Promise<IRustCertificateStatus | null> {
return this.bridge.getCertificateStatus(routeName); return this.bridge.getCertificateStatus(routeName);
} }
@@ -379,7 +390,7 @@ export class SmartProxy extends plugins.EventEmitter {
/** /**
* Get statistics (async - calls Rust). * Get statistics (async - calls Rust).
*/ */
public async getStatistics(): Promise<any> { public async getStatistics(): Promise<IRustStatistics> {
return this.bridge.getStatistics(); return this.bridge.getStatistics();
} }
@@ -484,37 +495,8 @@ export class SmartProxy extends plugins.EventEmitter {
/** /**
* Build the Rust configuration object from TS settings. * Build the Rust configuration object from TS settings.
*/ */
private buildRustConfig(routes: IRouteConfig[], acmeOverride?: IAcmeOptions): any { private buildRustConfig(routes: IRustProxyOptions['routes'], acmeOverride?: IAcmeOptions): IRustProxyOptions {
const acme = acmeOverride !== undefined ? acmeOverride : this.settings.acme; return buildRustProxyOptions(this.settings, routes, acmeOverride);
return {
routes,
defaults: this.settings.defaults,
acme: acme
? {
enabled: acme.enabled,
email: acme.email,
useProduction: acme.useProduction,
port: acme.port,
renewThresholdDays: acme.renewThresholdDays,
autoRenew: acme.autoRenew,
renewCheckIntervalHours: acme.renewCheckIntervalHours,
}
: undefined,
connectionTimeout: this.settings.connectionTimeout,
initialDataTimeout: this.settings.initialDataTimeout,
socketTimeout: this.settings.socketTimeout,
maxConnectionLifetime: this.settings.maxConnectionLifetime,
gracefulShutdownTimeout: this.settings.gracefulShutdownTimeout,
maxConnectionsPerIp: this.settings.maxConnectionsPerIP,
connectionRateLimitPerMinute: this.settings.connectionRateLimitPerMinute,
keepAliveTreatment: this.settings.keepAliveTreatment,
keepAliveInactivityMultiplier: this.settings.keepAliveInactivityMultiplier,
extendedKeepAliveLifetime: this.settings.extendedKeepAliveLifetime,
proxyIps: this.settings.proxyIPs,
acceptProxyProtocol: this.settings.acceptProxyProtocol,
sendProxyProtocol: this.settings.sendProxyProtocol,
metrics: this.settings.metrics,
};
} }
/** /**
+22 -8
View File
@@ -168,14 +168,28 @@ export function routeMatchesHeaders(
if (!route.match?.headers || Object.keys(route.match.headers).length === 0) { if (!route.match?.headers || Object.keys(route.match.headers).length === 0) {
return true; // No headers specified means it matches any headers return true; // No headers specified means it matches any headers
} }
// Convert RegExp patterns to strings for HeaderMatcher for (const [headerName, expectedValue] of Object.entries(route.match.headers)) {
const stringHeaders: Record<string, string> = {}; const actualKey = Object.keys(headers).find((key) => key.toLowerCase() === headerName.toLowerCase());
for (const [key, value] of Object.entries(route.match.headers)) { const actualValue = actualKey ? headers[actualKey] : undefined;
stringHeaders[key] = value instanceof RegExp ? value.source : value;
if (actualValue === undefined) {
return false;
}
if (expectedValue instanceof RegExp) {
if (!expectedValue.test(actualValue)) {
return false;
}
continue;
}
if (!HeaderMatcher.match(expectedValue, actualValue)) {
return false;
}
} }
return HeaderMatcher.matchAll(stringHeaders, headers); return true;
} }
/** /**
@@ -283,4 +297,4 @@ export function generateRouteId(route: IRouteConfig): string {
*/ */
export function cloneRoute(route: IRouteConfig): IRouteConfig { export function cloneRoute(route: IRouteConfig): IRouteConfig {
return JSON.parse(JSON.stringify(route)); return JSON.parse(JSON.stringify(route));
} }
@@ -196,10 +196,19 @@ export class RouteValidator {
// Validate IP allow/block lists // Validate IP allow/block lists
if (route.security.ipAllowList) { if (route.security.ipAllowList) {
const allowList = Array.isArray(route.security.ipAllowList) ? route.security.ipAllowList : [route.security.ipAllowList]; const allowList = Array.isArray(route.security.ipAllowList) ? route.security.ipAllowList : [route.security.ipAllowList];
for (const ip of allowList) { for (const entry of allowList) {
if (!this.isValidIPPattern(ip)) { if (typeof entry === 'string') {
errors.push(`Invalid IP pattern in allow list: ${ip}`); if (!this.isValidIPPattern(entry)) {
errors.push(`Invalid IP pattern in allow list: ${entry}`);
}
} else if (entry && typeof entry === 'object') {
if (!this.isValidIPPattern(entry.ip)) {
errors.push(`Invalid IP pattern in domain-scoped allow entry: ${entry.ip}`);
}
if (!Array.isArray(entry.domains) || entry.domains.length === 0) {
errors.push(`Domain-scoped allow entry for ${entry.ip} must have non-empty domains array`);
}
} }
} }
} }
+188
View File
@@ -0,0 +1,188 @@
import type { IAcmeOptions, ISmartProxyOptions } from '../models/interfaces.js';
import type { IRouteAction, IRouteConfig, IRouteMatch, IRouteTarget, ITargetMatch } from '../models/route-types.js';
import type {
IRustAcmeOptions,
IRustDefaultConfig,
IRustProxyOptions,
IRustRouteAction,
IRustRouteConfig,
IRustRouteMatch,
IRustRouteTarget,
IRustTargetMatch,
IRustRouteUdp,
TRustHeaderMatchers,
} from '../models/rust-types.js';
// Regex flags allowed in the serialized `/source/flags` form sent to the Rust
// side — presumably the set its regex parser accepts; confirm against the Rust
// implementation before extending.
const SUPPORTED_REGEX_FLAGS = new Set(['i', 'm', 's', 'u', 'g']);

/**
 * Serialize a single header matcher for the Rust proxy.
 *
 * Plain strings pass through unchanged; a RegExp is encoded as
 * `/source/flags` so the Rust side can reconstruct it.
 *
 * @throws Error when the RegExp carries a flag outside {@link SUPPORTED_REGEX_FLAGS}.
 */
export function serializeHeaderMatchValue(value: string | RegExp): string {
  if (typeof value === 'string') {
    return value;
  }
  const rejectedFlags: string[] = [];
  for (const flag of new Set(value.flags)) {
    if (!SUPPORTED_REGEX_FLAGS.has(flag)) {
      rejectedFlags.push(flag);
    }
  }
  if (rejectedFlags.length > 0) {
    throw new Error(
      `Header RegExp uses unsupported flags for Rust serialization: ${rejectedFlags.join(', ')}`
    );
  }
  return `/${value.source}/${value.flags}`;
}
/**
 * Serialize an optional header-matcher map for the Rust proxy.
 *
 * Every value (string or RegExp) is converted with `serializeHeaderMatchValue`;
 * an absent map stays absent so the Rust config omits the field entirely.
 */
export function serializeHeaderMatchers(headers?: Record<string, string | RegExp>): TRustHeaderMatchers | undefined {
  if (!headers) {
    return undefined;
  }
  const serialized: Record<string, string> = {};
  for (const [headerName, matcher] of Object.entries(headers)) {
    serialized[headerName] = serializeHeaderMatchValue(matcher);
  }
  return serialized;
}
/**
 * Convert a target-scoped sub-match into its Rust wire shape, serializing any
 * header RegExps. All non-header criteria are passed through untouched.
 */
export function serializeTargetMatchForRust(match?: ITargetMatch): IRustTargetMatch | undefined {
  if (!match) {
    return undefined;
  }
  const { headers, ...passthrough } = match;
  return {
    ...passthrough,
    headers: serializeHeaderMatchers(headers),
  };
}
/**
 * Convert a route-level match block into its Rust wire shape. Only the
 * `headers` field needs translation (RegExp values); everything else is copied.
 */
export function serializeRouteMatchForRust(match: IRouteMatch): IRustRouteMatch {
  const { headers, ...passthrough } = match;
  return {
    ...passthrough,
    headers: serializeHeaderMatchers(headers),
  };
}
/**
 * Validate and convert a single route target for the Rust proxy.
 *
 * By the time a target crosses the IPC boundary its host must be a string or
 * string array and its port a number or the literal 'preserve' — anything else
 * (e.g. a still-unresolved dynamic value) is rejected.
 *
 * @throws Error when host or port is not yet in a serializable form.
 */
export function serializeRouteTargetForRust(target: IRouteTarget): IRustRouteTarget {
  const { host, port } = target;
  if (typeof host !== 'string' && !Array.isArray(host)) {
    throw new Error('Route target host must be serialized before sending to Rust');
  }
  if (typeof port !== 'number' && port !== 'preserve') {
    throw new Error('Route target port must be serialized before sending to Rust');
  }
  return {
    ...target,
    host,
    port,
    match: serializeTargetMatchForRust(target.match),
  };
}
/**
 * Translate a route's UDP settings to the Rust spelling: the session-limit
 * field is renamed `maxSessionsPerIP` -> `maxSessionsPerIp`; all other UDP
 * settings pass through unchanged.
 */
function serializeUdpForRust(udp?: IRouteAction['udp']): IRustRouteUdp | undefined {
  if (!udp) {
    return undefined;
  }
  const { maxSessionsPerIP: maxSessionsPerIp, ...passthrough } = udp;
  return {
    ...passthrough,
    maxSessionsPerIp,
  };
}
/**
 * Convert a route action into its Rust wire shape.
 *
 * Fields the Rust config does not carry (socketHandler, datagramHandler,
 * forwardingEngine, nftables) are dropped; targets and UDP settings are
 * serialized field-by-field, the rest is copied as-is.
 */
export function serializeRouteActionForRust(action: IRouteAction): IRustRouteAction {
  const {
    socketHandler: _socketHandler,
    datagramHandler: _datagramHandler,
    forwardingEngine: _forwardingEngine,
    nftables: _nftables,
    targets,
    udp,
    ...serializable
  } = action;
  return {
    ...serializable,
    targets: targets?.map(serializeRouteTargetForRust),
    udp: serializeUdpForRust(udp),
  };
}
/**
 * Serialize one complete TS route config into the shape the Rust proxy
 * consumes: match and action are converted, all other route fields are copied.
 */
export function serializeRouteForRust(route: IRouteConfig): IRustRouteConfig {
  const match = serializeRouteMatchForRust(route.match);
  const action = serializeRouteActionForRust(route.action);
  return {
    ...route,
    match,
    action,
  };
}
/**
 * Project the TS ACME options onto the explicit field set the Rust side knows.
 * Unknown/extra fields are deliberately dropped; an absent config stays absent.
 */
function serializeAcmeForRust(acme?: IAcmeOptions): IRustAcmeOptions | undefined {
  if (!acme) {
    return undefined;
  }
  const {
    enabled,
    email,
    environment,
    accountEmail,
    port,
    useProduction,
    renewThresholdDays,
    autoRenew,
    skipConfiguredCerts,
    renewCheckIntervalHours,
  } = acme;
  return {
    enabled,
    email,
    environment,
    accountEmail,
    port,
    useProduction,
    renewThresholdDays,
    autoRenew,
    skipConfiguredCerts,
    renewCheckIntervalHours,
  };
}
/**
 * Translate the `defaults` block to the Rust spelling: `preserveSourceIP`
 * becomes `preserveSourceIp`; every other default is copied through unchanged.
 */
function serializeDefaultsForRust(defaults?: ISmartProxyOptions['defaults']): IRustDefaultConfig | undefined {
  if (!defaults) {
    return undefined;
  }
  const { preserveSourceIP: preserveSourceIp, ...passthrough } = defaults;
  return {
    ...passthrough,
    preserveSourceIp,
  };
}
/**
 * Build the top-level options object sent to the Rust proxy engine.
 *
 * This is the single source of truth for the TS-settings -> Rust-config field
 * mapping; the key set here is the IPC contract with the Rust side. Note the
 * casing renames (preserveSourceIP -> preserveSourceIp, proxyIPs -> proxyIps,
 * maxConnectionsPerIP -> maxConnectionsPerIp) and that `defaults`/`acme` go
 * through their own serializers.
 *
 * @param settings - The full TypeScript-side SmartProxy options.
 * @param routes - Routes already serialized to the Rust route shape.
 * @param acmeOverride - When provided (even a disabled config), replaces
 *   `settings.acme` entirely; only `undefined` falls back to the settings.
 * @returns The complete Rust proxy options payload.
 */
export function buildRustProxyOptions(
  settings: ISmartProxyOptions,
  routes: IRustRouteConfig[],
  acmeOverride?: IAcmeOptions,
): IRustProxyOptions {
  // Explicit !== undefined so an override object always wins over settings.acme.
  const acme = acmeOverride !== undefined ? acmeOverride : settings.acme;
  return {
    routes,
    preserveSourceIp: settings.preserveSourceIP,
    proxyIps: settings.proxyIPs,
    acceptProxyProtocol: settings.acceptProxyProtocol,
    sendProxyProtocol: settings.sendProxyProtocol,
    defaults: serializeDefaultsForRust(settings.defaults),
    // Timeouts / lifecycle knobs are forwarded one-to-one.
    connectionTimeout: settings.connectionTimeout,
    initialDataTimeout: settings.initialDataTimeout,
    socketTimeout: settings.socketTimeout,
    inactivityCheckInterval: settings.inactivityCheckInterval,
    maxConnectionLifetime: settings.maxConnectionLifetime,
    inactivityTimeout: settings.inactivityTimeout,
    gracefulShutdownTimeout: settings.gracefulShutdownTimeout,
    // Socket-level TCP options.
    noDelay: settings.noDelay,
    keepAlive: settings.keepAlive,
    keepAliveInitialDelay: settings.keepAliveInitialDelay,
    maxPendingDataSize: settings.maxPendingDataSize,
    disableInactivityCheck: settings.disableInactivityCheck,
    enableKeepAliveProbes: settings.enableKeepAliveProbes,
    // Logging / debugging toggles.
    enableDetailedLogging: settings.enableDetailedLogging,
    enableTlsDebugLogging: settings.enableTlsDebugLogging,
    enableRandomizedTimeouts: settings.enableRandomizedTimeouts,
    // Per-IP limits and keep-alive treatment policy.
    maxConnectionsPerIp: settings.maxConnectionsPerIP,
    connectionRateLimitPerMinute: settings.connectionRateLimitPerMinute,
    keepAliveTreatment: settings.keepAliveTreatment,
    keepAliveInactivityMultiplier: settings.keepAliveInactivityMultiplier,
    extendedKeepAliveLifetime: settings.extendedKeepAliveLifetime,
    metrics: settings.metrics,
    securityPolicy: settings.securityPolicy,
    acme: serializeAcmeForRust(acme),
  };
}