Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 30e5ab308f | |||
| d2a54b3491 | |||
| dc922c97df | |||
| 8d1bae7604 | |||
| 200e86e311 | |||
| a53a2c4ca5 | |||
| 6ee7237357 | |||
| b5b4c608f0 | |||
| af132f40fc | |||
| 781634446a | |||
| e988d935b6 | |||
| 99a026627d | |||
| 572e31587a | |||
| 8587fb997c | |||
| 9ba101c59b | |||
| 1ad3e61c15 | |||
| 3bfa451341 | |||
| 7b3ab7378b | |||
| 527c616cd4 | |||
| b04eb0ab17 | |||
| a55ff20391 | |||
| 3c24bf659b |
@@ -1,5 +1,80 @@
|
||||
# Changelog
|
||||
|
||||
## 2026-04-14 - 27.7.3 - fix(repo)
|
||||
no changes detected
|
||||
|
||||
|
||||
## 2026-04-14 - 27.7.2 - fix(docs)
|
||||
clarify metrics documentation for domain normalization and saturating gauges
|
||||
|
||||
- Document that per-IP domain keys are normalized to lowercase and have trailing dots stripped before counting.
|
||||
- Clarify that the saturating close pattern also applies to connection and UDP active gauges.
|
||||
|
||||
## 2026-04-14 - 27.7.1 - fix(rustproxy-http,rustproxy-metrics)
|
||||
fix domain-scoped request host detection and harden connection metrics cleanup
|
||||
|
||||
- use a shared request host extractor that falls back to URI authority so domain-scoped IP allow lists work for HTTP/2 and HTTP/3 requests without a Host header
|
||||
- add request filter and host extraction tests covering domain-scoped ACL behavior
|
||||
- prevent connection counters from underflowing during close handling and clean up per-IP metrics entries more safely
|
||||
- normalize tracked domain keys in metrics to reduce duplicate entries caused by case or trailing-dot variations
|
||||
|
||||
## 2026-04-13 - 27.7.0 - feat(smart-proxy)
|
||||
add typed Rust config serialization and regex header contract coverage
|
||||
|
||||
- serialize SmartProxy routes and top-level options into explicit Rust-safe types, including header regex literals, UDP field normalization, ACME, defaults, and proxy settings
|
||||
- support JS-style regex header literals with flags in Rust header matching and add cross-contract tests for route preprocessing and config deserialization
|
||||
- improve TypeScript safety for Rust bridge and metrics integration by replacing loose any-based payloads with dedicated Rust type definitions
|
||||
|
||||
## 2026-04-13 - 27.6.0 - feat(metrics)
|
||||
track per-IP domain request metrics across HTTP and TCP passthrough traffic
|
||||
|
||||
- records domain request counts per frontend IP from HTTP Host headers and TCP SNI
|
||||
- exposes per-IP domain maps and top IP-domain request pairs through the TypeScript metrics adapter
|
||||
- bounds per-IP domain tracking and prunes stale entries to limit memory growth
|
||||
- adds metrics system documentation covering architecture, collected data, and known gaps
|
||||
|
||||
## 2026-04-06 - 27.5.0 - feat(security)
|
||||
add domain-scoped IP allow list support across HTTP and passthrough filtering
|
||||
|
||||
- extend route security types to accept IP allow entries scoped to specific domains
|
||||
- apply domain-aware IP checks using Host headers for HTTP and SNI context for QUIC and passthrough connections
|
||||
- preserve compatibility for existing plain allow list entries and add validation and tests for scoped matching
|
||||
|
||||
## 2026-04-04 - 27.4.0 - feat(rustproxy)
|
||||
add HTTP/3 proxy service wiring for QUIC listeners
|
||||
|
||||
- registers H3ProxyService with the UDP listener manager so QUIC connections can serve HTTP/3
|
||||
- keeps proxy IP configuration intact while enabling HTTP/3 handling during listener setup
|
||||
|
||||
## 2026-04-04 - 27.3.1 - fix(metrics)
|
||||
correct frontend and backend protocol connection tracking across h1, h2, h3, and websocket traffic
|
||||
|
||||
- move frontend protocol accounting from per-request to connection lifetime tracking for HTTP/1, HTTP/2, and HTTP/3
|
||||
- add backend protocol guards to connection drivers so active protocol metrics reflect live upstream connections
|
||||
- prevent protocol counter underflow by using atomic saturating decrements in the metrics collector
|
||||
- read backend protocol distribution directly from cached aggregate counters in the Rust metrics adapter
|
||||
|
||||
## 2026-04-04 - 27.3.0 - feat(test)
|
||||
add end-to-end WebSocket proxy test coverage
|
||||
|
||||
- add comprehensive WebSocket e2e tests for upgrade handling, bidirectional messaging, header forwarding, close propagation, and large payloads
|
||||
- add ws and @types/ws as development dependencies to support the new test suite
|
||||
|
||||
## 2026-04-04 - 27.2.0 - feat(metrics)
|
||||
add frontend and backend protocol distribution metrics
|
||||
|
||||
- track active and total frontend protocol counts for h1, h2, h3, websocket, and other traffic
|
||||
- add backend protocol counters with RAII guards to ensure metrics are decremented on all exit paths
|
||||
- expose protocol distribution through the TypeScript metrics interfaces and Rust metrics adapter
|
||||
|
||||
## 2026-03-27 - 27.1.0 - feat(rustproxy-passthrough)
|
||||
add selective connection recycling for route, security, and certificate updates
|
||||
|
||||
- introduce a shared connection registry to track active TCP and QUIC connections by route, source IP, and domain
|
||||
- recycle only affected connections when route actions or security rules change instead of broadly invalidating traffic
|
||||
- gracefully recycle existing connections when TLS certificates change for a domain
|
||||
- apply route-level IP security checks to QUIC connections and share route cancellation state with UDP listeners
|
||||
|
||||
## 2026-03-26 - 27.0.0 - BREAKING CHANGE(smart-proxy)
|
||||
remove route helper APIs and standardize route configuration on plain route objects
|
||||
|
||||
|
||||
@@ -12,9 +12,11 @@
|
||||
"npm:@push.rocks/smartserve@^2.0.3": "2.0.3",
|
||||
"npm:@tsclass/tsclass@^9.5.0": "9.5.0",
|
||||
"npm:@types/node@^25.5.0": "25.5.0",
|
||||
"npm:@types/ws@^8.18.1": "8.18.1",
|
||||
"npm:minimatch@^10.2.4": "10.2.4",
|
||||
"npm:typescript@^6.0.2": "6.0.2",
|
||||
"npm:why-is-node-running@^3.2.2": "3.2.2"
|
||||
"npm:why-is-node-running@^3.2.2": "3.2.2",
|
||||
"npm:ws@^8.20.0": "8.20.0"
|
||||
},
|
||||
"npm": {
|
||||
"@api.global/typedrequest-interfaces@2.0.2": {
|
||||
@@ -6743,9 +6745,11 @@
|
||||
"npm:@push.rocks/smartserve@^2.0.3",
|
||||
"npm:@tsclass/tsclass@^9.5.0",
|
||||
"npm:@types/node@^25.5.0",
|
||||
"npm:@types/ws@^8.18.1",
|
||||
"npm:minimatch@^10.2.4",
|
||||
"npm:typescript@^6.0.2",
|
||||
"npm:why-is-node-running@^3.2.2"
|
||||
"npm:why-is-node-running@^3.2.2",
|
||||
"npm:ws@^8.20.0"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
+4
-2
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@push.rocks/smartproxy",
|
||||
"version": "27.0.0",
|
||||
"version": "27.7.3",
|
||||
"private": false,
|
||||
"description": "A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.",
|
||||
"main": "dist_ts/index.js",
|
||||
@@ -22,8 +22,10 @@
|
||||
"@git.zone/tstest": "^3.6.0",
|
||||
"@push.rocks/smartserve": "^2.0.3",
|
||||
"@types/node": "^25.5.0",
|
||||
"@types/ws": "^8.18.1",
|
||||
"typescript": "^6.0.2",
|
||||
"why-is-node-running": "^3.2.2"
|
||||
"why-is-node-running": "^3.2.2",
|
||||
"ws": "^8.20.0"
|
||||
},
|
||||
"dependencies": {
|
||||
"@push.rocks/smartcrypto": "^2.0.4",
|
||||
|
||||
Generated
+7
-15
@@ -45,12 +45,18 @@ importers:
|
||||
'@types/node':
|
||||
specifier: ^25.5.0
|
||||
version: 25.5.0
|
||||
'@types/ws':
|
||||
specifier: ^8.18.1
|
||||
version: 8.18.1
|
||||
typescript:
|
||||
specifier: ^6.0.2
|
||||
version: 6.0.2
|
||||
why-is-node-running:
|
||||
specifier: ^3.2.2
|
||||
version: 3.2.2
|
||||
ws:
|
||||
specifier: ^8.20.0
|
||||
version: 8.20.0
|
||||
|
||||
packages:
|
||||
|
||||
@@ -3304,18 +3310,6 @@ packages:
|
||||
wrappy@1.0.2:
|
||||
resolution: {integrity: sha1-tSQ9jz7BqjXxNkYFvA0QNuMKtp8=}
|
||||
|
||||
ws@8.19.0:
|
||||
resolution: {integrity: sha512-blAT2mjOEIi0ZzruJfIhb3nps74PRWTCz1IjglWEEpQl5XS/UNama6u2/rjFkDDouqr4L67ry+1aGIALViWjDg==}
|
||||
engines: {node: '>=10.0.0'}
|
||||
peerDependencies:
|
||||
bufferutil: ^4.0.1
|
||||
utf-8-validate: '>=5.0.2'
|
||||
peerDependenciesMeta:
|
||||
bufferutil:
|
||||
optional: true
|
||||
utf-8-validate:
|
||||
optional: true
|
||||
|
||||
ws@8.20.0:
|
||||
resolution: {integrity: sha512-sAt8BhgNbzCtgGbt2OxmpuryO63ZoDk/sqaB/znQm94T4fCEsy/yV+7CdC1kJhOU9lboAEU7R3kquuycDoibVA==}
|
||||
engines: {node: '>=10.0.0'}
|
||||
@@ -5296,7 +5290,7 @@ snapshots:
|
||||
'@push.rocks/smartenv': 6.0.0
|
||||
'@push.rocks/smartlog': 3.2.1
|
||||
'@push.rocks/smartpath': 6.0.0
|
||||
ws: 8.19.0
|
||||
ws: 8.20.0
|
||||
transitivePeerDependencies:
|
||||
- bufferutil
|
||||
- utf-8-validate
|
||||
@@ -8033,8 +8027,6 @@ snapshots:
|
||||
|
||||
wrappy@1.0.2: {}
|
||||
|
||||
ws@8.19.0: {}
|
||||
|
||||
ws@8.20.0: {}
|
||||
|
||||
xml-parse-from-string@1.0.1: {}
|
||||
|
||||
@@ -0,0 +1,484 @@
|
||||
# SmartProxy Metrics System
|
||||
|
||||
## Architecture
|
||||
|
||||
Two-tier design separating the data plane from the observation plane:
|
||||
|
||||
**Hot path (per-chunk, lock-free):** All recording in the proxy data plane touches only `AtomicU64` counters. No `Mutex` is ever acquired on the forwarding path. `CountingBody` batches flushes every 64KB to reduce DashMap shard contention.
|
||||
|
||||
**Cold path (1Hz sampling):** A background tokio task drains pending atomics into `ThroughputTracker` circular buffers (Mutex-guarded), producing per-second throughput history. Same task prunes orphaned entries and cleans up rate limiter state.
|
||||
|
||||
**Read path (on-demand):** `snapshot()` reads all atomics and locks ThroughputTrackers to build a serializable `Metrics` struct. TypeScript polls this at 1s intervals via IPC.
|
||||
|
||||
```
|
||||
Data Plane (lock-free) Background (1Hz) Read Path
|
||||
───────────────────── ────────────────── ─────────
|
||||
record_bytes() ──> AtomicU64 ──┐
|
||||
record_http_request() ──> AtomicU64 ──┤
|
||||
connection_opened/closed() ──> AtomicU64 ──┤ sample_all() snapshot()
|
||||
backend_*() ──> DashMap<AtomicU64> ──┤────> drain atomics ──────> Metrics struct
|
||||
protocol_*() ──> AtomicU64 ──┤ feed ThroughputTrackers ──> JSON
|
||||
datagram_*() ──> AtomicU64 ──┘ prune orphans ──> IPC stdout
|
||||
──> TS cache
|
||||
──> IMetrics API
|
||||
```
|
||||
|
||||
### Key Types
|
||||
|
||||
| Type | Crate | Purpose |
|
||||
|---|---|---|
|
||||
| `MetricsCollector` | `rustproxy-metrics` | Central store. All DashMaps, atomics, and throughput trackers |
|
||||
| `ThroughputTracker` | `rustproxy-metrics` | Circular buffer of 1Hz samples. Default 3600 capacity (1 hour) |
|
||||
| `ForwardMetricsCtx` | `rustproxy-passthrough` | Carries `Arc<MetricsCollector>` + route_id + source_ip through TCP forwarding |
|
||||
| `CountingBody` | `rustproxy-http` | Wraps HTTP bodies, batches byte recording per 64KB, flushes on drop |
|
||||
| `ProtocolGuard` | `rustproxy-http` | RAII guard for frontend/backend protocol active/total counters |
|
||||
| `ConnectionGuard` | `rustproxy-passthrough` | RAII guard calling `connection_closed()` on drop |
|
||||
| `RustMetricsAdapter` | TypeScript | Polls Rust via IPC, implements `IMetrics` interface over cached JSON |
|
||||
|
||||
---
|
||||
|
||||
## What's Collected
|
||||
|
||||
### Global Counters
|
||||
|
||||
| Metric | Type | Updated by |
|
||||
|---|---|---|
|
||||
| Active connections | AtomicU64 | `connection_opened/closed` |
|
||||
| Total connections (lifetime) | AtomicU64 | `connection_opened` |
|
||||
| Total bytes in | AtomicU64 | `record_bytes` |
|
||||
| Total bytes out | AtomicU64 | `record_bytes` |
|
||||
| Total HTTP requests | AtomicU64 | `record_http_request` |
|
||||
| Active UDP sessions | AtomicU64 | `udp_session_opened/closed` |
|
||||
| Total UDP sessions | AtomicU64 | `udp_session_opened` |
|
||||
| Total datagrams in | AtomicU64 | `record_datagram_in` |
|
||||
| Total datagrams out | AtomicU64 | `record_datagram_out` |
|
||||
|
||||
### Per-Route Metrics (keyed by route ID string)
|
||||
|
||||
| Metric | Storage |
|
||||
|---|---|
|
||||
| Active connections | `DashMap<String, AtomicU64>` |
|
||||
| Total connections | `DashMap<String, AtomicU64>` |
|
||||
| Bytes in / out | `DashMap<String, AtomicU64>` |
|
||||
| Pending throughput (in, out) | `DashMap<String, (AtomicU64, AtomicU64)>` |
|
||||
| Throughput history | `DashMap<String, Mutex<ThroughputTracker>>` |
|
||||
|
||||
Entries are pruned via `retain_routes()` when routes are removed.
|
||||
|
||||
### Per-IP Metrics (keyed by IP string)
|
||||
|
||||
| Metric | Storage |
|
||||
|---|---|
|
||||
| Active connections | `DashMap<String, AtomicU64>` |
|
||||
| Total connections | `DashMap<String, AtomicU64>` |
|
||||
| Bytes in / out | `DashMap<String, AtomicU64>` |
|
||||
| Pending throughput (in, out) | `DashMap<String, (AtomicU64, AtomicU64)>` |
|
||||
| Throughput history | `DashMap<String, Mutex<ThroughputTracker>>` |
|
||||
| Domain requests | `DashMap<String, DashMap<String, AtomicU64>>` (IP → domain → count) |
|
||||
|
||||
All seven maps for an IP are evicted when its active connection count drops to 0. Safety-net pruning in `sample_all()` catches entries orphaned by races. Snapshots cap at 100 IPs, sorted by active connections descending.
|
||||
|
||||
**Domain request tracking:** Records which domains each frontend IP has requested. Populated from HTTP Host headers (for HTTP/1.1, HTTP/2, HTTP/3 requests) and from SNI (for TLS passthrough connections). Domain keys are normalized to lowercase with any trailing dot stripped so the same hostname does not fragment across multiple counters. Capped at 256 domains per IP (`MAX_DOMAINS_PER_IP`) to prevent subdomain-spray abuse. Inner DashMap uses 2 shards to minimise base memory per IP (~200 bytes). Common case (IP + domain both known) is two DashMap reads + one atomic increment with zero allocation.
|
||||
|
||||
### Per-Backend Metrics (keyed by "host:port")
|
||||
|
||||
| Metric | Storage |
|
||||
|---|---|
|
||||
| Active connections | `DashMap<String, AtomicU64>` |
|
||||
| Total connections | `DashMap<String, AtomicU64>` |
|
||||
| Detected protocol (h1/h2/h3) | `DashMap<String, String>` |
|
||||
| Connect errors | `DashMap<String, AtomicU64>` |
|
||||
| Handshake errors | `DashMap<String, AtomicU64>` |
|
||||
| Request errors | `DashMap<String, AtomicU64>` |
|
||||
| Total connect time (microseconds) | `DashMap<String, AtomicU64>` |
|
||||
| Connect count | `DashMap<String, AtomicU64>` |
|
||||
| Pool hits | `DashMap<String, AtomicU64>` |
|
||||
| Pool misses | `DashMap<String, AtomicU64>` |
|
||||
| H2 failures (fallback to H1) | `DashMap<String, AtomicU64>` |
|
||||
|
||||
All per-backend maps are evicted when active count reaches 0. Pruned via `retain_backends()`.
|
||||
|
||||
### Frontend Protocol Distribution
|
||||
|
||||
Tracked via `ProtocolGuard` RAII guards and `FrontendProtocolTracker`. Five protocol categories, each with active + total counters (AtomicU64):
|
||||
|
||||
| Protocol | Where detected |
|
||||
|---|---|
|
||||
| h1 | `FrontendProtocolTracker` on first HTTP/1.x request |
|
||||
| h2 | `FrontendProtocolTracker` on first HTTP/2 request |
|
||||
| h3 | `ProtocolGuard::frontend("h3")` in H3ProxyService |
|
||||
| ws | `ProtocolGuard::frontend("ws")` on WebSocket upgrade |
|
||||
| other | Fallback (TCP passthrough without HTTP) |
|
||||
|
||||
Uses `fetch_update` for saturating decrements to prevent underflow races. The same saturating-close pattern is also used for connection and UDP active gauges.
|
||||
|
||||
### Backend Protocol Distribution
|
||||
|
||||
Same five categories (h1/h2/h3/ws/other), tracked via `ProtocolGuard::backend()` at connection establishment. Backend h2 failures (fallback to h1) are separately counted.
|
||||
|
||||
### Throughput History
|
||||
|
||||
`ThroughputTracker` is a circular buffer storing `ThroughputSample { timestamp_ms, bytes_in, bytes_out }` at 1Hz.
|
||||
|
||||
- Global tracker: 1 instance, default 3600 capacity
|
||||
- Per-route trackers: 1 per active route
|
||||
- Per-IP trackers: 1 per connected IP (evicted with the IP)
|
||||
- HTTP request tracker: reuses ThroughputTracker with bytes_in = request count, bytes_out = 0
|
||||
|
||||
Query methods:
|
||||
- `instant()` — last 1 second average
|
||||
- `recent()` — last 10 seconds average
|
||||
- `throughput(N)` — last N seconds average
|
||||
- `history(N)` — last N raw samples in chronological order
|
||||
|
||||
Snapshots return 60 samples of global throughput history.
|
||||
|
||||
### Protocol Detection Cache
|
||||
|
||||
Not part of MetricsCollector. Maintained by `HttpProxyService`'s protocol detection system. Injected into the metrics snapshot at read time by `get_metrics()`.
|
||||
|
||||
Each entry records: host, port, domain, detected protocol (h1/h2/h3), H3 port, age, last accessed, last probed, suppression flags, cooldown timers, consecutive failure counts.
|
||||
|
||||
---
|
||||
|
||||
## Instrumentation Points
|
||||
|
||||
### TCP Passthrough (`rustproxy-passthrough`)
|
||||
|
||||
**Connection lifecycle** — `tcp_listener.rs`:
|
||||
- Accept: `conn_tracker.connection_opened(&ip)` (rate limiter) + `ConnectionTrackerGuard` RAII
|
||||
- Route match: `metrics.connection_opened(route_id, source_ip)` + `ConnectionGuard` RAII
|
||||
- Close: Both guards call their respective `_closed()` methods on drop
|
||||
|
||||
**Byte recording** — `forwarder.rs` (`forward_bidirectional_with_timeouts`):
|
||||
- Initial peeked data recorded immediately
|
||||
- Per-chunk in both directions: `record_bytes(n, 0, ...)` / `record_bytes(0, n, ...)`
|
||||
- Same pattern in `forward_bidirectional_split_with_timeouts` (tcp_listener.rs) for TLS-terminated paths
|
||||
|
||||
### HTTP Proxy (`rustproxy-http`)
|
||||
|
||||
**Request counting** — `proxy_service.rs`:
|
||||
- `record_http_request()` called once per request after route matching succeeds
|
||||
|
||||
**Body byte counting** — `counting_body.rs` wrapping:
|
||||
- Request bodies: `CountingBody::new(body, ..., Direction::In)` — counts client-to-upstream bytes
|
||||
- Response bodies: `CountingBody::new(body, ..., Direction::Out)` — counts upstream-to-client bytes
|
||||
- Batched flush every 64KB (`BYTE_FLUSH_THRESHOLD = 65_536`), remainder flushed on drop
|
||||
- Also updates `connection_activity` atomic (idle watchdog) and `active_requests` counter (streaming detection)
|
||||
|
||||
**Backend metrics** — `proxy_service.rs`:
|
||||
- `backend_connection_opened(key, connect_time)` — after TCP/TLS connect succeeds
|
||||
- `backend_connection_closed(key)` — on teardown
|
||||
- `backend_connect_error(key)` — TCP/TLS connect failure or timeout
|
||||
- `backend_handshake_error(key)` — H1/H2 protocol handshake failure
|
||||
- `backend_request_error(key)` — send_request failure
|
||||
- `backend_h2_failure(key)` — H2 attempted, fell back to H1
|
||||
- `backend_pool_hit(key)` / `backend_pool_miss(key)` — connection pool reuse
|
||||
- `set_backend_protocol(key, proto)` — records detected protocol
|
||||
|
||||
**WebSocket** — `proxy_service.rs`:
|
||||
- Does NOT use CountingBody; records bytes directly per-chunk in both directions of the bidirectional copy loop
|
||||
|
||||
### QUIC (`rustproxy-passthrough`)
|
||||
|
||||
**Connection level** — `quic_handler.rs`:
|
||||
- `connection_opened` / `connection_closed` via `QuicConnGuard` RAII
|
||||
- `conn_tracker.connection_opened/closed` for rate limiting
|
||||
|
||||
**Stream level**:
|
||||
- For QUIC-to-TCP stream forwarding: `record_bytes(bytes_in, bytes_out, ...)` called once per stream at completion (post-hoc, not per-chunk)
|
||||
- For HTTP/3: delegates to `HttpProxyService.handle_request()`, so all HTTP proxy metrics apply
|
||||
|
||||
**H3 specifics** — `h3_service.rs`:
|
||||
- `ProtocolGuard::frontend("h3")` tracks the H3 connection
|
||||
- H3 request bodies: `record_bytes(data.len(), 0, ...)` called directly (not CountingBody) since H3 uses `stream.send_data()`
|
||||
- H3 response bodies: wrapped in CountingBody like HTTP/1 and HTTP/2
|
||||
|
||||
### UDP (`rustproxy-passthrough`)
|
||||
|
||||
**Session lifecycle** — `udp_listener.rs` / `udp_session.rs`:
|
||||
- `udp_session_opened()` + `connection_opened(route_id, source_ip)` on new session
|
||||
- `udp_session_closed()` + `connection_closed(route_id, source_ip)` on idle reap or port drain
|
||||
|
||||
**Datagram counting** — `udp_listener.rs`:
|
||||
- Inbound: `record_bytes(len, 0, ...)` + `record_datagram_in()`
|
||||
- Outbound (backend reply): `record_bytes(0, len, ...)` + `record_datagram_out()`
|
||||
|
||||
---
|
||||
|
||||
## Sampling Loop
|
||||
|
||||
`lib.rs` spawns a tokio task at configurable interval (default 1000ms):
|
||||
|
||||
```rust
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel => break,
|
||||
_ = interval.tick() => {
|
||||
metrics.sample_all();
|
||||
conn_tracker.cleanup_stale_timestamps();
|
||||
http_proxy.cleanup_all_rate_limiters();
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`sample_all()` performs in one pass:
|
||||
1. Drains `global_pending_tp_in/out` into global ThroughputTracker, samples
|
||||
2. Drains per-route pending counters into per-route trackers, samples each
|
||||
3. Samples idle route trackers (no new data) to advance their window
|
||||
4. Drains per-IP pending counters into per-IP trackers, samples each
|
||||
5. Drains `pending_http_requests` into HTTP request throughput tracker
|
||||
6. Prunes orphaned per-IP entries (bytes/throughput maps with no matching ip_connections key)
|
||||
7. Prunes orphaned per-backend entries (error/stats maps with no matching active/total key)
|
||||
|
||||
---
|
||||
|
||||
## Data Flow: Rust to TypeScript
|
||||
|
||||
```
|
||||
MetricsCollector.snapshot()
|
||||
├── reads all AtomicU64 counters
|
||||
├── iterates DashMaps (routes, IPs, backends)
|
||||
├── locks ThroughputTrackers for instant/recent rates + history
|
||||
└── produces Metrics struct
|
||||
|
||||
RustProxy::get_metrics()
|
||||
├── calls snapshot()
|
||||
├── enriches with detectedProtocols from HTTP proxy protocol cache
|
||||
└── returns Metrics
|
||||
|
||||
management.rs "getMetrics" IPC command
|
||||
├── calls get_metrics()
|
||||
├── serde_json::to_value (camelCase)
|
||||
└── writes JSON to stdout
|
||||
|
||||
RustProxyBridge (TypeScript)
|
||||
├── reads JSON from Rust process stdout
|
||||
└── returns parsed object
|
||||
|
||||
RustMetricsAdapter
|
||||
├── setInterval polls bridge.getMetrics() every 1s
|
||||
├── stores raw JSON in this.cache
|
||||
└── IMetrics methods read synchronously from cache
|
||||
|
||||
SmartProxy.getMetrics()
|
||||
└── returns the RustMetricsAdapter instance
|
||||
```
|
||||
|
||||
### IPC JSON Shape (Metrics)
|
||||
|
||||
```json
|
||||
{
|
||||
"activeConnections": 42,
|
||||
"totalConnections": 1000,
|
||||
"bytesIn": 123456789,
|
||||
"bytesOut": 987654321,
|
||||
"throughputInBytesPerSec": 50000,
|
||||
"throughputOutBytesPerSec": 80000,
|
||||
"throughputRecentInBytesPerSec": 45000,
|
||||
"throughputRecentOutBytesPerSec": 75000,
|
||||
"routes": {
|
||||
"<route-id>": {
|
||||
"activeConnections": 5,
|
||||
"totalConnections": 100,
|
||||
"bytesIn": 0, "bytesOut": 0,
|
||||
"throughputInBytesPerSec": 0, "throughputOutBytesPerSec": 0,
|
||||
"throughputRecentInBytesPerSec": 0, "throughputRecentOutBytesPerSec": 0
|
||||
}
|
||||
},
|
||||
"ips": {
|
||||
"<ip>": {
|
||||
"activeConnections": 2, "totalConnections": 10,
|
||||
"bytesIn": 0, "bytesOut": 0,
|
||||
"throughputInBytesPerSec": 0, "throughputOutBytesPerSec": 0,
|
||||
"domainRequests": {
|
||||
"example.com": 4821,
|
||||
"api.example.com": 312
|
||||
}
|
||||
}
|
||||
},
|
||||
"backends": {
|
||||
"<host:port>": {
|
||||
"activeConnections": 3, "totalConnections": 50,
|
||||
"protocol": "h2",
|
||||
"connectErrors": 0, "handshakeErrors": 0, "requestErrors": 0,
|
||||
"totalConnectTimeUs": 150000, "connectCount": 50,
|
||||
"poolHits": 40, "poolMisses": 10, "h2Failures": 1
|
||||
}
|
||||
},
|
||||
"throughputHistory": [
|
||||
{ "timestampMs": 1713000000000, "bytesIn": 50000, "bytesOut": 80000 }
|
||||
],
|
||||
"totalHttpRequests": 5000,
|
||||
"httpRequestsPerSec": 100,
|
||||
"httpRequestsPerSecRecent": 95,
|
||||
"activeUdpSessions": 0, "totalUdpSessions": 5,
|
||||
"totalDatagramsIn": 1000, "totalDatagramsOut": 1000,
|
||||
"frontendProtocols": {
|
||||
"h1Active": 10, "h1Total": 500,
|
||||
"h2Active": 5, "h2Total": 200,
|
||||
"h3Active": 1, "h3Total": 50,
|
||||
"wsActive": 2, "wsTotal": 30,
|
||||
"otherActive": 0, "otherTotal": 0
|
||||
},
|
||||
"backendProtocols": { "...same shape..." },
|
||||
"detectedProtocols": [
|
||||
{
|
||||
"host": "backend", "port": 443, "domain": "example.com",
|
||||
"protocol": "h2", "h3Port": 443,
|
||||
"ageSecs": 120, "lastAccessedSecs": 5, "lastProbedSecs": 120,
|
||||
"h2Suppressed": false, "h3Suppressed": false,
|
||||
"h2CooldownRemainingSecs": null, "h3CooldownRemainingSecs": null,
|
||||
"h2ConsecutiveFailures": null, "h3ConsecutiveFailures": null
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### IPC JSON Shape (Statistics)
|
||||
|
||||
Lightweight administrative summary, fetched on-demand (not polled):
|
||||
|
||||
```json
|
||||
{
|
||||
"activeConnections": 42,
|
||||
"totalConnections": 1000,
|
||||
"routesCount": 5,
|
||||
"listeningPorts": [80, 443, 8443],
|
||||
"uptimeSeconds": 86400
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## TypeScript Consumer API
|
||||
|
||||
`SmartProxy.getMetrics()` returns an `IMetrics` object. All members are synchronous methods reading from the polled cache.
|
||||
|
||||
### connections
|
||||
|
||||
| Method | Return | Source |
|
||||
|---|---|---|
|
||||
| `active()` | `number` | `cache.activeConnections` |
|
||||
| `total()` | `number` | `cache.totalConnections` |
|
||||
| `byRoute()` | `Map<string, number>` | `cache.routes[name].activeConnections` |
|
||||
| `byIP()` | `Map<string, number>` | `cache.ips[ip].activeConnections` |
|
||||
| `topIPs(limit?)` | `Array<{ip, count}>` | `cache.ips` sorted by active desc, default 10 |
|
||||
| `domainRequestsByIP()` | `Map<string, Map<string, number>>` | `cache.ips[ip].domainRequests` |
|
||||
| `topDomainRequests(limit?)` | `Array<{ip, domain, count}>` | Flattened from all IPs, sorted by count desc, default 20 |
|
||||
| `frontendProtocols()` | `IProtocolDistribution` | `cache.frontendProtocols.*` |
|
||||
| `backendProtocols()` | `IProtocolDistribution` | `cache.backendProtocols.*` |
|
||||
|
||||
### throughput
|
||||
|
||||
| Method | Return | Source |
|
||||
|---|---|---|
|
||||
| `instant()` | `{in, out}` | `cache.throughputInBytesPerSec/Out` |
|
||||
| `recent()` | `{in, out}` | `cache.throughputRecentInBytesPerSec/Out` |
|
||||
| `average()` | `{in, out}` | Falls back to `instant()` (not wired to windowed average) |
|
||||
| `custom(seconds)` | `{in, out}` | Falls back to `instant()` (not wired) |
|
||||
| `history(seconds)` | `IThroughputHistoryPoint[]` | `cache.throughputHistory` sliced to last N entries |
|
||||
| `byRoute(windowSeconds?)` | `Map<string, {in, out}>` | `cache.routes[name].throughputIn/OutBytesPerSec` |
|
||||
| `byIP(windowSeconds?)` | `Map<string, {in, out}>` | `cache.ips[ip].throughputIn/OutBytesPerSec` |
|
||||
|
||||
### requests
|
||||
|
||||
| Method | Return | Source |
|
||||
|---|---|---|
|
||||
| `perSecond()` | `number` | `cache.httpRequestsPerSec` |
|
||||
| `perMinute()` | `number` | `cache.httpRequestsPerSecRecent * 60` |
|
||||
| `total()` | `number` | `cache.totalHttpRequests` (fallback: totalConnections) |
|
||||
|
||||
### totals
|
||||
|
||||
| Method | Return | Source |
|
||||
|---|---|---|
|
||||
| `bytesIn()` | `number` | `cache.bytesIn` |
|
||||
| `bytesOut()` | `number` | `cache.bytesOut` |
|
||||
| `connections()` | `number` | `cache.totalConnections` |
|
||||
|
||||
### backends
|
||||
|
||||
| Method | Return | Source |
|
||||
|---|---|---|
|
||||
| `byBackend()` | `Map<string, IBackendMetrics>` | `cache.backends[key].*` with computed `avgConnectTimeMs` and `poolHitRate` |
|
||||
| `protocols()` | `Map<string, string>` | `cache.backends[key].protocol` |
|
||||
| `topByErrors(limit?)` | `IBackendMetrics[]` | Sorted by total errors desc |
|
||||
| `detectedProtocols()` | `IProtocolCacheEntry[]` | `cache.detectedProtocols` passthrough |
|
||||
|
||||
`IBackendMetrics`: `{ protocol, activeConnections, totalConnections, connectErrors, handshakeErrors, requestErrors, avgConnectTimeMs, poolHitRate, h2Failures }`
|
||||
|
||||
### udp
|
||||
|
||||
| Method | Return | Source |
|
||||
|---|---|---|
|
||||
| `activeSessions()` | `number` | `cache.activeUdpSessions` |
|
||||
| `totalSessions()` | `number` | `cache.totalUdpSessions` |
|
||||
| `datagramsIn()` | `number` | `cache.totalDatagramsIn` |
|
||||
| `datagramsOut()` | `number` | `cache.totalDatagramsOut` |
|
||||
|
||||
### percentiles (stub)
|
||||
|
||||
`connectionDuration()` and `bytesTransferred()` always return zeros. Not implemented.
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
```typescript
|
||||
interface IMetricsConfig {
|
||||
enabled: boolean; // default true
|
||||
sampleIntervalMs: number; // default 1000 (1Hz sampling + TS polling)
|
||||
retentionSeconds: number; // default 3600 (ThroughputTracker capacity)
|
||||
enableDetailedTracking: boolean;
|
||||
enablePercentiles: boolean;
|
||||
cacheResultsMs: number;
|
||||
prometheusEnabled: boolean; // not wired
|
||||
prometheusPath: string; // not wired
|
||||
prometheusPrefix: string; // not wired
|
||||
}
|
||||
```
|
||||
|
||||
Rust-side config (`MetricsConfig` in `rustproxy-config`):
|
||||
|
||||
```rust
|
||||
pub struct MetricsConfig {
|
||||
pub enabled: Option<bool>,
|
||||
pub sample_interval_ms: Option<u64>, // default 1000
|
||||
pub retention_seconds: Option<u64>, // default 3600
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Design Decisions
|
||||
|
||||
**Lock-free hot path.** `record_bytes()` is the most frequently called method (per-chunk in TCP, per-64KB in HTTP). It only touches `AtomicU64` with `Relaxed` ordering and short-circuits zero-byte directions to skip DashMap lookups entirely.
|
||||
|
||||
**CountingBody batching.** HTTP body frames are typically 16KB. Flushing to MetricsCollector every 64KB reduces DashMap shard contention by ~4x compared to per-frame recording.
|
||||
|
||||
**RAII guards everywhere.** `ConnectionGuard`, `ConnectionTrackerGuard`, `QuicConnGuard`, `ProtocolGuard`, `FrontendProtocolTracker` all use Drop to guarantee counter cleanup on all exit paths including panics.
|
||||
|
||||
**Saturating decrements.** Protocol counters use `fetch_update` loops instead of `fetch_sub` to prevent underflow to `u64::MAX` from concurrent close races.
|
||||
|
||||
**Bounded memory.** Per-IP entries evicted on last connection close. Per-backend entries evicted on last connection close. Snapshot caps IPs and backends at 100 each. `sample_all()` prunes orphaned entries every second.
|
||||
|
||||
**Two-phase throughput.** Pending bytes accumulate in lock-free atomics. The 1Hz cold path drains them into Mutex-guarded ThroughputTrackers. This means the hot path never contends on a Mutex, while the cold path does minimal work (one drain + one sample per tracker).
|
||||
|
||||
---
|
||||
|
||||
## Known Gaps
|
||||
|
||||
| Gap | Status |
|
||||
|---|---|
|
||||
| `throughput.average()` / `throughput.custom(seconds)` | Fall back to `instant()`. Not wired to Rust windowed queries. |
|
||||
| `percentiles.connectionDuration()` / `percentiles.bytesTransferred()` | Stub returning zeros. No histogram in Rust. |
|
||||
| Prometheus export | Config fields exist but not wired to any exporter. |
|
||||
| `LogDeduplicator` | Implemented in `rustproxy-metrics` but not connected to any call site. |
|
||||
| Rate limit hit counters | Rate-limited requests return 429 but no counter is recorded in MetricsCollector. |
|
||||
| QUIC stream byte counting | Post-hoc (per-stream totals after close), not per-chunk like TCP. |
|
||||
| Throughput history in snapshot | Capped at 60 samples. TS `history(seconds)` cannot return more than 60 points regardless of `retentionSeconds`. |
|
||||
| Per-route total connections / bytes | Available in Rust JSON but `IMetrics.connections.byRoute()` only exposes active connections. |
|
||||
| Per-IP total connections / bytes | Available in Rust JSON but `IMetrics.connections.byIP()` only exposes active connections. |
|
||||
| IPC response typing | `RustProxyBridge` declares `result: any` for both metrics commands. No type-safe response. |
|
||||
@@ -129,7 +129,6 @@ pub struct RustProxyOptions {
|
||||
pub defaults: Option<DefaultConfig>,
|
||||
|
||||
// ─── Timeout Settings ────────────────────────────────────────────
|
||||
|
||||
/// Timeout for establishing connection to backend (ms), default: 30000
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub connection_timeout: Option<u64>,
|
||||
@@ -159,7 +158,6 @@ pub struct RustProxyOptions {
|
||||
pub graceful_shutdown_timeout: Option<u64>,
|
||||
|
||||
// ─── Socket Optimization ─────────────────────────────────────────
|
||||
|
||||
/// Disable Nagle's algorithm (default: true)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub no_delay: Option<bool>,
|
||||
@@ -177,7 +175,6 @@ pub struct RustProxyOptions {
|
||||
pub max_pending_data_size: Option<u64>,
|
||||
|
||||
// ─── Enhanced Features ───────────────────────────────────────────
|
||||
|
||||
/// Disable inactivity checking entirely
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub disable_inactivity_check: Option<bool>,
|
||||
@@ -199,7 +196,6 @@ pub struct RustProxyOptions {
|
||||
pub enable_randomized_timeouts: Option<bool>,
|
||||
|
||||
// ─── Rate Limiting ───────────────────────────────────────────────
|
||||
|
||||
/// Maximum simultaneous connections from a single IP
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_connections_per_ip: Option<u64>,
|
||||
@@ -213,7 +209,6 @@ pub struct RustProxyOptions {
|
||||
pub max_connections: Option<u64>,
|
||||
|
||||
// ─── Keep-Alive Settings ─────────────────────────────────────────
|
||||
|
||||
/// How to treat keep-alive connections
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub keep_alive_treatment: Option<KeepAliveTreatment>,
|
||||
@@ -227,7 +222,6 @@ pub struct RustProxyOptions {
|
||||
pub extended_keep_alive_lifetime: Option<u64>,
|
||||
|
||||
// ─── HttpProxy Integration ───────────────────────────────────────
|
||||
|
||||
/// Array of ports to forward to HttpProxy
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_http_proxy: Option<Vec<u16>>,
|
||||
@@ -237,13 +231,11 @@ pub struct RustProxyOptions {
|
||||
pub http_proxy_port: Option<u16>,
|
||||
|
||||
// ─── Metrics ─────────────────────────────────────────────────────
|
||||
|
||||
/// Metrics configuration
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metrics: Option<MetricsConfig>,
|
||||
|
||||
// ─── ACME ────────────────────────────────────────────────────────
|
||||
|
||||
/// Global ACME configuration
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub acme: Option<AcmeOptions>,
|
||||
@@ -318,7 +310,8 @@ impl RustProxyOptions {
|
||||
|
||||
/// Get all unique ports that routes listen on.
|
||||
pub fn all_listening_ports(&self) -> Vec<u16> {
|
||||
let mut ports: Vec<u16> = self.routes
|
||||
let mut ports: Vec<u16> = self
|
||||
.routes
|
||||
.iter()
|
||||
.flat_map(|r| r.listening_ports())
|
||||
.collect();
|
||||
@@ -340,7 +333,12 @@ mod tests {
|
||||
route_match: RouteMatch {
|
||||
ports: PortRange::Single(listen_port),
|
||||
domains: Some(DomainSpec::Single(domain.to_string())),
|
||||
path: None, client_ip: None, transport: None, tls_version: None, headers: None, protocol: None,
|
||||
path: None,
|
||||
client_ip: None,
|
||||
transport: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
protocol: None,
|
||||
},
|
||||
action: RouteAction {
|
||||
action_type: RouteActionType::Forward,
|
||||
@@ -348,14 +346,30 @@ mod tests {
|
||||
target_match: None,
|
||||
host: HostSpec::Single(host.to_string()),
|
||||
port: PortSpec::Fixed(port),
|
||||
tls: None, websocket: None, load_balancing: None, send_proxy_protocol: None,
|
||||
headers: None, advanced: None, backend_transport: None, priority: None,
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
backend_transport: None,
|
||||
priority: None,
|
||||
}]),
|
||||
tls: None, websocket: None, load_balancing: None, advanced: None,
|
||||
options: None, send_proxy_protocol: None, udp: None,
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
advanced: None,
|
||||
options: None,
|
||||
send_proxy_protocol: None,
|
||||
udp: None,
|
||||
},
|
||||
headers: None, security: None, name: None, description: None,
|
||||
priority: None, tags: None, enabled: None,
|
||||
headers: None,
|
||||
security: None,
|
||||
name: None,
|
||||
description: None,
|
||||
priority: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -363,8 +377,12 @@ mod tests {
|
||||
let mut route = make_route(domain, host, port, 443);
|
||||
route.action.tls = Some(RouteTls {
|
||||
mode: TlsMode::Passthrough,
|
||||
certificate: None, acme: None, versions: None, ciphers: None,
|
||||
honor_cipher_order: None, session_timeout: None,
|
||||
certificate: None,
|
||||
acme: None,
|
||||
versions: None,
|
||||
ciphers: None,
|
||||
honor_cipher_order: None,
|
||||
session_timeout: None,
|
||||
});
|
||||
route
|
||||
}
|
||||
@@ -410,6 +428,209 @@ mod tests {
|
||||
assert_eq!(parsed.connection_timeout, Some(5000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_ts_contract_route_shapes() {
|
||||
let value = serde_json::json!({
|
||||
"routes": [{
|
||||
"name": "contract-route",
|
||||
"match": {
|
||||
"ports": [443, { "from": 8443, "to": 8444 }],
|
||||
"domains": ["api.example.com", "*.example.com"],
|
||||
"transport": "udp",
|
||||
"protocol": "http3",
|
||||
"headers": {
|
||||
"content-type": "/^application\\/json$/i"
|
||||
}
|
||||
},
|
||||
"action": {
|
||||
"type": "forward",
|
||||
"targets": [{
|
||||
"match": {
|
||||
"ports": [443],
|
||||
"path": "/api/*",
|
||||
"method": ["GET"],
|
||||
"headers": {
|
||||
"x-env": "/^(prod|stage)$/"
|
||||
}
|
||||
},
|
||||
"host": ["backend-a", "backend-b"],
|
||||
"port": "preserve",
|
||||
"sendProxyProtocol": true,
|
||||
"backendTransport": "tcp"
|
||||
}],
|
||||
"tls": {
|
||||
"mode": "terminate",
|
||||
"certificate": "auto"
|
||||
},
|
||||
"sendProxyProtocol": true,
|
||||
"udp": {
|
||||
"maxSessionsPerIp": 321,
|
||||
"quic": {
|
||||
"enableHttp3": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": {
|
||||
"ipAllowList": [{
|
||||
"ip": "10.0.0.0/8",
|
||||
"domains": ["api.example.com"]
|
||||
}]
|
||||
}
|
||||
}],
|
||||
"preserveSourceIp": true,
|
||||
"proxyIps": ["10.0.0.1"],
|
||||
"acceptProxyProtocol": true,
|
||||
"sendProxyProtocol": true,
|
||||
"noDelay": true,
|
||||
"keepAlive": true,
|
||||
"keepAliveInitialDelay": 1500,
|
||||
"maxPendingDataSize": 4096,
|
||||
"disableInactivityCheck": true,
|
||||
"enableKeepAliveProbes": true,
|
||||
"enableDetailedLogging": true,
|
||||
"enableTlsDebugLogging": true,
|
||||
"enableRandomizedTimeouts": true,
|
||||
"connectionTimeout": 5000,
|
||||
"initialDataTimeout": 7000,
|
||||
"socketTimeout": 9000,
|
||||
"inactivityCheckInterval": 1100,
|
||||
"maxConnectionLifetime": 13000,
|
||||
"inactivityTimeout": 15000,
|
||||
"gracefulShutdownTimeout": 17000,
|
||||
"maxConnectionsPerIp": 20,
|
||||
"connectionRateLimitPerMinute": 30,
|
||||
"keepAliveTreatment": "extended",
|
||||
"keepAliveInactivityMultiplier": 2.0,
|
||||
"extendedKeepAliveLifetime": 19000,
|
||||
"metrics": {
|
||||
"enabled": true,
|
||||
"sampleIntervalMs": 250,
|
||||
"retentionSeconds": 60
|
||||
},
|
||||
"acme": {
|
||||
"enabled": true,
|
||||
"email": "ops@example.com",
|
||||
"environment": "staging",
|
||||
"useProduction": false,
|
||||
"skipConfiguredCerts": true,
|
||||
"renewThresholdDays": 14,
|
||||
"renewCheckIntervalHours": 12,
|
||||
"autoRenew": true,
|
||||
"port": 80
|
||||
}
|
||||
});
|
||||
|
||||
let options: RustProxyOptions = serde_json::from_value(value).unwrap();
|
||||
|
||||
assert_eq!(options.routes.len(), 1);
|
||||
assert_eq!(options.preserve_source_ip, Some(true));
|
||||
assert_eq!(options.proxy_ips, Some(vec!["10.0.0.1".to_string()]));
|
||||
assert_eq!(options.accept_proxy_protocol, Some(true));
|
||||
assert_eq!(options.send_proxy_protocol, Some(true));
|
||||
assert_eq!(options.no_delay, Some(true));
|
||||
assert_eq!(options.keep_alive, Some(true));
|
||||
assert_eq!(options.keep_alive_initial_delay, Some(1500));
|
||||
assert_eq!(options.max_pending_data_size, Some(4096));
|
||||
assert_eq!(options.disable_inactivity_check, Some(true));
|
||||
assert_eq!(options.enable_keep_alive_probes, Some(true));
|
||||
assert_eq!(options.enable_detailed_logging, Some(true));
|
||||
assert_eq!(options.enable_tls_debug_logging, Some(true));
|
||||
assert_eq!(options.enable_randomized_timeouts, Some(true));
|
||||
assert_eq!(options.connection_timeout, Some(5000));
|
||||
assert_eq!(options.initial_data_timeout, Some(7000));
|
||||
assert_eq!(options.socket_timeout, Some(9000));
|
||||
assert_eq!(options.inactivity_check_interval, Some(1100));
|
||||
assert_eq!(options.max_connection_lifetime, Some(13000));
|
||||
assert_eq!(options.inactivity_timeout, Some(15000));
|
||||
assert_eq!(options.graceful_shutdown_timeout, Some(17000));
|
||||
assert_eq!(options.max_connections_per_ip, Some(20));
|
||||
assert_eq!(options.connection_rate_limit_per_minute, Some(30));
|
||||
assert_eq!(
|
||||
options.keep_alive_treatment,
|
||||
Some(KeepAliveTreatment::Extended)
|
||||
);
|
||||
assert_eq!(options.keep_alive_inactivity_multiplier, Some(2.0));
|
||||
assert_eq!(options.extended_keep_alive_lifetime, Some(19000));
|
||||
|
||||
let route = &options.routes[0];
|
||||
assert_eq!(route.route_match.transport, Some(TransportProtocol::Udp));
|
||||
assert_eq!(route.route_match.protocol.as_deref(), Some("http3"));
|
||||
assert_eq!(
|
||||
route
|
||||
.route_match
|
||||
.headers
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.get("content-type")
|
||||
.unwrap(),
|
||||
"/^application\\/json$/i"
|
||||
);
|
||||
|
||||
let target = &route.action.targets.as_ref().unwrap()[0];
|
||||
assert!(matches!(target.host, HostSpec::List(_)));
|
||||
assert!(matches!(target.port, PortSpec::Special(ref p) if p == "preserve"));
|
||||
assert_eq!(target.backend_transport, Some(TransportProtocol::Tcp));
|
||||
assert_eq!(target.send_proxy_protocol, Some(true));
|
||||
assert_eq!(
|
||||
target
|
||||
.target_match
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.headers
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.get("x-env")
|
||||
.unwrap(),
|
||||
"/^(prod|stage)$/"
|
||||
);
|
||||
assert_eq!(route.action.send_proxy_protocol, Some(true));
|
||||
assert_eq!(
|
||||
route.action.udp.as_ref().unwrap().max_sessions_per_ip,
|
||||
Some(321)
|
||||
);
|
||||
assert_eq!(
|
||||
route
|
||||
.action
|
||||
.udp
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.quic
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.enable_http3,
|
||||
Some(true)
|
||||
);
|
||||
|
||||
let allow_list = route
|
||||
.security
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.ip_allow_list
|
||||
.as_ref()
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
&allow_list[0],
|
||||
crate::security_types::IpAllowEntry::DomainScoped { ip, domains }
|
||||
if ip == "10.0.0.0/8" && domains == &vec!["api.example.com".to_string()]
|
||||
));
|
||||
|
||||
let metrics = options.metrics.as_ref().unwrap();
|
||||
assert_eq!(metrics.enabled, Some(true));
|
||||
assert_eq!(metrics.sample_interval_ms, Some(250));
|
||||
assert_eq!(metrics.retention_seconds, Some(60));
|
||||
|
||||
let acme = options.acme.as_ref().unwrap();
|
||||
assert_eq!(acme.enabled, Some(true));
|
||||
assert_eq!(acme.email.as_deref(), Some("ops@example.com"));
|
||||
assert_eq!(acme.environment, Some(AcmeEnvironment::Staging));
|
||||
assert_eq!(acme.use_production, Some(false));
|
||||
assert_eq!(acme.skip_configured_certs, Some(true));
|
||||
assert_eq!(acme.renew_threshold_days, Some(14));
|
||||
assert_eq!(acme.renew_check_interval_hours, Some(12));
|
||||
assert_eq!(acme.auto_renew, Some(true));
|
||||
assert_eq!(acme.port, Some(80));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_timeouts() {
|
||||
let options = RustProxyOptions::default();
|
||||
@@ -438,9 +659,9 @@ mod tests {
|
||||
fn test_all_listening_ports() {
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_route("a.com", "backend", 8080, 80), // port 80
|
||||
make_route("a.com", "backend", 8080, 80), // port 80
|
||||
make_passthrough_route("b.com", "backend", 443), // port 443
|
||||
make_route("c.com", "backend", 9090, 80), // port 80 (duplicate)
|
||||
make_route("c.com", "backend", 9090, 80), // port 80 (duplicate)
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
@@ -464,9 +685,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_example_json() {
|
||||
let content = std::fs::read_to_string(
|
||||
concat!(env!("CARGO_MANIFEST_DIR"), "/../../config/example.json")
|
||||
).unwrap();
|
||||
let content = std::fs::read_to_string(concat!(
|
||||
env!("CARGO_MANIFEST_DIR"),
|
||||
"/../../config/example.json"
|
||||
))
|
||||
.unwrap();
|
||||
let options: RustProxyOptions = serde_json::from_str(&content).unwrap();
|
||||
assert_eq!(options.routes.len(), 4);
|
||||
let ports = options.all_listening_ports();
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::tls_types::RouteTls;
|
||||
use crate::security_types::RouteSecurity;
|
||||
use crate::tls_types::RouteTls;
|
||||
|
||||
// ─── Port Range ──────────────────────────────────────────────────────
|
||||
|
||||
@@ -32,12 +32,13 @@ impl PortRange {
|
||||
pub fn to_ports(&self) -> Vec<u16> {
|
||||
match self {
|
||||
PortRange::Single(p) => vec![*p],
|
||||
PortRange::List(items) => {
|
||||
items.iter().flat_map(|item| match item {
|
||||
PortRange::List(items) => items
|
||||
.iter()
|
||||
.flat_map(|item| match item {
|
||||
PortRangeItem::Port(p) => vec![*p],
|
||||
PortRangeItem::Range(r) => (r.from..=r.to).collect(),
|
||||
}).collect()
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -105,7 +106,8 @@ impl From<Vec<&str>> for DomainSpec {
|
||||
}
|
||||
|
||||
/// Header match value: either exact string or regex pattern.
|
||||
/// In JSON, all values come as strings. Regex patterns are prefixed with `/` and suffixed with `/`.
|
||||
/// In JSON, all values come as strings. Regex patterns use JS-style literal syntax,
|
||||
/// e.g. `/^application\/json$/` or `/^application\/json$/i`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum HeaderMatchValue {
|
||||
|
||||
@@ -103,14 +103,30 @@ pub struct JwtAuthConfig {
|
||||
pub exclude_paths: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// An entry in the IP allow list: either a plain IP/CIDR string
|
||||
/// or a domain-scoped entry that restricts the IP to specific domains.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum IpAllowEntry {
|
||||
/// Plain IP/CIDR — allowed for all domains on this route
|
||||
Plain(String),
|
||||
/// Domain-scoped — allowed only when the requested domain matches
|
||||
DomainScoped {
|
||||
ip: String,
|
||||
domains: Vec<String>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Security options for routes.
|
||||
/// Matches TypeScript: `IRouteSecurity`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteSecurity {
|
||||
/// IP addresses that are allowed to connect
|
||||
/// IP addresses that are allowed to connect.
|
||||
/// Entries can be plain strings (full route access) or objects with
|
||||
/// `{ ip, domains }` to scope access to specific domains.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ip_allow_list: Option<Vec<String>>,
|
||||
pub ip_allow_list: Option<Vec<IpAllowEntry>>,
|
||||
/// IP addresses that are blocked from connecting
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ip_block_list: Option<Vec<String>>,
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
//! Reuses idle keep-alive connections to avoid per-request TCP+TLS handshakes.
|
||||
//! HTTP/2 and HTTP/3 connections are multiplexed (clone the sender / share the connection).
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use bytes::Bytes;
|
||||
@@ -105,13 +105,19 @@ impl ConnectionPool {
|
||||
|
||||
/// Try to check out an idle HTTP/1.1 sender for the given key.
|
||||
/// Returns `None` if no usable idle connection exists.
|
||||
pub fn checkout_h1(&self, key: &PoolKey) -> Option<http1::SendRequest<BoxBody<Bytes, hyper::Error>>> {
|
||||
pub fn checkout_h1(
|
||||
&self,
|
||||
key: &PoolKey,
|
||||
) -> Option<http1::SendRequest<BoxBody<Bytes, hyper::Error>>> {
|
||||
let mut entry = self.h1_pool.get_mut(key)?;
|
||||
let idles = entry.value_mut();
|
||||
|
||||
while let Some(idle) = idles.pop() {
|
||||
// Check if the connection is still alive and ready
|
||||
if idle.idle_since.elapsed() < IDLE_TIMEOUT && idle.sender.is_ready() && !idle.sender.is_closed() {
|
||||
if idle.idle_since.elapsed() < IDLE_TIMEOUT
|
||||
&& idle.sender.is_ready()
|
||||
&& !idle.sender.is_closed()
|
||||
{
|
||||
// H1 pool hit — no logging on hot path
|
||||
return Some(idle.sender);
|
||||
}
|
||||
@@ -128,7 +134,11 @@ impl ConnectionPool {
|
||||
|
||||
/// Return an HTTP/1.1 sender to the pool after the response body has been prepared.
|
||||
/// The caller should NOT call this if the sender is closed or not ready.
|
||||
pub fn checkin_h1(&self, key: PoolKey, sender: http1::SendRequest<BoxBody<Bytes, hyper::Error>>) {
|
||||
pub fn checkin_h1(
|
||||
&self,
|
||||
key: PoolKey,
|
||||
sender: http1::SendRequest<BoxBody<Bytes, hyper::Error>>,
|
||||
) {
|
||||
if sender.is_closed() || !sender.is_ready() {
|
||||
return; // Don't pool broken connections
|
||||
}
|
||||
@@ -145,7 +155,10 @@ impl ConnectionPool {
|
||||
|
||||
/// Try to get a cloned HTTP/2 sender for the given key.
|
||||
/// HTTP/2 senders are Clone-able (multiplexed), so we clone rather than remove.
|
||||
pub fn checkout_h2(&self, key: &PoolKey) -> Option<(http2::SendRequest<BoxBody<Bytes, hyper::Error>>, Duration)> {
|
||||
pub fn checkout_h2(
|
||||
&self,
|
||||
key: &PoolKey,
|
||||
) -> Option<(http2::SendRequest<BoxBody<Bytes, hyper::Error>>, Duration)> {
|
||||
let entry = self.h2_pool.get(key)?;
|
||||
let pooled = entry.value();
|
||||
let age = pooled.created_at.elapsed();
|
||||
@@ -184,16 +197,23 @@ impl ConnectionPool {
|
||||
/// Register an HTTP/2 sender in the pool. Returns the generation ID for this entry.
|
||||
/// The caller should pass this generation to the connection driver so it can use
|
||||
/// `remove_h2_if_generation` instead of `remove_h2` to avoid phantom eviction.
|
||||
pub fn register_h2(&self, key: PoolKey, sender: http2::SendRequest<BoxBody<Bytes, hyper::Error>>) -> u64 {
|
||||
pub fn register_h2(
|
||||
&self,
|
||||
key: PoolKey,
|
||||
sender: http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
|
||||
) -> u64 {
|
||||
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
|
||||
if sender.is_closed() {
|
||||
return gen;
|
||||
}
|
||||
self.h2_pool.insert(key, PooledH2 {
|
||||
sender,
|
||||
created_at: Instant::now(),
|
||||
generation: gen,
|
||||
});
|
||||
self.h2_pool.insert(
|
||||
key,
|
||||
PooledH2 {
|
||||
sender,
|
||||
created_at: Instant::now(),
|
||||
generation: gen,
|
||||
},
|
||||
);
|
||||
gen
|
||||
}
|
||||
|
||||
@@ -204,7 +224,11 @@ impl ConnectionPool {
|
||||
pub fn checkout_h3(
|
||||
&self,
|
||||
key: &PoolKey,
|
||||
) -> Option<(h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>, quinn::Connection, Duration)> {
|
||||
) -> Option<(
|
||||
h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
|
||||
quinn::Connection,
|
||||
Duration,
|
||||
)> {
|
||||
let entry = self.h3_pool.get(key)?;
|
||||
let pooled = entry.value();
|
||||
let age = pooled.created_at.elapsed();
|
||||
@@ -234,12 +258,15 @@ impl ConnectionPool {
|
||||
send_request: h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
|
||||
) -> u64 {
|
||||
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
|
||||
self.h3_pool.insert(key, PooledH3 {
|
||||
send_request,
|
||||
connection,
|
||||
created_at: Instant::now(),
|
||||
generation: gen,
|
||||
});
|
||||
self.h3_pool.insert(
|
||||
key,
|
||||
PooledH3 {
|
||||
send_request,
|
||||
connection,
|
||||
created_at: Instant::now(),
|
||||
generation: gen,
|
||||
},
|
||||
);
|
||||
gen
|
||||
}
|
||||
|
||||
@@ -280,7 +307,9 @@ impl ConnectionPool {
|
||||
// Evict dead or aged-out H2 connections
|
||||
let mut dead_h2 = Vec::new();
|
||||
for entry in h2_pool.iter() {
|
||||
if entry.value().sender.is_closed() || entry.value().created_at.elapsed() >= MAX_H2_AGE {
|
||||
if entry.value().sender.is_closed()
|
||||
|| entry.value().created_at.elapsed() >= MAX_H2_AGE
|
||||
{
|
||||
dead_h2.push(entry.key().clone());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
//! A body wrapper that counts bytes flowing through and reports them to MetricsCollector.
|
||||
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
use bytes::Bytes;
|
||||
@@ -76,7 +76,11 @@ impl<B> CountingBody<B> {
|
||||
/// Set the connection-level activity tracker. When set, each data frame
|
||||
/// updates this timestamp to prevent the idle watchdog from killing the
|
||||
/// connection during active body streaming.
|
||||
pub fn with_connection_activity(mut self, activity: Arc<AtomicU64>, start: std::time::Instant) -> Self {
|
||||
pub fn with_connection_activity(
|
||||
mut self,
|
||||
activity: Arc<AtomicU64>,
|
||||
start: std::time::Instant,
|
||||
) -> Self {
|
||||
self.connection_activity = Some(activity);
|
||||
self.activity_start = Some(start);
|
||||
self
|
||||
@@ -134,7 +138,9 @@ where
|
||||
}
|
||||
// Keep the connection-level idle watchdog alive on every frame
|
||||
// (this is just one atomic store — cheap enough per-frame)
|
||||
if let (Some(activity), Some(start)) = (&this.connection_activity, &this.activity_start) {
|
||||
if let (Some(activity), Some(start)) =
|
||||
(&this.connection_activity, &this.activity_start)
|
||||
{
|
||||
activity.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,14 +11,14 @@ use std::task::{Context, Poll};
|
||||
|
||||
use bytes::{Buf, Bytes};
|
||||
use http_body::Frame;
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use rustproxy_config::RouteConfig;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::proxy_service::{ConnActivity, HttpProxyService};
|
||||
use crate::proxy_service::{ConnActivity, HttpProxyService, ProtocolGuard};
|
||||
|
||||
/// HTTP/3 proxy service.
|
||||
///
|
||||
@@ -48,6 +48,10 @@ impl H3ProxyService {
|
||||
let remote_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
|
||||
debug!("HTTP/3 connection from {} on port {}", remote_addr, port);
|
||||
|
||||
// Track frontend H3 connection for the QUIC connection's lifetime.
|
||||
let _frontend_h3_guard =
|
||||
ProtocolGuard::frontend(Arc::clone(self.http_proxy.metrics()), "h3");
|
||||
|
||||
let mut h3_conn: h3::server::Connection<h3_quinn::Connection, Bytes> =
|
||||
h3::server::builder()
|
||||
.send_grease(false)
|
||||
@@ -89,8 +93,15 @@ impl H3ProxyService {
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_h3_request(
|
||||
request, stream, port, remote_addr, &http_proxy, request_cancel,
|
||||
).await {
|
||||
request,
|
||||
stream,
|
||||
port,
|
||||
remote_addr,
|
||||
&http_proxy,
|
||||
request_cancel,
|
||||
)
|
||||
.await
|
||||
{
|
||||
debug!("HTTP/3 request error from {}: {}", remote_addr, e);
|
||||
}
|
||||
});
|
||||
@@ -150,11 +161,14 @@ async fn handle_h3_request(
|
||||
// Delegate to HttpProxyService — same backend path as TCP/HTTP:
|
||||
// route matching, ALPN protocol detection, connection pool, H1/H2/H3 auto.
|
||||
let conn_activity = ConnActivity::new_standalone();
|
||||
let response = http_proxy.handle_request(req, peer_addr, port, cancel, conn_activity).await
|
||||
let response = http_proxy
|
||||
.handle_request(req, peer_addr, port, cancel, conn_activity)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?;
|
||||
|
||||
// Await the body reader to get the H3 stream back
|
||||
let mut stream = body_reader.await
|
||||
let mut stream = body_reader
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Body reader task failed: {}", e))?;
|
||||
|
||||
// Send response headers over H3 (skip hop-by-hop headers)
|
||||
@@ -167,10 +181,13 @@ async fn handle_h3_request(
|
||||
}
|
||||
h3_response = h3_response.header(name, value);
|
||||
}
|
||||
let h3_response = h3_response.body(())
|
||||
let h3_response = h3_response
|
||||
.body(())
|
||||
.map_err(|e| anyhow::anyhow!("Failed to build H3 response: {}", e))?;
|
||||
|
||||
stream.send_response(h3_response).await
|
||||
stream
|
||||
.send_response(h3_response)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to send H3 response: {}", e))?;
|
||||
|
||||
// Stream response body back over H3
|
||||
@@ -179,7 +196,9 @@ async fn handle_h3_request(
|
||||
match frame {
|
||||
Ok(frame) => {
|
||||
if let Ok(data) = frame.into_data() {
|
||||
stream.send_data(data).await
|
||||
stream
|
||||
.send_data(data)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to send H3 data: {}", e))?;
|
||||
}
|
||||
}
|
||||
@@ -191,7 +210,9 @@ async fn handle_h3_request(
|
||||
}
|
||||
|
||||
// Finish the H3 stream (send QUIC FIN)
|
||||
stream.finish().await
|
||||
stream
|
||||
.finish()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to finish H3 stream: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -5,14 +5,15 @@
|
||||
|
||||
pub mod connection_pool;
|
||||
pub mod counting_body;
|
||||
pub mod h3_service;
|
||||
pub mod protocol_cache;
|
||||
pub mod proxy_service;
|
||||
pub mod request_filter;
|
||||
mod request_host;
|
||||
pub mod response_filter;
|
||||
pub mod shutdown_on_drop;
|
||||
pub mod template;
|
||||
pub mod upstream_selector;
|
||||
pub mod h3_service;
|
||||
|
||||
pub use connection_pool::*;
|
||||
pub use counting_body::*;
|
||||
|
||||
@@ -144,10 +144,14 @@ impl FailureState {
|
||||
}
|
||||
|
||||
fn all_expired(&self) -> bool {
|
||||
let h2_expired = self.h2.as_ref()
|
||||
let h2_expired = self
|
||||
.h2
|
||||
.as_ref()
|
||||
.map(|r| r.failed_at.elapsed() >= r.cooldown)
|
||||
.unwrap_or(true);
|
||||
let h3_expired = self.h3.as_ref()
|
||||
let h3_expired = self
|
||||
.h3
|
||||
.as_ref()
|
||||
.map(|r| r.failed_at.elapsed() >= r.cooldown)
|
||||
.unwrap_or(true);
|
||||
h2_expired && h3_expired
|
||||
@@ -355,9 +359,13 @@ impl ProtocolCache {
|
||||
|
||||
let record = entry.get_mut(protocol);
|
||||
let (consecutive, new_cooldown) = match record {
|
||||
Some(existing) if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) => {
|
||||
Some(existing)
|
||||
if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) =>
|
||||
{
|
||||
// Still within the "recent" window — escalate
|
||||
let c = existing.consecutive_failures.saturating_add(1)
|
||||
let c = existing
|
||||
.consecutive_failures
|
||||
.saturating_add(1)
|
||||
.min(PROTOCOL_FAILURE_ESCALATION_CAP);
|
||||
(c, escalate_cooldown(c))
|
||||
}
|
||||
@@ -394,8 +402,13 @@ impl ProtocolCache {
|
||||
if protocol == DetectedProtocol::H1 {
|
||||
return false;
|
||||
}
|
||||
self.failures.get(key)
|
||||
.and_then(|entry| entry.get(protocol).map(|r| r.failed_at.elapsed() < r.cooldown))
|
||||
self.failures
|
||||
.get(key)
|
||||
.and_then(|entry| {
|
||||
entry
|
||||
.get(protocol)
|
||||
.map(|r| r.failed_at.elapsed() < r.cooldown)
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
@@ -464,19 +477,18 @@ impl ProtocolCache {
|
||||
|
||||
/// Snapshot all non-expired cache entries for metrics/UI display.
|
||||
pub fn snapshot(&self) -> Vec<ProtocolCacheEntry> {
|
||||
self.cache.iter()
|
||||
self.cache
|
||||
.iter()
|
||||
.filter(|entry| entry.value().last_accessed_at.elapsed() < PROTOCOL_CACHE_TTL)
|
||||
.map(|entry| {
|
||||
let key = entry.key();
|
||||
let val = entry.value();
|
||||
let failure_info = self.failures.get(key);
|
||||
|
||||
let (h2_sup, h2_cd, h2_cons) = Self::suppression_info(
|
||||
failure_info.as_deref().and_then(|f| f.h2.as_ref()),
|
||||
);
|
||||
let (h3_sup, h3_cd, h3_cons) = Self::suppression_info(
|
||||
failure_info.as_deref().and_then(|f| f.h3.as_ref()),
|
||||
);
|
||||
let (h2_sup, h2_cd, h2_cons) =
|
||||
Self::suppression_info(failure_info.as_deref().and_then(|f| f.h2.as_ref()));
|
||||
let (h3_sup, h3_cd, h3_cons) =
|
||||
Self::suppression_info(failure_info.as_deref().and_then(|f| f.h3.as_ref()));
|
||||
|
||||
ProtocolCacheEntry {
|
||||
host: key.host.clone(),
|
||||
@@ -507,7 +519,13 @@ impl ProtocolCache {
|
||||
/// Insert a protocol detection result with an optional H3 port.
|
||||
/// Logs protocol transitions when overwriting an existing entry.
|
||||
/// No suppression check — callers must check before calling.
|
||||
fn insert_internal(&self, key: ProtocolCacheKey, protocol: DetectedProtocol, h3_port: Option<u16>, reason: &str) {
|
||||
fn insert_internal(
|
||||
&self,
|
||||
key: ProtocolCacheKey,
|
||||
protocol: DetectedProtocol,
|
||||
h3_port: Option<u16>,
|
||||
reason: &str,
|
||||
) {
|
||||
// Check for existing entry to log protocol transitions
|
||||
if let Some(existing) = self.cache.get(&key) {
|
||||
if existing.protocol != protocol {
|
||||
@@ -522,7 +540,9 @@ impl ProtocolCache {
|
||||
|
||||
// Evict oldest entry if at capacity
|
||||
if self.cache.len() >= PROTOCOL_CACHE_MAX_ENTRIES && !self.cache.contains_key(&key) {
|
||||
let oldest = self.cache.iter()
|
||||
let oldest = self
|
||||
.cache
|
||||
.iter()
|
||||
.min_by_key(|entry| entry.value().last_accessed_at)
|
||||
.map(|entry| entry.key().clone());
|
||||
if let Some(oldest_key) = oldest {
|
||||
@@ -531,13 +551,16 @@ impl ProtocolCache {
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
self.cache.insert(key, CachedEntry {
|
||||
protocol,
|
||||
detected_at: now,
|
||||
last_accessed_at: now,
|
||||
last_probed_at: now,
|
||||
h3_port,
|
||||
});
|
||||
self.cache.insert(
|
||||
key,
|
||||
CachedEntry {
|
||||
protocol,
|
||||
detected_at: now,
|
||||
last_accessed_at: now,
|
||||
last_probed_at: now,
|
||||
h3_port,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Reduce a failure record's remaining cooldown to `target`, if it currently
|
||||
@@ -582,26 +605,34 @@ impl ProtocolCache {
|
||||
interval.tick().await;
|
||||
|
||||
// Clean expired cache entries (sliding TTL based on last_accessed_at)
|
||||
let expired: Vec<ProtocolCacheKey> = cache.iter()
|
||||
let expired: Vec<ProtocolCacheKey> = cache
|
||||
.iter()
|
||||
.filter(|entry| entry.value().last_accessed_at.elapsed() >= PROTOCOL_CACHE_TTL)
|
||||
.map(|entry| entry.key().clone())
|
||||
.collect();
|
||||
|
||||
if !expired.is_empty() {
|
||||
debug!("Protocol cache cleanup: removing {} expired entries", expired.len());
|
||||
debug!(
|
||||
"Protocol cache cleanup: removing {} expired entries",
|
||||
expired.len()
|
||||
);
|
||||
for key in expired {
|
||||
cache.remove(&key);
|
||||
}
|
||||
}
|
||||
|
||||
// Clean fully-expired failure entries
|
||||
let expired_failures: Vec<ProtocolCacheKey> = failures.iter()
|
||||
let expired_failures: Vec<ProtocolCacheKey> = failures
|
||||
.iter()
|
||||
.filter(|entry| entry.value().all_expired())
|
||||
.map(|entry| entry.key().clone())
|
||||
.collect();
|
||||
|
||||
if !expired_failures.is_empty() {
|
||||
debug!("Protocol cache cleanup: removing {} expired failure entries", expired_failures.len());
|
||||
debug!(
|
||||
"Protocol cache cleanup: removing {} expired failure entries",
|
||||
expired_failures.len()
|
||||
);
|
||||
for key in expired_failures {
|
||||
failures.remove(&key);
|
||||
}
|
||||
@@ -609,7 +640,8 @@ impl ProtocolCache {
|
||||
|
||||
// Safety net: cap failures map at 2× max entries
|
||||
if failures.len() > PROTOCOL_CACHE_MAX_ENTRIES * 2 {
|
||||
let oldest: Vec<ProtocolCacheKey> = failures.iter()
|
||||
let oldest: Vec<ProtocolCacheKey> = failures
|
||||
.iter()
|
||||
.filter(|e| e.value().all_expired())
|
||||
.map(|e| e.key().clone())
|
||||
.take(failures.len() - PROTOCOL_CACHE_MAX_ENTRIES)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,13 +4,15 @@ use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use http_body_util::Full;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::Full;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
|
||||
use rustproxy_config::RouteSecurity;
|
||||
use rustproxy_security::{IpFilter, BasicAuthValidator, JwtValidator, RateLimiter};
|
||||
use rustproxy_security::{BasicAuthValidator, IpFilter, JwtValidator, RateLimiter};
|
||||
|
||||
use crate::request_host::extract_request_host;
|
||||
|
||||
pub struct RequestFilter;
|
||||
|
||||
@@ -35,13 +37,14 @@ impl RequestFilter {
|
||||
let client_ip = peer_addr.ip();
|
||||
let request_path = req.uri().path();
|
||||
|
||||
// IP filter
|
||||
// IP filter (domain-aware: use the same host extraction as route matching)
|
||||
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
|
||||
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
|
||||
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
|
||||
let filter = IpFilter::new(allow, block);
|
||||
let normalized = IpFilter::normalize_ip(&client_ip);
|
||||
if !filter.is_allowed(&normalized) {
|
||||
let host = extract_request_host(req);
|
||||
if !filter.is_allowed_for_domain(&normalized, host) {
|
||||
return Some(error_response(StatusCode::FORBIDDEN, "Access denied"));
|
||||
}
|
||||
}
|
||||
@@ -55,16 +58,15 @@ impl RequestFilter {
|
||||
!limiter.check(&key)
|
||||
} else {
|
||||
// Create a per-check limiter (less ideal but works for non-shared case)
|
||||
let limiter = RateLimiter::new(
|
||||
rate_limit_config.max_requests,
|
||||
rate_limit_config.window,
|
||||
);
|
||||
let limiter =
|
||||
RateLimiter::new(rate_limit_config.max_requests, rate_limit_config.window);
|
||||
let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
|
||||
!limiter.check(&key)
|
||||
};
|
||||
|
||||
if should_block {
|
||||
let message = rate_limit_config.error_message
|
||||
let message = rate_limit_config
|
||||
.error_message
|
||||
.as_deref()
|
||||
.unwrap_or("Rate limit exceeded");
|
||||
return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message));
|
||||
@@ -80,36 +82,48 @@ impl RequestFilter {
|
||||
if let Some(ref basic_auth) = security.basic_auth {
|
||||
if basic_auth.enabled {
|
||||
// Check basic auth exclude paths
|
||||
let skip_basic = basic_auth.exclude_paths.as_ref()
|
||||
let skip_basic = basic_auth
|
||||
.exclude_paths
|
||||
.as_ref()
|
||||
.map(|paths| Self::path_matches_any(request_path, paths))
|
||||
.unwrap_or(false);
|
||||
|
||||
if !skip_basic {
|
||||
let users: Vec<(String, String)> = basic_auth.users.iter()
|
||||
let users: Vec<(String, String)> = basic_auth
|
||||
.users
|
||||
.iter()
|
||||
.map(|c| (c.username.clone(), c.password.clone()))
|
||||
.collect();
|
||||
let validator = BasicAuthValidator::new(users, basic_auth.realm.clone());
|
||||
|
||||
let auth_header = req.headers()
|
||||
let auth_header = req
|
||||
.headers()
|
||||
.get("authorization")
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
match auth_header {
|
||||
Some(header) => {
|
||||
if validator.validate(header).is_none() {
|
||||
return Some(Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header("WWW-Authenticate", validator.www_authenticate())
|
||||
.body(boxed_body("Invalid credentials"))
|
||||
.unwrap());
|
||||
return Some(
|
||||
Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header(
|
||||
"WWW-Authenticate",
|
||||
validator.www_authenticate(),
|
||||
)
|
||||
.body(boxed_body("Invalid credentials"))
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Some(Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header("WWW-Authenticate", validator.www_authenticate())
|
||||
.body(boxed_body("Authentication required"))
|
||||
.unwrap());
|
||||
return Some(
|
||||
Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header("WWW-Authenticate", validator.www_authenticate())
|
||||
.body(boxed_body("Authentication required"))
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -120,7 +134,9 @@ impl RequestFilter {
|
||||
if let Some(ref jwt_auth) = security.jwt_auth {
|
||||
if jwt_auth.enabled {
|
||||
// Check JWT auth exclude paths
|
||||
let skip_jwt = jwt_auth.exclude_paths.as_ref()
|
||||
let skip_jwt = jwt_auth
|
||||
.exclude_paths
|
||||
.as_ref()
|
||||
.map(|paths| Self::path_matches_any(request_path, paths))
|
||||
.unwrap_or(false);
|
||||
|
||||
@@ -132,18 +148,25 @@ impl RequestFilter {
|
||||
jwt_auth.audience.as_deref(),
|
||||
);
|
||||
|
||||
let auth_header = req.headers()
|
||||
let auth_header = req
|
||||
.headers()
|
||||
.get("authorization")
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
match auth_header.and_then(JwtValidator::extract_token) {
|
||||
Some(token) => {
|
||||
if validator.validate(token).is_err() {
|
||||
return Some(error_response(StatusCode::UNAUTHORIZED, "Invalid token"));
|
||||
return Some(error_response(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"Invalid token",
|
||||
));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Some(error_response(StatusCode::UNAUTHORIZED, "Bearer token required"));
|
||||
return Some(error_response(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"Bearer token required",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -203,14 +226,19 @@ impl RequestFilter {
|
||||
}
|
||||
|
||||
/// Check IP-based security (for use in passthrough / TCP-level connections).
|
||||
/// `domain` is the SNI from the TLS handshake (if available) for domain-scoped filtering.
|
||||
/// Returns true if allowed, false if blocked.
|
||||
pub fn check_ip_security(security: &RouteSecurity, client_ip: &std::net::IpAddr) -> bool {
|
||||
pub fn check_ip_security(
|
||||
security: &RouteSecurity,
|
||||
client_ip: &std::net::IpAddr,
|
||||
domain: Option<&str>,
|
||||
) -> bool {
|
||||
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
|
||||
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
|
||||
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
|
||||
let filter = IpFilter::new(allow, block);
|
||||
let normalized = IpFilter::normalize_ip(client_ip);
|
||||
filter.is_allowed(&normalized)
|
||||
filter.is_allowed_for_domain(&normalized, domain)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
@@ -233,19 +261,28 @@ impl RequestFilter {
|
||||
return None;
|
||||
}
|
||||
|
||||
let origin = req.headers()
|
||||
let origin = req
|
||||
.headers()
|
||||
.get("origin")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("*");
|
||||
|
||||
Some(Response::builder()
|
||||
.status(StatusCode::NO_CONTENT)
|
||||
.header("Access-Control-Allow-Origin", origin)
|
||||
.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
|
||||
.header("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
|
||||
.header("Access-Control-Max-Age", "86400")
|
||||
.body(boxed_body(""))
|
||||
.unwrap())
|
||||
Some(
|
||||
Response::builder()
|
||||
.status(StatusCode::NO_CONTENT)
|
||||
.header("Access-Control-Allow-Origin", origin)
|
||||
.header(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET, POST, PUT, DELETE, PATCH, OPTIONS",
|
||||
)
|
||||
.header(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Content-Type, Authorization, X-Requested-With",
|
||||
)
|
||||
.header("Access-Control-Max-Age", "86400")
|
||||
.body(boxed_body(""))
|
||||
.unwrap(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -260,3 +297,71 @@ fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes,
|
||||
fn boxed_body(data: &str) -> BoxBody<Bytes, hyper::Error> {
|
||||
BoxBody::new(Full::new(Bytes::from(data.to_string())).map_err(|never| match never {}))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use bytes::Bytes;
|
||||
use http_body_util::Empty;
|
||||
use hyper::{Request, StatusCode, Version};
|
||||
use rustproxy_config::{IpAllowEntry, RouteSecurity};
|
||||
|
||||
use super::RequestFilter;
|
||||
|
||||
fn domain_scoped_security() -> RouteSecurity {
|
||||
RouteSecurity {
|
||||
ip_allow_list: Some(vec![IpAllowEntry::DomainScoped {
|
||||
ip: "10.8.0.2".to_string(),
|
||||
domains: vec!["*.abc.xyz".to_string()],
|
||||
}]),
|
||||
ip_block_list: None,
|
||||
max_connections: None,
|
||||
authentication: None,
|
||||
rate_limit: None,
|
||||
basic_auth: None,
|
||||
jwt_auth: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn peer_addr() -> std::net::SocketAddr {
|
||||
std::net::SocketAddr::from(([10, 8, 0, 2], 4242))
|
||||
}
|
||||
|
||||
fn request(uri: &str, version: Version, host: Option<&str>) -> Request<Empty<Bytes>> {
|
||||
let mut builder = Request::builder().uri(uri).version(version);
|
||||
if let Some(host) = host {
|
||||
builder = builder.header("host", host);
|
||||
}
|
||||
|
||||
builder.body(Empty::<Bytes>::new()).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn domain_scoped_acl_allows_uri_authority_without_host_header() {
|
||||
let security = domain_scoped_security();
|
||||
let req = request("https://outline.abc.xyz/", Version::HTTP_2, None);
|
||||
|
||||
assert!(RequestFilter::apply(&security, &req, &peer_addr()).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn domain_scoped_acl_allows_host_header_with_port() {
|
||||
let security = domain_scoped_security();
|
||||
let req = request(
|
||||
"https://unrelated.invalid/",
|
||||
Version::HTTP_11,
|
||||
Some("outline.abc.xyz:443"),
|
||||
);
|
||||
|
||||
assert!(RequestFilter::apply(&security, &req, &peer_addr()).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn domain_scoped_acl_denies_non_matching_uri_authority() {
|
||||
let security = domain_scoped_security();
|
||||
let req = request("https://outline.other.xyz/", Version::HTTP_2, None);
|
||||
|
||||
let response = RequestFilter::apply(&security, &req, &peer_addr())
|
||||
.expect("non-matching domain should be denied");
|
||||
assert_eq!(response.status(), StatusCode::FORBIDDEN);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
use hyper::Request;
|
||||
|
||||
/// Extract the effective request host for routing and scoped ACL checks.
|
||||
///
|
||||
/// Prefer the explicit `Host` header when present, otherwise fall back to the
|
||||
/// URI authority used by HTTP/2 and HTTP/3 requests.
|
||||
pub(crate) fn extract_request_host<B>(req: &Request<B>) -> Option<&str> {
|
||||
req.headers()
|
||||
.get("host")
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(|host| host.split(':').next().unwrap_or(host))
|
||||
.or_else(|| req.uri().host())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use bytes::Bytes;
|
||||
use http_body_util::Empty;
|
||||
use hyper::Request;
|
||||
|
||||
use super::extract_request_host;
|
||||
|
||||
#[test]
|
||||
fn extracts_host_header_before_uri_authority() {
|
||||
let req = Request::builder()
|
||||
.uri("https://uri.abc.xyz/test")
|
||||
.header("host", "header.abc.xyz:443")
|
||||
.body(Empty::<Bytes>::new())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(extract_request_host(&req), Some("header.abc.xyz"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn falls_back_to_uri_authority_when_host_header_missing() {
|
||||
let req = Request::builder()
|
||||
.uri("https://outline.abc.xyz/test")
|
||||
.body(Empty::<Bytes>::new())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(extract_request_host(&req), Some("outline.abc.xyz"));
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,7 @@
|
||||
use hyper::header::{HeaderMap, HeaderName, HeaderValue};
|
||||
use rustproxy_config::RouteConfig;
|
||||
|
||||
use crate::template::{RequestContext, expand_template};
|
||||
use crate::template::{expand_template, RequestContext};
|
||||
|
||||
pub struct ResponseFilter;
|
||||
|
||||
@@ -11,12 +11,17 @@ impl ResponseFilter {
|
||||
/// Apply response headers from route config and CORS settings.
|
||||
/// If a `RequestContext` is provided, template variables in header values will be expanded.
|
||||
/// Also injects Alt-Svc header for routes with HTTP/3 enabled.
|
||||
pub fn apply_headers(route: &RouteConfig, headers: &mut HeaderMap, req_ctx: Option<&RequestContext>) {
|
||||
pub fn apply_headers(
|
||||
route: &RouteConfig,
|
||||
headers: &mut HeaderMap,
|
||||
req_ctx: Option<&RequestContext>,
|
||||
) {
|
||||
// Inject Alt-Svc for HTTP/3 advertisement if QUIC/HTTP3 is enabled on this route
|
||||
if let Some(ref udp) = route.action.udp {
|
||||
if let Some(ref quic) = udp.quic {
|
||||
if quic.enable_http3.unwrap_or(false) {
|
||||
let port = quic.alt_svc_port
|
||||
let port = quic
|
||||
.alt_svc_port
|
||||
.or_else(|| req_ctx.map(|c| c.port))
|
||||
.unwrap_or(443);
|
||||
let max_age = quic.alt_svc_max_age.unwrap_or(86400);
|
||||
@@ -63,10 +68,7 @@ impl ResponseFilter {
|
||||
headers.insert("access-control-allow-origin", val);
|
||||
}
|
||||
} else {
|
||||
headers.insert(
|
||||
"access-control-allow-origin",
|
||||
HeaderValue::from_static("*"),
|
||||
);
|
||||
headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
|
||||
}
|
||||
|
||||
// Allow-Methods
|
||||
|
||||
@@ -62,17 +62,11 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncWrite for Shutdown
|
||||
self.inner.as_ref().unwrap().is_write_vectored()
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(self.get_mut().inner.as_mut().unwrap()).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let this = self.get_mut();
|
||||
let result = Pin::new(this.inner.as_mut().unwrap()).poll_shutdown(cx);
|
||||
if result.is_ready() {
|
||||
@@ -93,7 +87,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Drop for ShutdownOnDrop
|
||||
let _ = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(2),
|
||||
tokio::io::AsyncWriteExt::shutdown(&mut stream),
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
// stream is dropped here — all resources freed
|
||||
});
|
||||
}
|
||||
|
||||
@@ -39,7 +39,8 @@ pub fn expand_headers(
|
||||
headers: &HashMap<String, String>,
|
||||
ctx: &RequestContext,
|
||||
) -> HashMap<String, String> {
|
||||
headers.iter()
|
||||
headers
|
||||
.iter()
|
||||
.map(|(k, v)| (k.clone(), expand_template(v, ctx)))
|
||||
.collect()
|
||||
}
|
||||
@@ -150,7 +151,10 @@ mod tests {
|
||||
let ctx = test_context();
|
||||
let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}";
|
||||
let result = expand_template(template, &ctx);
|
||||
assert_eq!(result, "192.168.1.100|example.com|443|/api/v1/users|api-route|42");
|
||||
assert_eq!(
|
||||
result,
|
||||
"192.168.1.100|example.com|443|/api/v1/users|api-route|42"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -7,7 +7,7 @@ use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use rustproxy_config::{RouteTarget, LoadBalancingAlgorithm};
|
||||
use rustproxy_config::{LoadBalancingAlgorithm, RouteTarget};
|
||||
|
||||
/// Upstream selection result.
|
||||
pub struct UpstreamSelection {
|
||||
@@ -51,21 +51,19 @@ impl UpstreamSelector {
|
||||
}
|
||||
|
||||
// Determine load balancing algorithm
|
||||
let algorithm = target.load_balancing.as_ref()
|
||||
let algorithm = target
|
||||
.load_balancing
|
||||
.as_ref()
|
||||
.map(|lb| &lb.algorithm)
|
||||
.unwrap_or(&LoadBalancingAlgorithm::RoundRobin);
|
||||
|
||||
let idx = match algorithm {
|
||||
LoadBalancingAlgorithm::RoundRobin => {
|
||||
self.round_robin_select(&hosts, port)
|
||||
}
|
||||
LoadBalancingAlgorithm::RoundRobin => self.round_robin_select(&hosts, port),
|
||||
LoadBalancingAlgorithm::IpHash => {
|
||||
let hash = Self::ip_hash(client_addr);
|
||||
hash % hosts.len()
|
||||
}
|
||||
LoadBalancingAlgorithm::LeastConnections => {
|
||||
self.least_connections_select(&hosts, port)
|
||||
}
|
||||
LoadBalancingAlgorithm::LeastConnections => self.least_connections_select(&hosts, port),
|
||||
};
|
||||
|
||||
UpstreamSelection {
|
||||
@@ -78,9 +76,7 @@ impl UpstreamSelector {
|
||||
fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize {
|
||||
let key = format!("{}:{}", hosts[0], port);
|
||||
let mut counters = self.round_robin.lock().unwrap();
|
||||
let counter = counters
|
||||
.entry(key)
|
||||
.or_insert_with(|| AtomicUsize::new(0));
|
||||
let counter = counters.entry(key).or_insert_with(|| AtomicUsize::new(0));
|
||||
let idx = counter.fetch_add(1, Ordering::Relaxed);
|
||||
idx % hosts.len()
|
||||
}
|
||||
@@ -91,7 +87,8 @@ impl UpstreamSelector {
|
||||
|
||||
for (i, host) in hosts.iter().enumerate() {
|
||||
let key = format!("{}:{}", host, port);
|
||||
let conns = self.active_connections
|
||||
let conns = self
|
||||
.active_connections
|
||||
.get(&key)
|
||||
.map(|entry| entry.value().load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
@@ -228,13 +225,21 @@ mod tests {
|
||||
selector.connection_started("backend:8080");
|
||||
selector.connection_started("backend:8080");
|
||||
assert_eq!(
|
||||
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed),
|
||||
selector
|
||||
.active_connections
|
||||
.get("backend:8080")
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
2
|
||||
);
|
||||
|
||||
selector.connection_ended("backend:8080");
|
||||
assert_eq!(
|
||||
selector.active_connections.get("backend:8080").unwrap().load(Ordering::Relaxed),
|
||||
selector
|
||||
.active_connections
|
||||
.get("backend:8080")
|
||||
.unwrap()
|
||||
.load(Ordering::Relaxed),
|
||||
1
|
||||
);
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,329 @@
|
||||
//! Shared connection registry for selective connection recycling.
|
||||
//!
|
||||
//! Tracks active connections across both TCP and QUIC with metadata
|
||||
//! (source IP, SNI domain, route ID, cancel token) so that connections
|
||||
//! can be selectively recycled when certificates, security rules, or
|
||||
//! route targets change.
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::info;
|
||||
|
||||
use rustproxy_config::RouteSecurity;
|
||||
use rustproxy_http::request_filter::RequestFilter;
|
||||
use rustproxy_routing::matchers::domain_matches;
|
||||
|
||||
/// Metadata about an active connection.
|
||||
pub struct ConnectionEntry {
|
||||
/// Per-connection cancel token (child of per-route token).
|
||||
pub cancel: CancellationToken,
|
||||
/// Client source IP.
|
||||
pub source_ip: IpAddr,
|
||||
/// SNI domain from TLS handshake (None for non-TLS connections).
|
||||
pub domain: Option<String>,
|
||||
/// Route ID this connection was matched to (None if route has no ID).
|
||||
pub route_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Transport-agnostic registry of active connections.
|
||||
///
|
||||
/// Used by both `TcpListenerManager` and `UdpListenerManager` to track
|
||||
/// connections and enable selective recycling on config changes.
|
||||
pub struct ConnectionRegistry {
|
||||
connections: DashMap<u64, ConnectionEntry>,
|
||||
next_id: AtomicU64,
|
||||
}
|
||||
|
||||
impl ConnectionRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
connections: DashMap::new(),
|
||||
next_id: AtomicU64::new(1),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a connection and return its ID + RAII guard.
|
||||
///
|
||||
/// The guard automatically removes the connection from the registry on drop.
|
||||
pub fn register(self: &Arc<Self>, entry: ConnectionEntry) -> (u64, ConnectionRegistryGuard) {
|
||||
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
|
||||
self.connections.insert(id, entry);
|
||||
let guard = ConnectionRegistryGuard {
|
||||
registry: Arc::clone(self),
|
||||
conn_id: id,
|
||||
};
|
||||
(id, guard)
|
||||
}
|
||||
|
||||
/// Number of tracked connections (for metrics/debugging).
|
||||
pub fn len(&self) -> usize {
|
||||
self.connections.len()
|
||||
}
|
||||
|
||||
/// Recycle connections whose SNI domain matches a renewed certificate domain.
|
||||
///
|
||||
/// Uses bidirectional domain matching so that:
|
||||
/// - Cert `*.example.com` recycles connections for `sub.example.com`
|
||||
/// - Cert `sub.example.com` recycles connections on routes with `*.example.com`
|
||||
pub fn recycle_for_cert_change(&self, cert_domain: &str) {
|
||||
let mut recycled = 0u64;
|
||||
self.connections.retain(|_, entry| {
|
||||
let matches = entry.domain.as_deref()
|
||||
.map(|d| domain_matches(cert_domain, d) || domain_matches(d, cert_domain))
|
||||
.unwrap_or(false);
|
||||
if matches {
|
||||
entry.cancel.cancel();
|
||||
recycled += 1;
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
if recycled > 0 {
|
||||
info!(
|
||||
"Recycled {} connection(s) for cert change on domain '{}'",
|
||||
recycled, cert_domain
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Recycle connections on a route whose security config changed.
|
||||
///
|
||||
/// Re-evaluates each connection's source IP against the new security rules.
|
||||
/// Only connections from now-blocked IPs are terminated; allowed IPs are undisturbed.
|
||||
pub fn recycle_for_security_change(&self, route_id: &str, new_security: &RouteSecurity) {
|
||||
let mut recycled = 0u64;
|
||||
self.connections.retain(|_, entry| {
|
||||
if entry.route_id.as_deref() == Some(route_id) {
|
||||
if !RequestFilter::check_ip_security(new_security, &entry.source_ip, entry.domain.as_deref()) {
|
||||
info!(
|
||||
"Terminating connection from {} — IP now blocked on route '{}'",
|
||||
entry.source_ip, route_id
|
||||
);
|
||||
entry.cancel.cancel();
|
||||
recycled += 1;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
});
|
||||
if recycled > 0 {
|
||||
info!(
|
||||
"Recycled {} connection(s) for security change on route '{}'",
|
||||
recycled, route_id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Recycle all connections on a route (e.g., when targets changed).
|
||||
pub fn recycle_for_route_change(&self, route_id: &str) {
|
||||
let mut recycled = 0u64;
|
||||
self.connections.retain(|_, entry| {
|
||||
if entry.route_id.as_deref() == Some(route_id) {
|
||||
entry.cancel.cancel();
|
||||
recycled += 1;
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
if recycled > 0 {
|
||||
info!(
|
||||
"Recycled {} connection(s) for config change on route '{}'",
|
||||
recycled, route_id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove connections on routes that no longer exist.
|
||||
///
|
||||
/// This complements per-route CancellationToken cancellation —
|
||||
/// the token cascade handles graceful shutdown, this cleans up the registry.
|
||||
pub fn cleanup_removed_routes(&self, active_route_ids: &HashSet<String>) {
|
||||
self.connections.retain(|_, entry| {
|
||||
match &entry.route_id {
|
||||
Some(id) => active_route_ids.contains(id),
|
||||
None => true, // keep connections without a route ID
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// RAII guard that removes a connection from the registry on drop.
|
||||
pub struct ConnectionRegistryGuard {
|
||||
registry: Arc<ConnectionRegistry>,
|
||||
conn_id: u64,
|
||||
}
|
||||
|
||||
impl Drop for ConnectionRegistryGuard {
|
||||
fn drop(&mut self) {
|
||||
self.registry.connections.remove(&self.conn_id);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_registry() -> Arc<ConnectionRegistry> {
|
||||
Arc::new(ConnectionRegistry::new())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_register_and_guard_cleanup() {
|
||||
let reg = make_registry();
|
||||
let token = CancellationToken::new();
|
||||
let entry = ConnectionEntry {
|
||||
cancel: token.clone(),
|
||||
source_ip: "10.0.0.1".parse().unwrap(),
|
||||
domain: Some("example.com".to_string()),
|
||||
route_id: Some("route-1".to_string()),
|
||||
};
|
||||
let (id, guard) = reg.register(entry);
|
||||
assert_eq!(reg.len(), 1);
|
||||
assert!(reg.connections.contains_key(&id));
|
||||
|
||||
drop(guard);
|
||||
assert_eq!(reg.len(), 0);
|
||||
assert!(!token.is_cancelled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recycle_for_cert_change_exact() {
|
||||
let reg = make_registry();
|
||||
let t1 = CancellationToken::new();
|
||||
let t2 = CancellationToken::new();
|
||||
let (_, _g1) = reg.register(ConnectionEntry {
|
||||
cancel: t1.clone(),
|
||||
source_ip: "10.0.0.1".parse().unwrap(),
|
||||
domain: Some("api.example.com".to_string()),
|
||||
route_id: Some("r1".to_string()),
|
||||
});
|
||||
let (_, _g2) = reg.register(ConnectionEntry {
|
||||
cancel: t2.clone(),
|
||||
source_ip: "10.0.0.2".parse().unwrap(),
|
||||
domain: Some("other.com".to_string()),
|
||||
route_id: Some("r2".to_string()),
|
||||
});
|
||||
|
||||
reg.recycle_for_cert_change("api.example.com");
|
||||
assert!(t1.is_cancelled());
|
||||
assert!(!t2.is_cancelled());
|
||||
// Registry retains unmatched entry (guard still alive keeps it too,
|
||||
// but the retain removed the matched one before guard could)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recycle_for_cert_change_wildcard() {
|
||||
let reg = make_registry();
|
||||
let t1 = CancellationToken::new();
|
||||
let t2 = CancellationToken::new();
|
||||
let (_, _g1) = reg.register(ConnectionEntry {
|
||||
cancel: t1.clone(),
|
||||
source_ip: "10.0.0.1".parse().unwrap(),
|
||||
domain: Some("sub.example.com".to_string()),
|
||||
route_id: Some("r1".to_string()),
|
||||
});
|
||||
let (_, _g2) = reg.register(ConnectionEntry {
|
||||
cancel: t2.clone(),
|
||||
source_ip: "10.0.0.2".parse().unwrap(),
|
||||
domain: Some("other.com".to_string()),
|
||||
route_id: Some("r2".to_string()),
|
||||
});
|
||||
|
||||
// Wildcard cert should match subdomain
|
||||
reg.recycle_for_cert_change("*.example.com");
|
||||
assert!(t1.is_cancelled());
|
||||
assert!(!t2.is_cancelled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recycle_for_security_change() {
|
||||
let reg = make_registry();
|
||||
let t1 = CancellationToken::new();
|
||||
let t2 = CancellationToken::new();
|
||||
let (_, _g1) = reg.register(ConnectionEntry {
|
||||
cancel: t1.clone(),
|
||||
source_ip: "10.0.0.1".parse().unwrap(),
|
||||
domain: None,
|
||||
route_id: Some("r1".to_string()),
|
||||
});
|
||||
let (_, _g2) = reg.register(ConnectionEntry {
|
||||
cancel: t2.clone(),
|
||||
source_ip: "10.0.0.2".parse().unwrap(),
|
||||
domain: None,
|
||||
route_id: Some("r1".to_string()),
|
||||
});
|
||||
|
||||
// Block 10.0.0.1, allow 10.0.0.2
|
||||
let security = RouteSecurity {
|
||||
ip_allow_list: None,
|
||||
ip_block_list: Some(vec!["10.0.0.1".to_string()]),
|
||||
max_connections: None,
|
||||
authentication: None,
|
||||
rate_limit: None,
|
||||
basic_auth: None,
|
||||
jwt_auth: None,
|
||||
};
|
||||
|
||||
reg.recycle_for_security_change("r1", &security);
|
||||
assert!(t1.is_cancelled());
|
||||
assert!(!t2.is_cancelled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recycle_for_route_change() {
|
||||
let reg = make_registry();
|
||||
let t1 = CancellationToken::new();
|
||||
let t2 = CancellationToken::new();
|
||||
let (_, _g1) = reg.register(ConnectionEntry {
|
||||
cancel: t1.clone(),
|
||||
source_ip: "10.0.0.1".parse().unwrap(),
|
||||
domain: None,
|
||||
route_id: Some("r1".to_string()),
|
||||
});
|
||||
let (_, _g2) = reg.register(ConnectionEntry {
|
||||
cancel: t2.clone(),
|
||||
source_ip: "10.0.0.2".parse().unwrap(),
|
||||
domain: None,
|
||||
route_id: Some("r2".to_string()),
|
||||
});
|
||||
|
||||
reg.recycle_for_route_change("r1");
|
||||
assert!(t1.is_cancelled());
|
||||
assert!(!t2.is_cancelled());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cleanup_removed_routes() {
|
||||
let reg = make_registry();
|
||||
let t1 = CancellationToken::new();
|
||||
let t2 = CancellationToken::new();
|
||||
let (_, _g1) = reg.register(ConnectionEntry {
|
||||
cancel: t1.clone(),
|
||||
source_ip: "10.0.0.1".parse().unwrap(),
|
||||
domain: None,
|
||||
route_id: Some("active".to_string()),
|
||||
});
|
||||
let (_, _g2) = reg.register(ConnectionEntry {
|
||||
cancel: t2.clone(),
|
||||
source_ip: "10.0.0.2".parse().unwrap(),
|
||||
domain: None,
|
||||
route_id: Some("removed".to_string()),
|
||||
});
|
||||
|
||||
let mut active = HashSet::new();
|
||||
active.insert("active".to_string());
|
||||
reg.cleanup_removed_routes(&active);
|
||||
|
||||
// "removed" route entry was cleaned from registry
|
||||
// (but guard is still alive so len may differ — the retain already removed it)
|
||||
assert!(!t1.is_cancelled()); // not cancelled by cleanup, only by token cascade
|
||||
assert!(!t2.is_cancelled()); // cleanup doesn't cancel, just removes from registry
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ pub mod forwarder;
|
||||
pub mod proxy_protocol;
|
||||
pub mod tls_handler;
|
||||
pub mod connection_tracker;
|
||||
pub mod connection_registry;
|
||||
pub mod socket_opts;
|
||||
pub mod udp_session;
|
||||
pub mod udp_listener;
|
||||
@@ -21,6 +22,7 @@ pub use forwarder::*;
|
||||
pub use proxy_protocol::*;
|
||||
pub use tls_handler::*;
|
||||
pub use connection_tracker::*;
|
||||
pub use connection_registry::*;
|
||||
pub use socket_opts::*;
|
||||
pub use udp_session::*;
|
||||
pub use udp_listener::*;
|
||||
|
||||
@@ -30,6 +30,7 @@ use rustproxy_routing::{MatchContext, RouteManager};
|
||||
use rustproxy_http::h3_service::H3ProxyService;
|
||||
|
||||
use crate::connection_tracker::ConnectionTracker;
|
||||
use crate::connection_registry::{ConnectionEntry, ConnectionRegistry};
|
||||
|
||||
/// Create a QUIC server endpoint on the given port with the provided TLS config.
|
||||
///
|
||||
@@ -350,6 +351,8 @@ pub async fn quic_accept_loop(
|
||||
cancel: CancellationToken,
|
||||
h3_service: Option<Arc<H3ProxyService>>,
|
||||
real_client_map: Option<Arc<DashMap<SocketAddr, SocketAddr>>>,
|
||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||
connection_registry: Arc<ConnectionRegistry>,
|
||||
) {
|
||||
loop {
|
||||
let incoming = tokio::select! {
|
||||
@@ -406,17 +409,48 @@ pub async fn quic_accept_loop(
|
||||
}
|
||||
};
|
||||
|
||||
// Check route-level IP security for QUIC (domain from SNI context)
|
||||
if let Some(ref security) = route.security {
|
||||
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
|
||||
security, &ip, ctx.domain,
|
||||
) {
|
||||
debug!("QUIC connection from {} blocked by route security", real_addr);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
conn_tracker.connection_opened(&ip);
|
||||
let route_id = route.name.clone().or(route.id.clone());
|
||||
metrics.connection_opened(route_id.as_deref(), Some(&ip_str));
|
||||
|
||||
// Resolve per-route cancel token (child of global cancel)
|
||||
let route_cancel = match route_id.as_deref() {
|
||||
Some(id) => route_cancels.entry(id.to_string())
|
||||
.or_insert_with(|| cancel.child_token())
|
||||
.clone(),
|
||||
None => cancel.child_token(),
|
||||
};
|
||||
// Per-connection child token for selective recycling
|
||||
let conn_cancel = route_cancel.child_token();
|
||||
|
||||
// Register in connection registry
|
||||
let registry = Arc::clone(&connection_registry);
|
||||
let reg_entry = ConnectionEntry {
|
||||
cancel: conn_cancel.clone(),
|
||||
source_ip: ip,
|
||||
domain: None, // QUIC Initial is encrypted, domain comes later via H3 :authority
|
||||
route_id: route_id.clone(),
|
||||
};
|
||||
|
||||
let metrics = Arc::clone(&metrics);
|
||||
let conn_tracker = Arc::clone(&conn_tracker);
|
||||
let cancel = cancel.child_token();
|
||||
let h3_svc = h3_service.clone();
|
||||
let real_client_addr = if real_addr != remote_addr { Some(real_addr) } else { None };
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Register in connection registry (RAII guard removes on drop)
|
||||
let (_conn_id, _registry_guard) = registry.register(reg_entry);
|
||||
|
||||
// RAII guard: ensures metrics/tracker cleanup even on panic
|
||||
struct QuicConnGuard {
|
||||
tracker: Arc<ConnectionTracker>,
|
||||
@@ -439,7 +473,7 @@ pub async fn quic_accept_loop(
|
||||
route_id,
|
||||
};
|
||||
|
||||
match handle_quic_connection(incoming, route, port, Arc::clone(&metrics), &cancel, h3_svc, real_client_addr).await {
|
||||
match handle_quic_connection(incoming, route, port, Arc::clone(&metrics), &conn_cancel, h3_svc, real_client_addr).await {
|
||||
Ok(()) => debug!("QUIC connection from {} completed", real_addr),
|
||||
Err(e) => debug!("QUIC connection from {} error: {}", real_addr, e),
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ use crate::sni_parser;
|
||||
use crate::forwarder;
|
||||
use crate::tls_handler;
|
||||
use crate::connection_tracker::ConnectionTracker;
|
||||
use crate::connection_registry::{ConnectionEntry, ConnectionRegistry};
|
||||
use crate::socket_opts;
|
||||
|
||||
/// RAII guard that decrements the active connection metric on drop.
|
||||
@@ -42,6 +43,33 @@ impl Drop for ConnectionGuard {
|
||||
}
|
||||
}
|
||||
|
||||
/// RAII guard for frontend+backend protocol distribution tracking.
|
||||
/// Calls the appropriate _closed methods on drop for both frontend and backend.
|
||||
struct ProtocolGuard {
|
||||
metrics: Arc<MetricsCollector>,
|
||||
frontend_proto: Option<&'static str>,
|
||||
backend_proto: Option<&'static str>,
|
||||
}
|
||||
|
||||
impl ProtocolGuard {
|
||||
fn new(metrics: Arc<MetricsCollector>, frontend: &'static str, backend: &'static str) -> Self {
|
||||
metrics.frontend_protocol_opened(frontend);
|
||||
metrics.backend_protocol_opened(backend);
|
||||
Self { metrics, frontend_proto: Some(frontend), backend_proto: Some(backend) }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ProtocolGuard {
|
||||
fn drop(&mut self) {
|
||||
if let Some(proto) = self.frontend_proto {
|
||||
self.metrics.frontend_protocol_closed(proto);
|
||||
}
|
||||
if let Some(proto) = self.backend_proto {
|
||||
self.metrics.backend_protocol_closed(proto);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// RAII guard that calls ConnectionTracker::connection_closed on drop.
|
||||
/// Ensures per-IP tracking is cleaned up on ALL exit paths — normal, error, or panic.
|
||||
struct ConnectionTrackerGuard {
|
||||
@@ -166,6 +194,8 @@ pub struct TcpListenerManager {
|
||||
/// Per-route cancellation tokens (child of cancel_token).
|
||||
/// When a route is removed, its token is cancelled, terminating all connections on that route.
|
||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||
/// Shared connection registry for selective recycling on config changes.
|
||||
connection_registry: Arc<ConnectionRegistry>,
|
||||
}
|
||||
|
||||
impl TcpListenerManager {
|
||||
@@ -205,6 +235,7 @@ impl TcpListenerManager {
|
||||
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
|
||||
conn_semaphore: Arc::new(tokio::sync::Semaphore::new(max_conns)),
|
||||
route_cancels: Arc::new(DashMap::new()),
|
||||
connection_registry: Arc::new(ConnectionRegistry::new()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -244,6 +275,7 @@ impl TcpListenerManager {
|
||||
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
|
||||
conn_semaphore: Arc::new(tokio::sync::Semaphore::new(max_conns)),
|
||||
route_cancels: Arc::new(DashMap::new()),
|
||||
connection_registry: Arc::new(ConnectionRegistry::new()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -328,12 +360,13 @@ impl TcpListenerManager {
|
||||
let relay = Arc::clone(&self.socket_handler_relay);
|
||||
let semaphore = Arc::clone(&self.conn_semaphore);
|
||||
let route_cancels = Arc::clone(&self.route_cancels);
|
||||
let connection_registry = Arc::clone(&self.connection_registry);
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
Self::accept_loop(
|
||||
listener, port, route_manager_swap, metrics, tls_configs,
|
||||
shared_tls_acceptor, http_proxy, conn_config, conn_tracker, cancel, relay,
|
||||
semaphore, route_cancels,
|
||||
semaphore, route_cancels, connection_registry,
|
||||
).await;
|
||||
});
|
||||
|
||||
@@ -446,6 +479,16 @@ impl TcpListenerManager {
|
||||
&self.metrics
|
||||
}
|
||||
|
||||
/// Get a reference to the shared connection registry.
|
||||
pub fn connection_registry(&self) -> &Arc<ConnectionRegistry> {
|
||||
&self.connection_registry
|
||||
}
|
||||
|
||||
/// Get a reference to the per-route cancellation tokens.
|
||||
pub fn route_cancels(&self) -> &Arc<DashMap<String, CancellationToken>> {
|
||||
&self.route_cancels
|
||||
}
|
||||
|
||||
/// Accept loop for a single port.
|
||||
async fn accept_loop(
|
||||
listener: TcpListener,
|
||||
@@ -461,6 +504,7 @@ impl TcpListenerManager {
|
||||
socket_handler_relay: Arc<std::sync::RwLock<Option<String>>>,
|
||||
conn_semaphore: Arc<tokio::sync::Semaphore>,
|
||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||
connection_registry: Arc<ConnectionRegistry>,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
@@ -514,6 +558,7 @@ impl TcpListenerManager {
|
||||
let cn = cancel.clone();
|
||||
let sr = Arc::clone(&socket_handler_relay);
|
||||
let rc = Arc::clone(&route_cancels);
|
||||
let cr = Arc::clone(&connection_registry);
|
||||
debug!("Accepted connection from {} on port {}", peer_addr, port);
|
||||
|
||||
tokio::spawn(async move {
|
||||
@@ -522,7 +567,7 @@ impl TcpListenerManager {
|
||||
// RAII guard ensures connection_closed is called on all paths
|
||||
let _ct_guard = ConnectionTrackerGuard::new(ct, ip);
|
||||
let result = Self::handle_connection(
|
||||
stream, port, peer_addr, rm, m, tc, sa, hp, cc, cn, sr, rc,
|
||||
stream, port, peer_addr, rm, m, tc, sa, hp, cc, cn, sr, rc, cr,
|
||||
).await;
|
||||
if let Err(e) = result {
|
||||
warn!("Connection error from {}: {}", peer_addr, e);
|
||||
@@ -553,6 +598,7 @@ impl TcpListenerManager {
|
||||
cancel: CancellationToken,
|
||||
socket_handler_relay: Arc<std::sync::RwLock<Option<String>>>,
|
||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||
connection_registry: Arc<ConnectionRegistry>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
@@ -672,17 +718,29 @@ impl TcpListenerManager {
|
||||
let route_id = quick_match.route.id.as_deref();
|
||||
|
||||
// Resolve per-route cancel token (child of global cancel)
|
||||
let conn_cancel = match route_id {
|
||||
let route_cancel = match route_id {
|
||||
Some(id) => route_cancels.entry(id.to_string())
|
||||
.or_insert_with(|| cancel.child_token())
|
||||
.clone(),
|
||||
None => cancel.clone(),
|
||||
};
|
||||
// Per-connection child token for selective recycling
|
||||
let conn_cancel = route_cancel.child_token();
|
||||
|
||||
// Check route-level IP security
|
||||
// Register in connection registry for selective recycling
|
||||
let (_conn_id, _registry_guard) = connection_registry.register(
|
||||
ConnectionEntry {
|
||||
cancel: conn_cancel.clone(),
|
||||
source_ip: peer_addr.ip(),
|
||||
domain: None, // fast path has no domain
|
||||
route_id: route_id.map(|s| s.to_string()),
|
||||
},
|
||||
);
|
||||
|
||||
// Check route-level IP security (fast path: no SNI available)
|
||||
if let Some(ref security) = quick_match.route.security {
|
||||
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
|
||||
security, &peer_addr.ip(),
|
||||
security, &peer_addr.ip(), None,
|
||||
) {
|
||||
warn!("Connection from {} blocked by route security", peer_addr);
|
||||
return Ok(());
|
||||
@@ -852,18 +910,31 @@ impl TcpListenerManager {
|
||||
// Resolve per-route cancel token (child of global cancel).
|
||||
// When this route is removed via updateRoutes, the token is cancelled,
|
||||
// terminating all connections on this route.
|
||||
let cancel = match route_id {
|
||||
let route_cancel = match route_id {
|
||||
Some(id) => route_cancels.entry(id.to_string())
|
||||
.or_insert_with(|| cancel.child_token())
|
||||
.clone(),
|
||||
None => cancel,
|
||||
};
|
||||
// Per-connection child token for selective recycling
|
||||
let cancel = route_cancel.child_token();
|
||||
|
||||
// Check route-level IP security for passthrough connections
|
||||
// Register in connection registry for selective recycling
|
||||
let (_conn_id, _registry_guard) = connection_registry.register(
|
||||
ConnectionEntry {
|
||||
cancel: cancel.clone(),
|
||||
source_ip: peer_addr.ip(),
|
||||
domain: domain.clone(),
|
||||
route_id: route_id.map(|s| s.to_string()),
|
||||
},
|
||||
);
|
||||
|
||||
// Check route-level IP security for passthrough connections (SNI available)
|
||||
if let Some(ref security) = route_match.route.security {
|
||||
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
|
||||
security,
|
||||
&peer_addr.ip(),
|
||||
domain.as_deref(),
|
||||
) {
|
||||
warn!("Connection from {} blocked by route security", peer_addr);
|
||||
return Ok(());
|
||||
@@ -872,6 +943,9 @@ impl TcpListenerManager {
|
||||
|
||||
// Track connection in metrics — guard ensures connection_closed on all exit paths
|
||||
metrics.connection_opened(route_id, Some(&ip_str));
|
||||
if let Some(d) = effective_domain {
|
||||
metrics.record_ip_domain_request(&ip_str, d);
|
||||
}
|
||||
let _conn_guard = ConnectionGuard::new(Arc::clone(&metrics), route_id, Some(&ip_str));
|
||||
|
||||
// Check if this is a socket-handler route that should be relayed to TypeScript
|
||||
@@ -981,6 +1055,9 @@ impl TcpListenerManager {
|
||||
peer_addr, target_host, target_port, domain
|
||||
);
|
||||
|
||||
// Track as "other" protocol (non-HTTP passthrough)
|
||||
let _proto_guard = ProtocolGuard::new(Arc::clone(&metrics), "other", "other");
|
||||
|
||||
let mut actual_buf = vec![0u8; n];
|
||||
stream.read_exact(&mut actual_buf).await?;
|
||||
|
||||
@@ -1047,6 +1124,8 @@ impl TcpListenerManager {
|
||||
"TLS Terminate + TCP: {} -> {}:{} (domain: {:?})",
|
||||
peer_addr, target_host, target_port, domain
|
||||
);
|
||||
// Track as "other" protocol (TLS-terminated non-HTTP)
|
||||
let _proto_guard = ProtocolGuard::new(Arc::clone(&metrics), "other", "other");
|
||||
// Raw TCP forwarding of decrypted stream
|
||||
let backend = match tokio::time::timeout(
|
||||
connect_timeout,
|
||||
@@ -1133,6 +1212,8 @@ impl TcpListenerManager {
|
||||
"TLS Terminate+Reencrypt + TCP: {} -> {}:{}",
|
||||
peer_addr, target_host, target_port
|
||||
);
|
||||
// Track as "other" protocol (TLS-terminated non-HTTP, re-encrypted)
|
||||
let _proto_guard = ProtocolGuard::new(Arc::clone(&metrics), "other", "other");
|
||||
Self::handle_tls_reencrypt_tunnel(
|
||||
buf_stream, &target_host, target_port,
|
||||
peer_addr, Arc::clone(&metrics), route_id,
|
||||
@@ -1149,6 +1230,8 @@ impl TcpListenerManager {
|
||||
Ok(())
|
||||
} else {
|
||||
// Plain TCP forwarding (non-HTTP)
|
||||
// Track as "other" protocol (plain TCP, non-HTTP)
|
||||
let _proto_guard = ProtocolGuard::new(Arc::clone(&metrics), "other", "other");
|
||||
let mut backend = match tokio::time::timeout(
|
||||
connect_timeout,
|
||||
tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)),
|
||||
|
||||
@@ -28,6 +28,8 @@ use rustproxy_routing::{MatchContext, RouteManager};
|
||||
|
||||
use rustproxy_http::h3_service::H3ProxyService;
|
||||
|
||||
use crate::connection_registry::ConnectionRegistry;
|
||||
|
||||
use crate::connection_tracker::ConnectionTracker;
|
||||
use crate::udp_session::{SessionKey, UdpSession, UdpSessionConfig, UdpSessionTable};
|
||||
|
||||
@@ -56,6 +58,10 @@ pub struct UdpListenerManager {
|
||||
/// Trusted proxy IPs that may send PROXY protocol v2 headers.
|
||||
/// When non-empty, PROXY v2 detection is enabled on both raw UDP and QUIC paths.
|
||||
proxy_ips: Arc<Vec<IpAddr>>,
|
||||
/// Per-route cancellation tokens (shared with TcpListenerManager).
|
||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||
/// Shared connection registry for selective recycling.
|
||||
connection_registry: Arc<ConnectionRegistry>,
|
||||
}
|
||||
|
||||
impl Drop for UdpListenerManager {
|
||||
@@ -76,6 +82,8 @@ impl UdpListenerManager {
|
||||
metrics: Arc<MetricsCollector>,
|
||||
conn_tracker: Arc<ConnectionTracker>,
|
||||
cancel_token: CancellationToken,
|
||||
route_cancels: Arc<DashMap<String, CancellationToken>>,
|
||||
connection_registry: Arc<ConnectionRegistry>,
|
||||
) -> Self {
|
||||
Self {
|
||||
listeners: HashMap::new(),
|
||||
@@ -89,6 +97,8 @@ impl UdpListenerManager {
|
||||
relay_reader_cancel: None,
|
||||
h3_service: None,
|
||||
proxy_ips: Arc::new(Vec::new()),
|
||||
route_cancels,
|
||||
connection_registry,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,6 +162,8 @@ impl UdpListenerManager {
|
||||
self.cancel_token.child_token(),
|
||||
self.h3_service.clone(),
|
||||
None,
|
||||
Arc::clone(&self.route_cancels),
|
||||
Arc::clone(&self.connection_registry),
|
||||
));
|
||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
info!("QUIC endpoint started on port {}", port);
|
||||
@@ -173,6 +185,8 @@ impl UdpListenerManager {
|
||||
self.cancel_token.child_token(),
|
||||
self.h3_service.clone(),
|
||||
Some(relay.real_client_map),
|
||||
Arc::clone(&self.route_cancels),
|
||||
Arc::clone(&self.connection_registry),
|
||||
));
|
||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
info!("QUIC endpoint with PROXY relay started on port {}", port);
|
||||
@@ -356,6 +370,8 @@ impl UdpListenerManager {
|
||||
self.cancel_token.child_token(),
|
||||
self.h3_service.clone(),
|
||||
None,
|
||||
Arc::clone(&self.route_cancels),
|
||||
Arc::clone(&self.connection_registry),
|
||||
));
|
||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
Ok(())
|
||||
@@ -379,6 +395,8 @@ impl UdpListenerManager {
|
||||
self.cancel_token.child_token(),
|
||||
self.h3_service.clone(),
|
||||
Some(relay.real_client_map),
|
||||
Arc::clone(&self.route_cancels),
|
||||
Arc::clone(&self.connection_registry),
|
||||
));
|
||||
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||
Ok(())
|
||||
|
||||
@@ -1,5 +1,42 @@
|
||||
use std::collections::HashMap;
|
||||
use regex::Regex;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn compile_regex_pattern(pattern: &str) -> Option<Regex> {
|
||||
if !pattern.starts_with('/') {
|
||||
return None;
|
||||
}
|
||||
|
||||
let last_slash = pattern.rfind('/')?;
|
||||
if last_slash == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let regex_body = &pattern[1..last_slash];
|
||||
let flags = &pattern[last_slash + 1..];
|
||||
|
||||
let mut inline_flags = String::new();
|
||||
for flag in flags.chars() {
|
||||
match flag {
|
||||
'i' | 'm' | 's' | 'u' => {
|
||||
if !inline_flags.contains(flag) {
|
||||
inline_flags.push(flag);
|
||||
}
|
||||
}
|
||||
'g' => {
|
||||
// Global has no effect for single header matching.
|
||||
}
|
||||
_ => return None,
|
||||
}
|
||||
}
|
||||
|
||||
let compiled = if inline_flags.is_empty() {
|
||||
regex_body.to_string()
|
||||
} else {
|
||||
format!("(?{}){}", inline_flags, regex_body)
|
||||
};
|
||||
|
||||
Regex::new(&compiled).ok()
|
||||
}
|
||||
|
||||
/// Match HTTP headers against a set of patterns.
|
||||
///
|
||||
@@ -24,16 +61,15 @@ pub fn headers_match(
|
||||
None => return false, // Required header not present
|
||||
};
|
||||
|
||||
// Check if pattern is a regex (surrounded by /)
|
||||
if pattern.starts_with('/') && pattern.ends_with('/') && pattern.len() > 2 {
|
||||
let regex_str = &pattern[1..pattern.len() - 1];
|
||||
match Regex::new(regex_str) {
|
||||
Ok(re) => {
|
||||
// Check if pattern is a regex literal (/pattern/ or /pattern/flags)
|
||||
if pattern.starts_with('/') && pattern.len() > 2 {
|
||||
match compile_regex_pattern(pattern) {
|
||||
Some(re) => {
|
||||
if !re.is_match(header_value) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
None => {
|
||||
// Invalid regex, fall back to exact match
|
||||
if header_value != pattern {
|
||||
return false;
|
||||
@@ -85,6 +121,24 @@ mod tests {
|
||||
assert!(headers_match(&patterns, &headers));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_regex_header_match_with_flags() {
|
||||
let patterns: HashMap<String, String> = {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"Content-Type".to_string(),
|
||||
"/^application\\/json$/i".to_string(),
|
||||
);
|
||||
m
|
||||
};
|
||||
let headers: HashMap<String, String> = {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("content-type".to_string(), "Application/JSON".to_string());
|
||||
m
|
||||
};
|
||||
assert!(headers_match(&patterns, &headers));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_missing_header() {
|
||||
let patterns: HashMap<String, String> = {
|
||||
|
||||
@@ -281,6 +281,11 @@ impl RouteManager {
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get all enabled routes.
|
||||
pub fn routes(&self) -> &[RouteConfig] {
|
||||
&self.routes
|
||||
}
|
||||
|
||||
/// Get the total number of enabled routes.
|
||||
pub fn route_count(&self) -> usize {
|
||||
self.routes.len()
|
||||
|
||||
@@ -2,12 +2,24 @@ use ipnet::IpNet;
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
|
||||
use rustproxy_config::IpAllowEntry;
|
||||
|
||||
/// IP filter supporting CIDR ranges, wildcards, and exact matches.
|
||||
/// Supports domain-scoped allow entries that restrict an IP to specific domains.
|
||||
pub struct IpFilter {
|
||||
/// Plain allow entries — IP allowed for any domain on the route
|
||||
allow_list: Vec<IpPattern>,
|
||||
/// Domain-scoped allow entries — IP allowed only for matching domains
|
||||
domain_scoped: Vec<DomainScopedEntry>,
|
||||
block_list: Vec<IpPattern>,
|
||||
}
|
||||
|
||||
/// A domain-scoped allow entry: IP + list of allowed domain patterns.
|
||||
struct DomainScopedEntry {
|
||||
pattern: IpPattern,
|
||||
domains: Vec<String>,
|
||||
}
|
||||
|
||||
/// Represents an IP pattern for matching.
|
||||
#[derive(Debug)]
|
||||
enum IpPattern {
|
||||
@@ -31,10 +43,6 @@ impl IpPattern {
|
||||
if let Ok(addr) = IpAddr::from_str(s) {
|
||||
return IpPattern::Exact(addr);
|
||||
}
|
||||
// Try as CIDR by appending default prefix
|
||||
if let Ok(addr) = IpAddr::from_str(s) {
|
||||
return IpPattern::Exact(addr);
|
||||
}
|
||||
// Fallback: treat as exact, will never match an invalid string
|
||||
IpPattern::Exact(IpAddr::from_str("0.0.0.0").unwrap())
|
||||
}
|
||||
@@ -48,19 +56,56 @@ impl IpPattern {
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple domain pattern matching (exact, `*`, or `*.suffix`).
|
||||
fn domain_matches_pattern(pattern: &str, domain: &str) -> bool {
|
||||
let p = pattern.trim();
|
||||
let d = domain.trim();
|
||||
if p == "*" {
|
||||
return true;
|
||||
}
|
||||
if p.eq_ignore_ascii_case(d) {
|
||||
return true;
|
||||
}
|
||||
if p.starts_with("*.") {
|
||||
let suffix = &p[1..]; // e.g., ".abc.xyz"
|
||||
d.len() > suffix.len()
|
||||
&& d[d.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl IpFilter {
|
||||
/// Create a new IP filter from allow and block lists.
|
||||
pub fn new(allow_list: &[String], block_list: &[String]) -> Self {
|
||||
/// Create a new IP filter from allow entries and a block list.
|
||||
pub fn new(allow_entries: &[IpAllowEntry], block_list: &[String]) -> Self {
|
||||
let mut allow_list = Vec::new();
|
||||
let mut domain_scoped = Vec::new();
|
||||
|
||||
for entry in allow_entries {
|
||||
match entry {
|
||||
IpAllowEntry::Plain(ip) => {
|
||||
allow_list.push(IpPattern::parse(ip));
|
||||
}
|
||||
IpAllowEntry::DomainScoped { ip, domains } => {
|
||||
domain_scoped.push(DomainScopedEntry {
|
||||
pattern: IpPattern::parse(ip),
|
||||
domains: domains.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
allow_list: allow_list.iter().map(|s| IpPattern::parse(s)).collect(),
|
||||
allow_list,
|
||||
domain_scoped,
|
||||
block_list: block_list.iter().map(|s| IpPattern::parse(s)).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if an IP is allowed.
|
||||
/// If allow_list is non-empty, IP must match at least one entry.
|
||||
/// If block_list is non-empty, IP must NOT match any entry.
|
||||
pub fn is_allowed(&self, ip: &IpAddr) -> bool {
|
||||
/// Check if an IP is allowed, considering domain-scoped entries.
|
||||
/// If `domain` is Some, domain-scoped entries are evaluated against it.
|
||||
/// If `domain` is None, only plain allow entries are considered.
|
||||
pub fn is_allowed_for_domain(&self, ip: &IpAddr, domain: Option<&str>) -> bool {
|
||||
// Check block list first
|
||||
if !self.block_list.is_empty() {
|
||||
for pattern in &self.block_list {
|
||||
@@ -70,14 +115,36 @@ impl IpFilter {
|
||||
}
|
||||
}
|
||||
|
||||
// If allow list is non-empty, must match at least one
|
||||
if !self.allow_list.is_empty() {
|
||||
return self.allow_list.iter().any(|p| p.matches(ip));
|
||||
// If there are any allow entries (plain or domain-scoped), IP must match
|
||||
let has_any_allow = !self.allow_list.is_empty() || !self.domain_scoped.is_empty();
|
||||
if has_any_allow {
|
||||
// Check plain allow list — grants access to entire route
|
||||
if self.allow_list.iter().any(|p| p.matches(ip)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check domain-scoped entries — grants access only if domain matches
|
||||
if let Some(req_domain) = domain {
|
||||
for entry in &self.domain_scoped {
|
||||
if entry.pattern.matches(ip) {
|
||||
if entry.domains.iter().any(|d| domain_matches_pattern(d, req_domain)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Check if an IP is allowed (backwards-compat wrapper, no domain context).
|
||||
pub fn is_allowed(&self, ip: &IpAddr) -> bool {
|
||||
self.is_allowed_for_domain(ip, None)
|
||||
}
|
||||
|
||||
/// Normalize IPv4-mapped IPv6 addresses (::ffff:x.x.x.x -> x.x.x.x)
|
||||
pub fn normalize_ip(ip: &IpAddr) -> IpAddr {
|
||||
match ip {
|
||||
@@ -97,19 +164,28 @@ impl IpFilter {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn plain(s: &str) -> IpAllowEntry {
|
||||
IpAllowEntry::Plain(s.to_string())
|
||||
}
|
||||
|
||||
fn scoped(ip: &str, domains: &[&str]) -> IpAllowEntry {
|
||||
IpAllowEntry::DomainScoped {
|
||||
ip: ip.to_string(),
|
||||
domains: domains.iter().map(|s| s.to_string()).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_lists_allow_all() {
|
||||
let filter = IpFilter::new(&[], &[]);
|
||||
let ip: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
assert!(filter.is_allowed(&ip));
|
||||
assert!(filter.is_allowed_for_domain(&ip, Some("example.com")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allow_list_exact() {
|
||||
let filter = IpFilter::new(
|
||||
&["10.0.0.1".to_string()],
|
||||
&[],
|
||||
);
|
||||
fn test_plain_allow_list_exact() {
|
||||
let filter = IpFilter::new(&[plain("10.0.0.1")], &[]);
|
||||
let allowed: IpAddr = "10.0.0.1".parse().unwrap();
|
||||
let denied: IpAddr = "10.0.0.2".parse().unwrap();
|
||||
assert!(filter.is_allowed(&allowed));
|
||||
@@ -117,11 +193,8 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allow_list_cidr() {
|
||||
let filter = IpFilter::new(
|
||||
&["10.0.0.0/8".to_string()],
|
||||
&[],
|
||||
);
|
||||
fn test_plain_allow_list_cidr() {
|
||||
let filter = IpFilter::new(&[plain("10.0.0.0/8")], &[]);
|
||||
let allowed: IpAddr = "10.255.255.255".parse().unwrap();
|
||||
let denied: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
assert!(filter.is_allowed(&allowed));
|
||||
@@ -130,10 +203,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_block_list() {
|
||||
let filter = IpFilter::new(
|
||||
&[],
|
||||
&["192.168.1.100".to_string()],
|
||||
);
|
||||
let filter = IpFilter::new(&[], &["192.168.1.100".to_string()]);
|
||||
let blocked: IpAddr = "192.168.1.100".parse().unwrap();
|
||||
let allowed: IpAddr = "192.168.1.101".parse().unwrap();
|
||||
assert!(!filter.is_allowed(&blocked));
|
||||
@@ -143,7 +213,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_block_trumps_allow() {
|
||||
let filter = IpFilter::new(
|
||||
&["10.0.0.0/8".to_string()],
|
||||
&[plain("10.0.0.0/8")],
|
||||
&["10.0.0.5".to_string()],
|
||||
);
|
||||
let blocked: IpAddr = "10.0.0.5".parse().unwrap();
|
||||
@@ -154,20 +224,14 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_allow() {
|
||||
let filter = IpFilter::new(
|
||||
&["*".to_string()],
|
||||
&[],
|
||||
);
|
||||
let filter = IpFilter::new(&[plain("*")], &[]);
|
||||
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
assert!(filter.is_allowed(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_block() {
|
||||
let filter = IpFilter::new(
|
||||
&[],
|
||||
&["*".to_string()],
|
||||
);
|
||||
let filter = IpFilter::new(&[], &["*".to_string()]);
|
||||
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
assert!(!filter.is_allowed(&ip));
|
||||
}
|
||||
@@ -186,4 +250,97 @@ mod tests {
|
||||
let normalized = IpFilter::normalize_ip(&ip);
|
||||
assert_eq!(normalized, ip);
|
||||
}
|
||||
|
||||
// Domain-scoped tests
|
||||
|
||||
#[test]
|
||||
fn test_domain_scoped_allows_matching_domain() {
|
||||
let filter = IpFilter::new(
|
||||
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
|
||||
&[],
|
||||
);
|
||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_scoped_denies_non_matching_domain() {
|
||||
let filter = IpFilter::new(
|
||||
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
|
||||
&[],
|
||||
);
|
||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||
assert!(!filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_scoped_denies_without_domain() {
|
||||
let filter = IpFilter::new(
|
||||
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
|
||||
&[],
|
||||
);
|
||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||
// Without domain context, domain-scoped entries cannot match
|
||||
assert!(!filter.is_allowed_for_domain(&ip, None));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_scoped_wildcard_domain() {
|
||||
let filter = IpFilter::new(
|
||||
&[scoped("10.8.0.2", &["*.abc.xyz"])],
|
||||
&[],
|
||||
);
|
||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||
assert!(filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
|
||||
assert!(filter.is_allowed_for_domain(&ip, Some("app.abc.xyz")));
|
||||
assert!(!filter.is_allowed_for_domain(&ip, Some("other.com")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plain_and_domain_scoped_coexist() {
|
||||
let filter = IpFilter::new(
|
||||
&[
|
||||
plain("1.2.3.4"), // full route access
|
||||
scoped("10.8.0.2", &["outline.abc.xyz"]), // scoped access
|
||||
],
|
||||
&[],
|
||||
);
|
||||
|
||||
let admin: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
let vpn: IpAddr = "10.8.0.2".parse().unwrap();
|
||||
let other: IpAddr = "9.9.9.9".parse().unwrap();
|
||||
|
||||
// Admin IP has full access
|
||||
assert!(filter.is_allowed_for_domain(&admin, Some("anything.abc.xyz")));
|
||||
assert!(filter.is_allowed_for_domain(&admin, Some("outline.abc.xyz")));
|
||||
|
||||
// VPN IP only has scoped access
|
||||
assert!(filter.is_allowed_for_domain(&vpn, Some("outline.abc.xyz")));
|
||||
assert!(!filter.is_allowed_for_domain(&vpn, Some("app.abc.xyz")));
|
||||
|
||||
// Unknown IP denied
|
||||
assert!(!filter.is_allowed_for_domain(&other, Some("outline.abc.xyz")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_block_trumps_domain_scoped() {
|
||||
let filter = IpFilter::new(
|
||||
&[scoped("10.8.0.2", &["outline.abc.xyz"])],
|
||||
&["10.8.0.2".to_string()],
|
||||
);
|
||||
let ip: IpAddr = "10.8.0.2".parse().unwrap();
|
||||
assert!(!filter.is_allowed_for_domain(&ip, Some("outline.abc.xyz")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_matches_pattern_fn() {
|
||||
assert!(domain_matches_pattern("example.com", "example.com"));
|
||||
assert!(domain_matches_pattern("*.abc.xyz", "outline.abc.xyz"));
|
||||
assert!(domain_matches_pattern("*.abc.xyz", "app.abc.xyz"));
|
||||
assert!(!domain_matches_pattern("*.abc.xyz", "abc.xyz")); // suffix only, not exact parent
|
||||
assert!(domain_matches_pattern("*", "anything.com"));
|
||||
assert!(!domain_matches_pattern("outline.abc.xyz", "app.abc.xyz"));
|
||||
// Case insensitive
|
||||
assert!(domain_matches_pattern("*.ABC.XYZ", "outline.abc.xyz"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,7 +198,9 @@ impl RustProxy {
|
||||
};
|
||||
|
||||
if let Some(ref allow_list) = default_security.ip_allow_list {
|
||||
security.ip_allow_list = Some(allow_list.clone());
|
||||
security.ip_allow_list = Some(
|
||||
allow_list.iter().map(|s| rustproxy_config::IpAllowEntry::Plain(s.clone())).collect()
|
||||
);
|
||||
}
|
||||
if let Some(ref block_list) = default_security.ip_block_list {
|
||||
security.ip_block_list = Some(block_list.clone());
|
||||
@@ -356,12 +358,17 @@ impl RustProxy {
|
||||
|
||||
// Bind UDP ports (if any)
|
||||
if !udp_ports.is_empty() {
|
||||
let conn_tracker = self.listener_manager.as_ref().unwrap().conn_tracker().clone();
|
||||
let tcp_mgr = self.listener_manager.as_ref().unwrap();
|
||||
let conn_tracker = tcp_mgr.conn_tracker().clone();
|
||||
let route_cancels = tcp_mgr.route_cancels().clone();
|
||||
let connection_registry = tcp_mgr.connection_registry().clone();
|
||||
let mut udp_mgr = UdpListenerManager::new(
|
||||
Arc::clone(&*self.route_table.load()),
|
||||
Arc::clone(&self.metrics),
|
||||
conn_tracker,
|
||||
self.cancel_token.clone(),
|
||||
route_cancels,
|
||||
connection_registry,
|
||||
);
|
||||
udp_mgr.set_proxy_ips(udp_proxy_ips.clone());
|
||||
|
||||
@@ -707,6 +714,9 @@ impl RustProxy {
|
||||
.collect();
|
||||
self.metrics.retain_backends(&active_backends);
|
||||
|
||||
// Capture old route manager for diff-based connection recycling
|
||||
let old_manager = self.route_table.load_full();
|
||||
|
||||
// Atomically swap the route table
|
||||
let new_manager = Arc::new(new_manager);
|
||||
self.route_table.store(Arc::clone(&new_manager));
|
||||
@@ -742,9 +752,47 @@ impl RustProxy {
|
||||
listener.update_route_manager(Arc::clone(&new_manager));
|
||||
// Cancel connections on routes that were removed or disabled
|
||||
listener.invalidate_removed_routes(&active_route_ids);
|
||||
// Clean up registry entries for removed routes
|
||||
listener.connection_registry().cleanup_removed_routes(&active_route_ids);
|
||||
// Prune HTTP proxy caches (rate limiters, regex cache, round-robin counters)
|
||||
listener.prune_http_proxy_caches(&active_route_ids);
|
||||
|
||||
// Diff-based connection recycling for changed routes
|
||||
{
|
||||
let registry = listener.connection_registry();
|
||||
for new_route in &routes {
|
||||
let new_id = match &new_route.id {
|
||||
Some(id) => id.as_str(),
|
||||
None => continue,
|
||||
};
|
||||
// Find corresponding old route
|
||||
let old_route = old_manager.routes().iter().find(|r| {
|
||||
r.id.as_deref() == Some(new_id)
|
||||
});
|
||||
let old_route = match old_route {
|
||||
Some(r) => r,
|
||||
None => continue, // new route, no existing connections to recycle
|
||||
};
|
||||
|
||||
// Security diff: re-evaluate existing connections' IPs
|
||||
let old_sec = serde_json::to_string(&old_route.security).ok();
|
||||
let new_sec = serde_json::to_string(&new_route.security).ok();
|
||||
if old_sec != new_sec {
|
||||
if let Some(ref security) = new_route.security {
|
||||
registry.recycle_for_security_change(new_id, security);
|
||||
}
|
||||
// If security removed entirely (became more permissive), no recycling needed
|
||||
}
|
||||
|
||||
// Action diff (targets, TLS mode, etc.): recycle all connections on route
|
||||
let old_action = serde_json::to_string(&old_route.action).ok();
|
||||
let new_action = serde_json::to_string(&new_route.action).ok();
|
||||
if old_action != new_action {
|
||||
registry.recycle_for_route_change(new_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add new ports
|
||||
for port in &new_ports {
|
||||
if !old_ports.contains(port) {
|
||||
@@ -787,14 +835,22 @@ impl RustProxy {
|
||||
if self.udp_listener_manager.is_none() {
|
||||
if let Some(ref listener) = self.listener_manager {
|
||||
let conn_tracker = listener.conn_tracker().clone();
|
||||
let route_cancels = listener.route_cancels().clone();
|
||||
let connection_registry = listener.connection_registry().clone();
|
||||
let conn_config = Self::build_connection_config(&self.options);
|
||||
let mut udp_mgr = UdpListenerManager::new(
|
||||
Arc::clone(&new_manager),
|
||||
Arc::clone(&self.metrics),
|
||||
conn_tracker,
|
||||
self.cancel_token.clone(),
|
||||
route_cancels,
|
||||
connection_registry,
|
||||
);
|
||||
udp_mgr.set_proxy_ips(conn_config.proxy_ips);
|
||||
// Wire up H3ProxyService so QUIC connections can serve HTTP/3
|
||||
let http_proxy = listener.http_proxy().clone();
|
||||
let h3_svc = rustproxy_http::h3_service::H3ProxyService::new(http_proxy);
|
||||
udp_mgr.set_h3_service(std::sync::Arc::new(h3_svc));
|
||||
self.udp_listener_manager = Some(udp_mgr);
|
||||
}
|
||||
}
|
||||
@@ -1096,6 +1152,10 @@ impl RustProxy {
|
||||
}
|
||||
|
||||
/// Load a certificate for a domain and hot-swap the TLS configuration.
|
||||
///
|
||||
/// If the cert PEM differs from the currently loaded cert for this domain,
|
||||
/// existing connections for the domain are gracefully recycled (GOAWAY for
|
||||
/// HTTP/2, Connection: close for HTTP/1.1, graceful FIN for TCP).
|
||||
pub async fn load_certificate(
|
||||
&mut self,
|
||||
domain: &str,
|
||||
@@ -1105,6 +1165,12 @@ impl RustProxy {
|
||||
) -> Result<()> {
|
||||
info!("Loading certificate for domain: {}", domain);
|
||||
|
||||
// Check if the cert actually changed (for selective connection recycling)
|
||||
let cert_changed = self.loaded_certs
|
||||
.get(domain)
|
||||
.map(|existing| existing.cert_pem != cert_pem)
|
||||
.unwrap_or(false); // new domain = no existing connections to recycle
|
||||
|
||||
// Store in cert manager if available
|
||||
if let Some(ref cm_arc) = self.cert_manager {
|
||||
let now = std::time::SystemTime::now()
|
||||
@@ -1153,6 +1219,13 @@ impl RustProxy {
|
||||
}
|
||||
}
|
||||
|
||||
// Recycle existing connections if cert actually changed
|
||||
if cert_changed {
|
||||
if let Some(ref listener) = self.listener_manager {
|
||||
listener.connection_registry().recycle_for_cert_change(domain);
|
||||
}
|
||||
}
|
||||
|
||||
info!("Certificate loaded and TLS config updated for {}", domain);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -537,6 +537,31 @@ tap.test('Route Matching - routeMatchesHeaders', async () => {
|
||||
'X-Custom-Header': 'value'
|
||||
})).toBeFalse();
|
||||
|
||||
const regexHeaderRoute: IRouteConfig = {
|
||||
match: {
|
||||
domains: 'example.com',
|
||||
ports: 80,
|
||||
headers: {
|
||||
'Content-Type': /^application\/(json|problem\+json)$/i,
|
||||
}
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 3000
|
||||
}]
|
||||
}
|
||||
};
|
||||
|
||||
expect(routeMatchesHeaders(regexHeaderRoute, {
|
||||
'Content-Type': 'Application/Problem+Json',
|
||||
})).toBeTrue();
|
||||
|
||||
expect(routeMatchesHeaders(regexHeaderRoute, {
|
||||
'Content-Type': 'text/html',
|
||||
})).toBeFalse();
|
||||
|
||||
// Route without header matching should match any headers
|
||||
const noHeaderRoute: IRouteConfig = {
|
||||
match: { ports: 80, domains: 'example.com' },
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
|
||||
import type { ISmartProxyOptions } from '../ts/proxies/smart-proxy/models/interfaces.js';
|
||||
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
|
||||
import { RoutePreprocessor } from '../ts/proxies/smart-proxy/route-preprocessor.js';
|
||||
import { buildRustProxyOptions } from '../ts/proxies/smart-proxy/utils/rust-config.js';
|
||||
|
||||
tap.test('Rust contract - preprocessor serializes regex headers for Rust', async () => {
|
||||
const route: IRouteConfig = {
|
||||
name: 'contract-route',
|
||||
match: {
|
||||
ports: [443, { from: 8443, to: 8444 }],
|
||||
domains: ['api.example.com', '*.example.com'],
|
||||
transport: 'udp',
|
||||
protocol: 'http3',
|
||||
headers: {
|
||||
'Content-Type': /^application\/json$/i,
|
||||
},
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
match: {
|
||||
ports: [443],
|
||||
path: '/api/*',
|
||||
method: ['GET'],
|
||||
headers: {
|
||||
'X-Env': /^(prod|stage)$/,
|
||||
},
|
||||
},
|
||||
host: ['backend-a', 'backend-b'],
|
||||
port: 'preserve',
|
||||
sendProxyProtocol: true,
|
||||
backendTransport: 'tcp',
|
||||
}],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto',
|
||||
},
|
||||
sendProxyProtocol: true,
|
||||
udp: {
|
||||
maxSessionsPerIP: 321,
|
||||
quic: {
|
||||
enableHttp3: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
security: {
|
||||
ipAllowList: [{
|
||||
ip: '10.0.0.0/8',
|
||||
domains: ['api.example.com'],
|
||||
}],
|
||||
},
|
||||
};
|
||||
|
||||
const preprocessor = new RoutePreprocessor();
|
||||
const [rustRoute] = preprocessor.preprocessForRust([route]);
|
||||
|
||||
expect(rustRoute.match.headers?.['Content-Type']).toEqual('/^application\\/json$/i');
|
||||
expect(rustRoute.match.transport).toEqual('udp');
|
||||
expect(rustRoute.match.protocol).toEqual('http3');
|
||||
expect(rustRoute.action.targets?.[0].match?.headers?.['X-Env']).toEqual('/^(prod|stage)$/');
|
||||
expect(rustRoute.action.targets?.[0].port).toEqual('preserve');
|
||||
expect(rustRoute.action.targets?.[0].backendTransport).toEqual('tcp');
|
||||
expect(rustRoute.action.sendProxyProtocol).toBeTrue();
|
||||
expect(rustRoute.action.udp?.maxSessionsPerIp).toEqual(321);
|
||||
});
|
||||
|
||||
tap.test('Rust contract - preprocessor converts dynamic targets to relay-safe payloads', async () => {
|
||||
const route: IRouteConfig = {
|
||||
name: 'dynamic-contract-route',
|
||||
match: {
|
||||
ports: 8080,
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: () => 'dynamic-backend.internal',
|
||||
port: () => 9443,
|
||||
}],
|
||||
},
|
||||
};
|
||||
|
||||
const preprocessor = new RoutePreprocessor();
|
||||
const [rustRoute] = preprocessor.preprocessForRust([route]);
|
||||
|
||||
expect(rustRoute.action.type).toEqual('socket-handler');
|
||||
expect(rustRoute.action.targets?.[0].host).toEqual('localhost');
|
||||
expect(rustRoute.action.targets?.[0].port).toEqual(0);
|
||||
expect(preprocessor.getOriginalRoute('dynamic-contract-route')).toEqual(route);
|
||||
});
|
||||
|
||||
tap.test('Rust contract - top-level config keeps shared SmartProxy settings', async () => {
|
||||
const settings: ISmartProxyOptions = {
|
||||
routes: [{
|
||||
name: 'top-level-contract-route',
|
||||
match: {
|
||||
ports: 443,
|
||||
domains: 'api.example.com',
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'backend.internal',
|
||||
port: 8443,
|
||||
}],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto',
|
||||
},
|
||||
},
|
||||
}],
|
||||
preserveSourceIP: true,
|
||||
proxyIPs: ['10.0.0.1'],
|
||||
acceptProxyProtocol: true,
|
||||
sendProxyProtocol: true,
|
||||
noDelay: true,
|
||||
keepAlive: true,
|
||||
keepAliveInitialDelay: 1500,
|
||||
maxPendingDataSize: 4096,
|
||||
disableInactivityCheck: true,
|
||||
enableKeepAliveProbes: true,
|
||||
enableDetailedLogging: true,
|
||||
enableTlsDebugLogging: true,
|
||||
enableRandomizedTimeouts: true,
|
||||
connectionTimeout: 5000,
|
||||
initialDataTimeout: 7000,
|
||||
socketTimeout: 9000,
|
||||
inactivityCheckInterval: 1100,
|
||||
maxConnectionLifetime: 13000,
|
||||
inactivityTimeout: 15000,
|
||||
gracefulShutdownTimeout: 17000,
|
||||
maxConnectionsPerIP: 20,
|
||||
connectionRateLimitPerMinute: 30,
|
||||
keepAliveTreatment: 'extended',
|
||||
keepAliveInactivityMultiplier: 2,
|
||||
extendedKeepAliveLifetime: 19000,
|
||||
metrics: {
|
||||
enabled: true,
|
||||
sampleIntervalMs: 250,
|
||||
retentionSeconds: 60,
|
||||
},
|
||||
acme: {
|
||||
enabled: true,
|
||||
email: 'ops@example.com',
|
||||
environment: 'staging',
|
||||
useProduction: false,
|
||||
skipConfiguredCerts: true,
|
||||
renewThresholdDays: 14,
|
||||
renewCheckIntervalHours: 12,
|
||||
autoRenew: true,
|
||||
port: 80,
|
||||
},
|
||||
};
|
||||
|
||||
const preprocessor = new RoutePreprocessor();
|
||||
const routes = preprocessor.preprocessForRust(settings.routes);
|
||||
const config = buildRustProxyOptions(settings, routes);
|
||||
|
||||
expect(config.preserveSourceIp).toBeTrue();
|
||||
expect(config.proxyIps).toEqual(['10.0.0.1']);
|
||||
expect(config.acceptProxyProtocol).toBeTrue();
|
||||
expect(config.sendProxyProtocol).toBeTrue();
|
||||
expect(config.noDelay).toBeTrue();
|
||||
expect(config.keepAlive).toBeTrue();
|
||||
expect(config.keepAliveInitialDelay).toEqual(1500);
|
||||
expect(config.maxPendingDataSize).toEqual(4096);
|
||||
expect(config.disableInactivityCheck).toBeTrue();
|
||||
expect(config.enableKeepAliveProbes).toBeTrue();
|
||||
expect(config.enableDetailedLogging).toBeTrue();
|
||||
expect(config.enableTlsDebugLogging).toBeTrue();
|
||||
expect(config.enableRandomizedTimeouts).toBeTrue();
|
||||
expect(config.connectionTimeout).toEqual(5000);
|
||||
expect(config.initialDataTimeout).toEqual(7000);
|
||||
expect(config.socketTimeout).toEqual(9000);
|
||||
expect(config.inactivityCheckInterval).toEqual(1100);
|
||||
expect(config.maxConnectionLifetime).toEqual(13000);
|
||||
expect(config.inactivityTimeout).toEqual(15000);
|
||||
expect(config.gracefulShutdownTimeout).toEqual(17000);
|
||||
expect(config.maxConnectionsPerIp).toEqual(20);
|
||||
expect(config.connectionRateLimitPerMinute).toEqual(30);
|
||||
expect(config.keepAliveTreatment).toEqual('extended');
|
||||
expect(config.keepAliveInactivityMultiplier).toEqual(2);
|
||||
expect(config.extendedKeepAliveLifetime).toEqual(19000);
|
||||
expect(config.metrics?.sampleIntervalMs).toEqual(250);
|
||||
expect(config.acme?.email).toEqual('ops@example.com');
|
||||
expect(config.acme?.environment).toEqual('staging');
|
||||
expect(config.acme?.skipConfiguredCerts).toBeTrue();
|
||||
expect(config.acme?.renewThresholdDays).toEqual(14);
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -0,0 +1,418 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
import * as http from 'http';
|
||||
import WebSocket, { WebSocketServer } from 'ws';
|
||||
import { findFreePorts, assertPortsFree } from './helpers/port-allocator.js';
|
||||
|
||||
/**
|
||||
* Helper: create a WebSocket client that connects through the proxy.
|
||||
* Registers the message handler BEFORE awaiting open to avoid race conditions.
|
||||
*/
|
||||
function connectWs(
|
||||
url: string,
|
||||
headers: Record<string, string> = {},
|
||||
opts: WebSocket.ClientOptions = {},
|
||||
): { ws: WebSocket; messages: string[]; opened: Promise<void> } {
|
||||
const messages: string[] = [];
|
||||
const ws = new WebSocket(url, { headers, ...opts });
|
||||
|
||||
// Register message handler immediately — before open fires
|
||||
ws.on('message', (data) => {
|
||||
messages.push(data.toString());
|
||||
});
|
||||
|
||||
const opened = new Promise<void>((resolve, reject) => {
|
||||
const timeout = setTimeout(() => reject(new Error('WebSocket open timeout')), 5000);
|
||||
ws.on('open', () => { clearTimeout(timeout); resolve(); });
|
||||
ws.on('error', (err) => { clearTimeout(timeout); reject(err); });
|
||||
});
|
||||
|
||||
return { ws, messages, opened };
|
||||
}
|
||||
|
||||
/** Wait until `predicate` returns true, with a hard timeout. */
|
||||
function waitFor(predicate: () => boolean, timeoutMs = 5000): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const deadline = setTimeout(() => reject(new Error('waitFor timeout')), timeoutMs);
|
||||
const check = () => {
|
||||
if (predicate()) { clearTimeout(deadline); resolve(); }
|
||||
else setTimeout(check, 30);
|
||||
};
|
||||
check();
|
||||
});
|
||||
}
|
||||
|
||||
/** Graceful close helper */
|
||||
function closeWs(ws: WebSocket): Promise<void> {
|
||||
return new Promise((resolve) => {
|
||||
if (ws.readyState === WebSocket.CLOSED) return resolve();
|
||||
ws.on('close', () => resolve());
|
||||
ws.close();
|
||||
setTimeout(resolve, 2000); // fallback
|
||||
});
|
||||
}
|
||||
|
||||
// ─── Test 1: Basic WebSocket upgrade and bidirectional messaging ───
|
||||
tap.test('should proxy WebSocket connections with bidirectional messaging', async () => {
|
||||
const [PROXY_PORT, BACKEND_PORT] = await findFreePorts(2);
|
||||
|
||||
// Backend: echoes messages with prefix, sends greeting on connect
|
||||
const backendServer = http.createServer();
|
||||
const wss = new WebSocketServer({ server: backendServer });
|
||||
const backendMessages: string[] = [];
|
||||
|
||||
wss.on('connection', (ws) => {
|
||||
ws.on('message', (data) => {
|
||||
const msg = data.toString();
|
||||
backendMessages.push(msg);
|
||||
ws.send(`echo: ${msg}`);
|
||||
});
|
||||
ws.send('hello from backend');
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
backendServer.listen(BACKEND_PORT, '127.0.0.1', () => resolve());
|
||||
});
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'ws-test-route',
|
||||
match: { ports: PROXY_PORT },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: '127.0.0.1', port: BACKEND_PORT }],
|
||||
websocket: { enabled: true },
|
||||
},
|
||||
}],
|
||||
});
|
||||
await proxy.start();
|
||||
|
||||
// Connect client — message handler registered before open
|
||||
const { ws, messages, opened } = connectWs(
|
||||
`ws://127.0.0.1:${PROXY_PORT}/`,
|
||||
{ Host: 'test.local' },
|
||||
);
|
||||
await opened;
|
||||
|
||||
// Wait for the backend greeting
|
||||
await waitFor(() => messages.length >= 1);
|
||||
expect(messages[0]).toEqual('hello from backend');
|
||||
|
||||
// Send 3 messages, expect 3 echoes
|
||||
ws.send('ping 1');
|
||||
ws.send('ping 2');
|
||||
ws.send('ping 3');
|
||||
|
||||
await waitFor(() => messages.length >= 4);
|
||||
|
||||
expect(messages).toContain('echo: ping 1');
|
||||
expect(messages).toContain('echo: ping 2');
|
||||
expect(messages).toContain('echo: ping 3');
|
||||
expect(backendMessages).toInclude('ping 1');
|
||||
expect(backendMessages).toInclude('ping 2');
|
||||
expect(backendMessages).toInclude('ping 3');
|
||||
|
||||
await closeWs(ws);
|
||||
await proxy.stop();
|
||||
await new Promise<void>((resolve) => backendServer.close(() => resolve()));
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
await assertPortsFree([PROXY_PORT, BACKEND_PORT]);
|
||||
});
|
||||
|
||||
// ─── Test 2: Multiple concurrent WebSocket connections ───
|
||||
tap.test('should handle multiple concurrent WebSocket connections', async () => {
|
||||
const [PROXY_PORT, BACKEND_PORT] = await findFreePorts(2);
|
||||
|
||||
const backendServer = http.createServer();
|
||||
const wss = new WebSocketServer({ server: backendServer });
|
||||
|
||||
let connectionCount = 0;
|
||||
wss.on('connection', (ws) => {
|
||||
const id = ++connectionCount;
|
||||
ws.on('message', (data) => {
|
||||
ws.send(`conn${id}: ${data.toString()}`);
|
||||
});
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
backendServer.listen(BACKEND_PORT, '127.0.0.1', () => resolve());
|
||||
});
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'ws-multi-route',
|
||||
match: { ports: PROXY_PORT },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: '127.0.0.1', port: BACKEND_PORT }],
|
||||
websocket: { enabled: true },
|
||||
},
|
||||
}],
|
||||
});
|
||||
await proxy.start();
|
||||
|
||||
const NUM_CLIENTS = 5;
|
||||
const clients: { ws: WebSocket; messages: string[] }[] = [];
|
||||
|
||||
for (let i = 0; i < NUM_CLIENTS; i++) {
|
||||
const c = connectWs(
|
||||
`ws://127.0.0.1:${PROXY_PORT}/`,
|
||||
{ Host: 'test.local' },
|
||||
);
|
||||
await c.opened;
|
||||
clients.push(c);
|
||||
}
|
||||
|
||||
// Each client sends a unique message
|
||||
for (let i = 0; i < NUM_CLIENTS; i++) {
|
||||
clients[i].ws.send(`hello from client ${i}`);
|
||||
}
|
||||
|
||||
// Wait for all replies
|
||||
await waitFor(() => clients.every((c) => c.messages.length >= 1));
|
||||
|
||||
for (let i = 0; i < NUM_CLIENTS; i++) {
|
||||
expect(clients[i].messages.length).toBeGreaterThanOrEqual(1);
|
||||
expect(clients[i].messages[0]).toInclude(`hello from client ${i}`);
|
||||
}
|
||||
expect(connectionCount).toEqual(NUM_CLIENTS);
|
||||
|
||||
for (const c of clients) await closeWs(c.ws);
|
||||
await proxy.stop();
|
||||
await new Promise<void>((resolve) => backendServer.close(() => resolve()));
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
await assertPortsFree([PROXY_PORT, BACKEND_PORT]);
|
||||
});
|
||||
|
||||
// ─── Test 3: WebSocket with binary data ───
|
||||
tap.test('should proxy binary WebSocket frames', async () => {
|
||||
const [PROXY_PORT, BACKEND_PORT] = await findFreePorts(2);
|
||||
|
||||
const backendServer = http.createServer();
|
||||
const wss = new WebSocketServer({ server: backendServer });
|
||||
|
||||
wss.on('connection', (ws) => {
|
||||
ws.on('message', (data) => {
|
||||
ws.send(data, { binary: true });
|
||||
});
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
backendServer.listen(BACKEND_PORT, '127.0.0.1', () => resolve());
|
||||
});
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'ws-binary-route',
|
||||
match: { ports: PROXY_PORT },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: '127.0.0.1', port: BACKEND_PORT }],
|
||||
websocket: { enabled: true },
|
||||
},
|
||||
}],
|
||||
});
|
||||
await proxy.start();
|
||||
|
||||
const receivedBuffers: Buffer[] = [];
|
||||
const ws = new WebSocket(`ws://127.0.0.1:${PROXY_PORT}/`, {
|
||||
headers: { Host: 'test.local' },
|
||||
});
|
||||
ws.on('message', (data) => {
|
||||
receivedBuffers.push(Buffer.from(data as ArrayBuffer));
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
const timeout = setTimeout(() => reject(new Error('timeout')), 5000);
|
||||
ws.on('open', () => { clearTimeout(timeout); resolve(); });
|
||||
ws.on('error', (err) => { clearTimeout(timeout); reject(err); });
|
||||
});
|
||||
|
||||
// Send a 256-byte buffer with known content
|
||||
const sentBuffer = Buffer.alloc(256);
|
||||
for (let i = 0; i < 256; i++) sentBuffer[i] = i;
|
||||
ws.send(sentBuffer);
|
||||
|
||||
await waitFor(() => receivedBuffers.length >= 1);
|
||||
|
||||
expect(receivedBuffers[0].length).toEqual(256);
|
||||
expect(Buffer.compare(receivedBuffers[0], sentBuffer)).toEqual(0);
|
||||
|
||||
await closeWs(ws);
|
||||
await proxy.stop();
|
||||
await new Promise<void>((resolve) => backendServer.close(() => resolve()));
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
await assertPortsFree([PROXY_PORT, BACKEND_PORT]);
|
||||
});
|
||||
|
||||
// ─── Test 4: WebSocket path and query string preserved ───
|
||||
tap.test('should preserve path and query string through proxy', async () => {
|
||||
const [PROXY_PORT, BACKEND_PORT] = await findFreePorts(2);
|
||||
|
||||
const backendServer = http.createServer();
|
||||
const wss = new WebSocketServer({ server: backendServer });
|
||||
|
||||
let receivedUrl = '';
|
||||
wss.on('connection', (ws, req) => {
|
||||
receivedUrl = req.url || '';
|
||||
ws.send(`url: ${receivedUrl}`);
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
backendServer.listen(BACKEND_PORT, '127.0.0.1', () => resolve());
|
||||
});
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'ws-path-route',
|
||||
match: { ports: PROXY_PORT },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: '127.0.0.1', port: BACKEND_PORT }],
|
||||
websocket: { enabled: true },
|
||||
},
|
||||
}],
|
||||
});
|
||||
await proxy.start();
|
||||
|
||||
const { ws, messages, opened } = connectWs(
|
||||
`ws://127.0.0.1:${PROXY_PORT}/chat/room1?token=abc123`,
|
||||
{ Host: 'test.local' },
|
||||
);
|
||||
await opened;
|
||||
|
||||
await waitFor(() => messages.length >= 1);
|
||||
|
||||
expect(receivedUrl).toEqual('/chat/room1?token=abc123');
|
||||
expect(messages[0]).toEqual('url: /chat/room1?token=abc123');
|
||||
|
||||
await closeWs(ws);
|
||||
await proxy.stop();
|
||||
await new Promise<void>((resolve) => backendServer.close(() => resolve()));
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
await assertPortsFree([PROXY_PORT, BACKEND_PORT]);
|
||||
});
|
||||
|
||||
// ─── Test 5: Clean close propagation ───
|
||||
tap.test('should handle clean WebSocket close from client', async () => {
|
||||
const [PROXY_PORT, BACKEND_PORT] = await findFreePorts(2);
|
||||
|
||||
const backendServer = http.createServer();
|
||||
const wss = new WebSocketServer({ server: backendServer });
|
||||
|
||||
let backendGotClose = false;
|
||||
let backendCloseCode = 0;
|
||||
wss.on('connection', (ws) => {
|
||||
ws.on('close', (code) => {
|
||||
backendGotClose = true;
|
||||
backendCloseCode = code;
|
||||
});
|
||||
ws.on('message', (data) => {
|
||||
ws.send(data);
|
||||
});
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
backendServer.listen(BACKEND_PORT, '127.0.0.1', () => resolve());
|
||||
});
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'ws-close-route',
|
||||
match: { ports: PROXY_PORT },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: '127.0.0.1', port: BACKEND_PORT }],
|
||||
websocket: { enabled: true },
|
||||
},
|
||||
}],
|
||||
});
|
||||
await proxy.start();
|
||||
|
||||
const { ws, messages, opened } = connectWs(
|
||||
`ws://127.0.0.1:${PROXY_PORT}/`,
|
||||
{ Host: 'test.local' },
|
||||
);
|
||||
await opened;
|
||||
|
||||
// Confirm connection works with a round-trip
|
||||
ws.send('test');
|
||||
await waitFor(() => messages.length >= 1);
|
||||
|
||||
// Close with code 1000
|
||||
let clientCloseCode = 0;
|
||||
const closed = new Promise<void>((resolve) => {
|
||||
ws.on('close', (code) => {
|
||||
clientCloseCode = code;
|
||||
resolve();
|
||||
});
|
||||
setTimeout(resolve, 3000);
|
||||
});
|
||||
ws.close(1000, 'done');
|
||||
await closed;
|
||||
|
||||
// Wait for backend to register
|
||||
await waitFor(() => backendGotClose, 3000);
|
||||
|
||||
expect(backendGotClose).toBeTrue();
|
||||
expect(clientCloseCode).toEqual(1000);
|
||||
|
||||
await proxy.stop();
|
||||
await new Promise<void>((resolve) => backendServer.close(() => resolve()));
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
await assertPortsFree([PROXY_PORT, BACKEND_PORT]);
|
||||
});
|
||||
|
||||
// ─── Test 6: Large messages ───
|
||||
tap.test('should handle large WebSocket messages', async () => {
|
||||
const [PROXY_PORT, BACKEND_PORT] = await findFreePorts(2);
|
||||
|
||||
const backendServer = http.createServer();
|
||||
const wss = new WebSocketServer({ server: backendServer, maxPayload: 5 * 1024 * 1024 });
|
||||
|
||||
wss.on('connection', (ws) => {
|
||||
ws.on('message', (data) => {
|
||||
const buf = Buffer.from(data as ArrayBuffer);
|
||||
ws.send(`received ${buf.length} bytes`);
|
||||
});
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
backendServer.listen(BACKEND_PORT, '127.0.0.1', () => resolve());
|
||||
});
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'ws-large-route',
|
||||
match: { ports: PROXY_PORT },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: '127.0.0.1', port: BACKEND_PORT }],
|
||||
websocket: { enabled: true },
|
||||
},
|
||||
}],
|
||||
});
|
||||
await proxy.start();
|
||||
|
||||
const { ws, messages, opened } = connectWs(
|
||||
`ws://127.0.0.1:${PROXY_PORT}/`,
|
||||
{ Host: 'test.local' },
|
||||
{ maxPayload: 5 * 1024 * 1024 },
|
||||
);
|
||||
await opened;
|
||||
|
||||
// Send a 1MB message
|
||||
const largePayload = Buffer.alloc(1024 * 1024, 0x42);
|
||||
ws.send(largePayload);
|
||||
|
||||
await waitFor(() => messages.length >= 1);
|
||||
expect(messages[0]).toEqual(`received ${1024 * 1024} bytes`);
|
||||
|
||||
await closeWs(ws);
|
||||
await proxy.stop();
|
||||
await new Promise<void>((resolve) => backendServer.close(() => resolve()));
|
||||
await new Promise((r) => setTimeout(r, 500));
|
||||
await assertPortsFree([PROXY_PORT, BACKEND_PORT]);
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -3,6 +3,6 @@
|
||||
*/
|
||||
export const commitinfo = {
|
||||
name: '@push.rocks/smartproxy',
|
||||
version: '27.0.0',
|
||||
version: '27.7.3',
|
||||
description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.'
|
||||
}
|
||||
|
||||
@@ -32,6 +32,23 @@ export interface IThroughputHistoryPoint {
|
||||
/**
|
||||
* Main metrics interface with clean, grouped API
|
||||
*/
|
||||
/**
|
||||
* Protocol distribution for frontend (client→proxy) or backend (proxy→upstream).
|
||||
* Tracks active and total counts for h1/h2/h3/ws/other.
|
||||
*/
|
||||
export interface IProtocolDistribution {
|
||||
h1Active: number;
|
||||
h1Total: number;
|
||||
h2Active: number;
|
||||
h2Total: number;
|
||||
h3Active: number;
|
||||
h3Total: number;
|
||||
wsActive: number;
|
||||
wsTotal: number;
|
||||
otherActive: number;
|
||||
otherTotal: number;
|
||||
}
|
||||
|
||||
export interface IMetrics {
|
||||
// Connection metrics
|
||||
connections: {
|
||||
@@ -40,6 +57,12 @@ export interface IMetrics {
|
||||
byRoute(): Map<string, number>;
|
||||
byIP(): Map<string, number>;
|
||||
topIPs(limit?: number): Array<{ ip: string; count: number }>;
|
||||
/** Per-IP domain request counts: IP -> { domain -> count }. */
|
||||
domainRequestsByIP(): Map<string, Map<string, number>>;
|
||||
/** Top IP-domain pairs sorted by request count descending. */
|
||||
topDomainRequests(limit?: number): Array<{ ip: string; domain: string; count: number }>;
|
||||
frontendProtocols(): IProtocolDistribution;
|
||||
backendProtocols(): IProtocolDistribution;
|
||||
};
|
||||
|
||||
// Throughput metrics (bytes per second)
|
||||
|
||||
@@ -141,8 +141,10 @@ export interface IRouteAuthentication {
|
||||
* Security options for routes
|
||||
*/
|
||||
export interface IRouteSecurity {
|
||||
// Access control lists
|
||||
ipAllowList?: string[]; // IP addresses that are allowed to connect
|
||||
// Access control lists.
|
||||
// Entries can be plain IP/CIDR strings (full route access) or
|
||||
// objects { ip, domains } to scope access to specific domains on this route.
|
||||
ipAllowList?: Array<string | { ip: string; domains: string[] }>;
|
||||
ipBlockList?: string[]; // IP addresses that are blocked from connecting
|
||||
|
||||
// Connection limits
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
import type { IProtocolCacheEntry, IProtocolDistribution } from './metrics-types.js';
|
||||
import type { IAcmeOptions, ISmartProxyOptions } from './interfaces.js';
|
||||
import type {
|
||||
IRouteAction,
|
||||
IRouteConfig,
|
||||
IRouteMatch,
|
||||
IRouteTarget,
|
||||
ITargetMatch,
|
||||
IRouteUdp,
|
||||
} from './route-types.js';
|
||||
|
||||
export type TRustHeaderMatchers = Record<string, string>;
|
||||
|
||||
export interface IRustRouteMatch extends Omit<IRouteMatch, 'headers'> {
|
||||
headers?: TRustHeaderMatchers;
|
||||
}
|
||||
|
||||
export interface IRustTargetMatch extends Omit<ITargetMatch, 'headers'> {
|
||||
headers?: TRustHeaderMatchers;
|
||||
}
|
||||
|
||||
export interface IRustRouteTarget extends Omit<IRouteTarget, 'host' | 'port' | 'match'> {
|
||||
host: string | string[];
|
||||
port: number | 'preserve';
|
||||
match?: IRustTargetMatch;
|
||||
}
|
||||
|
||||
export interface IRustRouteUdp extends Omit<IRouteUdp, 'maxSessionsPerIP'> {
|
||||
maxSessionsPerIp?: number;
|
||||
}
|
||||
|
||||
export interface IRustDefaultConfig extends Omit<NonNullable<ISmartProxyOptions['defaults']>, 'preserveSourceIP'> {
|
||||
preserveSourceIp?: boolean;
|
||||
}
|
||||
|
||||
export interface IRustRouteAction
|
||||
extends Omit<IRouteAction, 'targets' | 'socketHandler' | 'datagramHandler' | 'forwardingEngine' | 'nftables' | 'udp'> {
|
||||
targets?: IRustRouteTarget[];
|
||||
udp?: IRustRouteUdp;
|
||||
}
|
||||
|
||||
export interface IRustRouteConfig extends Omit<IRouteConfig, 'match' | 'action'> {
|
||||
match: IRustRouteMatch;
|
||||
action: IRustRouteAction;
|
||||
}
|
||||
|
||||
export interface IRustAcmeOptions extends Omit<IAcmeOptions, 'routeForwards'> {}
|
||||
|
||||
export interface IRustProxyOptions {
|
||||
routes: IRustRouteConfig[];
|
||||
preserveSourceIp?: boolean;
|
||||
proxyIps?: string[];
|
||||
acceptProxyProtocol?: boolean;
|
||||
sendProxyProtocol?: boolean;
|
||||
defaults?: IRustDefaultConfig;
|
||||
connectionTimeout?: number;
|
||||
initialDataTimeout?: number;
|
||||
socketTimeout?: number;
|
||||
inactivityCheckInterval?: number;
|
||||
maxConnectionLifetime?: number;
|
||||
inactivityTimeout?: number;
|
||||
gracefulShutdownTimeout?: number;
|
||||
noDelay?: boolean;
|
||||
keepAlive?: boolean;
|
||||
keepAliveInitialDelay?: number;
|
||||
maxPendingDataSize?: number;
|
||||
disableInactivityCheck?: boolean;
|
||||
enableKeepAliveProbes?: boolean;
|
||||
enableDetailedLogging?: boolean;
|
||||
enableTlsDebugLogging?: boolean;
|
||||
enableRandomizedTimeouts?: boolean;
|
||||
maxConnectionsPerIp?: number;
|
||||
connectionRateLimitPerMinute?: number;
|
||||
keepAliveTreatment?: ISmartProxyOptions['keepAliveTreatment'];
|
||||
keepAliveInactivityMultiplier?: number;
|
||||
extendedKeepAliveLifetime?: number;
|
||||
metrics?: ISmartProxyOptions['metrics'];
|
||||
acme?: IRustAcmeOptions;
|
||||
}
|
||||
|
||||
export interface IRustStatistics {
|
||||
activeConnections: number;
|
||||
totalConnections: number;
|
||||
routesCount: number;
|
||||
listeningPorts: number[];
|
||||
uptimeSeconds: number;
|
||||
}
|
||||
|
||||
export interface IRustCertificateStatus {
|
||||
domain: string;
|
||||
source: string;
|
||||
expiresAt: number;
|
||||
isValid: boolean;
|
||||
}
|
||||
|
||||
export interface IRustThroughputSample {
|
||||
timestampMs: number;
|
||||
bytesIn: number;
|
||||
bytesOut: number;
|
||||
}
|
||||
|
||||
export interface IRustRouteMetrics {
|
||||
activeConnections: number;
|
||||
totalConnections: number;
|
||||
bytesIn: number;
|
||||
bytesOut: number;
|
||||
throughputInBytesPerSec: number;
|
||||
throughputOutBytesPerSec: number;
|
||||
throughputRecentInBytesPerSec: number;
|
||||
throughputRecentOutBytesPerSec: number;
|
||||
}
|
||||
|
||||
export interface IRustIpMetrics {
|
||||
activeConnections: number;
|
||||
totalConnections: number;
|
||||
bytesIn: number;
|
||||
bytesOut: number;
|
||||
throughputInBytesPerSec: number;
|
||||
throughputOutBytesPerSec: number;
|
||||
domainRequests: Record<string, number>;
|
||||
}
|
||||
|
||||
export interface IRustBackendMetrics {
|
||||
activeConnections: number;
|
||||
totalConnections: number;
|
||||
protocol: string;
|
||||
connectErrors: number;
|
||||
handshakeErrors: number;
|
||||
requestErrors: number;
|
||||
totalConnectTimeUs: number;
|
||||
connectCount: number;
|
||||
poolHits: number;
|
||||
poolMisses: number;
|
||||
h2Failures: number;
|
||||
}
|
||||
|
||||
export interface IRustMetricsSnapshot {
|
||||
activeConnections: number;
|
||||
totalConnections: number;
|
||||
bytesIn: number;
|
||||
bytesOut: number;
|
||||
throughputInBytesPerSec: number;
|
||||
throughputOutBytesPerSec: number;
|
||||
throughputRecentInBytesPerSec: number;
|
||||
throughputRecentOutBytesPerSec: number;
|
||||
routes: Record<string, IRustRouteMetrics>;
|
||||
ips: Record<string, IRustIpMetrics>;
|
||||
backends: Record<string, IRustBackendMetrics>;
|
||||
throughputHistory: IRustThroughputSample[];
|
||||
totalHttpRequests: number;
|
||||
httpRequestsPerSec: number;
|
||||
httpRequestsPerSecRecent: number;
|
||||
activeUdpSessions: number;
|
||||
totalUdpSessions: number;
|
||||
totalDatagramsIn: number;
|
||||
totalDatagramsOut: number;
|
||||
detectedProtocols: IProtocolCacheEntry[];
|
||||
frontendProtocols: IProtocolDistribution;
|
||||
backendProtocols: IProtocolDistribution;
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { IRouteConfig, IRouteAction, IRouteTarget } from './models/route-types.js';
|
||||
import { logger } from '../../core/utils/logger.js';
|
||||
import type { IRustRouteConfig } from './models/rust-types.js';
|
||||
import { serializeRouteForRust } from './utils/rust-config.js';
|
||||
|
||||
/**
|
||||
* Preprocesses routes before sending them to Rust.
|
||||
@@ -24,7 +25,7 @@ export class RoutePreprocessor {
|
||||
* - Non-serializable fields are stripped
|
||||
* - Original routes are preserved in the local map for handler lookup
|
||||
*/
|
||||
public preprocessForRust(routes: IRouteConfig[]): IRouteConfig[] {
|
||||
public preprocessForRust(routes: IRouteConfig[]): IRustRouteConfig[] {
|
||||
this.originalRoutes.clear();
|
||||
return routes.map((route, index) => this.preprocessRoute(route, index));
|
||||
}
|
||||
@@ -43,7 +44,7 @@ export class RoutePreprocessor {
|
||||
return new Map(this.originalRoutes);
|
||||
}
|
||||
|
||||
private preprocessRoute(route: IRouteConfig, index: number): IRouteConfig {
|
||||
private preprocessRoute(route: IRouteConfig, index: number): IRustRouteConfig {
|
||||
const routeKey = route.name || route.id || `route_${index}`;
|
||||
|
||||
// Check if this route needs TS-side handling
|
||||
@@ -57,7 +58,7 @@ export class RoutePreprocessor {
|
||||
// Create a clean copy for Rust
|
||||
const cleanRoute: IRouteConfig = {
|
||||
...route,
|
||||
action: this.cleanAction(route.action, routeKey, needsTsHandling),
|
||||
action: this.cleanAction(route.action, needsTsHandling),
|
||||
};
|
||||
|
||||
// Ensure we have a name for handler lookup
|
||||
@@ -65,7 +66,7 @@ export class RoutePreprocessor {
|
||||
cleanRoute.name = routeKey;
|
||||
}
|
||||
|
||||
return cleanRoute;
|
||||
return serializeRouteForRust(cleanRoute);
|
||||
}
|
||||
|
||||
private routeNeedsTsHandling(route: IRouteConfig): boolean {
|
||||
@@ -91,15 +92,16 @@ export class RoutePreprocessor {
|
||||
return false;
|
||||
}
|
||||
|
||||
private cleanAction(action: IRouteAction, routeKey: string, needsTsHandling: boolean): IRouteAction {
|
||||
const cleanAction: IRouteAction = { ...action };
|
||||
private cleanAction(action: IRouteAction, needsTsHandling: boolean): IRouteAction {
|
||||
let cleanAction: IRouteAction = { ...action };
|
||||
|
||||
if (needsTsHandling) {
|
||||
// Convert to socket-handler type for Rust (Rust will relay back to TS)
|
||||
cleanAction.type = 'socket-handler';
|
||||
// Remove the JS handlers (not serializable)
|
||||
delete (cleanAction as any).socketHandler;
|
||||
delete (cleanAction as any).datagramHandler;
|
||||
const { socketHandler: _socketHandler, datagramHandler: _datagramHandler, ...serializableAction } = cleanAction;
|
||||
cleanAction = {
|
||||
...serializableAction,
|
||||
type: 'socket-handler',
|
||||
};
|
||||
}
|
||||
|
||||
// Clean targets - replace functions with static values
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { IMetrics, IBackendMetrics, IProtocolCacheEntry, IThroughputData, IThroughputHistoryPoint } from './models/metrics-types.js';
|
||||
import type { IMetrics, IBackendMetrics, IProtocolCacheEntry, IProtocolDistribution, IThroughputData, IThroughputHistoryPoint } from './models/metrics-types.js';
|
||||
import type { RustProxyBridge } from './rust-proxy-bridge.js';
|
||||
import type { IRustBackendMetrics, IRustIpMetrics, IRustMetricsSnapshot, IRustRouteMetrics } from './models/rust-types.js';
|
||||
|
||||
/**
|
||||
* Adapts Rust JSON metrics to the IMetrics interface.
|
||||
@@ -14,7 +15,7 @@ import type { RustProxyBridge } from './rust-proxy-bridge.js';
|
||||
*/
|
||||
export class RustMetricsAdapter implements IMetrics {
|
||||
private bridge: RustProxyBridge;
|
||||
private cache: any = null;
|
||||
private cache: IRustMetricsSnapshot | null = null;
|
||||
private pollTimer: ReturnType<typeof setInterval> | null = null;
|
||||
private pollIntervalMs: number;
|
||||
|
||||
@@ -65,8 +66,8 @@ export class RustMetricsAdapter implements IMetrics {
|
||||
byRoute: (): Map<string, number> => {
|
||||
const result = new Map<string, number>();
|
||||
if (this.cache?.routes) {
|
||||
for (const [name, rm] of Object.entries(this.cache.routes)) {
|
||||
result.set(name, (rm as any).activeConnections ?? 0);
|
||||
for (const [name, rm] of Object.entries(this.cache.routes) as Array<[string, IRustRouteMetrics]>) {
|
||||
result.set(name, rm.activeConnections ?? 0);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
@@ -74,8 +75,8 @@ export class RustMetricsAdapter implements IMetrics {
|
||||
byIP: (): Map<string, number> => {
|
||||
const result = new Map<string, number>();
|
||||
if (this.cache?.ips) {
|
||||
for (const [ip, im] of Object.entries(this.cache.ips)) {
|
||||
result.set(ip, (im as any).activeConnections ?? 0);
|
||||
for (const [ip, im] of Object.entries(this.cache.ips) as Array<[string, IRustIpMetrics]>) {
|
||||
result.set(ip, im.activeConnections ?? 0);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
@@ -83,13 +84,76 @@ export class RustMetricsAdapter implements IMetrics {
|
||||
topIPs: (limit: number = 10): Array<{ ip: string; count: number }> => {
|
||||
const result: Array<{ ip: string; count: number }> = [];
|
||||
if (this.cache?.ips) {
|
||||
for (const [ip, im] of Object.entries(this.cache.ips)) {
|
||||
result.push({ ip, count: (im as any).activeConnections ?? 0 });
|
||||
for (const [ip, im] of Object.entries(this.cache.ips) as Array<[string, IRustIpMetrics]>) {
|
||||
result.push({ ip, count: im.activeConnections ?? 0 });
|
||||
}
|
||||
}
|
||||
result.sort((a, b) => b.count - a.count);
|
||||
return result.slice(0, limit);
|
||||
},
|
||||
domainRequestsByIP: (): Map<string, Map<string, number>> => {
|
||||
const result = new Map<string, Map<string, number>>();
|
||||
if (this.cache?.ips) {
|
||||
for (const [ip, im] of Object.entries(this.cache.ips) as Array<[string, IRustIpMetrics]>) {
|
||||
const dr = im.domainRequests;
|
||||
if (dr && typeof dr === 'object') {
|
||||
const domainMap = new Map<string, number>();
|
||||
for (const [domain, count] of Object.entries(dr)) {
|
||||
domainMap.set(domain, count as number);
|
||||
}
|
||||
if (domainMap.size > 0) {
|
||||
result.set(ip, domainMap);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
},
|
||||
topDomainRequests: (limit: number = 20): Array<{ ip: string; domain: string; count: number }> => {
|
||||
const result: Array<{ ip: string; domain: string; count: number }> = [];
|
||||
if (this.cache?.ips) {
|
||||
for (const [ip, im] of Object.entries(this.cache.ips) as Array<[string, IRustIpMetrics]>) {
|
||||
const dr = im.domainRequests;
|
||||
if (dr && typeof dr === 'object') {
|
||||
for (const [domain, count] of Object.entries(dr)) {
|
||||
result.push({ ip, domain, count: count as number });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result.sort((a, b) => b.count - a.count);
|
||||
return result.slice(0, limit);
|
||||
},
|
||||
frontendProtocols: (): IProtocolDistribution => {
|
||||
const fp = this.cache?.frontendProtocols;
|
||||
return {
|
||||
h1Active: fp?.h1Active ?? 0,
|
||||
h1Total: fp?.h1Total ?? 0,
|
||||
h2Active: fp?.h2Active ?? 0,
|
||||
h2Total: fp?.h2Total ?? 0,
|
||||
h3Active: fp?.h3Active ?? 0,
|
||||
h3Total: fp?.h3Total ?? 0,
|
||||
wsActive: fp?.wsActive ?? 0,
|
||||
wsTotal: fp?.wsTotal ?? 0,
|
||||
otherActive: fp?.otherActive ?? 0,
|
||||
otherTotal: fp?.otherTotal ?? 0,
|
||||
};
|
||||
},
|
||||
backendProtocols: (): IProtocolDistribution => {
|
||||
const bp = this.cache?.backendProtocols;
|
||||
return {
|
||||
h1Active: bp?.h1Active ?? 0,
|
||||
h1Total: bp?.h1Total ?? 0,
|
||||
h2Active: bp?.h2Active ?? 0,
|
||||
h2Total: bp?.h2Total ?? 0,
|
||||
h3Active: bp?.h3Active ?? 0,
|
||||
h3Total: bp?.h3Total ?? 0,
|
||||
wsActive: bp?.wsActive ?? 0,
|
||||
wsTotal: bp?.wsTotal ?? 0,
|
||||
otherActive: bp?.otherActive ?? 0,
|
||||
otherTotal: bp?.otherTotal ?? 0,
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
public throughput = {
|
||||
@@ -113,7 +177,7 @@ export class RustMetricsAdapter implements IMetrics {
|
||||
},
|
||||
history: (seconds: number): Array<IThroughputHistoryPoint> => {
|
||||
if (!this.cache?.throughputHistory) return [];
|
||||
return this.cache.throughputHistory.slice(-seconds).map((p: any) => ({
|
||||
return this.cache.throughputHistory.slice(-seconds).map((p) => ({
|
||||
timestamp: p.timestampMs,
|
||||
in: p.bytesIn,
|
||||
out: p.bytesOut,
|
||||
@@ -122,10 +186,10 @@ export class RustMetricsAdapter implements IMetrics {
|
||||
byRoute: (_windowSeconds?: number): Map<string, IThroughputData> => {
|
||||
const result = new Map<string, IThroughputData>();
|
||||
if (this.cache?.routes) {
|
||||
for (const [name, rm] of Object.entries(this.cache.routes)) {
|
||||
for (const [name, rm] of Object.entries(this.cache.routes) as Array<[string, IRustRouteMetrics]>) {
|
||||
result.set(name, {
|
||||
in: (rm as any).throughputInBytesPerSec ?? 0,
|
||||
out: (rm as any).throughputOutBytesPerSec ?? 0,
|
||||
in: rm.throughputInBytesPerSec ?? 0,
|
||||
out: rm.throughputOutBytesPerSec ?? 0,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -134,10 +198,10 @@ export class RustMetricsAdapter implements IMetrics {
|
||||
byIP: (_windowSeconds?: number): Map<string, IThroughputData> => {
|
||||
const result = new Map<string, IThroughputData>();
|
||||
if (this.cache?.ips) {
|
||||
for (const [ip, im] of Object.entries(this.cache.ips)) {
|
||||
for (const [ip, im] of Object.entries(this.cache.ips) as Array<[string, IRustIpMetrics]>) {
|
||||
result.set(ip, {
|
||||
in: (im as any).throughputInBytesPerSec ?? 0,
|
||||
out: (im as any).throughputOutBytesPerSec ?? 0,
|
||||
in: im.throughputInBytesPerSec ?? 0,
|
||||
out: im.throughputOutBytesPerSec ?? 0,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -173,23 +237,22 @@ export class RustMetricsAdapter implements IMetrics {
|
||||
byBackend: (): Map<string, IBackendMetrics> => {
|
||||
const result = new Map<string, IBackendMetrics>();
|
||||
if (this.cache?.backends) {
|
||||
for (const [key, bm] of Object.entries(this.cache.backends)) {
|
||||
const m = bm as any;
|
||||
const totalTimeUs = m.totalConnectTimeUs ?? 0;
|
||||
const count = m.connectCount ?? 0;
|
||||
const poolHits = m.poolHits ?? 0;
|
||||
const poolMisses = m.poolMisses ?? 0;
|
||||
for (const [key, bm] of Object.entries(this.cache.backends) as Array<[string, IRustBackendMetrics]>) {
|
||||
const totalTimeUs = bm.totalConnectTimeUs ?? 0;
|
||||
const count = bm.connectCount ?? 0;
|
||||
const poolHits = bm.poolHits ?? 0;
|
||||
const poolMisses = bm.poolMisses ?? 0;
|
||||
const poolTotal = poolHits + poolMisses;
|
||||
result.set(key, {
|
||||
protocol: m.protocol ?? 'unknown',
|
||||
activeConnections: m.activeConnections ?? 0,
|
||||
totalConnections: m.totalConnections ?? 0,
|
||||
connectErrors: m.connectErrors ?? 0,
|
||||
handshakeErrors: m.handshakeErrors ?? 0,
|
||||
requestErrors: m.requestErrors ?? 0,
|
||||
protocol: bm.protocol ?? 'unknown',
|
||||
activeConnections: bm.activeConnections ?? 0,
|
||||
totalConnections: bm.totalConnections ?? 0,
|
||||
connectErrors: bm.connectErrors ?? 0,
|
||||
handshakeErrors: bm.handshakeErrors ?? 0,
|
||||
requestErrors: bm.requestErrors ?? 0,
|
||||
avgConnectTimeMs: count > 0 ? (totalTimeUs / count) / 1000 : 0,
|
||||
poolHitRate: poolTotal > 0 ? poolHits / poolTotal : 0,
|
||||
h2Failures: m.h2Failures ?? 0,
|
||||
h2Failures: bm.h2Failures ?? 0,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -198,8 +261,8 @@ export class RustMetricsAdapter implements IMetrics {
|
||||
protocols: (): Map<string, string> => {
|
||||
const result = new Map<string, string>();
|
||||
if (this.cache?.backends) {
|
||||
for (const [key, bm] of Object.entries(this.cache.backends)) {
|
||||
result.set(key, (bm as any).protocol ?? 'unknown');
|
||||
for (const [key, bm] of Object.entries(this.cache.backends) as Array<[string, IRustBackendMetrics]>) {
|
||||
result.set(key, bm.protocol ?? 'unknown');
|
||||
}
|
||||
}
|
||||
return result;
|
||||
@@ -207,9 +270,8 @@ export class RustMetricsAdapter implements IMetrics {
|
||||
topByErrors: (limit: number = 10): Array<{ backend: string; errors: number }> => {
|
||||
const result: Array<{ backend: string; errors: number }> = [];
|
||||
if (this.cache?.backends) {
|
||||
for (const [key, bm] of Object.entries(this.cache.backends)) {
|
||||
const m = bm as any;
|
||||
const errors = (m.connectErrors ?? 0) + (m.handshakeErrors ?? 0) + (m.requestErrors ?? 0);
|
||||
for (const [key, bm] of Object.entries(this.cache.backends) as Array<[string, IRustBackendMetrics]>) {
|
||||
const errors = (bm.connectErrors ?? 0) + (bm.handshakeErrors ?? 0) + (bm.requestErrors ?? 0);
|
||||
if (errors > 0) result.push({ backend: key, errors });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,23 +1,29 @@
|
||||
import * as plugins from '../../plugins.js';
|
||||
import { logger } from '../../core/utils/logger.js';
|
||||
import type { IRouteConfig } from './models/route-types.js';
|
||||
import type {
|
||||
IRustCertificateStatus,
|
||||
IRustMetricsSnapshot,
|
||||
IRustProxyOptions,
|
||||
IRustRouteConfig,
|
||||
IRustStatistics,
|
||||
} from './models/rust-types.js';
|
||||
|
||||
/**
|
||||
* Type-safe command definitions for the Rust proxy IPC protocol.
|
||||
*/
|
||||
type TSmartProxyCommands = {
|
||||
start: { params: { config: any }; result: void };
|
||||
stop: { params: Record<string, never>; result: void };
|
||||
updateRoutes: { params: { routes: IRouteConfig[] }; result: void };
|
||||
getMetrics: { params: Record<string, never>; result: any };
|
||||
getStatistics: { params: Record<string, never>; result: any };
|
||||
provisionCertificate: { params: { routeName: string }; result: void };
|
||||
renewCertificate: { params: { routeName: string }; result: void };
|
||||
getCertificateStatus: { params: { routeName: string }; result: any };
|
||||
getListeningPorts: { params: Record<string, never>; result: { ports: number[] } };
|
||||
setSocketHandlerRelay: { params: { socketPath: string }; result: void };
|
||||
addListeningPort: { params: { port: number }; result: void };
|
||||
removeListeningPort: { params: { port: number }; result: void };
|
||||
start: { params: { config: IRustProxyOptions }; result: void };
|
||||
stop: { params: Record<string, never>; result: void };
|
||||
updateRoutes: { params: { routes: IRustRouteConfig[] }; result: void };
|
||||
getMetrics: { params: Record<string, never>; result: IRustMetricsSnapshot };
|
||||
getStatistics: { params: Record<string, never>; result: IRustStatistics };
|
||||
provisionCertificate: { params: { routeName: string }; result: void };
|
||||
renewCertificate: { params: { routeName: string }; result: void };
|
||||
getCertificateStatus: { params: { routeName: string }; result: IRustCertificateStatus | null };
|
||||
getListeningPorts: { params: Record<string, never>; result: { ports: number[] } };
|
||||
setSocketHandlerRelay: { params: { socketPath: string }; result: void };
|
||||
addListeningPort: { params: { port: number }; result: void };
|
||||
removeListeningPort: { params: { port: number }; result: void };
|
||||
loadCertificate: { params: { domain: string; cert: string; key: string; ca?: string }; result: void };
|
||||
setDatagramHandlerRelay: { params: { socketPath: string }; result: void };
|
||||
};
|
||||
@@ -121,7 +127,7 @@ export class RustProxyBridge extends plugins.EventEmitter {
|
||||
|
||||
// --- Convenience methods for each management command ---
|
||||
|
||||
public async startProxy(config: any): Promise<void> {
|
||||
public async startProxy(config: IRustProxyOptions): Promise<void> {
|
||||
await this.bridge.sendCommand('start', { config });
|
||||
}
|
||||
|
||||
@@ -129,15 +135,15 @@ export class RustProxyBridge extends plugins.EventEmitter {
|
||||
await this.bridge.sendCommand('stop', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
public async updateRoutes(routes: IRouteConfig[]): Promise<void> {
|
||||
public async updateRoutes(routes: IRustRouteConfig[]): Promise<void> {
|
||||
await this.bridge.sendCommand('updateRoutes', { routes });
|
||||
}
|
||||
|
||||
public async getMetrics(): Promise<any> {
|
||||
public async getMetrics(): Promise<IRustMetricsSnapshot> {
|
||||
return this.bridge.sendCommand('getMetrics', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
public async getStatistics(): Promise<any> {
|
||||
public async getStatistics(): Promise<IRustStatistics> {
|
||||
return this.bridge.sendCommand('getStatistics', {} as Record<string, never>);
|
||||
}
|
||||
|
||||
@@ -149,7 +155,7 @@ export class RustProxyBridge extends plugins.EventEmitter {
|
||||
await this.bridge.sendCommand('renewCertificate', { routeName });
|
||||
}
|
||||
|
||||
public async getCertificateStatus(routeName: string): Promise<any> {
|
||||
public async getCertificateStatus(routeName: string): Promise<IRustCertificateStatus | null> {
|
||||
return this.bridge.sendCommand('getCertificateStatus', { routeName });
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import { RustMetricsAdapter } from './rust-metrics-adapter.js';
|
||||
// Route management
|
||||
import { SharedRouteManager as RouteManager } from '../../core/routing/route-manager.js';
|
||||
import { RouteValidator } from './utils/route-validator.js';
|
||||
import { buildRustProxyOptions } from './utils/rust-config.js';
|
||||
import { generateDefaultCertificate } from './utils/default-cert-generator.js';
|
||||
import { Mutex } from './utils/mutex.js';
|
||||
import { ConcurrencySemaphore } from './utils/concurrency-semaphore.js';
|
||||
@@ -19,6 +20,7 @@ import { ConcurrencySemaphore } from './utils/concurrency-semaphore.js';
|
||||
import type { ISmartProxyOptions, TSmartProxyCertProvisionObject, IAcmeOptions, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './models/interfaces.js';
|
||||
import type { IRouteConfig } from './models/route-types.js';
|
||||
import type { IMetrics } from './models/metrics-types.js';
|
||||
import type { IRustCertificateStatus, IRustProxyOptions, IRustStatistics } from './models/rust-types.js';
|
||||
|
||||
/**
|
||||
* SmartProxy - Rust-backed proxy engine with TypeScript configuration API.
|
||||
@@ -365,7 +367,7 @@ export class SmartProxy extends plugins.EventEmitter {
|
||||
/**
|
||||
* Get certificate status for a route (async - calls Rust).
|
||||
*/
|
||||
public async getCertificateStatus(routeName: string): Promise<any> {
|
||||
public async getCertificateStatus(routeName: string): Promise<IRustCertificateStatus | null> {
|
||||
return this.bridge.getCertificateStatus(routeName);
|
||||
}
|
||||
|
||||
@@ -379,7 +381,7 @@ export class SmartProxy extends plugins.EventEmitter {
|
||||
/**
|
||||
* Get statistics (async - calls Rust).
|
||||
*/
|
||||
public async getStatistics(): Promise<any> {
|
||||
public async getStatistics(): Promise<IRustStatistics> {
|
||||
return this.bridge.getStatistics();
|
||||
}
|
||||
|
||||
@@ -484,37 +486,8 @@ export class SmartProxy extends plugins.EventEmitter {
|
||||
/**
|
||||
* Build the Rust configuration object from TS settings.
|
||||
*/
|
||||
private buildRustConfig(routes: IRouteConfig[], acmeOverride?: IAcmeOptions): any {
|
||||
const acme = acmeOverride !== undefined ? acmeOverride : this.settings.acme;
|
||||
return {
|
||||
routes,
|
||||
defaults: this.settings.defaults,
|
||||
acme: acme
|
||||
? {
|
||||
enabled: acme.enabled,
|
||||
email: acme.email,
|
||||
useProduction: acme.useProduction,
|
||||
port: acme.port,
|
||||
renewThresholdDays: acme.renewThresholdDays,
|
||||
autoRenew: acme.autoRenew,
|
||||
renewCheckIntervalHours: acme.renewCheckIntervalHours,
|
||||
}
|
||||
: undefined,
|
||||
connectionTimeout: this.settings.connectionTimeout,
|
||||
initialDataTimeout: this.settings.initialDataTimeout,
|
||||
socketTimeout: this.settings.socketTimeout,
|
||||
maxConnectionLifetime: this.settings.maxConnectionLifetime,
|
||||
gracefulShutdownTimeout: this.settings.gracefulShutdownTimeout,
|
||||
maxConnectionsPerIp: this.settings.maxConnectionsPerIP,
|
||||
connectionRateLimitPerMinute: this.settings.connectionRateLimitPerMinute,
|
||||
keepAliveTreatment: this.settings.keepAliveTreatment,
|
||||
keepAliveInactivityMultiplier: this.settings.keepAliveInactivityMultiplier,
|
||||
extendedKeepAliveLifetime: this.settings.extendedKeepAliveLifetime,
|
||||
proxyIps: this.settings.proxyIPs,
|
||||
acceptProxyProtocol: this.settings.acceptProxyProtocol,
|
||||
sendProxyProtocol: this.settings.sendProxyProtocol,
|
||||
metrics: this.settings.metrics,
|
||||
};
|
||||
private buildRustConfig(routes: IRustProxyOptions['routes'], acmeOverride?: IAcmeOptions): IRustProxyOptions {
|
||||
return buildRustProxyOptions(this.settings, routes, acmeOverride);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -168,14 +168,28 @@ export function routeMatchesHeaders(
|
||||
if (!route.match?.headers || Object.keys(route.match.headers).length === 0) {
|
||||
return true; // No headers specified means it matches any headers
|
||||
}
|
||||
|
||||
// Convert RegExp patterns to strings for HeaderMatcher
|
||||
const stringHeaders: Record<string, string> = {};
|
||||
for (const [key, value] of Object.entries(route.match.headers)) {
|
||||
stringHeaders[key] = value instanceof RegExp ? value.source : value;
|
||||
|
||||
for (const [headerName, expectedValue] of Object.entries(route.match.headers)) {
|
||||
const actualKey = Object.keys(headers).find((key) => key.toLowerCase() === headerName.toLowerCase());
|
||||
const actualValue = actualKey ? headers[actualKey] : undefined;
|
||||
|
||||
if (actualValue === undefined) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (expectedValue instanceof RegExp) {
|
||||
if (!expectedValue.test(actualValue)) {
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!HeaderMatcher.match(expectedValue, actualValue)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return HeaderMatcher.matchAll(stringHeaders, headers);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -283,4 +297,4 @@ export function generateRouteId(route: IRouteConfig): string {
|
||||
*/
|
||||
export function cloneRoute(route: IRouteConfig): IRouteConfig {
|
||||
return JSON.parse(JSON.stringify(route));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -196,10 +196,19 @@ export class RouteValidator {
|
||||
// Validate IP allow/block lists
|
||||
if (route.security.ipAllowList) {
|
||||
const allowList = Array.isArray(route.security.ipAllowList) ? route.security.ipAllowList : [route.security.ipAllowList];
|
||||
|
||||
for (const ip of allowList) {
|
||||
if (!this.isValidIPPattern(ip)) {
|
||||
errors.push(`Invalid IP pattern in allow list: ${ip}`);
|
||||
|
||||
for (const entry of allowList) {
|
||||
if (typeof entry === 'string') {
|
||||
if (!this.isValidIPPattern(entry)) {
|
||||
errors.push(`Invalid IP pattern in allow list: ${entry}`);
|
||||
}
|
||||
} else if (entry && typeof entry === 'object') {
|
||||
if (!this.isValidIPPattern(entry.ip)) {
|
||||
errors.push(`Invalid IP pattern in domain-scoped allow entry: ${entry.ip}`);
|
||||
}
|
||||
if (!Array.isArray(entry.domains) || entry.domains.length === 0) {
|
||||
errors.push(`Domain-scoped allow entry for ${entry.ip} must have non-empty domains array`);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,187 @@
|
||||
import type { IAcmeOptions, ISmartProxyOptions } from '../models/interfaces.js';
|
||||
import type { IRouteAction, IRouteConfig, IRouteMatch, IRouteTarget, ITargetMatch } from '../models/route-types.js';
|
||||
import type {
|
||||
IRustAcmeOptions,
|
||||
IRustDefaultConfig,
|
||||
IRustProxyOptions,
|
||||
IRustRouteAction,
|
||||
IRustRouteConfig,
|
||||
IRustRouteMatch,
|
||||
IRustRouteTarget,
|
||||
IRustTargetMatch,
|
||||
IRustRouteUdp,
|
||||
TRustHeaderMatchers,
|
||||
} from '../models/rust-types.js';
|
||||
|
||||
const SUPPORTED_REGEX_FLAGS = new Set(['i', 'm', 's', 'u', 'g']);
|
||||
|
||||
export function serializeHeaderMatchValue(value: string | RegExp): string {
|
||||
if (typeof value === 'string') {
|
||||
return value;
|
||||
}
|
||||
|
||||
const unsupportedFlags = Array.from(new Set(value.flags)).filter((flag) => !SUPPORTED_REGEX_FLAGS.has(flag));
|
||||
if (unsupportedFlags.length > 0) {
|
||||
throw new Error(
|
||||
`Header RegExp uses unsupported flags for Rust serialization: ${unsupportedFlags.join(', ')}`
|
||||
);
|
||||
}
|
||||
|
||||
return `/${value.source}/${value.flags}`;
|
||||
}
|
||||
|
||||
export function serializeHeaderMatchers(headers?: Record<string, string | RegExp>): TRustHeaderMatchers | undefined {
|
||||
if (!headers) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return Object.fromEntries(
|
||||
Object.entries(headers).map(([key, value]) => [key, serializeHeaderMatchValue(value)])
|
||||
);
|
||||
}
|
||||
|
||||
export function serializeTargetMatchForRust(match?: ITargetMatch): IRustTargetMatch | undefined {
|
||||
if (!match) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
...match,
|
||||
headers: serializeHeaderMatchers(match.headers),
|
||||
};
|
||||
}
|
||||
|
||||
export function serializeRouteMatchForRust(match: IRouteMatch): IRustRouteMatch {
|
||||
return {
|
||||
...match,
|
||||
headers: serializeHeaderMatchers(match.headers),
|
||||
};
|
||||
}
|
||||
|
||||
export function serializeRouteTargetForRust(target: IRouteTarget): IRustRouteTarget {
|
||||
if (typeof target.host !== 'string' && !Array.isArray(target.host)) {
|
||||
throw new Error('Route target host must be serialized before sending to Rust');
|
||||
}
|
||||
|
||||
if (typeof target.port !== 'number' && target.port !== 'preserve') {
|
||||
throw new Error('Route target port must be serialized before sending to Rust');
|
||||
}
|
||||
|
||||
return {
|
||||
...target,
|
||||
host: target.host,
|
||||
port: target.port,
|
||||
match: serializeTargetMatchForRust(target.match),
|
||||
};
|
||||
}
|
||||
|
||||
function serializeUdpForRust(udp?: IRouteAction['udp']): IRustRouteUdp | undefined {
|
||||
if (!udp) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const { maxSessionsPerIP, ...rest } = udp;
|
||||
|
||||
return {
|
||||
...rest,
|
||||
maxSessionsPerIp: maxSessionsPerIP,
|
||||
};
|
||||
}
|
||||
|
||||
export function serializeRouteActionForRust(action: IRouteAction): IRustRouteAction {
|
||||
const {
|
||||
socketHandler: _socketHandler,
|
||||
datagramHandler: _datagramHandler,
|
||||
forwardingEngine: _forwardingEngine,
|
||||
nftables: _nftables,
|
||||
targets,
|
||||
udp,
|
||||
...rest
|
||||
} = action;
|
||||
|
||||
return {
|
||||
...rest,
|
||||
targets: targets?.map((target) => serializeRouteTargetForRust(target)),
|
||||
udp: serializeUdpForRust(udp),
|
||||
};
|
||||
}
|
||||
|
||||
export function serializeRouteForRust(route: IRouteConfig): IRustRouteConfig {
|
||||
return {
|
||||
...route,
|
||||
match: serializeRouteMatchForRust(route.match),
|
||||
action: serializeRouteActionForRust(route.action),
|
||||
};
|
||||
}
|
||||
|
||||
function serializeAcmeForRust(acme?: IAcmeOptions): IRustAcmeOptions | undefined {
|
||||
if (!acme) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
enabled: acme.enabled,
|
||||
email: acme.email,
|
||||
environment: acme.environment,
|
||||
accountEmail: acme.accountEmail,
|
||||
port: acme.port,
|
||||
useProduction: acme.useProduction,
|
||||
renewThresholdDays: acme.renewThresholdDays,
|
||||
autoRenew: acme.autoRenew,
|
||||
skipConfiguredCerts: acme.skipConfiguredCerts,
|
||||
renewCheckIntervalHours: acme.renewCheckIntervalHours,
|
||||
};
|
||||
}
|
||||
|
||||
function serializeDefaultsForRust(defaults?: ISmartProxyOptions['defaults']): IRustDefaultConfig | undefined {
|
||||
if (!defaults) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const { preserveSourceIP, ...rest } = defaults;
|
||||
|
||||
return {
|
||||
...rest,
|
||||
preserveSourceIp: preserveSourceIP,
|
||||
};
|
||||
}
|
||||
|
||||
export function buildRustProxyOptions(
|
||||
settings: ISmartProxyOptions,
|
||||
routes: IRustRouteConfig[],
|
||||
acmeOverride?: IAcmeOptions,
|
||||
): IRustProxyOptions {
|
||||
const acme = acmeOverride !== undefined ? acmeOverride : settings.acme;
|
||||
|
||||
return {
|
||||
routes,
|
||||
preserveSourceIp: settings.preserveSourceIP,
|
||||
proxyIps: settings.proxyIPs,
|
||||
acceptProxyProtocol: settings.acceptProxyProtocol,
|
||||
sendProxyProtocol: settings.sendProxyProtocol,
|
||||
defaults: serializeDefaultsForRust(settings.defaults),
|
||||
connectionTimeout: settings.connectionTimeout,
|
||||
initialDataTimeout: settings.initialDataTimeout,
|
||||
socketTimeout: settings.socketTimeout,
|
||||
inactivityCheckInterval: settings.inactivityCheckInterval,
|
||||
maxConnectionLifetime: settings.maxConnectionLifetime,
|
||||
inactivityTimeout: settings.inactivityTimeout,
|
||||
gracefulShutdownTimeout: settings.gracefulShutdownTimeout,
|
||||
noDelay: settings.noDelay,
|
||||
keepAlive: settings.keepAlive,
|
||||
keepAliveInitialDelay: settings.keepAliveInitialDelay,
|
||||
maxPendingDataSize: settings.maxPendingDataSize,
|
||||
disableInactivityCheck: settings.disableInactivityCheck,
|
||||
enableKeepAliveProbes: settings.enableKeepAliveProbes,
|
||||
enableDetailedLogging: settings.enableDetailedLogging,
|
||||
enableTlsDebugLogging: settings.enableTlsDebugLogging,
|
||||
enableRandomizedTimeouts: settings.enableRandomizedTimeouts,
|
||||
maxConnectionsPerIp: settings.maxConnectionsPerIP,
|
||||
connectionRateLimitPerMinute: settings.connectionRateLimitPerMinute,
|
||||
keepAliveTreatment: settings.keepAliveTreatment,
|
||||
keepAliveInactivityMultiplier: settings.keepAliveInactivityMultiplier,
|
||||
extendedKeepAliveLifetime: settings.extendedKeepAliveLifetime,
|
||||
metrics: settings.metrics,
|
||||
acme: serializeAcmeForRust(acme),
|
||||
};
|
||||
}
|
||||
Reference in New Issue
Block a user