Compare commits

..

44 Commits

Author SHA1 Message Date
5be93c8d38 v27.0.0
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-26 20:45:41 +00:00
788ccea81e BREAKING CHANGE(smart-proxy): remove route helper APIs and standardize route configuration on plain route objects 2026-03-26 20:45:41 +00:00
47140e5403 v26.3.0
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-26 13:11:57 +00:00
a6ffa24e36 feat(nftables): move NFTables forwarding management from the Rust engine to @push.rocks/smartnftables 2026-03-26 13:11:57 +00:00
c0e432fd9b v26.2.4
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-26 07:05:57 +00:00
a3d8a3a388 fix(rustproxy-http): improve HTTP/3 connection reuse and clean up stale proxy state 2026-03-26 07:05:57 +00:00
437d1a3329 v26.2.3
Some checks failed
Default (tags) / security (push) Failing after 3s
Default (tags) / test (push) Failing after 3s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-25 07:26:47 +00:00
746d93663d fix(repo): no changes to commit 2026-03-25 07:26:47 +00:00
a3f3fee253 v26.2.2
Some checks failed
Default (tags) / security (push) Failing after 4s
Default (tags) / test (push) Failing after 5s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-25 07:22:17 +00:00
53dee1fffc fix(proxy): improve connection cleanup and route validation handling 2026-03-25 07:22:17 +00:00
34dc0cb9b6 v26.2.1
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-23 11:11:55 +00:00
c83c43194b fix(rustproxy-http): include the upstream request URL when caching H3 Alt-Svc discoveries 2026-03-23 11:11:55 +00:00
d026d7c266 v26.2.0
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-23 09:42:07 +00:00
3b01144c51 feat(protocol-cache): add sliding TTL re-probing and eviction for backend protocol detection 2026-03-23 09:42:07 +00:00
56f5697e1b v26.1.0
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-22 10:20:00 +00:00
f04875885f feat(rustproxy-http): add protocol failure suppression, h3 fallback escalation, and protocol cache metrics exposure 2026-03-22 10:20:00 +00:00
d12812bb8d v26.0.0
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-21 22:23:38 +00:00
fc04a0210b BREAKING CHANGE(ts-api,rustproxy): remove deprecated TypeScript protocol and utility exports while hardening QUIC, HTTP/3, WebSocket, and rate limiter cleanup paths 2026-03-21 22:23:38 +00:00
33fdf42a70 v25.17.10
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 08:57:18 +00:00
fb1c59ac9a fix(rustproxy-http): reuse the shared HTTP proxy service for HTTP/3 request handling 2026-03-20 08:57:18 +00:00
ea8224c400 v25.17.9
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 08:30:09 +00:00
da1cc58a3d fix(rustproxy-http): correct HTTP/3 host extraction and avoid protocol filtering during UDP route lookup 2026-03-20 08:30:09 +00:00
606c620849 v25.17.8
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 08:06:32 +00:00
4ae09ac6ae fix(rustproxy): use SNI-based certificate resolution for QUIC TLS connections 2026-03-20 08:06:32 +00:00
2fce910795 v25.17.7
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 07:50:41 +00:00
ff09cef350 fix(readme): document QUIC and HTTP/3 compatibility caveats 2026-03-20 07:50:41 +00:00
d0148b2ac3 v25.17.6
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 07:48:26 +00:00
7217e15649 fix(rustproxy-http): disable HTTP/3 GREASE for client and server connections 2026-03-20 07:48:26 +00:00
bfcf92a855 v25.17.5
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 07:43:32 +00:00
8e0804cd20 fix(rustproxy): add HTTP/3 integration test for QUIC response stream FIN handling 2026-03-20 07:43:32 +00:00
c63f6fcd5f v25.17.4
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 03:19:57 +00:00
f3cd4d193e fix(rustproxy-http): prevent HTTP/3 response body streaming from hanging on backend completion 2026-03-20 03:19:57 +00:00
81de611255 v25.17.3
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 02:54:44 +00:00
91598b3be9 fix(repository): no changes detected 2026-03-20 02:54:44 +00:00
4e3c548012 v25.17.2
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 02:53:41 +00:00
1a2d7529db fix(rustproxy-http): enable TLS connections for HTTP/3 upstream requests when backend re-encryption or TLS is configured 2026-03-20 02:53:41 +00:00
31514f54ae v25.17.1
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-20 02:35:22 +00:00
247653c9d0 fix(rustproxy-routing): allow QUIC UDP TLS connections without SNI to match domain-restricted routes 2026-03-20 02:35:22 +00:00
07d88f6f6a v25.17.0
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 0s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-19 23:16:42 +00:00
4b64de2c67 feat(rustproxy-passthrough): add PROXY protocol v2 client IP handling for UDP and QUIC listeners 2026-03-19 23:16:42 +00:00
e8db7bc96d v25.16.3
Some checks failed
Default (tags) / security (push) Failing after 0s
Default (tags) / test (push) Failing after 1s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-19 22:00:07 +00:00
2621dea9fa fix(rustproxy): upgrade fallback UDP listeners to QUIC when TLS certificates become available 2026-03-19 22:00:07 +00:00
bb5b9b3d12 v25.16.2
Some checks failed
Default (tags) / security (push) Failing after 12s
Default (tags) / test (push) Failing after 13s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped
2026-03-19 21:24:05 +00:00
d70c2d77ed fix(rustproxy-http): cache backend Alt-Svc only from original upstream responses during protocol auto-detection 2026-03-19 21:24:05 +00:00
138 changed files with 4404 additions and 15443 deletions

View File

@@ -1,5 +1,156 @@
# Changelog
## 2026-03-26 - 27.0.0 - BREAKING CHANGE(smart-proxy)
remove route helper APIs and standardize route configuration on plain route objects
- Removes TypeScript route helper exports and related Rust config helpers in favor of defining routes directly with match and action properties.
- Updates documentation and tests to use plain IRouteConfig objects and SocketHandlers imports instead of helper factory functions.
- Moves socket handlers to a top-level utils export and keeps direct socket-handler route configuration as the supported pattern.
## 2026-03-26 - 26.3.0 - feat(nftables)
move NFTables forwarding management from the Rust engine to @push.rocks/smartnftables
- add @push.rocks/smartnftables as a runtime dependency and export it via the plugin layer
- remove the internal rustproxy-nftables crate along with Rust-side NFTables rule application and status management
- apply and clean up NFTables port-forwarding rules in the TypeScript SmartProxy lifecycle and route update flow
- change getNfTablesStatus to return local smartnftables status instead of querying the Rust bridge
- update README documentation to describe NFTables support as provided through @push.rocks/smartnftables
## 2026-03-26 - 26.2.4 - fix(rustproxy-http)
improve HTTP/3 connection reuse and clean up stale proxy state
- Reuse pooled HTTP/3 SendRequest handles to skip repeated SETTINGS handshakes and reduce request overhead on QUIC pool hits
- Add periodic cleanup for per-route rate limiters and orphaned backend metrics to prevent unbounded memory growth after traffic or backend errors stop
- Enforce HTTP max connection lifetime alongside idle timeouts and apply configured lifetime values from the TCP listener
- Reduce HTTP/3 body copying by using owned Bytes paths for request and response streaming, and replace the custom response body adapter with a stream-based implementation
- Harden auxiliary proxy components by capping datagram handler buffer growth and removing duplicate RustProxy exit listeners
## 2026-03-25 - 26.2.3 - fix(repo)
no changes to commit
## 2026-03-25 - 26.2.2 - fix(proxy)
improve connection cleanup and route validation handling
- add timeouts for HTTP/1 upstream connection drivers to prevent lingering tasks
- ensure QUIC relay sessions cancel and abort background tasks on drop
- avoid registering unnamed routes as duplicates and label unnamed catch-all conflicts clearly
- fix offset mapping route helper to forward only remaining route options without overriding derived values
- update project config filename and toolchain versions for the current build setup
## 2026-03-23 - 26.2.1 - fix(rustproxy-http)
include the upstream request URL when caching H3 Alt-Svc discoveries
- Tracks the request path that triggered Alt-Svc discovery in connection activity state
- Adds request URL context to Alt-Svc debug logging and protocol cache insertion reasons for better traceability
## 2026-03-23 - 26.2.0 - feat(protocol-cache)
add sliding TTL re-probing and eviction for backend protocol detection
- extend protocol cache entries and metrics with last accessed and last probed timestamps
- trigger periodic ALPN re-probes for cached H1/H2 entries while keeping active entries alive with a sliding 1 day TTL
- log protocol transitions with reasons and evict cache entries when all protocol fallback attempts fail
## 2026-03-22 - 26.1.0 - feat(rustproxy-http)
add protocol failure suppression, h3 fallback escalation, and protocol cache metrics exposure
- introduces escalating cooldowns for failed H2/H3 protocol detection to prevent repeated upgrades to unstable backends
- adds within-request escalation to cached HTTP/3 when TCP or TLS backend connections fail in auto-detect mode
- exposes detected protocol cache entries and suppression state through Rust metrics and the TypeScript metrics adapter
## 2026-03-21 - 26.0.0 - BREAKING CHANGE(ts-api,rustproxy)
remove deprecated TypeScript protocol and utility exports while hardening QUIC, HTTP/3, WebSocket, and rate limiter cleanup paths
- Removes large parts of the public TypeScript surface including detection, TLS, router, websocket, proxy/common protocol, and multiple core utility exports and files.
- Adds parent-child cancellation handling for HTTP/3 and QUIC stream forwarding to stop orphaned tasks and close idle or overlong streams.
- Improves cleanup reliability with RAII guards for WebSocket upstream tracking and QUIC connection metrics, plus periodic cleanup for rate limiter and proxy address maps.
- Cleans backend metrics state when active backend connections drop to zero and tracks passthrough backend sockets for shutdown cleanup.
## 2026-03-20 - 25.17.10 - fix(rustproxy-http)
reuse the shared HTTP proxy service for HTTP/3 request handling
- Refactors H3ProxyService to delegate requests to the shared HttpProxyService instead of maintaining separate routing and backend forwarding logic.
- Aligns HTTP/3 with the TCP/HTTP path for route matching, connection pooling, and ALPN-based upstream protocol detection.
- Generalizes request handling and filters to accept boxed/generic HTTP bodies so both HTTP/3 and existing HTTP paths share the same proxy pipeline.
- Updates the HTTP/3 integration route matcher to allow transport matching across shared HTTP and QUIC handling.
## 2026-03-20 - 25.17.9 - fix(rustproxy-http)
correct HTTP/3 host extraction and avoid protocol filtering during UDP route lookup
- Use the URI host or strip the port from the Host header so HTTP/3 requests match routes consistently with TCP/HTTP handling.
- Remove protocol filtering from HTTP/3 route lookup because QUIC transport already constrains routing to UDP and protocol validation happens earlier.
## 2026-03-20 - 25.17.8 - fix(rustproxy)
use SNI-based certificate resolution for QUIC TLS connections
- Replaces static first-certificate selection with the shared CertResolver used by the TCP/TLS path.
- Ensures QUIC connections can present the correct certificate per requested domain.
- Keeps HTTP/3 ALPN configuration while improving multi-domain TLS handling.
## 2026-03-20 - 25.17.7 - fix(readme)
document QUIC and HTTP/3 compatibility caveats
- Add notes explaining that GREASE frames are disabled on both server and client HTTP/3 paths to avoid interoperability issues
- Document that the current HTTP/3 stack depends on pre-1.0 h3 ecosystem components and may still have rough edges
## 2026-03-20 - 25.17.6 - fix(rustproxy-http)
disable HTTP/3 GREASE for client and server connections
- Switch the HTTP/3 server connection setup to use the builder API with send_grease(false)
- Switch the HTTP/3 client handshake to use the builder API with send_grease(false) to improve compatibility
## 2026-03-20 - 25.17.5 - fix(rustproxy)
add HTTP/3 integration test for QUIC response stream FIN handling
- adds an integration test covering HTTP/3 proxying over QUIC with TLS termination
- verifies response bodies fully arrive and the client receives stream termination instead of hanging
- adds test-only dependencies for quinn, h3, h3-quinn, rustls, bytes, and http
## 2026-03-20 - 25.17.4 - fix(rustproxy-http)
prevent HTTP/3 response body streaming from hanging on backend completion
- extract and track Content-Length before consuming the response body
- stop the HTTP/3 body loop when the stream reports end-of-stream or the expected byte count has been sent
- add a per-frame idle timeout to avoid indefinite waits on stalled or close-delimited backend bodies
## 2026-03-20 - 25.17.3 - fix(repository)
no changes detected
## 2026-03-20 - 25.17.2 - fix(rustproxy-http)
enable TLS connections for HTTP/3 upstream requests when backend re-encryption or TLS is configured
- Pass backend TLS client configuration into the HTTP/3 request handler.
- Detect TLS-required upstream targets using route and target TLS settings before connecting.
- Build backend request URIs with the correct http or https scheme to match the upstream connection.
## 2026-03-20 - 25.17.1 - fix(rustproxy-routing)
allow QUIC UDP TLS connections without SNI to match domain-restricted routes
- Exempts UDP transport from the no-SNI rejection logic because QUIC encrypts the TLS ClientHello and SNI is unavailable at accept time
- Adds regression tests to confirm QUIC route matching succeeds without SNI while TCP TLS without SNI remains rejected
## 2026-03-19 - 25.17.0 - feat(rustproxy-passthrough)
add PROXY protocol v2 client IP handling for UDP and QUIC listeners
- propagate trusted proxy IP configuration into UDP and QUIC listener managers
- extract and preserve real client addresses from PROXY protocol v2 headers for HTTP/3 and QUIC stream handling
- apply rate limiting, session limits, routing, and metrics using the resolved client IP while preserving correct proxy return-path routing
## 2026-03-19 - 25.16.3 - fix(rustproxy)
upgrade fallback UDP listeners to QUIC when TLS certificates become available
- Rebuild and apply QUIC TLS configuration during route and certificate updates instead of only when adding new UDP ports.
- Add logic to drain UDP sessions, stop raw fallback listeners, and start QUIC endpoints on existing ports once TLS is available.
- Retry QUIC endpoint creation during upgrade and fall back to rebinding raw UDP if the upgrade cannot complete.
## 2026-03-19 - 25.16.2 - fix(rustproxy-http)
cache backend Alt-Svc only from original upstream responses during protocol auto-detection
- Moves Alt-Svc discovery into streaming response construction so it reads backend headers before response filters inject client-facing Alt-Svc values
- Stores the protocol cache key in connection activity during auto-detect mode and clears it after HTTP/3 connection failure to avoid re-caching failed H3 routes
- Prevents fallback requests from reintroducing stale or self-injected Alt-Svc entries that could cause repeated H3 retry loops
## 2026-03-19 - 25.16.1 - fix(http-proxy)
avoid repeated HTTP/3 recaching after QUIC fallback and document backend protocol selection

780
deno.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
Copyright (c) 2019 Lossless GmbH (hello@lossless.com)
Copyright (c) 2019 Task Venture Capital GmbH (hello@task.vc)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@@ -1,6 +1,6 @@
{
"name": "@push.rocks/smartproxy",
"version": "25.16.1",
"version": "27.0.0",
"private": false,
"description": "A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.",
"main": "dist_ts/index.js",
@@ -16,18 +16,19 @@
"buildDocs": "tsdoc"
},
"devDependencies": {
"@git.zone/tsbuild": "^4.3.0",
"@git.zone/tsrun": "^2.0.1",
"@git.zone/tsrust": "^1.3.0",
"@git.zone/tstest": "^3.5.0",
"@push.rocks/smartserve": "^2.0.1",
"@git.zone/tsbuild": "^4.4.0",
"@git.zone/tsrun": "^2.0.2",
"@git.zone/tsrust": "^1.3.2",
"@git.zone/tstest": "^3.6.0",
"@push.rocks/smartserve": "^2.0.3",
"@types/node": "^25.5.0",
"typescript": "^5.9.3",
"typescript": "^6.0.2",
"why-is-node-running": "^3.2.2"
},
"dependencies": {
"@push.rocks/smartcrypto": "^2.0.4",
"@push.rocks/smartlog": "^3.2.1",
"@push.rocks/smartnftables": "^1.0.1",
"@push.rocks/smartrust": "^1.3.2",
"@tsclass/tsclass": "^9.5.0",
"minimatch": "^10.2.4"
@@ -41,7 +42,7 @@
"dist_ts_web/**/*",
"assets/**/*",
"cli.js",
"npmextra.json",
".smartconfig.json",
"readme.md",
"changelog.md"
],

506
pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -462,35 +462,57 @@ For TLS termination modes (`terminate` and `terminate-and-reencrypt`), SmartProx
**HTTP to HTTPS Redirect**:
```typescript
import { createHttpToHttpsRedirect } from '@push.rocks/smartproxy';
import { SocketHandlers } from '@push.rocks/smartproxy';
const redirectRoute = createHttpToHttpsRedirect(['example.com', 'www.example.com']);
const redirectRoute = {
name: 'http-to-https',
match: { ports: 80, domains: ['example.com', 'www.example.com'] },
action: {
type: 'socket-handler' as const,
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
}
};
```
**Complete HTTPS Server (with redirect)**:
```typescript
import { createCompleteHttpsServer } from '@push.rocks/smartproxy';
const routes = createCompleteHttpsServer(
'example.com',
{ host: 'localhost', port: 8080 },
{ certificate: 'auto' }
);
const routes = [
{
name: 'https-server',
match: { ports: 443, domains: 'example.com' },
action: {
type: 'forward' as const,
targets: [{ host: 'localhost', port: 8080 }],
tls: { mode: 'terminate' as const, certificate: 'auto' as const }
}
},
{
name: 'http-redirect',
match: { ports: 80, domains: 'example.com' },
action: {
type: 'socket-handler' as const,
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
}
}
];
```
**Load Balancer with Health Checks**:
```typescript
import { createLoadBalancerRoute } from '@push.rocks/smartproxy';
const lbRoute = createLoadBalancerRoute(
'api.example.com',
[
{ host: 'backend1', port: 8080 },
{ host: 'backend2', port: 8080 },
{ host: 'backend3', port: 8080 }
],
{ tls: { mode: 'terminate', certificate: 'auto' } }
);
const lbRoute = {
name: 'load-balancer',
match: { ports: 443, domains: 'api.example.com' },
action: {
type: 'forward' as const,
targets: [
{ host: 'backend1', port: 8080 },
{ host: 'backend2', port: 8080 },
{ host: 'backend3', port: 8080 }
],
tls: { mode: 'terminate' as const, certificate: 'auto' as const },
loadBalancing: { algorithm: 'round-robin' as const }
}
};
```
### Smart SNI Requirement (v22.3+)

328
readme.md
View File

@@ -1,6 +1,6 @@
# @push.rocks/smartproxy 🚀
**A high-performance, Rust-powered proxy toolkit for Node.js** — unified route-based configuration for SSL/TLS termination, HTTP/HTTPS reverse proxying, WebSocket support, UDP/QUIC/HTTP3, load balancing, custom protocol handlers, and kernel-level NFTables forwarding.
**A high-performance, Rust-powered proxy toolkit for Node.js** — unified route-based configuration for SSL/TLS termination, HTTP/HTTPS reverse proxying, WebSocket support, UDP/QUIC/HTTP3, load balancing, custom protocol handlers, and kernel-level NFTables forwarding via [`@push.rocks/smartnftables`](https://code.foss.global/push.rocks/smartnftables).
## 📦 Installation
@@ -44,7 +44,7 @@ Whether you're building microservices, deploying edge infrastructure, proxying U
Get up and running in 30 seconds:
```typescript
import { SmartProxy, createCompleteHttpsServer } from '@push.rocks/smartproxy';
import { SmartProxy, SocketHandlers } from '@push.rocks/smartproxy';
// Create a proxy with automatic HTTPS
const proxy = new SmartProxy({
@@ -53,13 +53,25 @@ const proxy = new SmartProxy({
useProduction: true
},
routes: [
// Complete HTTPS setup in one call! ✨
...createCompleteHttpsServer('app.example.com', {
host: 'localhost',
port: 3000
}, {
certificate: 'auto' // Automatic Let's Encrypt cert 🎩
})
// HTTPS route with automatic Let's Encrypt cert
{
name: 'https-app',
match: { ports: 443, domains: 'app.example.com' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 3000 }],
tls: { mode: 'terminate', certificate: 'auto' }
}
},
// HTTP → HTTPS redirect
{
name: 'http-redirect',
match: { ports: 80, domains: 'app.example.com' },
action: {
type: 'socket-handler',
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
}
}
]
});
@@ -111,31 +123,38 @@ SmartProxy supports three TLS handling modes:
### 🌐 HTTP to HTTPS Redirect
```typescript
import { SmartProxy, createHttpToHttpsRedirect } from '@push.rocks/smartproxy';
import { SmartProxy, SocketHandlers } from '@push.rocks/smartproxy';
const proxy = new SmartProxy({
routes: [
createHttpToHttpsRedirect(['example.com', '*.example.com'])
]
routes: [{
name: 'http-to-https',
match: { ports: 80, domains: ['example.com', '*.example.com'] },
action: {
type: 'socket-handler',
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
}
}]
});
```
### ⚖️ Load Balancer with Health Checks
```typescript
import { SmartProxy, createLoadBalancerRoute } from '@push.rocks/smartproxy';
import { SmartProxy } from '@push.rocks/smartproxy';
const proxy = new SmartProxy({
routes: [
createLoadBalancerRoute(
'app.example.com',
[
routes: [{
name: 'load-balancer',
match: { ports: 443, domains: 'app.example.com' },
action: {
type: 'forward',
targets: [
{ host: 'server1.internal', port: 8080 },
{ host: 'server2.internal', port: 8080 },
{ host: 'server3.internal', port: 8080 }
],
{
tls: { mode: 'terminate', certificate: 'auto' },
tls: { mode: 'terminate', certificate: 'auto' },
loadBalancing: {
algorithm: 'round-robin',
healthCheck: {
path: '/health',
@@ -145,57 +164,68 @@ const proxy = new SmartProxy({
healthyThreshold: 2
}
}
)
]
}
}]
});
```
### 🔌 WebSocket Proxy
```typescript
import { SmartProxy, createWebSocketRoute } from '@push.rocks/smartproxy';
import { SmartProxy } from '@push.rocks/smartproxy';
const proxy = new SmartProxy({
routes: [
createWebSocketRoute(
'ws.example.com',
{ host: 'websocket-server', port: 8080 },
{
path: '/socket',
useTls: true,
certificate: 'auto',
routes: [{
name: 'websocket',
match: { ports: 443, domains: 'ws.example.com', path: '/socket' },
priority: 100,
action: {
type: 'forward',
targets: [{ host: 'websocket-server', port: 8080 }],
tls: { mode: 'terminate', certificate: 'auto' },
websocket: {
enabled: true,
pingInterval: 30000,
pingTimeout: 10000
}
)
]
}
}]
});
```
### 🚦 API Gateway with Rate Limiting
```typescript
import { SmartProxy, createApiGatewayRoute, addRateLimiting } from '@push.rocks/smartproxy';
import { SmartProxy } from '@push.rocks/smartproxy';
let apiRoute = createApiGatewayRoute(
'api.example.com',
'/api',
{ host: 'api-backend', port: 8080 },
{
useTls: true,
certificate: 'auto',
addCorsHeaders: true
}
);
// Add rate limiting — 100 requests per minute per IP
apiRoute = addRateLimiting(apiRoute, {
maxRequests: 100,
window: 60,
keyBy: 'ip'
const proxy = new SmartProxy({
routes: [{
name: 'api-gateway',
match: { ports: 443, domains: 'api.example.com', path: '/api/*' },
priority: 100,
action: {
type: 'forward',
targets: [{ host: 'api-backend', port: 8080 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
headers: {
response: {
'Access-Control-Allow-Origin': '*',
'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
'Access-Control-Allow-Headers': 'Content-Type, Authorization',
'Access-Control-Max-Age': '86400'
}
},
security: {
rateLimit: {
enabled: true,
maxRequests: 100,
window: 60,
keyBy: 'ip'
}
}
}]
});
const proxy = new SmartProxy({ routes: [apiRoute] });
```
### 🎮 Custom Protocol Handler (TCP)
@@ -203,36 +233,40 @@ const proxy = new SmartProxy({ routes: [apiRoute] });
SmartProxy lets you implement any protocol with full socket control. Routes with JavaScript socket handlers are automatically relayed from the Rust engine back to your TypeScript code:
```typescript
import { SmartProxy, createSocketHandlerRoute, SocketHandlers } from '@push.rocks/smartproxy';
import { SmartProxy, SocketHandlers } from '@push.rocks/smartproxy';
// Use pre-built handlers
const echoRoute = createSocketHandlerRoute(
'echo.example.com',
7777,
SocketHandlers.echo
);
const proxy = new SmartProxy({
routes: [
// Use pre-built handlers
{
name: 'echo-server',
match: { ports: 7777, domains: 'echo.example.com' },
action: { type: 'socket-handler', socketHandler: SocketHandlers.echo }
},
// Or create your own custom protocol
{
name: 'custom-protocol',
match: { ports: 9999, domains: 'custom.example.com' },
action: {
type: 'socket-handler',
socketHandler: async (socket) => {
console.log(`New connection on custom protocol`);
socket.write('Welcome to my custom protocol!\n');
// Or create your own custom protocol
const customRoute = createSocketHandlerRoute(
'custom.example.com',
9999,
async (socket) => {
console.log(`New connection on custom protocol`);
socket.write('Welcome to my custom protocol!\n');
socket.on('data', (data) => {
const command = data.toString().trim();
switch (command) {
case 'PING': socket.write('PONG\n'); break;
case 'TIME': socket.write(`${new Date().toISOString()}\n`); break;
case 'QUIT': socket.end('Goodbye!\n'); break;
default: socket.write(`Unknown: ${command}\n`);
socket.on('data', (data) => {
const command = data.toString().trim();
switch (command) {
case 'PING': socket.write('PONG\n'); break;
case 'TIME': socket.write(`${new Date().toISOString()}\n`); break;
case 'QUIT': socket.end('Goodbye!\n'); break;
default: socket.write(`Unknown: ${command}\n`);
}
});
}
}
});
}
);
const proxy = new SmartProxy({ routes: [echoRoute, customRoute] });
}
]
});
```
**Pre-built Socket Handlers:**
@@ -384,23 +418,26 @@ const dualStackRoute: IRouteConfig = {
### ⚡ High-Performance NFTables Forwarding
For ultra-low latency on Linux, use kernel-level forwarding (requires root):
For ultra-low latency on Linux, use kernel-level forwarding via [`@push.rocks/smartnftables`](https://code.foss.global/push.rocks/smartnftables) (requires root):
```typescript
import { SmartProxy, createNfTablesTerminateRoute } from '@push.rocks/smartproxy';
import { SmartProxy } from '@push.rocks/smartproxy';
const proxy = new SmartProxy({
routes: [
createNfTablesTerminateRoute(
'fast.example.com',
{ host: 'backend', port: 8080 },
{
ports: 443,
certificate: 'auto',
routes: [{
name: 'nftables-fast',
match: { ports: 443, domains: 'fast.example.com' },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'backend', port: 8080 }],
tls: { mode: 'terminate', certificate: 'auto' },
nftables: {
protocol: 'tcp',
preserveSourceIP: true // Backend sees real client IP
}
)
]
}
}]
});
```
@@ -409,15 +446,18 @@ const proxy = new SmartProxy({
Forward encrypted traffic to backends without terminating TLS — the proxy routes based on the SNI hostname alone:
```typescript
import { SmartProxy, createHttpsPassthroughRoute } from '@push.rocks/smartproxy';
import { SmartProxy } from '@push.rocks/smartproxy';
const proxy = new SmartProxy({
routes: [
createHttpsPassthroughRoute('secure.example.com', {
host: 'backend-that-handles-tls',
port: 8443
})
]
routes: [{
name: 'sni-passthrough',
match: { ports: 443, domains: 'secure.example.com' },
action: {
type: 'forward',
targets: [{ host: 'backend-that-handles-tls', port: 8443 }],
tls: { mode: 'passthrough' }
}
}]
});
```
@@ -524,15 +564,7 @@ Comprehensive per-route security options:
}
```
**Security modifier helpers** let you add security to any existing route:
```typescript
import { addRateLimiting, addBasicAuth, addJwtAuth } from '@push.rocks/smartproxy';
let route = createHttpsTerminateRoute('api.example.com', { host: 'backend', port: 8080 });
route = addRateLimiting(route, { maxRequests: 100, window: 60, keyBy: 'ip' });
route = addBasicAuth(route, { users: [{ username: 'admin', password: 'secret' }] });
```
Security options are configured directly on each route's `security` property — no separate helpers needed.
### 📊 Runtime Management
@@ -694,22 +726,26 @@ SmartProxy uses a hybrid **Rust + TypeScript** architecture:
│ │ Listener│ │ Reverse │ │ Matcher │ │ Cert Mgr │ │
│ │ │ │ Proxy │ │ │ │ │ │
│ └─────────┘ └─────────┘ └─────────┘ └──────────┘ │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌──────────┐
│ │ UDP │ │ Security│ │ Metrics │ │ NFTables │
│ │ QUIC │ │ Enforce │ │ Collect │ Mgr
│ │ HTTP/3 │ │ │ │ │
│ └─────────┘ └─────────┘ └─────────┘ └──────────┘
│ ┌─────────┐ ┌─────────┐ ┌─────────┐
│ │ UDP │ │ Security│ │ Metrics │
│ │ QUIC │ │ Enforce │ │ Collect │
│ │ HTTP/3 │ │ │ │ │
│ └─────────┘ └─────────┘ └─────────┘
└──────────────────┬──────────────────────────────────┘
│ Unix Socket Relay
┌──────────────────▼──────────────────────────────────┐
│ TypeScript Socket & Datagram Handler Servers │
│ (for JS socket handlers, datagram handlers, │
│ and dynamic routes) │
├─────────────────────────────────────────────────────┤
│ @push.rocks/smartnftables (kernel-level NFTables) │
│ (DNAT/SNAT, firewall, rate limiting via nft CLI) │
└─────────────────────────────────────────────────────┘
```
- **Rust Engine** handles all networking: TCP, UDP, TLS, QUIC, HTTP proxying, connection management, security, and metrics
- **TypeScript** provides the npm API, configuration types, route helpers, validation, and handler callbacks
- **TypeScript** provides the npm API, configuration types, validation, and handler callbacks
- **NFTables** managed by [`@push.rocks/smartnftables`](https://code.foss.global/push.rocks/smartnftables) — kernel-level DNAT/SNAT forwarding, firewall rules, and rate limiting via the `nft` CLI
- **IPC** — The TypeScript wrapper uses JSON commands/events over stdin/stdout to communicate with the Rust binary
- **Socket/Datagram Relay** — Unix domain socket servers for routes requiring TypeScript-side handling (socket handlers, datagram handlers, dynamic host/port functions)
@@ -854,47 +890,13 @@ interface IRouteQuic {
}
```
## 🛠️ Helper Functions Reference
All helpers are fully typed and return `IRouteConfig` or `IRouteConfig[]`:
## 🛠️ Exports Reference
```typescript
import {
// HTTP/HTTPS
createHttpRoute, // Plain HTTP route
createHttpsTerminateRoute, // HTTPS with TLS termination
createHttpsPassthroughRoute, // SNI passthrough (no termination)
createHttpToHttpsRedirect, // HTTP → HTTPS redirect
createCompleteHttpsServer, // HTTPS + redirect combo (returns IRouteConfig[])
// Load Balancing
createLoadBalancerRoute, // Multi-backend with health checks
createSmartLoadBalancer, // Dynamic domain-based backend selection
// API & WebSocket
createApiRoute, // API route with path matching
createApiGatewayRoute, // API gateway with CORS
createWebSocketRoute, // WebSocket-enabled route
// Custom Protocols
createSocketHandlerRoute, // Custom TCP socket handler
SocketHandlers, // Pre-built handlers (echo, proxy, block, etc.)
// NFTables (Linux, requires root)
createNfTablesRoute, // Kernel-level packet forwarding
createNfTablesTerminateRoute, // NFTables + TLS termination
createCompleteNfTablesHttpsServer, // NFTables HTTPS + redirect combo
// Dynamic Routing
createPortMappingRoute, // Port mapping with context
createOffsetPortMappingRoute, // Simple port offset
createDynamicRoute, // Dynamic host/port via functions
createPortOffset, // Port offset factory
// Security Modifiers
addRateLimiting, // Add rate limiting to any route
addBasicAuth, // Add basic auth to any route
addJwtAuth, // Add JWT auth to any route
// Core
SmartProxy, // Main proxy class
SocketHandlers, // Pre-built socket handlers (echo, proxy, block, httpRedirect, httpServer, etc.)
// Route Utilities
mergeRouteConfigs, // Deep-merge two route configs
@@ -906,7 +908,7 @@ import {
} from '@push.rocks/smartproxy';
```
> **Tip:** For UDP datagram handler routes or QUIC/HTTP3 routes, construct `IRouteConfig` objects directly — there are no helper functions for these yet. See the [UDP Datagram Handler](#-udp-datagram-handler) and [QUIC / HTTP3 Forwarding](#-quic--http3-forwarding) examples above.
All routes are configured as plain `IRouteConfig` objects with `match` and `action` properties — see the examples throughout this document.
## 📖 API Documentation
@@ -938,8 +940,8 @@ class SmartProxy extends EventEmitter {
getCertificateStatus(routeName: string): Promise<any>;
getEligibleDomainsForCertificates(): string[];
// NFTables
getNfTablesStatus(): Promise<Record<string, any>>;
// NFTables (managed by @push.rocks/smartnftables)
getNfTablesStatus(): INftStatus | null;
// Events
on(event: 'error', handler: (err: Error) => void): this;
@@ -991,11 +993,11 @@ interface ISmartProxyOptions {
sendProxyProtocol?: boolean; // Send PROXY protocol to targets
// Timeouts
connectionTimeout?: number; // Backend connection timeout (default: 30s)
initialDataTimeout?: number; // Initial data/SNI timeout (default: 120s)
socketTimeout?: number; // Socket inactivity timeout (default: 1h)
maxConnectionLifetime?: number; // Max connection lifetime (default: 24h)
inactivityTimeout?: number; // Inactivity timeout (default: 4h)
connectionTimeout?: number; // Backend connection timeout (default: 60s)
initialDataTimeout?: number; // Initial data/SNI timeout (default: 60s)
socketTimeout?: number; // Socket inactivity timeout (default: 60s)
maxConnectionLifetime?: number; // Max connection lifetime (default: 1h)
inactivityTimeout?: number; // Inactivity timeout (default: 75s)
gracefulShutdownTimeout?: number; // Shutdown grace period (default: 30s)
// Connection limits
@@ -1004,8 +1006,8 @@ interface ISmartProxyOptions {
// Keep-alive
keepAliveTreatment?: 'standard' | 'extended' | 'immortal';
keepAliveInactivityMultiplier?: number; // (default: 6)
extendedKeepAliveLifetime?: number; // (default: 7 days)
keepAliveInactivityMultiplier?: number; // (default: 4)
extendedKeepAliveLifetime?: number; // (default: 1h)
// Metrics
metrics?: {
@@ -1111,6 +1113,10 @@ SmartProxy searches for the Rust binary in this order:
5. Local dev build (`./rust/target/release/rustproxy`)
6. System PATH (`rustproxy`)
### QUIC / HTTP3 Caveats
- **GREASE frames are disabled.** The underlying h3 crate sends [GREASE frames](https://www.rfc-editor.org/rfc/rfc9114.html#frame-reserved) by default to test protocol extensibility. However, some HTTP/3 clients and servers don't properly ignore unknown frame types, causing 400/500 errors or stream hangs ([h3#206](https://github.com/hyperium/h3/issues/206)). SmartProxy disables GREASE on both the server side (for incoming H3 requests) and the client side (for H3 backend connections) to maximize compatibility.
- **HTTP/3 is pre-release.** The h3 ecosystem (h3 0.0.8, h3-quinn 0.0.10, quinn 0.11) is still pre-1.0. Expect rough edges.
### Performance Tuning
- ✅ Use NFTables forwarding for high-traffic routes (Linux only)
- ✅ Enable connection keep-alive where appropriate
@@ -1133,7 +1139,7 @@ SmartProxy searches for the Rust binary in this order:
## License and Legal Information
This repository contains open-source code licensed under the MIT License. A copy of the license can be found in the [LICENSE](./LICENSE) file.
This repository contains open-source code licensed under the MIT License. A copy of the license can be found in the [license](./license) file.
**Please note:** The MIT License does not grant permission to use the trade names, trademarks, service marks, or product names of the project, except as required for reasonable and customary use in describing the origin of the work and reproducing the content of the NOTICE file.

20
rust/Cargo.lock generated
View File

@@ -1224,17 +1224,20 @@ dependencies = [
"bytes",
"clap",
"dashmap",
"h3",
"h3-quinn",
"http",
"http-body-util",
"hyper",
"hyper-util",
"mimalloc",
"quinn",
"rcgen",
"rustls",
"rustls-pemfile",
"rustproxy-config",
"rustproxy-http",
"rustproxy-metrics",
"rustproxy-nftables",
"rustproxy-passthrough",
"rustproxy-routing",
"rustproxy-security",
@@ -1266,6 +1269,7 @@ dependencies = [
"arc-swap",
"bytes",
"dashmap",
"futures",
"h3",
"h3-quinn",
"http-body",
@@ -1299,20 +1303,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "rustproxy-nftables"
version = "0.1.0"
dependencies = [
"anyhow",
"libc",
"rustproxy-config",
"serde",
"serde_json",
"thiserror 2.0.18",
"tokio",
"tracing",
]
[[package]]
name = "rustproxy-passthrough"
version = "0.1.0"

View File

@@ -7,7 +7,6 @@ members = [
"crates/rustproxy-tls",
"crates/rustproxy-passthrough",
"crates/rustproxy-http",
"crates/rustproxy-nftables",
"crates/rustproxy-metrics",
"crates/rustproxy-security",
]
@@ -107,6 +106,5 @@ rustproxy-routing = { path = "crates/rustproxy-routing" }
rustproxy-tls = { path = "crates/rustproxy-tls" }
rustproxy-passthrough = { path = "crates/rustproxy-passthrough" }
rustproxy-http = { path = "crates/rustproxy-http" }
rustproxy-nftables = { path = "crates/rustproxy-nftables" }
rustproxy-metrics = { path = "crates/rustproxy-metrics" }
rustproxy-security = { path = "crates/rustproxy-security" }

View File

@@ -1,345 +0,0 @@
use crate::route_types::*;
use crate::tls_types::*;
/// Create a simple HTTP forwarding route.
/// Equivalent to SmartProxy's `createHttpRoute()`.
pub fn create_http_route(
domains: impl Into<DomainSpec>,
target_host: impl Into<String>,
target_port: u16,
) -> RouteConfig {
RouteConfig {
id: None,
route_match: RouteMatch {
ports: PortRange::Single(80),
domains: Some(domains.into()),
path: None,
client_ip: None,
transport: None,
tls_version: None,
headers: None,
protocol: None,
},
action: RouteAction {
action_type: RouteActionType::Forward,
targets: Some(vec![RouteTarget {
target_match: None,
host: HostSpec::Single(target_host.into()),
port: PortSpec::Fixed(target_port),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
}]),
tls: None,
websocket: None,
load_balancing: None,
advanced: None,
options: None,
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
udp: None,
},
headers: None,
security: None,
name: None,
description: None,
priority: None,
tags: None,
enabled: None,
}
}
/// Create an HTTPS termination route.
/// Equivalent to SmartProxy's `createHttpsTerminateRoute()`.
pub fn create_https_terminate_route(
domains: impl Into<DomainSpec>,
target_host: impl Into<String>,
target_port: u16,
) -> RouteConfig {
let mut route = create_http_route(domains, target_host, target_port);
route.route_match.ports = PortRange::Single(443);
route.action.tls = Some(RouteTls {
mode: TlsMode::Terminate,
certificate: Some(CertificateSpec::Auto("auto".to_string())),
acme: None,
versions: None,
ciphers: None,
honor_cipher_order: None,
session_timeout: None,
});
route
}
/// Create a TLS passthrough route.
/// Equivalent to SmartProxy's `createHttpsPassthroughRoute()`.
pub fn create_https_passthrough_route(
domains: impl Into<DomainSpec>,
target_host: impl Into<String>,
target_port: u16,
) -> RouteConfig {
let mut route = create_http_route(domains, target_host, target_port);
route.route_match.ports = PortRange::Single(443);
route.action.tls = Some(RouteTls {
mode: TlsMode::Passthrough,
certificate: None,
acme: None,
versions: None,
ciphers: None,
honor_cipher_order: None,
session_timeout: None,
});
route
}
/// Create an HTTP-to-HTTPS redirect route.
/// Equivalent to SmartProxy's `createHttpToHttpsRedirect()`.
pub fn create_http_to_https_redirect(
domains: impl Into<DomainSpec>,
) -> RouteConfig {
let domains = domains.into();
RouteConfig {
id: None,
route_match: RouteMatch {
ports: PortRange::Single(80),
domains: Some(domains),
path: None,
client_ip: None,
transport: None,
tls_version: None,
headers: None,
protocol: None,
},
action: RouteAction {
action_type: RouteActionType::Forward,
targets: None,
tls: None,
websocket: None,
load_balancing: None,
advanced: Some(RouteAdvanced {
timeout: None,
headers: None,
keep_alive: None,
static_files: None,
test_response: Some(RouteTestResponse {
status: 301,
headers: {
let mut h = std::collections::HashMap::new();
h.insert("Location".to_string(), "https://{domain}{path}".to_string());
h
},
body: String::new(),
}),
url_rewrite: None,
}),
options: None,
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
udp: None,
},
headers: None,
security: None,
name: Some("HTTP to HTTPS Redirect".to_string()),
description: None,
priority: None,
tags: None,
enabled: None,
}
}
/// Create a complete HTTPS server with HTTP redirect.
/// Equivalent to SmartProxy's `createCompleteHttpsServer()`.
pub fn create_complete_https_server(
domain: impl Into<String>,
target_host: impl Into<String>,
target_port: u16,
) -> Vec<RouteConfig> {
let domain = domain.into();
let target_host = target_host.into();
vec![
create_http_to_https_redirect(DomainSpec::Single(domain.clone())),
create_https_terminate_route(
DomainSpec::Single(domain),
target_host,
target_port,
),
]
}
/// Create a load balancer route.
/// Equivalent to SmartProxy's `createLoadBalancerRoute()`.
pub fn create_load_balancer_route(
domains: impl Into<DomainSpec>,
targets: Vec<(String, u16)>,
tls: Option<RouteTls>,
) -> RouteConfig {
let route_targets: Vec<RouteTarget> = targets
.into_iter()
.map(|(host, port)| RouteTarget {
target_match: None,
host: HostSpec::Single(host),
port: PortSpec::Fixed(port),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
})
.collect();
let port = if tls.is_some() { 443 } else { 80 };
RouteConfig {
id: None,
route_match: RouteMatch {
ports: PortRange::Single(port),
domains: Some(domains.into()),
path: None,
client_ip: None,
transport: None,
tls_version: None,
headers: None,
protocol: None,
},
action: RouteAction {
action_type: RouteActionType::Forward,
targets: Some(route_targets),
tls,
websocket: None,
load_balancing: Some(RouteLoadBalancing {
algorithm: LoadBalancingAlgorithm::RoundRobin,
health_check: None,
}),
advanced: None,
options: None,
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
udp: None,
},
headers: None,
security: None,
name: Some("Load Balancer".to_string()),
description: None,
priority: None,
tags: None,
enabled: None,
}
}
// Convenience conversions for DomainSpec
impl From<&str> for DomainSpec {
fn from(s: &str) -> Self {
DomainSpec::Single(s.to_string())
}
}
impl From<String> for DomainSpec {
fn from(s: String) -> Self {
DomainSpec::Single(s)
}
}
impl From<Vec<String>> for DomainSpec {
fn from(v: Vec<String>) -> Self {
DomainSpec::List(v)
}
}
impl From<Vec<&str>> for DomainSpec {
fn from(v: Vec<&str>) -> Self {
DomainSpec::List(v.into_iter().map(|s| s.to_string()).collect())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tls_types::TlsMode;
#[test]
fn test_create_http_route() {
let route = create_http_route("example.com", "localhost", 8080);
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
let domains = route.route_match.domains.as_ref().unwrap().to_vec();
assert_eq!(domains, vec!["example.com"]);
let target = &route.action.targets.as_ref().unwrap()[0];
assert_eq!(target.host.first(), "localhost");
assert_eq!(target.port.resolve(80), 8080);
assert!(route.action.tls.is_none());
}
#[test]
fn test_create_https_terminate_route() {
let route = create_https_terminate_route("api.example.com", "backend", 3000);
assert_eq!(route.route_match.ports.to_ports(), vec![443]);
let tls = route.action.tls.as_ref().unwrap();
assert_eq!(tls.mode, TlsMode::Terminate);
assert!(tls.certificate.as_ref().unwrap().is_auto());
}
#[test]
fn test_create_https_passthrough_route() {
let route = create_https_passthrough_route("secure.example.com", "backend", 443);
assert_eq!(route.route_match.ports.to_ports(), vec![443]);
let tls = route.action.tls.as_ref().unwrap();
assert_eq!(tls.mode, TlsMode::Passthrough);
assert!(tls.certificate.is_none());
}
#[test]
fn test_create_http_to_https_redirect() {
let route = create_http_to_https_redirect("example.com");
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
assert!(route.action.targets.is_none());
let test_response = route.action.advanced.as_ref().unwrap().test_response.as_ref().unwrap();
assert_eq!(test_response.status, 301);
assert!(test_response.headers.contains_key("Location"));
}
#[test]
fn test_create_complete_https_server() {
let routes = create_complete_https_server("example.com", "backend", 8080);
assert_eq!(routes.len(), 2);
// First route is HTTP redirect
assert_eq!(routes[0].route_match.ports.to_ports(), vec![80]);
// Second route is HTTPS terminate
assert_eq!(routes[1].route_match.ports.to_ports(), vec![443]);
}
#[test]
fn test_create_load_balancer_route() {
let targets = vec![
("backend1".to_string(), 8080),
("backend2".to_string(), 8080),
("backend3".to_string(), 8080),
];
let route = create_load_balancer_route("*.example.com", targets, None);
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
assert_eq!(route.action.targets.as_ref().unwrap().len(), 3);
let lb = route.action.load_balancing.as_ref().unwrap();
assert_eq!(lb.algorithm, LoadBalancingAlgorithm::RoundRobin);
}
#[test]
fn test_domain_spec_from_str() {
let spec: DomainSpec = "example.com".into();
assert_eq!(spec.to_vec(), vec!["example.com"]);
}
#[test]
fn test_domain_spec_from_vec() {
let spec: DomainSpec = vec!["a.com", "b.com"].into();
assert_eq!(spec.to_vec(), vec!["a.com", "b.com"]);
}
}

View File

@@ -8,7 +8,6 @@ pub mod proxy_options;
pub mod tls_types;
pub mod security_types;
pub mod validation;
pub mod helpers;
// Re-export all primary types
pub use route_types::*;
@@ -16,4 +15,3 @@ pub use proxy_options::*;
pub use tls_types::*;
pub use security_types::*;
pub use validation::*;
pub use helpers::*;

View File

@@ -331,12 +331,48 @@ impl RustProxyOptions {
#[cfg(test)]
mod tests {
use super::*;
use crate::helpers::*;
use crate::route_types::*;
use crate::tls_types::*;
fn make_route(domain: &str, host: &str, port: u16, listen_port: u16) -> RouteConfig {
RouteConfig {
id: None,
route_match: RouteMatch {
ports: PortRange::Single(listen_port),
domains: Some(DomainSpec::Single(domain.to_string())),
path: None, client_ip: None, transport: None, tls_version: None, headers: None, protocol: None,
},
action: RouteAction {
action_type: RouteActionType::Forward,
targets: Some(vec![RouteTarget {
target_match: None,
host: HostSpec::Single(host.to_string()),
port: PortSpec::Fixed(port),
tls: None, websocket: None, load_balancing: None, send_proxy_protocol: None,
headers: None, advanced: None, backend_transport: None, priority: None,
}]),
tls: None, websocket: None, load_balancing: None, advanced: None,
options: None, send_proxy_protocol: None, udp: None,
},
headers: None, security: None, name: None, description: None,
priority: None, tags: None, enabled: None,
}
}
fn make_passthrough_route(domain: &str, host: &str, port: u16) -> RouteConfig {
let mut route = make_route(domain, host, port, 443);
route.action.tls = Some(RouteTls {
mode: TlsMode::Passthrough,
certificate: None, acme: None, versions: None, ciphers: None,
honor_cipher_order: None, session_timeout: None,
});
route
}
#[test]
fn test_serde_roundtrip_minimal() {
let options = RustProxyOptions {
routes: vec![create_http_route("example.com", "localhost", 8080)],
routes: vec![make_route("example.com", "localhost", 8080, 80)],
..Default::default()
};
let json = serde_json::to_string(&options).unwrap();
@@ -348,8 +384,8 @@ mod tests {
fn test_serde_roundtrip_full() {
let options = RustProxyOptions {
routes: vec![
create_http_route("a.com", "backend1", 8080),
create_https_passthrough_route("b.com", "backend2", 443),
make_route("a.com", "backend1", 8080, 80),
make_passthrough_route("b.com", "backend2", 443),
],
connection_timeout: Some(5000),
socket_timeout: Some(60000),
@@ -402,9 +438,9 @@ mod tests {
fn test_all_listening_ports() {
let options = RustProxyOptions {
routes: vec![
create_http_route("a.com", "backend", 8080), // port 80
create_https_passthrough_route("b.com", "backend", 443), // port 443
create_http_route("c.com", "backend", 9090), // port 80 (duplicate)
make_route("a.com", "backend", 8080, 80), // port 80
make_passthrough_route("b.com", "backend", 443), // port 443
make_route("c.com", "backend", 9090, 80), // port 80 (duplicate)
],
..Default::default()
};

View File

@@ -60,16 +60,6 @@ pub enum RouteActionType {
SocketHandler,
}
// ─── Forwarding Engine ───────────────────────────────────────────────
/// Forwarding engine specification.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ForwardingEngine {
Node,
Nftables,
}
// ─── Route Match ─────────────────────────────────────────────────────
/// Domain specification: single string or array.
@@ -89,6 +79,31 @@ impl DomainSpec {
}
}
// Convenience conversions for DomainSpec
impl From<&str> for DomainSpec {
fn from(s: &str) -> Self {
DomainSpec::Single(s.to_string())
}
}
impl From<String> for DomainSpec {
fn from(s: String) -> Self {
DomainSpec::Single(s)
}
}
impl From<Vec<String>> for DomainSpec {
fn from(v: Vec<String>) -> Self {
DomainSpec::List(v)
}
}
impl From<Vec<&str>> for DomainSpec {
fn from(v: Vec<&str>) -> Self {
DomainSpec::List(v.into_iter().map(|s| s.to_string()).collect())
}
}
/// Header match value: either exact string or regex pattern.
/// In JSON, all values come as strings. Regex patterns are prefixed with `/` and suffixed with `/`.
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -341,38 +356,6 @@ pub struct RouteAdvanced {
pub url_rewrite: Option<RouteUrlRewrite>,
}
// ─── NFTables Options ────────────────────────────────────────────────
/// NFTables protocol type.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NfTablesProtocol {
Tcp,
Udp,
All,
}
/// NFTables-specific configuration options.
/// Matches TypeScript: `INfTablesOptions`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NfTablesOptions {
#[serde(skip_serializing_if = "Option::is_none")]
pub preserve_source_ip: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub protocol: Option<NfTablesProtocol>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_rate: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub priority: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub table_name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub use_ip_sets: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub use_advanced_nat: Option<bool>,
}
// ─── Backend Protocol ────────────────────────────────────────────────
/// Backend protocol.
@@ -541,14 +524,6 @@ pub struct RouteAction {
#[serde(skip_serializing_if = "Option::is_none")]
pub options: Option<ActionOptions>,
/// Forwarding engine specification
#[serde(skip_serializing_if = "Option::is_none")]
pub forwarding_engine: Option<ForwardingEngine>,
/// NFTables-specific options
#[serde(skip_serializing_if = "Option::is_none")]
pub nftables: Option<NfTablesOptions>,
/// PROXY protocol support (default for all targets)
#[serde(skip_serializing_if = "Option::is_none")]
pub send_proxy_protocol: Option<bool>,

View File

@@ -104,7 +104,49 @@ mod tests {
use crate::route_types::*;
fn make_valid_route() -> RouteConfig {
crate::helpers::create_http_route("example.com", "localhost", 8080)
RouteConfig {
id: None,
route_match: RouteMatch {
ports: PortRange::Single(80),
domains: Some(DomainSpec::Single("example.com".to_string())),
path: None,
client_ip: None,
transport: None,
tls_version: None,
headers: None,
protocol: None,
},
action: RouteAction {
action_type: RouteActionType::Forward,
targets: Some(vec![RouteTarget {
target_match: None,
host: HostSpec::Single("localhost".to_string()),
port: PortSpec::Fixed(8080),
tls: None,
websocket: None,
load_balancing: None,
send_proxy_protocol: None,
headers: None,
advanced: None,
backend_transport: None,
priority: None,
}]),
tls: None,
websocket: None,
load_balancing: None,
advanced: None,
options: None,
send_proxy_protocol: None,
udp: None,
},
headers: None,
security: None,
name: None,
description: None,
priority: None,
tags: None,
enabled: None,
}
}
#[test]

View File

@@ -30,3 +30,4 @@ socket2 = { workspace = true }
quinn = { workspace = true }
h3 = { workspace = true }
h3-quinn = { workspace = true }
futures = { version = "0.3", default-features = false, features = ["std"] }

View File

@@ -56,7 +56,11 @@ struct PooledH2 {
}
/// A pooled QUIC/HTTP/3 connection (multiplexed like H2).
/// Stores the h3 `SendRequest` handle so pool hits skip the h3 SETTINGS handshake.
pub struct PooledH3 {
/// Multiplexed h3 request handle — clone to open a new stream.
pub send_request: h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
/// Raw QUIC connection — kept for liveness probing (close_reason) only.
pub connection: quinn::Connection,
pub created_at: Instant,
pub generation: u64,
@@ -197,7 +201,10 @@ impl ConnectionPool {
/// Try to get a pooled QUIC connection for the given key.
/// QUIC connections are multiplexed — the connection is shared, not removed.
pub fn checkout_h3(&self, key: &PoolKey) -> Option<(quinn::Connection, Duration)> {
pub fn checkout_h3(
&self,
key: &PoolKey,
) -> Option<(h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>, quinn::Connection, Duration)> {
let entry = self.h3_pool.get(key)?;
let pooled = entry.value();
let age = pooled.created_at.elapsed();
@@ -215,13 +222,20 @@ impl ConnectionPool {
return None;
}
Some((pooled.connection.clone(), age))
Some((pooled.send_request.clone(), pooled.connection.clone(), age))
}
/// Register a QUIC connection in the pool. Returns the generation ID.
pub fn register_h3(&self, key: PoolKey, connection: quinn::Connection) -> u64 {
/// Register a QUIC connection and its h3 SendRequest handle in the pool.
/// Returns the generation ID.
pub fn register_h3(
&self,
key: PoolKey,
connection: quinn::Connection,
send_request: h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>,
) -> u64 {
let gen = self.h2_generation.fetch_add(1, Ordering::Relaxed);
self.h3_pool.insert(key, PooledH3 {
send_request,
connection,
created_at: Instant::now(),
generation: gen,

View File

@@ -1,205 +1,138 @@
//! HTTP/3 proxy service.
//!
//! Accepts QUIC connections via quinn, runs h3 server to handle HTTP/3 requests,
//! and forwards them to backends using the same routing and pool infrastructure
//! as the HTTP/1+2 proxy.
//! and delegates backend forwarding to the shared `HttpProxyService` — same
//! route matching, connection pool, and protocol auto-detection as TCP/HTTP.
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use arc_swap::ArcSwap;
use bytes::{Buf, Bytes};
use http_body::Frame;
use http_body_util::BodyExt;
use http_body_util::combinators::BoxBody;
use tracing::{debug, warn};
use rustproxy_config::{RouteConfig, TransportProtocol};
use rustproxy_metrics::MetricsCollector;
use rustproxy_routing::{MatchContext, RouteManager};
use rustproxy_config::RouteConfig;
use tokio_util::sync::CancellationToken;
use crate::connection_pool::ConnectionPool;
use crate::protocol_cache::ProtocolCache;
use crate::upstream_selector::UpstreamSelector;
use crate::proxy_service::{ConnActivity, HttpProxyService};
/// HTTP/3 proxy service.
///
/// Handles QUIC connections with the h3 crate, parses HTTP/3 requests,
/// and forwards them to backends using per-request route matching and
/// shared connection pooling.
/// Accepts QUIC connections, parses HTTP/3 requests, and delegates backend
/// forwarding to the shared `HttpProxyService`.
pub struct H3ProxyService {
route_manager: Arc<ArcSwap<RouteManager>>,
metrics: Arc<MetricsCollector>,
connection_pool: Arc<ConnectionPool>,
#[allow(dead_code)]
protocol_cache: Arc<ProtocolCache>,
#[allow(dead_code)]
upstream_selector: UpstreamSelector,
#[allow(dead_code)]
backend_tls_config: Arc<rustls::ClientConfig>,
connect_timeout: Duration,
http_proxy: Arc<HttpProxyService>,
}
impl H3ProxyService {
pub fn new(
route_manager: Arc<ArcSwap<RouteManager>>,
metrics: Arc<MetricsCollector>,
connection_pool: Arc<ConnectionPool>,
protocol_cache: Arc<ProtocolCache>,
backend_tls_config: Arc<rustls::ClientConfig>,
connect_timeout: Duration,
) -> Self {
Self {
route_manager: Arc::clone(&route_manager),
metrics: Arc::clone(&metrics),
connection_pool,
protocol_cache,
upstream_selector: UpstreamSelector::new(),
backend_tls_config,
connect_timeout,
}
pub fn new(http_proxy: Arc<HttpProxyService>) -> Self {
Self { http_proxy }
}
/// Handle an accepted QUIC connection as HTTP/3.
///
/// If `real_client_addr` is provided (from PROXY protocol), it overrides
/// `connection.remote_address()` for client IP attribution.
pub async fn handle_connection(
&self,
connection: quinn::Connection,
_fallback_route: &RouteConfig,
port: u16,
real_client_addr: Option<SocketAddr>,
parent_cancel: &CancellationToken,
) -> anyhow::Result<()> {
let remote_addr = connection.remote_address();
let remote_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
debug!("HTTP/3 connection from {} on port {}", remote_addr, port);
let mut h3_conn: h3::server::Connection<h3_quinn::Connection, Bytes> =
h3::server::Connection::new(h3_quinn::Connection::new(connection))
h3::server::builder()
.send_grease(false)
.build(h3_quinn::Connection::new(connection))
.await
.map_err(|e| anyhow::anyhow!("H3 connection setup failed: {}", e))?;
let client_ip = remote_addr.ip().to_string();
loop {
match h3_conn.accept().await {
Ok(Some(resolver)) => {
let (request, stream) = match resolver.resolve_request().await {
Ok(pair) => pair,
let resolver = tokio::select! {
_ = parent_cancel.cancelled() => {
debug!("HTTP/3 connection from {} cancelled by parent", remote_addr);
break;
}
result = h3_conn.accept() => {
match result {
Ok(Some(resolver)) => resolver,
Ok(None) => {
debug!("HTTP/3 connection from {} closed", remote_addr);
break;
}
Err(e) => {
debug!("HTTP/3 request resolve error: {}", e);
continue;
debug!("HTTP/3 accept error from {}: {}", remote_addr, e);
break;
}
};
self.metrics.record_http_request();
let rm = self.route_manager.load();
let pool = Arc::clone(&self.connection_pool);
let metrics = Arc::clone(&self.metrics);
let connect_timeout = self.connect_timeout;
let client_ip = client_ip.clone();
tokio::spawn(async move {
if let Err(e) = handle_h3_request(
request, stream, port, &client_ip, &rm, &pool, &metrics, connect_timeout,
).await {
debug!("HTTP/3 request error from {}: {}", client_ip, e);
}
});
}
Ok(None) => {
debug!("HTTP/3 connection from {} closed", remote_addr);
break;
}
}
};
let (request, stream) = match resolver.resolve_request().await {
Ok(pair) => pair,
Err(e) => {
debug!("HTTP/3 accept error from {}: {}", remote_addr, e);
break;
debug!("HTTP/3 request resolve error: {}", e);
continue;
}
}
};
let http_proxy = Arc::clone(&self.http_proxy);
let request_cancel = parent_cancel.child_token();
tokio::spawn(async move {
if let Err(e) = handle_h3_request(
request, stream, port, remote_addr, &http_proxy, request_cancel,
).await {
debug!("HTTP/3 request error from {}: {}", remote_addr, e);
}
});
}
Ok(())
}
}
/// Handle a single HTTP/3 request with per-request route matching.
/// Handle a single HTTP/3 request by delegating to HttpProxyService.
///
/// 1. Read the H3 request body via an mpsc channel (streaming, not buffered)
/// 2. Build a `hyper::Request<BoxBody>` that HttpProxyService can handle
/// 3. Call `HttpProxyService::handle_request` — same route matching, connection
/// pool, ALPN protocol detection (H1/H2/H3) as the TCP/HTTP path
/// 4. Stream the response back over the H3 stream
async fn handle_h3_request(
request: hyper::Request<()>,
mut stream: h3::server::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>,
port: u16,
client_ip: &str,
route_manager: &RouteManager,
_connection_pool: &ConnectionPool,
metrics: &MetricsCollector,
connect_timeout: Duration,
peer_addr: SocketAddr,
http_proxy: &HttpProxyService,
cancel: CancellationToken,
) -> anyhow::Result<()> {
let method = request.method().clone();
let uri = request.uri().clone();
let path = uri.path().to_string();
// Stream request body from H3 client via an mpsc channel.
let (body_tx, body_rx) = tokio::sync::mpsc::channel::<Bytes>(32);
// Extract host from :authority or Host header
let host = request.uri().authority()
.map(|a| a.as_str().to_string())
.or_else(|| request.headers().get("host").and_then(|v| v.to_str().ok()).map(|s| s.to_string()))
.unwrap_or_default();
debug!("HTTP/3 {} {} (host: {}, client: {})", method, path, host, client_ip);
// Per-request route matching
let ctx = MatchContext {
port,
domain: if host.is_empty() { None } else { Some(&host) },
path: Some(&path),
client_ip: Some(client_ip),
tls_version: Some("TLSv1.3"),
headers: None,
is_tls: true,
protocol: Some("http"),
transport: Some(TransportProtocol::Udp),
};
let route_match = route_manager.find_route(&ctx)
.ok_or_else(|| anyhow::anyhow!("No route matched for HTTP/3 request to {}{}", host, path))?;
let route = route_match.route;
// Resolve backend target (use matched target or first target)
let target = route_match.target
.or_else(|| route.action.targets.as_ref().and_then(|t| t.first()))
.ok_or_else(|| anyhow::anyhow!("No target for HTTP/3 route"))?;
let backend_host = target.host.first();
let backend_port = target.port.resolve(port);
let backend_addr = format!("{}:{}", backend_host, backend_port);
// Connect to backend via TCP HTTP/1.1 with timeout
let tcp_stream = tokio::time::timeout(
connect_timeout,
tokio::net::TcpStream::connect(&backend_addr),
).await
.map_err(|_| anyhow::anyhow!("Backend connect timeout to {}", backend_addr))?
.map_err(|e| anyhow::anyhow!("Backend connect to {} failed: {}", backend_addr, e))?;
let _ = tcp_stream.set_nodelay(true);
let io = hyper_util::rt::TokioIo::new(tcp_stream);
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await
.map_err(|e| anyhow::anyhow!("Backend handshake failed: {}", e))?;
tokio::spawn(async move {
if let Err(e) = conn.await {
debug!("Backend connection closed: {}", e);
}
});
// Stream request body from H3 client to backend via an mpsc channel.
// This avoids buffering the entire request body in memory.
let (body_tx, body_rx) = tokio::sync::mpsc::channel::<Bytes>(4);
let total_bytes_in = Arc::new(std::sync::atomic::AtomicU64::new(0));
let total_bytes_in_writer = Arc::clone(&total_bytes_in);
// Spawn the H3 body reader task
// Spawn the H3 body reader task with cancellation
let body_cancel = cancel.clone();
let body_reader = tokio::spawn(async move {
while let Ok(Some(mut chunk)) = stream.recv_data().await {
let data = Bytes::copy_from_slice(chunk.chunk());
total_bytes_in_writer.fetch_add(data.len() as u64, std::sync::atomic::Ordering::Relaxed);
chunk.advance(chunk.remaining());
loop {
let chunk = tokio::select! {
_ = body_cancel.cancelled() => break,
result = stream.recv_data() => {
match result {
Ok(Some(chunk)) => chunk,
_ => break,
}
}
};
let mut chunk = chunk;
let data = chunk.copy_to_bytes(chunk.remaining());
if body_tx.send(data).await.is_err() {
break;
}
@@ -207,106 +140,63 @@ async fn handle_h3_request(
stream
});
// Create a body that polls from the mpsc receiver
// Build a hyper::Request<BoxBody> from the H3 request + streaming body.
// The URI already has scheme + authority + path set by the h3 crate.
let body = H3RequestBody { receiver: body_rx };
let backend_req = build_backend_request(&method, &backend_addr, &path, &host, &request, body)?;
let (parts, _) = request.into_parts();
let boxed_body: BoxBody<Bytes, hyper::Error> = BoxBody::new(body);
let req = hyper::Request::from_parts(parts, boxed_body);
let response = sender.send_request(backend_req).await
// Delegate to HttpProxyService — same backend path as TCP/HTTP:
// route matching, ALPN protocol detection, connection pool, H1/H2/H3 auto.
let conn_activity = ConnActivity::new_standalone();
let response = http_proxy.handle_request(req, peer_addr, port, cancel, conn_activity).await
.map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?;
// Await the body reader to get the stream back
// Await the body reader to get the H3 stream back
let mut stream = body_reader.await
.map_err(|e| anyhow::anyhow!("Body reader task failed: {}", e))?;
let total_bytes_in = total_bytes_in.load(std::sync::atomic::Ordering::Relaxed);
// Build H3 response
let status = response.status();
let mut h3_response = hyper::Response::builder().status(status);
// Copy response headers (skip hop-by-hop)
for (name, value) in response.headers() {
let n = name.as_str().to_lowercase();
// Send response headers over H3 (skip hop-by-hop headers)
let (resp_parts, resp_body) = response.into_parts();
let mut h3_response = hyper::Response::builder().status(resp_parts.status);
for (name, value) in &resp_parts.headers {
let n = name.as_str();
if n == "transfer-encoding" || n == "connection" || n == "keep-alive" || n == "upgrade" {
continue;
}
h3_response = h3_response.header(name, value);
}
// Add Alt-Svc for HTTP/3 advertisement
let alt_svc = route.action.udp.as_ref()
.and_then(|u| u.quic.as_ref())
.map(|q| {
let p = q.alt_svc_port.unwrap_or(port);
let ma = q.alt_svc_max_age.unwrap_or(86400);
format!("h3=\":{}\"; ma={}", p, ma)
})
.unwrap_or_else(|| format!("h3=\":{}\"; ma=86400", port));
h3_response = h3_response.header("alt-svc", alt_svc);
let h3_response = h3_response.body(())
.map_err(|e| anyhow::anyhow!("Failed to build H3 response: {}", e))?;
// Send response headers
stream.send_response(h3_response).await
.map_err(|e| anyhow::anyhow!("Failed to send H3 response: {}", e))?;
// Stream response body back
use http_body_util::BodyExt;
let mut body = response.into_body();
let mut total_bytes_out: u64 = 0;
while let Some(frame) = body.frame().await {
// Stream response body back over H3
let mut resp_body = resp_body;
while let Some(frame) = resp_body.frame().await {
match frame {
Ok(frame) => {
if let Some(data) = frame.data_ref() {
total_bytes_out += data.len() as u64;
stream.send_data(Bytes::copy_from_slice(data)).await
if let Ok(data) = frame.into_data() {
stream.send_data(data).await
.map_err(|e| anyhow::anyhow!("Failed to send H3 data: {}", e))?;
}
}
Err(e) => {
warn!("Backend body read error: {}", e);
warn!("Response body read error: {}", e);
break;
}
}
}
// Record metrics
let route_id = route.name.as_deref().or(route.id.as_deref());
metrics.record_bytes(total_bytes_in, total_bytes_out, route_id, Some(client_ip));
// Finish the stream
// Finish the H3 stream (send QUIC FIN)
stream.finish().await
.map_err(|e| anyhow::anyhow!("Failed to finish H3 stream: {}", e))?;
Ok(())
}
/// Build an HTTP/1.1 backend request from the H3 frontend request.
fn build_backend_request<B>(
method: &hyper::Method,
backend_addr: &str,
path: &str,
host: &str,
original_request: &hyper::Request<()>,
body: B,
) -> anyhow::Result<hyper::Request<B>> {
let mut req = hyper::Request::builder()
.method(method)
.uri(format!("http://{}{}", backend_addr, path))
.header("host", host);
// Forward non-pseudo headers
for (name, value) in original_request.headers() {
let n = name.as_str();
if !n.starts_with(':') && n != "host" {
req = req.header(name, value);
}
}
req.body(body)
.map_err(|e| anyhow::anyhow!("Failed to build backend request: {}", e))
}
/// A streaming request body backed by an mpsc channel receiver.
///
/// Implements `http_body::Body` so hyper can poll chunks as they arrive

View File

@@ -1,43 +1,99 @@
//! Bounded, TTL-based protocol detection cache for backend protocol auto-detection.
//! Bounded, sliding-TTL protocol detection cache with periodic re-probing and failure suppression.
//!
//! Caches the detected protocol (H1, H2, or H3) per backend endpoint and requested
//! domain (host:port + requested_host). This prevents cache oscillation when multiple
//! frontend domains share the same backend but differ in protocol support.
//!
//! H3 detection uses the browser model: Alt-Svc headers from H1/H2 responses are
//! parsed and cached, including the advertised H3 port (which may differ from TCP).
//! ## Sliding TTL
//!
//! Each cache hit refreshes the entry's expiry timer (`last_accessed_at`). Entries
//! remain valid for up to 1 day of continuous use. Every 5 minutes, the next request
//! triggers an inline ALPN re-probe to verify the cached protocol is still correct.
//!
//! ## Upgrade signals
//!
//! - ALPN (TLS handshake) → detects H2 vs H1
//! - Alt-Svc (response header) → advertises H3
//!
//! ## Protocol transitions
//!
//! All protocol changes are logged at `info!()` level with the reason:
//! "Protocol transition: H1 → H2 because periodic ALPN re-probe"
//!
//! ## Failure suppression
//!
//! When a protocol fails, `record_failure()` prevents upgrade signals from
//! re-introducing it until an escalating cooldown expires (5s → 10s → ... → 300s).
//! Within-request escalation is allowed via `can_retry()` after a 5s minimum gap.
//!
//! ## Total failure eviction
//!
//! When all protocols (H3, H2, H1) fail for a backend, the cache entry is evicted
//! entirely via `evict()`, forcing a fresh probe on the next request.
//!
//! Cascading: when a lower protocol also fails, higher protocol cooldowns are
//! reduced to 5s remaining (not instant clear), preventing tight retry loops.
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use tracing::debug;
use tracing::{debug, info};
/// TTL for cached protocol detection results.
/// After this duration, the next request will re-probe the backend.
const PROTOCOL_CACHE_TTL: Duration = Duration::from_secs(300); // 5 minutes
/// Sliding TTL for cached protocol detection results.
/// Entries that haven't been accessed for this duration are evicted.
/// Each `get()` call refreshes the timer (sliding window).
const PROTOCOL_CACHE_TTL: Duration = Duration::from_secs(86400); // 1 day
/// Interval between inline ALPN re-probes for H1/H2 entries.
/// When a cached entry's `last_probed_at` exceeds this, the next request
/// triggers an ALPN re-probe to verify the backend still speaks the same protocol.
const PROTOCOL_REPROBE_INTERVAL: Duration = Duration::from_secs(300); // 5 minutes
/// Maximum number of entries in the protocol cache.
/// Prevents unbounded growth when backends come and go.
const PROTOCOL_CACHE_MAX_ENTRIES: usize = 4096;
/// Background cleanup interval for the protocol cache.
/// Background cleanup interval.
const PROTOCOL_CACHE_CLEANUP_INTERVAL: Duration = Duration::from_secs(60);
/// Minimum cooldown between retry attempts of a failed protocol.
const PROTOCOL_FAILURE_COOLDOWN: Duration = Duration::from_secs(5);
/// Maximum cooldown (escalation ceiling).
const PROTOCOL_FAILURE_MAX_COOLDOWN: Duration = Duration::from_secs(300);
/// Consecutive failure count at which cooldown reaches maximum.
/// 5s × 2^5 = 160s, 5s × 2^6 = 320s → capped at 300s.
const PROTOCOL_FAILURE_ESCALATION_CAP: u32 = 6;
/// Detected backend protocol.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DetectedProtocol {
H1,
H2,
H3,
}
impl std::fmt::Display for DetectedProtocol {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DetectedProtocol::H1 => write!(f, "H1"),
DetectedProtocol::H2 => write!(f, "H2"),
DetectedProtocol::H3 => write!(f, "H3"),
}
}
}
/// Result of a protocol cache lookup.
#[derive(Debug, Clone, Copy)]
pub struct CachedProtocol {
pub protocol: DetectedProtocol,
/// For H3: the port advertised by Alt-Svc (may differ from TCP port).
pub h3_port: Option<u16>,
/// True if the entry's `last_probed_at` exceeds `PROTOCOL_REPROBE_INTERVAL`.
/// Caller should perform an inline ALPN re-probe and call `update_probe_result()`.
/// Always `false` for H3 entries (H3 is discovered via Alt-Svc, not ALPN).
pub needs_reprobe: bool,
}
/// Key for the protocol cache: (host, port, requested_host).
@@ -50,24 +106,111 @@ pub struct ProtocolCacheKey {
pub requested_host: Option<String>,
}
/// A cached protocol detection result with a timestamp.
/// A cached protocol detection result with timestamps.
struct CachedEntry {
protocol: DetectedProtocol,
/// When this protocol was first detected (or last changed).
detected_at: Instant,
/// Last time any request used this entry (sliding-window TTL).
last_accessed_at: Instant,
/// Last time an ALPN re-probe was performed for this entry.
last_probed_at: Instant,
/// For H3: the port advertised by Alt-Svc (may differ from TCP port).
h3_port: Option<u16>,
}
/// Bounded, TTL-based protocol detection cache.
/// Failure record for a single protocol level.
#[derive(Debug, Clone)]
struct FailureRecord {
/// When the failure was last recorded.
failed_at: Instant,
/// Current cooldown duration. Escalates on consecutive failures.
cooldown: Duration,
/// Number of consecutive failures (for escalation).
consecutive_failures: u32,
}
/// Per-key failure state. Tracks failures at each upgradeable protocol level.
/// H1 is never tracked (it's the protocol floor — nothing to fall back to).
#[derive(Debug, Clone, Default)]
struct FailureState {
h2: Option<FailureRecord>,
h3: Option<FailureRecord>,
}
impl FailureState {
fn is_empty(&self) -> bool {
self.h2.is_none() && self.h3.is_none()
}
fn all_expired(&self) -> bool {
let h2_expired = self.h2.as_ref()
.map(|r| r.failed_at.elapsed() >= r.cooldown)
.unwrap_or(true);
let h3_expired = self.h3.as_ref()
.map(|r| r.failed_at.elapsed() >= r.cooldown)
.unwrap_or(true);
h2_expired && h3_expired
}
fn get(&self, protocol: DetectedProtocol) -> Option<&FailureRecord> {
match protocol {
DetectedProtocol::H2 => self.h2.as_ref(),
DetectedProtocol::H3 => self.h3.as_ref(),
DetectedProtocol::H1 => None,
}
}
fn get_mut(&mut self, protocol: DetectedProtocol) -> &mut Option<FailureRecord> {
match protocol {
DetectedProtocol::H2 => &mut self.h2,
DetectedProtocol::H3 => &mut self.h3,
DetectedProtocol::H1 => unreachable!("H1 failures are never recorded"),
}
}
}
/// Snapshot of a single protocol cache entry, suitable for metrics/UI display.
#[derive(Debug, Clone)]
pub struct ProtocolCacheEntry {
pub host: String,
pub port: u16,
pub domain: Option<String>,
pub protocol: String,
pub h3_port: Option<u16>,
pub age_secs: u64,
pub last_accessed_secs: u64,
pub last_probed_secs: u64,
pub h2_suppressed: bool,
pub h3_suppressed: bool,
pub h2_cooldown_remaining_secs: Option<u64>,
pub h3_cooldown_remaining_secs: Option<u64>,
pub h2_consecutive_failures: Option<u32>,
pub h3_consecutive_failures: Option<u32>,
}
/// Exponential backoff: PROTOCOL_FAILURE_COOLDOWN × 2^(n-1), capped at MAX.
fn escalate_cooldown(consecutive: u32) -> Duration {
let base = PROTOCOL_FAILURE_COOLDOWN.as_secs();
let exp = consecutive.saturating_sub(1).min(63) as u64;
let secs = base.saturating_mul(1u64.checked_shl(exp as u32).unwrap_or(u64::MAX));
Duration::from_secs(secs.min(PROTOCOL_FAILURE_MAX_COOLDOWN.as_secs()))
}
/// Bounded, sliding-TTL protocol detection cache with failure suppression.
///
/// Memory safety guarantees:
/// - Hard cap at `PROTOCOL_CACHE_MAX_ENTRIES` — cannot grow unboundedly.
/// - TTL expiry — stale entries naturally age out on lookup.
/// - Sliding TTL expiry — entries age out after 1 day without access.
/// - Background cleanup task — proactively removes expired entries every 60s.
/// - `clear()` — called on route updates to discard stale detections.
/// - `Drop` — aborts the background task to prevent dangling tokio tasks.
pub struct ProtocolCache {
cache: Arc<DashMap<ProtocolCacheKey, CachedEntry>>,
/// Generic protocol failure suppression map. Tracks per-protocol failure
/// records (H2, H3) for each cache key. Used to prevent upgrade signals
/// (ALPN, Alt-Svc) from re-introducing failed protocols.
failures: Arc<DashMap<ProtocolCacheKey, FailureState>>,
cleanup_handle: Option<tokio::task::JoinHandle<()>>,
}
@@ -75,26 +218,40 @@ impl ProtocolCache {
/// Create a new protocol cache and start the background cleanup task.
pub fn new() -> Self {
let cache: Arc<DashMap<ProtocolCacheKey, CachedEntry>> = Arc::new(DashMap::new());
let failures: Arc<DashMap<ProtocolCacheKey, FailureState>> = Arc::new(DashMap::new());
let cache_clone = Arc::clone(&cache);
let failures_clone = Arc::clone(&failures);
let cleanup_handle = tokio::spawn(async move {
Self::cleanup_loop(cache_clone).await;
Self::cleanup_loop(cache_clone, failures_clone).await;
});
Self {
cache,
failures,
cleanup_handle: Some(cleanup_handle),
}
}
/// Look up the cached protocol for a backend endpoint.
///
/// Returns `None` if not cached or expired (caller should probe via ALPN).
/// On hit, refreshes `last_accessed_at` (sliding TTL) and sets `needs_reprobe`
/// if the entry hasn't been probed in over 5 minutes (H1/H2 only).
pub fn get(&self, key: &ProtocolCacheKey) -> Option<CachedProtocol> {
let entry = self.cache.get(key)?;
if entry.detected_at.elapsed() < PROTOCOL_CACHE_TTL {
debug!("Protocol cache hit: {:?} for {}:{} (requested: {:?})", entry.protocol, key.host, key.port, key.requested_host);
let mut entry = self.cache.get_mut(key)?;
if entry.last_accessed_at.elapsed() < PROTOCOL_CACHE_TTL {
// Refresh sliding TTL
entry.last_accessed_at = Instant::now();
// H3 is the ceiling — can't ALPN-probe for H3 (discovered via Alt-Svc).
// Only H1/H2 entries trigger periodic re-probing.
let needs_reprobe = entry.protocol != DetectedProtocol::H3
&& entry.last_probed_at.elapsed() >= PROTOCOL_REPROBE_INTERVAL;
Some(CachedProtocol {
protocol: entry.protocol,
h3_port: entry.h3_port,
needs_reprobe,
})
} else {
// Expired — remove and return None to trigger re-probe
@@ -105,47 +262,328 @@ impl ProtocolCache {
}
/// Insert a detected protocol into the cache.
/// If the cache is at capacity, evict the oldest entry first.
pub fn insert(&self, key: ProtocolCacheKey, protocol: DetectedProtocol) {
self.insert_with_h3_port(key, protocol, None);
/// Returns `false` if suppressed due to active failure suppression.
///
/// **Key semantic**: only suppresses if the protocol being inserted matches
/// a suppressed protocol. H1 inserts are NEVER suppressed — downgrades
/// always succeed.
pub fn insert(&self, key: ProtocolCacheKey, protocol: DetectedProtocol, reason: &str) -> bool {
if self.is_suppressed(&key, protocol) {
debug!(
host = %key.host, port = %key.port, domain = ?key.requested_host,
protocol = ?protocol,
"Protocol cache insert suppressed — recent failure"
);
return false;
}
self.insert_internal(key, protocol, None, reason);
true
}
/// Insert an H3 detection result with the Alt-Svc advertised port.
pub fn insert_h3(&self, key: ProtocolCacheKey, h3_port: u16) {
self.insert_with_h3_port(key, DetectedProtocol::H3, Some(h3_port));
/// Returns `false` if H3 is suppressed.
pub fn insert_h3(&self, key: ProtocolCacheKey, h3_port: u16, reason: &str) -> bool {
if self.is_suppressed(&key, DetectedProtocol::H3) {
debug!(
host = %key.host, port = %key.port, domain = ?key.requested_host,
"H3 upgrade suppressed — recent failure"
);
return false;
}
self.insert_internal(key, DetectedProtocol::H3, Some(h3_port), reason);
true
}
/// Insert a protocol detection result with an optional H3 port.
fn insert_with_h3_port(&self, key: ProtocolCacheKey, protocol: DetectedProtocol, h3_port: Option<u16>) {
if self.cache.len() >= PROTOCOL_CACHE_MAX_ENTRIES && !self.cache.contains_key(&key) {
// Evict the oldest entry to stay within bounds
let oldest = self.cache.iter()
.min_by_key(|entry| entry.value().detected_at)
.map(|entry| entry.key().clone());
if let Some(oldest_key) = oldest {
self.cache.remove(&oldest_key);
/// Update the cache after an inline ALPN re-probe completes.
///
/// Always updates `last_probed_at`. If the protocol changed, logs the transition
/// and updates the entry. Returns `Some(new_protocol)` if changed, `None` if unchanged.
pub fn update_probe_result(
&self,
key: &ProtocolCacheKey,
probed_protocol: DetectedProtocol,
reason: &str,
) -> Option<DetectedProtocol> {
if let Some(mut entry) = self.cache.get_mut(key) {
let old_protocol = entry.protocol;
entry.last_probed_at = Instant::now();
entry.last_accessed_at = Instant::now();
if old_protocol != probed_protocol {
info!(
host = %key.host, port = %key.port, domain = ?key.requested_host,
old = %old_protocol, new = %probed_protocol, reason = %reason,
"Protocol transition"
);
entry.protocol = probed_protocol;
entry.detected_at = Instant::now();
// Clear h3_port if downgrading from H3
if old_protocol == DetectedProtocol::H3 && probed_protocol != DetectedProtocol::H3 {
entry.h3_port = None;
}
return Some(probed_protocol);
}
debug!(
host = %key.host, port = %key.port, domain = ?key.requested_host,
protocol = %old_protocol, reason = %reason,
"Re-probe confirmed — no protocol change"
);
None
} else {
// Entry was evicted between the get() and the probe completing.
// Insert as a fresh entry.
self.insert_internal(key.clone(), probed_protocol, None, reason);
Some(probed_protocol)
}
}
/// Record a protocol failure. Future `insert()` calls for this protocol
/// will be suppressed until the escalating cooldown expires.
///
/// Cooldown escalation: 5s → 10s → 20s → 40s → 80s → 160s → 300s.
/// Consecutive counter resets if the previous failure is older than 2× its cooldown.
///
/// Cascading: when H2 fails, H3 cooldown is reduced to 5s remaining.
/// H1 failures are ignored (H1 is the protocol floor).
pub fn record_failure(&self, key: ProtocolCacheKey, protocol: DetectedProtocol) {
if protocol == DetectedProtocol::H1 {
return; // H1 is the floor — nothing to suppress
}
let mut entry = self.failures.entry(key.clone()).or_default();
let record = entry.get_mut(protocol);
let (consecutive, new_cooldown) = match record {
Some(existing) if existing.failed_at.elapsed() < existing.cooldown.saturating_mul(2) => {
// Still within the "recent" window — escalate
let c = existing.consecutive_failures.saturating_add(1)
.min(PROTOCOL_FAILURE_ESCALATION_CAP);
(c, escalate_cooldown(c))
}
_ => {
// First failure or old failure that expired long ago — reset
(1, PROTOCOL_FAILURE_COOLDOWN)
}
};
*record = Some(FailureRecord {
failed_at: Instant::now(),
cooldown: new_cooldown,
consecutive_failures: consecutive,
});
// Cascading: when H2 fails, reduce H3 cooldown to 5s remaining
if protocol == DetectedProtocol::H2 {
Self::reduce_cooldown_to(entry.h3.as_mut(), PROTOCOL_FAILURE_COOLDOWN);
}
info!(
host = %key.host, port = %key.port, domain = ?key.requested_host,
protocol = ?protocol,
consecutive = consecutive,
cooldown_secs = new_cooldown.as_secs(),
"Protocol failure recorded — suppressing for {:?}", new_cooldown
);
}
/// Check whether a protocol is currently suppressed for the given key.
/// Returns `true` if the protocol failed within its cooldown period.
/// H1 is never suppressed.
pub fn is_suppressed(&self, key: &ProtocolCacheKey, protocol: DetectedProtocol) -> bool {
if protocol == DetectedProtocol::H1 {
return false;
}
self.failures.get(key)
.and_then(|entry| entry.get(protocol).map(|r| r.failed_at.elapsed() < r.cooldown))
.unwrap_or(false)
}
/// Check whether a protocol can be retried (for within-request escalation).
/// Returns `true` if there's no failure record OR if ≥5s have passed since
/// the last attempt. More permissive than `is_suppressed`.
pub fn can_retry(&self, key: &ProtocolCacheKey, protocol: DetectedProtocol) -> bool {
if protocol == DetectedProtocol::H1 {
return true;
}
match self.failures.get(key) {
Some(entry) => match entry.get(protocol) {
Some(r) => r.failed_at.elapsed() >= PROTOCOL_FAILURE_COOLDOWN,
None => true, // no failure record
},
None => true,
}
}
/// Record a retry attempt WITHOUT escalating the cooldown.
/// Resets the `failed_at` timestamp to prevent rapid retries (5s gate).
/// Called before an escalation attempt. If the attempt fails,
/// `record_failure` should be called afterward with proper escalation.
pub fn record_retry_attempt(&self, key: &ProtocolCacheKey, protocol: DetectedProtocol) {
if protocol == DetectedProtocol::H1 {
return;
}
if let Some(mut entry) = self.failures.get_mut(key) {
if let Some(ref mut r) = entry.get_mut(protocol) {
r.failed_at = Instant::now();
}
}
self.cache.insert(key, CachedEntry {
protocol,
detected_at: Instant::now(),
h3_port,
});
}
/// Clear the failure record for a protocol (it recovered).
/// Called when an escalation retry succeeds.
pub fn clear_failure(&self, key: &ProtocolCacheKey, protocol: DetectedProtocol) {
if protocol == DetectedProtocol::H1 {
return;
}
if let Some(mut entry) = self.failures.get_mut(key) {
*entry.get_mut(protocol) = None;
if entry.is_empty() {
drop(entry);
self.failures.remove(key);
}
}
}
/// Evict a cache entry entirely. Called when all protocol probes (H3, H2, H1)
/// have failed for a backend.
pub fn evict(&self, key: &ProtocolCacheKey) {
self.cache.remove(key);
self.failures.remove(key);
info!(
host = %key.host, port = %key.port, domain = ?key.requested_host,
"Cache entry evicted — all protocols failed"
);
}
/// Clear all entries. Called on route updates to discard stale detections.
pub fn clear(&self) {
self.cache.clear();
self.failures.clear();
}
/// Background cleanup loop — removes expired entries every `PROTOCOL_CACHE_CLEANUP_INTERVAL`.
async fn cleanup_loop(cache: Arc<DashMap<ProtocolCacheKey, CachedEntry>>) {
/// Snapshot all non-expired cache entries for metrics/UI display.
pub fn snapshot(&self) -> Vec<ProtocolCacheEntry> {
self.cache.iter()
.filter(|entry| entry.value().last_accessed_at.elapsed() < PROTOCOL_CACHE_TTL)
.map(|entry| {
let key = entry.key();
let val = entry.value();
let failure_info = self.failures.get(key);
let (h2_sup, h2_cd, h2_cons) = Self::suppression_info(
failure_info.as_deref().and_then(|f| f.h2.as_ref()),
);
let (h3_sup, h3_cd, h3_cons) = Self::suppression_info(
failure_info.as_deref().and_then(|f| f.h3.as_ref()),
);
ProtocolCacheEntry {
host: key.host.clone(),
port: key.port,
domain: key.requested_host.clone(),
protocol: match val.protocol {
DetectedProtocol::H1 => "h1".to_string(),
DetectedProtocol::H2 => "h2".to_string(),
DetectedProtocol::H3 => "h3".to_string(),
},
h3_port: val.h3_port,
age_secs: val.detected_at.elapsed().as_secs(),
last_accessed_secs: val.last_accessed_at.elapsed().as_secs(),
last_probed_secs: val.last_probed_at.elapsed().as_secs(),
h2_suppressed: h2_sup,
h3_suppressed: h3_sup,
h2_cooldown_remaining_secs: h2_cd,
h3_cooldown_remaining_secs: h3_cd,
h2_consecutive_failures: h2_cons,
h3_consecutive_failures: h3_cons,
}
})
.collect()
}
// --- Internal helpers ---
/// Insert a protocol detection result with an optional H3 port.
/// Logs protocol transitions when overwriting an existing entry.
/// No suppression check — callers must check before calling.
fn insert_internal(&self, key: ProtocolCacheKey, protocol: DetectedProtocol, h3_port: Option<u16>, reason: &str) {
// Check for existing entry to log protocol transitions
if let Some(existing) = self.cache.get(&key) {
if existing.protocol != protocol {
info!(
host = %key.host, port = %key.port, domain = ?key.requested_host,
old = %existing.protocol, new = %protocol, reason = %reason,
"Protocol transition"
);
}
drop(existing);
}
// Evict oldest entry if at capacity
if self.cache.len() >= PROTOCOL_CACHE_MAX_ENTRIES && !self.cache.contains_key(&key) {
let oldest = self.cache.iter()
.min_by_key(|entry| entry.value().last_accessed_at)
.map(|entry| entry.key().clone());
if let Some(oldest_key) = oldest {
self.cache.remove(&oldest_key);
}
}
let now = Instant::now();
self.cache.insert(key, CachedEntry {
protocol,
detected_at: now,
last_accessed_at: now,
last_probed_at: now,
h3_port,
});
}
/// Reduce a failure record's remaining cooldown to `target`, if it currently
/// has MORE than `target` remaining. Never increases cooldown.
fn reduce_cooldown_to(record: Option<&mut FailureRecord>, target: Duration) {
if let Some(r) = record {
let elapsed = r.failed_at.elapsed();
if elapsed < r.cooldown {
let remaining = r.cooldown - elapsed;
if remaining > target {
// Shrink cooldown so it expires in `target` from now
r.cooldown = elapsed + target;
}
}
}
}
/// Extract suppression info from a failure record for metrics.
fn suppression_info(record: Option<&FailureRecord>) -> (bool, Option<u64>, Option<u32>) {
match record {
Some(r) => {
let elapsed = r.failed_at.elapsed();
let suppressed = elapsed < r.cooldown;
let remaining = if suppressed {
Some((r.cooldown - elapsed).as_secs())
} else {
None
};
(suppressed, remaining, Some(r.consecutive_failures))
}
None => (false, None, None),
}
}
/// Background cleanup loop.
async fn cleanup_loop(
cache: Arc<DashMap<ProtocolCacheKey, CachedEntry>>,
failures: Arc<DashMap<ProtocolCacheKey, FailureState>>,
) {
let mut interval = tokio::time::interval(PROTOCOL_CACHE_CLEANUP_INTERVAL);
loop {
interval.tick().await;
// Clean expired cache entries (sliding TTL based on last_accessed_at)
let expired: Vec<ProtocolCacheKey> = cache.iter()
.filter(|entry| entry.value().detected_at.elapsed() >= PROTOCOL_CACHE_TTL)
.filter(|entry| entry.value().last_accessed_at.elapsed() >= PROTOCOL_CACHE_TTL)
.map(|entry| entry.key().clone())
.collect();
@@ -155,6 +593,31 @@ impl ProtocolCache {
cache.remove(&key);
}
}
// Clean fully-expired failure entries
let expired_failures: Vec<ProtocolCacheKey> = failures.iter()
.filter(|entry| entry.value().all_expired())
.map(|entry| entry.key().clone())
.collect();
if !expired_failures.is_empty() {
debug!("Protocol cache cleanup: removing {} expired failure entries", expired_failures.len());
for key in expired_failures {
failures.remove(&key);
}
}
// Safety net: cap failures map at 2× max entries
if failures.len() > PROTOCOL_CACHE_MAX_ENTRIES * 2 {
let oldest: Vec<ProtocolCacheKey> = failures.iter()
.filter(|e| e.value().all_expired())
.map(|e| e.key().clone())
.take(failures.len() - PROTOCOL_CACHE_MAX_ENTRIES)
.collect();
for key in oldest {
failures.remove(&key);
}
}
}
}
}

View File

@@ -36,13 +36,33 @@ use crate::upstream_selector::UpstreamSelector;
/// Per-connection context for keeping the idle watchdog alive during body streaming.
/// Passed through the forwarding chain so CountingBody can update the timestamp.
#[derive(Clone)]
struct ConnActivity {
pub struct ConnActivity {
last_activity: Arc<AtomicU64>,
start: std::time::Instant,
/// Active-request counter from handle_io's idle watchdog. When set, CountingBody
/// increments on creation and decrements on Drop, keeping the watchdog aware that
/// a response body is still streaming after the request handler has returned.
active_requests: Option<Arc<AtomicU64>>,
/// Protocol cache key for Alt-Svc discovery. When set, `build_streaming_response`
/// checks the backend's original response headers for Alt-Svc before our
/// ResponseFilter injects its own. None when not in auto-detect mode or after H3 failure.
alt_svc_cache_key: Option<crate::protocol_cache::ProtocolCacheKey>,
/// The upstream request path that triggered Alt-Svc discovery. Logged for traceability.
alt_svc_request_url: Option<String>,
}
impl ConnActivity {
/// Create a minimal ConnActivity (no idle watchdog, no Alt-Svc cache).
/// Used by H3ProxyService where the TCP idle watchdog doesn't apply.
pub fn new_standalone() -> Self {
Self {
last_activity: Arc::new(AtomicU64::new(0)),
start: std::time::Instant::now(),
active_requests: None,
alt_svc_cache_key: None,
alt_svc_request_url: None,
}
}
}
/// Default upstream connect timeout (30 seconds).
@@ -52,15 +72,16 @@ const DEFAULT_CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_s
/// If no new request arrives within this duration, the connection is closed.
const DEFAULT_HTTP_IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
/// Default HTTP max connection lifetime (1 hour).
/// HTTP connections are forcefully closed after this duration regardless of activity.
const DEFAULT_HTTP_MAX_LIFETIME: std::time::Duration = std::time::Duration::from_secs(3600);
/// Default WebSocket inactivity timeout (1 hour).
const DEFAULT_WS_INACTIVITY_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(3600);
/// Default WebSocket max lifetime (24 hours).
const DEFAULT_WS_MAX_LIFETIME: std::time::Duration = std::time::Duration::from_secs(86400);
/// Timeout for QUIC (H3) backend connections. Short because UDP is often firewalled.
const QUIC_CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(3);
/// Protocol decision for backend connection.
#[derive(Debug)]
enum ProtocolDecision {
@@ -186,6 +207,10 @@ pub struct HttpProxyService {
route_rate_limiters: Arc<DashMap<String, Arc<RateLimiter>>>,
/// Request counter for periodic rate limiter cleanup.
request_counter: AtomicU64,
/// Epoch for time-based rate limiter cleanup.
rate_limiter_epoch: std::time::Instant,
/// Last rate limiter cleanup time (ms since epoch).
last_rate_limiter_cleanup_ms: AtomicU64,
/// Cache of compiled URL rewrite regexes (keyed by pattern string).
regex_cache: DashMap<String, Regex>,
/// Shared backend TLS config for session resumption across connections.
@@ -198,6 +223,8 @@ pub struct HttpProxyService {
protocol_cache: Arc<crate::protocol_cache::ProtocolCache>,
/// HTTP keep-alive idle timeout: close connection if no new request arrives within this duration.
http_idle_timeout: std::time::Duration,
/// HTTP max connection lifetime: forcefully close connection after this duration regardless of activity.
http_max_lifetime: std::time::Duration,
/// WebSocket inactivity timeout (no data in either direction).
ws_inactivity_timeout: std::time::Duration,
/// WebSocket maximum connection lifetime.
@@ -216,12 +243,15 @@ impl HttpProxyService {
connect_timeout: DEFAULT_CONNECT_TIMEOUT,
route_rate_limiters: Arc::new(DashMap::new()),
request_counter: AtomicU64::new(0),
rate_limiter_epoch: std::time::Instant::now(),
last_rate_limiter_cleanup_ms: AtomicU64::new(0),
regex_cache: DashMap::new(),
backend_tls_config: Self::default_backend_tls_config(),
backend_tls_config_alpn: Self::default_backend_tls_config_with_alpn(),
connection_pool: Arc::new(crate::connection_pool::ConnectionPool::new()),
protocol_cache: Arc::new(crate::protocol_cache::ProtocolCache::new()),
http_idle_timeout: DEFAULT_HTTP_IDLE_TIMEOUT,
http_max_lifetime: DEFAULT_HTTP_MAX_LIFETIME,
ws_inactivity_timeout: DEFAULT_WS_INACTIVITY_TIMEOUT,
ws_max_lifetime: DEFAULT_WS_MAX_LIFETIME,
quinn_client_endpoint: Arc::new(Self::create_quinn_client_endpoint()),
@@ -241,27 +271,32 @@ impl HttpProxyService {
connect_timeout,
route_rate_limiters: Arc::new(DashMap::new()),
request_counter: AtomicU64::new(0),
rate_limiter_epoch: std::time::Instant::now(),
last_rate_limiter_cleanup_ms: AtomicU64::new(0),
regex_cache: DashMap::new(),
backend_tls_config: Self::default_backend_tls_config(),
backend_tls_config_alpn: Self::default_backend_tls_config_with_alpn(),
connection_pool: Arc::new(crate::connection_pool::ConnectionPool::new()),
protocol_cache: Arc::new(crate::protocol_cache::ProtocolCache::new()),
http_idle_timeout: DEFAULT_HTTP_IDLE_TIMEOUT,
http_max_lifetime: DEFAULT_HTTP_MAX_LIFETIME,
ws_inactivity_timeout: DEFAULT_WS_INACTIVITY_TIMEOUT,
ws_max_lifetime: DEFAULT_WS_MAX_LIFETIME,
quinn_client_endpoint: Arc::new(Self::create_quinn_client_endpoint()),
}
}
/// Set the HTTP keep-alive idle timeout, WebSocket inactivity timeout, and
/// WebSocket max lifetime from connection config values.
/// Set the HTTP keep-alive idle timeout, HTTP max lifetime, WebSocket inactivity
/// timeout, and WebSocket max lifetime from connection config values.
pub fn set_connection_timeouts(
&mut self,
http_idle_timeout: std::time::Duration,
http_max_lifetime: std::time::Duration,
ws_inactivity_timeout: std::time::Duration,
ws_max_lifetime: std::time::Duration,
) {
self.http_idle_timeout = http_idle_timeout;
self.http_max_lifetime = http_max_lifetime;
self.ws_inactivity_timeout = ws_inactivity_timeout;
self.ws_max_lifetime = ws_max_lifetime;
}
@@ -286,6 +321,20 @@ impl HttpProxyService {
self.protocol_cache.clear();
}
/// Clean up expired entries in all per-route rate limiters.
/// Called from the background sampling task to prevent unbounded growth
/// when traffic stops after a burst of unique IPs.
pub fn cleanup_all_rate_limiters(&self) {
for entry in self.route_rate_limiters.iter() {
entry.value().cleanup();
}
}
/// Snapshot the protocol cache for metrics/UI display.
pub fn protocol_cache_snapshot(&self) -> Vec<crate::protocol_cache::ProtocolCacheEntry> {
self.protocol_cache.snapshot()
}
/// Handle an incoming HTTP connection on a plain TCP stream.
pub async fn handle_connection(
self: Arc<Self>,
@@ -321,6 +370,7 @@ impl HttpProxyService {
// Capture timeouts before `self` is moved into the service closure.
let idle_timeout = self.http_idle_timeout;
let max_lifetime = self.http_max_lifetime;
// Activity tracker: updated at the START and END of each request.
// The idle watchdog checks this to determine if the connection is idle
@@ -341,8 +391,9 @@ impl HttpProxyService {
let cn = cancel_inner.clone();
let la = Arc::clone(&la_inner);
let st = start;
let ca = ConnActivity { last_activity: Arc::clone(&la_inner), start, active_requests: Some(Arc::clone(&ar_inner)) };
let ca = ConnActivity { last_activity: Arc::clone(&la_inner), start, active_requests: Some(Arc::clone(&ar_inner)), alt_svc_cache_key: None, alt_svc_request_url: None };
async move {
let req = req.map(|body| BoxBody::new(body));
let result = svc.handle_request(req, peer, port, cn, ca).await;
// Mark request end — update activity timestamp before guard drops
la.store(st.elapsed().as_millis() as u64, Ordering::Relaxed);
@@ -378,15 +429,23 @@ impl HttpProxyService {
}
}
_ = async {
// Idle watchdog: check every 5s whether the connection has been idle
// (no active requests AND no activity for idle_timeout).
// This avoids killing long-running requests or upgraded connections.
// Idle + lifetime watchdog: check every 5s whether the connection has been
// idle (no active requests AND no activity for idle_timeout) or exceeded
// the max connection lifetime.
let check_interval = std::time::Duration::from_secs(5);
let mut last_seen = 0u64;
loop {
tokio::time::sleep(check_interval).await;
// Never close while a request is in progress
// Check max connection lifetime (unconditional — even active connections
// must eventually be recycled to prevent resource accumulation).
if start.elapsed() >= max_lifetime {
debug!("HTTP connection exceeded max lifetime ({}s) from {}",
max_lifetime.as_secs(), peer_addr);
return;
}
// Never close for idleness while a request is in progress
if active_requests.load(Ordering::Relaxed) > 0 {
last_seen = last_activity.load(Ordering::Relaxed);
continue;
@@ -403,7 +462,7 @@ impl HttpProxyService {
last_seen = current;
}
} => {
debug!("HTTP connection idle timeout ({}s) from {}", idle_timeout.as_secs(), peer_addr);
debug!("HTTP connection timeout from {}", peer_addr);
conn.as_mut().graceful_shutdown();
// Give any in-flight work 5s to drain after graceful shutdown
let _ = tokio::time::timeout(std::time::Duration::from_secs(5), conn).await;
@@ -412,13 +471,17 @@ impl HttpProxyService {
}
/// Handle a single HTTP request.
async fn handle_request(
///
/// Accepts a generic body (`BoxBody`) so both the TCP/HTTP path (which boxes
/// `Incoming`) and the H3 path (which boxes the H3 request body stream) can
/// share the same backend forwarding logic.
pub async fn handle_request(
&self,
req: Request<Incoming>,
req: Request<BoxBody<Bytes, hyper::Error>>,
peer_addr: std::net::SocketAddr,
port: u16,
cancel: CancellationToken,
conn_activity: ConnActivity,
mut conn_activity: ConnActivity,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let host = req.headers()
.get("host")
@@ -502,9 +565,13 @@ impl HttpProxyService {
}
}
// Periodic rate limiter cleanup (every 1000 requests)
// Periodic rate limiter cleanup (every 1000 requests or every 60s)
let count = self.request_counter.fetch_add(1, Ordering::Relaxed);
if count % 1000 == 0 {
let now_ms = self.rate_limiter_epoch.elapsed().as_millis() as u64;
let last_cleanup = self.last_rate_limiter_cleanup_ms.load(Ordering::Relaxed);
let time_triggered = now_ms.saturating_sub(last_cleanup) >= 60_000;
if count % 1000 == 0 || time_triggered {
self.last_rate_limiter_cleanup_ms.store(now_ms, Ordering::Relaxed);
for entry in self.route_rate_limiters.iter() {
entry.value().cleanup();
}
@@ -667,6 +734,14 @@ impl HttpProxyService {
port: upstream.port,
requested_host: host.clone(),
};
// Save cached H3 port for within-request escalation (may be needed later
// if TCP connect fails and we escalate to H3 as a last resort)
let cached_h3_port = self.protocol_cache.get(&protocol_cache_key)
.and_then(|c| c.h3_port);
// Track whether this ALPN probe is a periodic re-probe (vs first-time detection)
let mut is_reprobe = false;
let protocol_decision = match backend_protocol_mode {
rustproxy_config::BackendProtocol::Http1 => ProtocolDecision::H1,
rustproxy_config::BackendProtocol::Http2 => ProtocolDecision::H2,
@@ -677,19 +752,40 @@ impl HttpProxyService {
ProtocolDecision::H1
} else {
match self.protocol_cache.get(&protocol_cache_key) {
Some(cached) if cached.needs_reprobe => {
// Entry exists but 5+ minutes since last probe — force ALPN re-probe
// (only fires for H1/H2; H3 entries have needs_reprobe=false)
is_reprobe = true;
ProtocolDecision::AlpnProbe
}
Some(cached) => match cached.protocol {
crate::protocol_cache::DetectedProtocol::H3 => {
if let Some(h3_port) = cached.h3_port {
if self.protocol_cache.is_suppressed(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3) {
// H3 cached but suppressed — fall back to ALPN probe
ProtocolDecision::AlpnProbe
} else if let Some(h3_port) = cached.h3_port {
ProtocolDecision::H3 { port: h3_port }
} else {
// H3 cached but no port — fall back to ALPN probe
ProtocolDecision::AlpnProbe
}
}
crate::protocol_cache::DetectedProtocol::H2 => ProtocolDecision::H2,
crate::protocol_cache::DetectedProtocol::H2 => {
if self.protocol_cache.is_suppressed(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H2) {
ProtocolDecision::H1
} else {
ProtocolDecision::H2
}
}
crate::protocol_cache::DetectedProtocol::H1 => ProtocolDecision::H1,
},
None => ProtocolDecision::AlpnProbe,
None => {
// Cache miss — skip ALPN probe if H2 is suppressed
if self.protocol_cache.is_suppressed(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H2) {
ProtocolDecision::H1
} else {
ProtocolDecision::AlpnProbe
}
}
}
}
}
@@ -703,9 +799,12 @@ impl HttpProxyService {
ProtocolDecision::AlpnProbe => (false, true),
};
// Track whether H3 connect failed — suppresses Alt-Svc re-caching to prevent
// the loop: H3 cached → QUIC timeout → H2/H1 fallback → Alt-Svc re-caches H3 → repeat
let mut h3_connect_failed = false;
// Set Alt-Svc cache key on conn_activity so build_streaming_response can check
// the backend's original Alt-Svc header before ResponseFilter injects our own.
if is_auto_detect_mode {
conn_activity.alt_svc_cache_key = Some(protocol_cache_key.clone());
conn_activity.alt_svc_request_url = Some(upstream_path.to_string());
}
// --- H3 path: try QUIC connection before TCP ---
if let ProtocolDecision::H3 { port: h3_port } = protocol_decision {
@@ -717,10 +816,10 @@ impl HttpProxyService {
};
// Try H3 pool checkout first
if let Some((quic_conn, _age)) = self.connection_pool.checkout_h3(&h3_pool_key) {
if let Some((pooled_sr, quic_conn, _age)) = self.connection_pool.checkout_h3(&h3_pool_key) {
self.metrics.backend_pool_hit(&upstream_key);
let result = self.forward_h3(
quic_conn, parts, body, upstream_headers, &upstream_path,
quic_conn, Some(pooled_sr), parts, body, upstream_headers, &upstream_path,
route_match.route, route_id, &ip_str, &h3_pool_key, domain_str, &conn_activity, &upstream_key,
).await;
self.upstream_selector.connection_ended(&upstream_key);
@@ -733,16 +832,26 @@ impl HttpProxyService {
self.metrics.backend_pool_miss(&upstream_key);
self.metrics.backend_connection_opened(&upstream_key, std::time::Instant::now().elapsed());
let result = self.forward_h3(
quic_conn, parts, body, upstream_headers, &upstream_path,
quic_conn, None, parts, body, upstream_headers, &upstream_path,
route_match.route, route_id, &ip_str, &h3_pool_key, domain_str, &conn_activity, &upstream_key,
).await;
self.upstream_selector.connection_ended(&upstream_key);
return result;
}
Err(e) => {
warn!(backend = %upstream_key, error = %e,
warn!(backend = %upstream_key, domain = %domain_str, error = %e,
"H3 backend connect failed, falling back to H2/H1");
h3_connect_failed = true;
// Record failure with escalating cooldown — prevents Alt-Svc
// from re-upgrading to H3 during cooldown period
if is_auto_detect_mode {
self.protocol_cache.record_failure(
protocol_cache_key.clone(),
crate::protocol_cache::DetectedProtocol::H3,
);
}
// Suppress Alt-Svc caching for the fallback to prevent re-caching H3
// from our own injected Alt-Svc header or a stale backend Alt-Svc
conn_activity.alt_svc_cache_key = None;
// Force ALPN probe on TCP fallback so we correctly detect H2 vs H1
// (don't cache anything yet — let the ALPN probe decide)
if is_auto_detect_mode && upstream.use_tls {
@@ -822,7 +931,7 @@ impl HttpProxyService {
let alpn = tls.get_ref().1.alpn_protocol();
let is_h2 = alpn.map(|p| p == b"h2").unwrap_or(false);
// Cache the result
// Cache the result (or update existing entry for re-probes)
let cache_key = crate::protocol_cache::ProtocolCacheKey {
host: upstream.host.clone(),
port: upstream.port,
@@ -833,13 +942,18 @@ impl HttpProxyService {
} else {
crate::protocol_cache::DetectedProtocol::H1
};
self.protocol_cache.insert(cache_key, detected);
if is_reprobe {
self.protocol_cache.update_probe_result(&cache_key, detected, "periodic ALPN re-probe");
} else {
self.protocol_cache.insert(cache_key, detected, "initial ALPN detection");
}
info!(
backend = %upstream_key,
domain = %domain_str,
protocol = if is_h2 { "h2" } else { "h1" },
connect_time_ms = %connect_start.elapsed().as_millis(),
reprobe = is_reprobe,
"Backend protocol detected via ALPN"
);
@@ -861,6 +975,38 @@ impl HttpProxyService {
);
self.metrics.backend_connect_error(&upstream_key);
self.upstream_selector.connection_ended(&upstream_key);
// --- Within-request escalation: try H3 via QUIC if retryable ---
if is_auto_detect_mode {
if let Some(h3_port) = cached_h3_port {
if self.protocol_cache.can_retry(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3) {
self.protocol_cache.record_retry_attempt(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3);
debug!(backend = %upstream_key, domain = %domain_str, "TLS connect failed — escalating to H3");
match self.connect_quic_backend(&upstream.host, h3_port).await {
Ok(quic_conn) => {
self.protocol_cache.clear_failure(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3);
self.protocol_cache.insert_h3(protocol_cache_key.clone(), h3_port, "recovery — TLS failed, H3 succeeded");
let h3_pool_key = crate::connection_pool::PoolKey {
host: upstream.host.clone(), port: h3_port, use_tls: true,
protocol: crate::connection_pool::PoolProtocol::H3,
};
let result = self.forward_h3(
quic_conn, None, parts, body, upstream_headers, &upstream_path,
route_match.route, route_id, &ip_str, &h3_pool_key, domain_str, &conn_activity, &upstream_key,
).await;
self.upstream_selector.connection_ended(&upstream_key);
return result;
}
Err(e3) => {
debug!(backend = %upstream_key, error = %e3, "H3 escalation also failed");
self.protocol_cache.record_failure(protocol_cache_key.clone(), crate::protocol_cache::DetectedProtocol::H3);
}
}
}
}
// All protocols failed — evict cache entry
self.protocol_cache.evict(&protocol_cache_key);
}
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend TLS unavailable"));
}
Err(_) => {
@@ -872,6 +1018,38 @@ impl HttpProxyService {
);
self.metrics.backend_connect_error(&upstream_key);
self.upstream_selector.connection_ended(&upstream_key);
// --- Within-request escalation: try H3 via QUIC if retryable ---
if is_auto_detect_mode {
if let Some(h3_port) = cached_h3_port {
if self.protocol_cache.can_retry(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3) {
self.protocol_cache.record_retry_attempt(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3);
debug!(backend = %upstream_key, domain = %domain_str, "TLS connect timeout — escalating to H3");
match self.connect_quic_backend(&upstream.host, h3_port).await {
Ok(quic_conn) => {
self.protocol_cache.clear_failure(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3);
self.protocol_cache.insert_h3(protocol_cache_key.clone(), h3_port, "recovery — TLS timeout, H3 succeeded");
let h3_pool_key = crate::connection_pool::PoolKey {
host: upstream.host.clone(), port: h3_port, use_tls: true,
protocol: crate::connection_pool::PoolProtocol::H3,
};
let result = self.forward_h3(
quic_conn, None, parts, body, upstream_headers, &upstream_path,
route_match.route, route_id, &ip_str, &h3_pool_key, domain_str, &conn_activity, &upstream_key,
).await;
self.upstream_selector.connection_ended(&upstream_key);
return result;
}
Err(e3) => {
debug!(backend = %upstream_key, error = %e3, "H3 escalation also failed");
self.protocol_cache.record_failure(protocol_cache_key.clone(), crate::protocol_cache::DetectedProtocol::H3);
}
}
}
}
// All protocols failed — evict cache entry
self.protocol_cache.evict(&protocol_cache_key);
}
return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend TLS connect timeout"));
}
}
@@ -899,6 +1077,38 @@ impl HttpProxyService {
);
self.metrics.backend_connect_error(&upstream_key);
self.upstream_selector.connection_ended(&upstream_key);
// --- Within-request escalation: try H3 via QUIC if retryable ---
if is_auto_detect_mode {
if let Some(h3_port) = cached_h3_port {
if self.protocol_cache.can_retry(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3) {
self.protocol_cache.record_retry_attempt(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3);
debug!(backend = %upstream_key, domain = %domain_str, "TCP connect failed — escalating to H3");
match self.connect_quic_backend(&upstream.host, h3_port).await {
Ok(quic_conn) => {
self.protocol_cache.clear_failure(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3);
self.protocol_cache.insert_h3(protocol_cache_key.clone(), h3_port, "recovery — TCP failed, H3 succeeded");
let h3_pool_key = crate::connection_pool::PoolKey {
host: upstream.host.clone(), port: h3_port, use_tls: true,
protocol: crate::connection_pool::PoolProtocol::H3,
};
let result = self.forward_h3(
quic_conn, None, parts, body, upstream_headers, &upstream_path,
route_match.route, route_id, &ip_str, &h3_pool_key, domain_str, &conn_activity, &upstream_key,
).await;
self.upstream_selector.connection_ended(&upstream_key);
return result;
}
Err(e3) => {
debug!(backend = %upstream_key, error = %e3, "H3 escalation also failed");
self.protocol_cache.record_failure(protocol_cache_key.clone(), crate::protocol_cache::DetectedProtocol::H3);
}
}
}
}
// All protocols failed — evict cache entry
self.protocol_cache.evict(&protocol_cache_key);
}
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
}
Err(_) => {
@@ -910,6 +1120,38 @@ impl HttpProxyService {
);
self.metrics.backend_connect_error(&upstream_key);
self.upstream_selector.connection_ended(&upstream_key);
// --- Within-request escalation: try H3 via QUIC if retryable ---
if is_auto_detect_mode {
if let Some(h3_port) = cached_h3_port {
if self.protocol_cache.can_retry(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3) {
self.protocol_cache.record_retry_attempt(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3);
debug!(backend = %upstream_key, domain = %domain_str, "TCP connect timeout — escalating to H3");
match self.connect_quic_backend(&upstream.host, h3_port).await {
Ok(quic_conn) => {
self.protocol_cache.clear_failure(&protocol_cache_key, crate::protocol_cache::DetectedProtocol::H3);
self.protocol_cache.insert_h3(protocol_cache_key.clone(), h3_port, "recovery — TCP timeout, H3 succeeded");
let h3_pool_key = crate::connection_pool::PoolKey {
host: upstream.host.clone(), port: h3_port, use_tls: true,
protocol: crate::connection_pool::PoolProtocol::H3,
};
let result = self.forward_h3(
quic_conn, None, parts, body, upstream_headers, &upstream_path,
route_match.route, route_id, &ip_str, &h3_pool_key, domain_str, &conn_activity, &upstream_key,
).await;
self.upstream_selector.connection_ended(&upstream_key);
return result;
}
Err(e3) => {
debug!(backend = %upstream_key, error = %e3, "H3 escalation also failed");
self.protocol_cache.record_failure(protocol_cache_key.clone(), crate::protocol_cache::DetectedProtocol::H3);
}
}
}
}
// All protocols failed — evict cache entry
self.protocol_cache.evict(&protocol_cache_key);
}
return Ok(error_response(StatusCode::GATEWAY_TIMEOUT, "Backend connect timeout"));
}
}
@@ -948,22 +1190,6 @@ impl HttpProxyService {
self.upstream_selector.connection_ended(&upstream_key);
self.metrics.backend_connection_closed(&upstream_key);
// --- Alt-Svc discovery: check if backend advertises H3 ---
// Suppress Alt-Svc caching when we just failed an H3 attempt to prevent the loop:
// H3 cached → QUIC timeout → fallback → Alt-Svc re-caches H3 → repeat.
// The ALPN probe already cached H1 or H2; it will expire after 5min TTL,
// at which point we'll re-probe and see Alt-Svc again, retrying QUIC then.
if is_auto_detect_mode && !h3_connect_failed {
if let Ok(ref resp) = result {
if let Some(alt_svc) = resp.headers().get("alt-svc").and_then(|v| v.to_str().ok()) {
if let Some(h3_port) = parse_alt_svc_h3_port(alt_svc) {
debug!(backend = %upstream_key, h3_port, "Backend advertises H3 via Alt-Svc");
self.protocol_cache.insert_h3(protocol_cache_key, h3_port);
}
}
}
}
result
}
@@ -973,7 +1199,7 @@ impl HttpProxyService {
&self,
io: TokioIo<BackendStream>,
parts: hyper::http::request::Parts,
body: Incoming,
body: BoxBody<Bytes, hyper::Error>,
upstream_headers: hyper::HeaderMap,
upstream_path: &str,
_upstream: &crate::upstream_selector::UpstreamSelection,
@@ -1008,8 +1234,10 @@ impl HttpProxyService {
};
tokio::spawn(async move {
if let Err(e) = conn.await {
debug!("Upstream connection error: {}", e);
match tokio::time::timeout(std::time::Duration::from_secs(300), conn).await {
Ok(Err(e)) => debug!("Upstream connection error: {}", e),
Err(_) => debug!("H1 connection driver timed out after 300s"),
_ => {}
}
});
@@ -1021,7 +1249,7 @@ impl HttpProxyService {
&self,
mut sender: hyper::client::conn::http1::SendRequest<BoxBody<Bytes, hyper::Error>>,
parts: hyper::http::request::Parts,
body: Incoming,
body: BoxBody<Bytes, hyper::Error>,
upstream_headers: hyper::HeaderMap,
upstream_path: &str,
route: &rustproxy_config::RouteConfig,
@@ -1085,7 +1313,7 @@ impl HttpProxyService {
&self,
io: TokioIo<BackendStream>,
parts: hyper::http::request::Parts,
body: Incoming,
body: BoxBody<Bytes, hyper::Error>,
upstream_headers: hyper::HeaderMap,
upstream_path: &str,
_upstream: &crate::upstream_selector::UpstreamSelection,
@@ -1159,7 +1387,7 @@ impl HttpProxyService {
&self,
sender: hyper::client::conn::http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
parts: hyper::http::request::Parts,
body: Incoming,
body: BoxBody<Bytes, hyper::Error>,
upstream_headers: hyper::HeaderMap,
upstream_path: &str,
route: &rustproxy_config::RouteConfig,
@@ -1352,7 +1580,7 @@ impl HttpProxyService {
&self,
io: TokioIo<BackendStream>,
parts: hyper::http::request::Parts,
body: Incoming,
body: BoxBody<Bytes, hyper::Error>,
mut upstream_headers: hyper::HeaderMap,
upstream_path: &str,
upstream: &crate::upstream_selector::UpstreamSelection,
@@ -1394,7 +1622,12 @@ impl HttpProxyService {
port: upstream.port,
requested_host: requested_host.clone(),
};
self.protocol_cache.insert(cache_key, crate::protocol_cache::DetectedProtocol::H1);
// Record H2 failure (escalating cooldown) before downgrading cache to H1
self.protocol_cache.record_failure(
cache_key.clone(),
crate::protocol_cache::DetectedProtocol::H2,
);
self.protocol_cache.insert(cache_key.clone(), crate::protocol_cache::DetectedProtocol::H1, "H2 handshake timeout — downgrade");
match self.reconnect_backend(upstream, domain, backend_key).await {
Some(fallback_backend) => {
@@ -1413,6 +1646,8 @@ impl HttpProxyService {
result
}
None => {
// H2 failed and H1 reconnect also failed — evict cache
self.protocol_cache.evict(&cache_key);
Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable after H2 timeout fallback"))
}
}
@@ -1527,13 +1762,17 @@ impl HttpProxyService {
self.metrics.backend_h2_failure(backend_key);
self.metrics.backend_handshake_error(backend_key);
// Update cache to H1 so subsequent requests skip H2
// Record H2 failure (escalating cooldown) and downgrade cache to H1
let cache_key = crate::protocol_cache::ProtocolCacheKey {
host: upstream.host.clone(),
port: upstream.port,
requested_host: requested_host.clone(),
};
self.protocol_cache.insert(cache_key, crate::protocol_cache::DetectedProtocol::H1);
self.protocol_cache.record_failure(
cache_key.clone(),
crate::protocol_cache::DetectedProtocol::H2,
);
self.protocol_cache.insert(cache_key.clone(), crate::protocol_cache::DetectedProtocol::H1, "H2 handshake error — downgrade");
// Reconnect for H1 (the original io was consumed by the failed h2 handshake)
match self.reconnect_backend(upstream, domain, backend_key).await {
@@ -1554,6 +1793,8 @@ impl HttpProxyService {
result
}
None => {
// H2 failed and H1 reconnect also failed — evict cache
self.protocol_cache.evict(&cache_key);
Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable after H2 fallback"))
}
}
@@ -1589,8 +1830,10 @@ impl HttpProxyService {
};
tokio::spawn(async move {
if let Err(e) = conn.await {
debug!("H1 fallback: upstream connection error: {}", e);
match tokio::time::timeout(std::time::Duration::from_secs(300), conn).await {
Ok(Err(e)) => debug!("H1 fallback: upstream connection error: {}", e),
Err(_) => debug!("H1 fallback: connection driver timed out after 300s"),
_ => {}
}
});
@@ -1683,7 +1926,7 @@ impl HttpProxyService {
&self,
mut sender: hyper::client::conn::http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
parts: hyper::http::request::Parts,
body: Incoming,
body: BoxBody<Bytes, hyper::Error>,
upstream_headers: hyper::HeaderMap,
upstream_path: &str,
route: &rustproxy_config::RouteConfig,
@@ -1762,6 +2005,21 @@ impl HttpProxyService {
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let (resp_parts, resp_body) = upstream_response.into_parts();
// Check for Alt-Svc in the backend's ORIGINAL response headers BEFORE
// ResponseFilter::apply_headers runs — the filter may inject our own Alt-Svc
// for client-facing HTTP/3 advertisement, which must not be confused with
// backend-originated Alt-Svc.
if let Some(ref cache_key) = conn_activity.alt_svc_cache_key {
if let Some(alt_svc) = resp_parts.headers.get("alt-svc").and_then(|v| v.to_str().ok()) {
if let Some(h3_port) = parse_alt_svc_h3_port(alt_svc) {
let url = conn_activity.alt_svc_request_url.as_deref().unwrap_or("-");
debug!(h3_port, url, "Backend advertises H3 via Alt-Svc");
let reason = format!("Alt-Svc response header ({})", url);
self.protocol_cache.insert_h3(cache_key.clone(), h3_port, &reason);
}
}
}
let mut response = Response::builder()
.status(resp_parts.status);
@@ -1811,7 +2069,7 @@ impl HttpProxyService {
/// Handle a WebSocket upgrade request (H1 Upgrade or H2 Extended CONNECT per RFC 8441).
async fn handle_websocket_upgrade(
&self,
req: Request<Incoming>,
req: Request<BoxBody<Bytes, hyper::Error>>,
peer_addr: std::net::SocketAddr,
upstream: &crate::upstream_selector::UpstreamSelection,
route: &rustproxy_config::RouteConfig,
@@ -2111,12 +2369,26 @@ impl HttpProxyService {
let ws_max_lifetime = self.ws_max_lifetime;
tokio::spawn(async move {
// RAII guard: ensures connection_ended is called even if this task panics
struct WsUpstreamGuard {
selector: UpstreamSelector,
key: String,
}
impl Drop for WsUpstreamGuard {
fn drop(&mut self) {
self.selector.connection_ended(&self.key);
}
}
let _upstream_guard = WsUpstreamGuard {
selector: upstream_selector,
key: upstream_key_owned.clone(),
};
let client_upgraded = match on_client_upgrade.await {
Ok(upgraded) => upgraded,
Err(e) => {
debug!("WebSocket: client upgrade failed: {}", e);
upstream_selector.connection_ended(&upstream_key_owned);
return;
return; // _upstream_guard Drop handles connection_ended
}
};
@@ -2275,9 +2547,7 @@ impl HttpProxyService {
watchdog.abort();
debug!("WebSocket tunnel closed: {} bytes in, {} bytes out", bytes_in, bytes_out);
upstream_selector.connection_ended(&upstream_key_owned);
// Bytes already reported per-chunk in the copy loops above
// _upstream_guard Drop handles connection_ended on all paths including panic
});
let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(
@@ -2499,7 +2769,12 @@ impl HttpProxyService {
let quic_crypto = quinn::crypto::rustls::QuicClientConfig::try_from(tls_config)
.expect("Failed to create QUIC client crypto config");
let client_config = quinn::ClientConfig::new(Arc::new(quic_crypto));
// Tune QUIC transport to match H2 flow-control: 2 MB per-stream receive window.
let mut transport = quinn::TransportConfig::default();
transport.stream_receive_window(quinn::VarInt::from_u32(2 * 1024 * 1024));
let mut client_config = quinn::ClientConfig::new(Arc::new(quic_crypto));
client_config.transport_config(Arc::new(transport));
let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
.expect("Failed to create QUIC client endpoint");
@@ -2521,8 +2796,8 @@ impl HttpProxyService {
let server_name = host.to_string();
let connecting = self.quinn_client_endpoint.connect(addr, &server_name)?;
let connection = tokio::time::timeout(QUIC_CONNECT_TIMEOUT, connecting).await
.map_err(|_| "QUIC connect timeout (3s)")??;
let connection = tokio::time::timeout(self.connect_timeout, connecting).await
.map_err(|_| format!("QUIC connect timeout ({:?}) for {}", self.connect_timeout, host))??;
debug!("QUIC backend connection established to {}:{}", host, port);
Ok(connection)
@@ -2532,8 +2807,9 @@ impl HttpProxyService {
async fn forward_h3(
&self,
quic_conn: quinn::Connection,
pooled_sender: Option<h3::client::SendRequest<h3_quinn::OpenStreams, Bytes>>,
parts: hyper::http::request::Parts,
body: Incoming,
body: BoxBody<Bytes, hyper::Error>,
upstream_headers: hyper::HeaderMap,
upstream_path: &str,
route: &rustproxy_config::RouteConfig,
@@ -2544,29 +2820,42 @@ impl HttpProxyService {
conn_activity: &ConnActivity,
backend_key: &str,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let h3_quinn_conn = h3_quinn::Connection::new(quic_conn.clone());
let (mut driver, mut send_request) = match h3::client::new(h3_quinn_conn).await {
Ok(pair) => pair,
Err(e) => {
error!(backend = %backend_key, domain = %domain, error = %e, "H3 client handshake failed");
self.metrics.backend_handshake_error(backend_key);
return Ok(error_response(StatusCode::BAD_GATEWAY, "H3 handshake failed"));
}
};
// Obtain the h3 SendRequest handle: skip handshake + driver on pool hit.
let (mut send_request, gen_holder) = if let Some(sr) = pooled_sender {
// Pool hit — reuse existing h3 session, no SETTINGS round-trip
(sr, None)
} else {
// Fresh QUIC connection — full h3 handshake + driver spawn
let h3_quinn_conn = h3_quinn::Connection::new(quic_conn.clone());
let (mut driver, sr) = match h3::client::builder()
.send_grease(false)
.build(h3_quinn_conn)
.await
{
Ok(pair) => pair,
Err(e) => {
error!(backend = %backend_key, domain = %domain, error = %e, "H3 client handshake failed");
self.metrics.backend_handshake_error(backend_key);
return Ok(error_response(StatusCode::BAD_GATEWAY, "H3 handshake failed"));
}
};
// Spawn the h3 connection driver
let driver_pool = Arc::clone(&self.connection_pool);
let driver_pool_key = pool_key.clone();
let gen_holder = Arc::new(std::sync::atomic::AtomicU64::new(u64::MAX));
let driver_gen = Arc::clone(&gen_holder);
tokio::spawn(async move {
let close_err = std::future::poll_fn(|cx| driver.poll_close(cx)).await;
debug!("H3 connection driver closed: {:?}", close_err);
let g = driver_gen.load(std::sync::atomic::Ordering::Relaxed);
if g != u64::MAX {
driver_pool.remove_h3_if_generation(&driver_pool_key, g);
let gen_holder = Arc::new(std::sync::atomic::AtomicU64::new(u64::MAX));
{
let driver_pool = Arc::clone(&self.connection_pool);
let driver_pool_key = pool_key.clone();
let driver_gen = Arc::clone(&gen_holder);
tokio::spawn(async move {
let close_err = std::future::poll_fn(|cx| driver.poll_close(cx)).await;
debug!("H3 connection driver closed: {:?}", close_err);
let g = driver_gen.load(std::sync::atomic::Ordering::Relaxed);
if g != u64::MAX {
driver_pool.remove_h3_if_generation(&driver_pool_key, g);
}
});
}
});
(sr, Some(gen_holder))
};
// Build the H3 request
let uri = hyper::Uri::builder()
@@ -2596,7 +2885,7 @@ impl HttpProxyService {
}
};
// Stream request body
// Stream request body (zero-copy: into_data yields owned Bytes)
let rid: Option<Arc<str>> = route_id.map(Arc::from);
let sip: Arc<str> = Arc::from(source_ip);
@@ -2606,9 +2895,9 @@ impl HttpProxyService {
while let Some(frame) = body.frame().await {
match frame {
Ok(frame) => {
if let Some(data) = frame.data_ref() {
if let Ok(data) = frame.into_data() {
self.metrics.record_bytes(data.len() as u64, 0, rid.as_deref(), Some(&sip));
if let Err(e) = stream.send_data(Bytes::copy_from_slice(data)).await {
if let Err(e) = stream.send_data(data).await {
error!(backend = %backend_key, error = %e, "H3 send_data failed");
return Ok(error_response(StatusCode::BAD_GATEWAY, "H3 body send failed"));
}
@@ -2650,8 +2939,23 @@ impl HttpProxyService {
ResponseFilter::apply_headers(route, headers, None);
}
// Stream response body back via an adapter
let h3_body = H3ClientResponseBody { stream };
// Stream response body back via unfold — correctly preserves waker across polls
let body_stream = futures::stream::unfold(stream, |mut s| async move {
match s.recv_data().await {
Ok(Some(mut buf)) => {
use bytes::Buf;
let data = buf.copy_to_bytes(buf.remaining());
Some((Ok::<_, hyper::Error>(http_body::Frame::data(data)), s))
}
Ok(None) => None,
Err(e) => {
warn!("H3 response body recv error: {}", e);
None
}
}
});
let h3_body = http_body_util::StreamBody::new(body_stream);
let counting_body = CountingBody::new(
h3_body,
Arc::clone(&self.metrics),
@@ -2668,10 +2972,16 @@ impl HttpProxyService {
let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(counting_body);
// Register connection in pool on success
// Register connection in pool on success (fresh connections only)
if status != StatusCode::BAD_GATEWAY {
let g = self.connection_pool.register_h3(pool_key.clone(), quic_conn);
gen_holder.store(g, std::sync::atomic::Ordering::Relaxed);
if let Some(gh) = gen_holder {
let g = self.connection_pool.register_h3(
pool_key.clone(),
quic_conn,
send_request,
);
gh.store(g, std::sync::atomic::Ordering::Relaxed);
}
}
self.metrics.set_backend_protocol(backend_key, "h3");
@@ -2700,41 +3010,6 @@ fn parse_alt_svc_h3_port(header_value: &str) -> Option<u16> {
None
}
/// Response body adapter for H3 client responses.
/// Reads data from the h3 `RequestStream` recv side and presents it as an `http_body::Body`.
struct H3ClientResponseBody {
stream: h3::client::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>,
}
impl http_body::Body for H3ClientResponseBody {
type Data = Bytes;
type Error = hyper::Error;
fn poll_frame(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
// h3's recv_data is async, so we need to poll it manually.
// Use a small future to poll the recv_data call.
use std::future::Future;
let mut fut = Box::pin(self.stream.recv_data());
match fut.as_mut().poll(_cx) {
Poll::Ready(Ok(Some(mut buf))) => {
use bytes::Buf;
let data = Bytes::copy_from_slice(buf.chunk());
buf.advance(buf.remaining());
Poll::Ready(Some(Ok(http_body::Frame::data(data))))
}
Poll::Ready(Ok(None)) => Poll::Ready(None),
Poll::Ready(Err(e)) => {
warn!("H3 response body recv error: {}", e);
Poll::Ready(None)
}
Poll::Pending => Poll::Pending,
}
}
}
/// Insecure certificate verifier for backend TLS connections (fallback only).
/// The production path uses the shared config from tls_handler which has the same
/// behavior but with session resumption across all outbound connections.
@@ -2795,12 +3070,15 @@ impl Default for HttpProxyService {
connect_timeout: DEFAULT_CONNECT_TIMEOUT,
route_rate_limiters: Arc::new(DashMap::new()),
request_counter: AtomicU64::new(0),
rate_limiter_epoch: std::time::Instant::now(),
last_rate_limiter_cleanup_ms: AtomicU64::new(0),
regex_cache: DashMap::new(),
backend_tls_config: Self::default_backend_tls_config(),
backend_tls_config_alpn: Self::default_backend_tls_config_with_alpn(),
connection_pool: Arc::new(crate::connection_pool::ConnectionPool::new()),
protocol_cache: Arc::new(crate::protocol_cache::ProtocolCache::new()),
http_idle_timeout: DEFAULT_HTTP_IDLE_TIMEOUT,
http_max_lifetime: DEFAULT_HTTP_MAX_LIFETIME,
ws_inactivity_timeout: DEFAULT_WS_INACTIVITY_TIMEOUT,
ws_max_lifetime: DEFAULT_WS_MAX_LIFETIME,
quinn_client_endpoint: Arc::new(Self::create_quinn_client_endpoint()),

View File

@@ -6,7 +6,6 @@ use std::sync::Arc;
use bytes::Bytes;
use http_body_util::Full;
use http_body_util::BodyExt;
use hyper::body::Incoming;
use hyper::{Request, Response, StatusCode};
use http_body_util::combinators::BoxBody;
@@ -19,7 +18,7 @@ impl RequestFilter {
/// Apply security filters. Returns Some(response) if the request should be blocked.
pub fn apply(
security: &RouteSecurity,
req: &Request<Incoming>,
req: &Request<impl hyper::body::Body>,
peer_addr: &SocketAddr,
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
Self::apply_with_rate_limiter(security, req, peer_addr, None)
@@ -29,7 +28,7 @@ impl RequestFilter {
/// Returns Some(response) if the request should be blocked.
pub fn apply_with_rate_limiter(
security: &RouteSecurity,
req: &Request<Incoming>,
req: &Request<impl hyper::body::Body>,
peer_addr: &SocketAddr,
rate_limiter: Option<&Arc<RateLimiter>>,
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
@@ -182,7 +181,7 @@ impl RequestFilter {
/// Determine the rate limit key based on configuration.
fn rate_limit_key(
config: &rustproxy_config::RouteRateLimit,
req: &Request<Incoming>,
req: &Request<impl hyper::body::Body>,
peer_addr: &SocketAddr,
) -> String {
use rustproxy_config::RateLimitKeyBy;
@@ -220,7 +219,7 @@ impl RequestFilter {
/// Handle CORS preflight (OPTIONS) requests.
/// Returns Some(response) if this is a CORS preflight that should be handled.
pub fn handle_cors_preflight(
req: &Request<Incoming>,
req: &Request<impl hyper::body::Body>,
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
if req.method() != hyper::Method::OPTIONS {
return None;

View File

@@ -31,6 +31,8 @@ pub struct Metrics {
pub total_udp_sessions: u64,
pub total_datagrams_in: u64,
pub total_datagrams_out: u64,
// Protocol detection cache snapshot (populated by RustProxy from HttpProxyService)
pub detected_protocols: Vec<ProtocolCacheEntryMetric>,
}
/// Per-route metrics.
@@ -76,6 +78,27 @@ pub struct BackendMetrics {
pub h2_failures: u64,
}
/// Protocol cache entry for metrics/UI display.
/// Populated from the HTTP proxy service's protocol detection cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ProtocolCacheEntryMetric {
pub host: String,
pub port: u16,
pub domain: Option<String>,
pub protocol: String,
pub h3_port: Option<u16>,
pub age_secs: u64,
pub last_accessed_secs: u64,
pub last_probed_secs: u64,
pub h2_suppressed: bool,
pub h3_suppressed: bool,
pub h2_cooldown_remaining_secs: Option<u64>,
pub h3_cooldown_remaining_secs: Option<u64>,
pub h2_consecutive_failures: Option<u32>,
pub h3_consecutive_failures: Option<u32>,
}
/// Statistics snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -411,11 +434,24 @@ impl MetricsCollector {
}
/// Record a backend connection closing.
/// Removes all per-backend tracking entries when the active count reaches 0.
pub fn backend_connection_closed(&self, key: &str) {
if let Some(counter) = self.backend_active.get(key) {
let val = counter.load(Ordering::Relaxed);
if val > 0 {
counter.fetch_sub(1, Ordering::Relaxed);
let prev = counter.fetch_sub(1, Ordering::Relaxed);
if prev <= 1 {
// Active count reached 0 — clean up all per-backend maps
drop(counter); // release DashMap ref before remove
self.backend_active.remove(key);
self.backend_total.remove(key);
self.backend_protocol.remove(key);
self.backend_connect_errors.remove(key);
self.backend_handshake_errors.remove(key);
self.backend_request_errors.remove(key);
self.backend_connect_time_us.remove(key);
self.backend_connect_count.remove(key);
self.backend_pool_hits.remove(key);
self.backend_pool_misses.remove(key);
self.backend_h2_failures.remove(key);
}
}
}
@@ -588,6 +624,24 @@ impl MetricsCollector {
self.ip_pending_tp.retain(|k, _| self.ip_connections.contains_key(k));
self.ip_throughput.retain(|k, _| self.ip_connections.contains_key(k));
self.ip_total_connections.retain(|k, _| self.ip_connections.contains_key(k));
// Safety-net: prune orphaned backend error/stats entries for backends
// that have no active or total connections (error-only backends).
// These accumulate when backend_connect_error/backend_handshake_error
// create entries but backend_connection_opened is never called.
let known_backends: HashSet<String> = self.backend_active.iter()
.map(|e| e.key().clone())
.chain(self.backend_total.iter().map(|e| e.key().clone()))
.collect();
self.backend_connect_errors.retain(|k, _| known_backends.contains(k));
self.backend_handshake_errors.retain(|k, _| known_backends.contains(k));
self.backend_request_errors.retain(|k, _| known_backends.contains(k));
self.backend_connect_time_us.retain(|k, _| known_backends.contains(k));
self.backend_connect_count.retain(|k, _| known_backends.contains(k));
self.backend_pool_hits.retain(|k, _| known_backends.contains(k));
self.backend_pool_misses.retain(|k, _| known_backends.contains(k));
self.backend_h2_failures.retain(|k, _| known_backends.contains(k));
self.backend_protocol.retain(|k, _| known_backends.contains(k));
}
/// Remove per-route metrics for route IDs that are no longer active.
@@ -811,6 +865,7 @@ impl MetricsCollector {
total_udp_sessions: self.total_udp_sessions.load(Ordering::Relaxed),
total_datagrams_in: self.total_datagrams_in.load(Ordering::Relaxed),
total_datagrams_out: self.total_datagrams_out.load(Ordering::Relaxed),
detected_protocols: vec![],
}
}
}
@@ -1213,10 +1268,13 @@ mod tests {
// No entry created
assert!(collector.backend_active.get(key).is_none());
// Open one, close two — should saturate at 0
// Open one, close — entries are removed when active count reaches 0
collector.backend_connection_opened(key, Duration::from_millis(1));
collector.backend_connection_closed(key);
// Entry should be cleaned up (active reached 0)
assert!(collector.backend_active.get(key).is_none());
// Second close on missing entry is a no-op
collector.backend_connection_closed(key);
assert_eq!(collector.backend_active.get(key).unwrap().load(Ordering::Relaxed), 0);
assert!(collector.backend_active.get(key).is_none());
}
}

View File

@@ -1,17 +0,0 @@
[package]
name = "rustproxy-nftables"
version.workspace = true
edition.workspace = true
license.workspace = true
authors.workspace = true
description = "NFTables kernel-level forwarding for RustProxy"
[dependencies]
rustproxy-config = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
libc = { workspace = true }

View File

@@ -1,10 +0,0 @@
//! # rustproxy-nftables
//!
//! NFTables kernel-level forwarding for RustProxy.
//! Generates and manages nft CLI rules for DNAT/SNAT.
pub mod nft_manager;
pub mod rule_builder;
pub use nft_manager::*;
pub use rule_builder::*;

View File

@@ -1,238 +0,0 @@
use thiserror::Error;
use std::collections::HashMap;
use tracing::{debug, info, warn};
#[derive(Debug, Error)]
pub enum NftError {
#[error("nft command failed: {0}")]
CommandFailed(String),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Not running as root")]
NotRoot,
}
/// Manager for nftables rules.
///
/// Executes `nft` CLI commands to manage kernel-level packet forwarding.
/// Requires root privileges; operations are skipped gracefully if not root.
pub struct NftManager {
table_name: String,
/// Active rules indexed by route ID
active_rules: HashMap<String, Vec<String>>,
/// Whether the table has been initialized
table_initialized: bool,
}
impl NftManager {
pub fn new(table_name: Option<String>) -> Self {
Self {
table_name: table_name.unwrap_or_else(|| "rustproxy".to_string()),
active_rules: HashMap::new(),
table_initialized: false,
}
}
/// Check if we are running as root.
fn is_root() -> bool {
unsafe { libc::geteuid() == 0 }
}
/// Execute a single nft command via the CLI.
async fn exec_nft(command: &str) -> Result<String, NftError> {
// The command starts with "nft ", strip it to get the args
let args = if command.starts_with("nft ") {
&command[4..]
} else {
command
};
let output = tokio::process::Command::new("nft")
.args(args.split_whitespace())
.output()
.await
.map_err(NftError::Io)?;
if output.status.success() {
Ok(String::from_utf8_lossy(&output.stdout).to_string())
} else {
let stderr = String::from_utf8_lossy(&output.stderr);
Err(NftError::CommandFailed(format!(
"Command '{}' failed: {}",
command, stderr
)))
}
}
/// Ensure the nftables table and chains are set up.
async fn ensure_table(&mut self) -> Result<(), NftError> {
if self.table_initialized {
return Ok(());
}
let setup_commands = crate::rule_builder::build_table_setup(&self.table_name);
for cmd in &setup_commands {
Self::exec_nft(cmd).await?;
}
self.table_initialized = true;
info!("NFTables table '{}' initialized", self.table_name);
Ok(())
}
/// Apply rules for a route.
///
/// Executes the nft commands via the CLI. If not running as root,
/// the rules are stored locally but not applied to the kernel.
pub async fn apply_rules(&mut self, route_id: &str, rules: Vec<String>) -> Result<(), NftError> {
if !Self::is_root() {
warn!("Not running as root, nftables rules will not be applied to kernel");
self.active_rules.insert(route_id.to_string(), rules);
return Ok(());
}
self.ensure_table().await?;
for cmd in &rules {
Self::exec_nft(cmd).await?;
debug!("Applied nft rule: {}", cmd);
}
info!("Applied {} nftables rules for route '{}'", rules.len(), route_id);
self.active_rules.insert(route_id.to_string(), rules);
Ok(())
}
/// Remove rules for a route.
///
/// Currently removes the route from tracking. To fully remove specific
/// rules would require handle-based tracking; for now, cleanup() removes
/// the entire table.
pub async fn remove_rules(&mut self, route_id: &str) -> Result<(), NftError> {
if let Some(rules) = self.active_rules.remove(route_id) {
info!("Removed {} tracked nft rules for route '{}'", rules.len(), route_id);
}
Ok(())
}
/// Clean up all managed rules by deleting the entire nftables table.
pub async fn cleanup(&mut self) -> Result<(), NftError> {
if !Self::is_root() {
warn!("Not running as root, skipping nftables cleanup");
self.active_rules.clear();
self.table_initialized = false;
return Ok(());
}
if self.table_initialized {
let cleanup_commands = crate::rule_builder::build_table_cleanup(&self.table_name);
for cmd in &cleanup_commands {
match Self::exec_nft(cmd).await {
Ok(_) => debug!("Cleanup: {}", cmd),
Err(e) => warn!("Cleanup command failed (may be ok): {}", e),
}
}
info!("NFTables table '{}' cleaned up", self.table_name);
}
self.active_rules.clear();
self.table_initialized = false;
Ok(())
}
/// Get the table name.
pub fn table_name(&self) -> &str {
&self.table_name
}
/// Whether the table has been initialized in the kernel.
pub fn is_initialized(&self) -> bool {
self.table_initialized
}
/// Get the number of active route rule sets.
pub fn active_route_count(&self) -> usize {
self.active_rules.len()
}
/// Get the status of all active rules.
pub fn status(&self) -> HashMap<String, serde_json::Value> {
let mut status = HashMap::new();
for (route_id, rules) in &self.active_rules {
status.insert(
route_id.clone(),
serde_json::json!({
"ruleCount": rules.len(),
"rules": rules,
}),
);
}
status
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_new_default_table_name() {
let mgr = NftManager::new(None);
assert_eq!(mgr.table_name(), "rustproxy");
assert!(!mgr.is_initialized());
}
#[test]
fn test_new_custom_table_name() {
let mgr = NftManager::new(Some("custom".to_string()));
assert_eq!(mgr.table_name(), "custom");
}
#[tokio::test]
async fn test_apply_rules_non_root() {
let mut mgr = NftManager::new(None);
// When not root, rules are stored but not applied to kernel
let rules = vec!["nft add rule ip rustproxy prerouting tcp dport 443 dnat to 10.0.0.1:8443".to_string()];
mgr.apply_rules("route-1", rules).await.unwrap();
assert_eq!(mgr.active_route_count(), 1);
let status = mgr.status();
assert!(status.contains_key("route-1"));
assert_eq!(status["route-1"]["ruleCount"], 1);
}
#[tokio::test]
async fn test_remove_rules() {
let mut mgr = NftManager::new(None);
let rules = vec!["nft add rule test".to_string()];
mgr.apply_rules("route-1", rules).await.unwrap();
assert_eq!(mgr.active_route_count(), 1);
mgr.remove_rules("route-1").await.unwrap();
assert_eq!(mgr.active_route_count(), 0);
}
#[tokio::test]
async fn test_cleanup_non_root() {
let mut mgr = NftManager::new(None);
let rules = vec!["nft add rule test".to_string()];
mgr.apply_rules("route-1", rules).await.unwrap();
mgr.apply_rules("route-2", vec!["nft add rule test2".to_string()]).await.unwrap();
mgr.cleanup().await.unwrap();
assert_eq!(mgr.active_route_count(), 0);
assert!(!mgr.is_initialized());
}
#[tokio::test]
async fn test_status_multiple_routes() {
let mut mgr = NftManager::new(None);
mgr.apply_rules("web", vec!["rule1".to_string(), "rule2".to_string()]).await.unwrap();
mgr.apply_rules("api", vec!["rule3".to_string()]).await.unwrap();
let status = mgr.status();
assert_eq!(status.len(), 2);
assert_eq!(status["web"]["ruleCount"], 2);
assert_eq!(status["api"]["ruleCount"], 1);
}
}

View File

@@ -1,146 +0,0 @@
use rustproxy_config::{NfTablesOptions, NfTablesProtocol};
/// Build nftables DNAT rule for port forwarding.
pub fn build_dnat_rule(
table_name: &str,
chain_name: &str,
source_port: u16,
target_host: &str,
target_port: u16,
options: &NfTablesOptions,
) -> Vec<String> {
let protocols: Vec<&str> = match options.protocol.as_ref().unwrap_or(&NfTablesProtocol::Tcp) {
NfTablesProtocol::Tcp => vec!["tcp"],
NfTablesProtocol::Udp => vec!["udp"],
NfTablesProtocol::All => vec!["tcp", "udp"],
};
let mut rules = Vec::new();
for protocol in &protocols {
// DNAT rule
rules.push(format!(
"nft add rule ip {} {} {} dport {} dnat to {}:{}",
table_name, chain_name, protocol, source_port, target_host, target_port,
));
// SNAT rule if preserving source IP is not enabled
if !options.preserve_source_ip.unwrap_or(false) {
rules.push(format!(
"nft add rule ip {} postrouting {} dport {} masquerade",
table_name, protocol, target_port,
));
}
// Rate limiting
if let Some(max_rate) = &options.max_rate {
rules.push(format!(
"nft add rule ip {} {} {} dport {} limit rate {} accept",
table_name, chain_name, protocol, source_port, max_rate,
));
}
}
rules
}
/// Build the initial table and chain setup commands.
pub fn build_table_setup(table_name: &str) -> Vec<String> {
vec![
format!("nft add table ip {}", table_name),
format!("nft add chain ip {} prerouting {{ type nat hook prerouting priority 0 \\; }}", table_name),
format!("nft add chain ip {} postrouting {{ type nat hook postrouting priority 100 \\; }}", table_name),
]
}
/// Build cleanup commands to remove the table.
pub fn build_table_cleanup(table_name: &str) -> Vec<String> {
vec![format!("nft delete table ip {}", table_name)]
}
#[cfg(test)]
mod tests {
use super::*;
fn make_options() -> NfTablesOptions {
NfTablesOptions {
preserve_source_ip: None,
protocol: None,
max_rate: None,
priority: None,
table_name: None,
use_ip_sets: None,
use_advanced_nat: None,
}
}
#[test]
fn test_basic_dnat_rule() {
let options = make_options();
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
assert!(rules.len() >= 1);
assert!(rules[0].contains("dnat to 10.0.0.1:8443"));
assert!(rules[0].contains("dport 443"));
}
#[test]
fn test_preserve_source_ip() {
let mut options = make_options();
options.preserve_source_ip = Some(true);
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
// When preserving source IP, no masquerade rule
assert!(rules.iter().all(|r| !r.contains("masquerade")));
}
#[test]
fn test_without_preserve_source_ip() {
let options = make_options();
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
assert!(rules.iter().any(|r| r.contains("masquerade")));
}
#[test]
fn test_rate_limited_rule() {
let mut options = make_options();
options.max_rate = Some("100/second".to_string());
let rules = build_dnat_rule("rustproxy", "prerouting", 80, "10.0.0.1", 8080, &options);
assert!(rules.iter().any(|r| r.contains("limit rate 100/second")));
}
#[test]
fn test_table_setup_commands() {
let commands = build_table_setup("rustproxy");
assert_eq!(commands.len(), 3);
assert!(commands[0].contains("add table ip rustproxy"));
assert!(commands[1].contains("prerouting"));
assert!(commands[2].contains("postrouting"));
}
#[test]
fn test_table_cleanup() {
let commands = build_table_cleanup("rustproxy");
assert_eq!(commands.len(), 1);
assert!(commands[0].contains("delete table ip rustproxy"));
}
#[test]
fn test_protocol_all_generates_tcp_and_udp_rules() {
let mut options = make_options();
options.protocol = Some(NfTablesProtocol::All);
let rules = build_dnat_rule("rustproxy", "prerouting", 53, "10.0.0.53", 53, &options);
// Should have TCP DNAT + masquerade + UDP DNAT + masquerade = 4 rules
assert_eq!(rules.len(), 4);
assert!(rules.iter().any(|r| r.contains("tcp dport 53 dnat")));
assert!(rules.iter().any(|r| r.contains("udp dport 53 dnat")));
assert!(rules.iter().filter(|r| r.contains("masquerade")).count() == 2);
}
#[test]
fn test_protocol_udp() {
let mut options = make_options();
options.protocol = Some(NfTablesProtocol::Udp);
let rules = build_dnat_rule("rustproxy", "prerouting", 53, "10.0.0.53", 53, &options);
assert!(rules.iter().all(|r| !r.contains("tcp")));
assert!(rules.iter().any(|r| r.contains("udp dport 53 dnat")));
}
}

View File

@@ -10,7 +10,6 @@ pub mod forwarder;
pub mod proxy_protocol;
pub mod tls_handler;
pub mod connection_tracker;
pub mod socket_relay;
pub mod socket_opts;
pub mod udp_session;
pub mod udp_listener;
@@ -22,7 +21,6 @@ pub use forwarder::*;
pub use proxy_protocol::*;
pub use tls_handler::*;
pub use connection_tracker::*;
pub use socket_relay::*;
pub use socket_opts::*;
pub use udp_session::*;
pub use udp_listener::*;

View File

@@ -3,13 +3,21 @@
//! Manages QUIC endpoints (via quinn), accepts connections, and either:
//! - Forwards streams bidirectionally to TCP backends (QUIC termination)
//! - Dispatches to H3ProxyService for HTTP/3 handling (Phase 5)
//!
//! When `proxy_ips` is configured, a UDP relay layer intercepts PROXY protocol v2
//! headers before they reach quinn, extracting real client IPs for attribution.
use std::net::SocketAddr;
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokio::io::AsyncWriteExt;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::UdpSocket;
use tokio::task::JoinHandle;
use arc_swap::ArcSwap;
use dashmap::DashMap;
use quinn::{Endpoint, ServerConfig as QuinnServerConfig};
use rustls::ServerConfig as RustlsServerConfig;
use tokio_util::sync::CancellationToken;
@@ -47,9 +55,292 @@ pub fn create_quic_endpoint(
Ok(endpoint)
}
// ===== PROXY protocol relay for QUIC =====
/// Result of creating a QUIC endpoint with a PROXY protocol relay layer.
pub struct QuicProxyRelay {
/// The quinn endpoint (bound to 127.0.0.1:ephemeral).
pub endpoint: Endpoint,
/// The relay recv loop task handle.
pub relay_task: JoinHandle<()>,
/// Maps relay socket local addr → real client SocketAddr (from PROXY v2).
/// Consulted by `quic_accept_loop` to resolve real client IPs.
pub real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>,
}
/// A single relay session for forwarding datagrams between an external source
/// and the internal quinn endpoint.
struct RelaySession {
socket: Arc<UdpSocket>,
last_activity: AtomicU64,
return_task: JoinHandle<()>,
cancel: CancellationToken,
}
impl Drop for RelaySession {
fn drop(&mut self) {
self.cancel.cancel();
self.return_task.abort();
}
}
/// Create a QUIC endpoint with a PROXY protocol v2 relay layer.
///
/// Instead of giving the external socket to quinn, we:
/// 1. Bind a raw UDP socket on 0.0.0.0:port (external)
/// 2. Bind quinn on 127.0.0.1:0 (internal, ephemeral)
/// 3. Run a relay loop that filters PROXY v2 headers and forwards datagrams
///
/// Only used when `proxy_ips` is non-empty.
pub fn create_quic_endpoint_with_proxy_relay(
port: u16,
tls_config: Arc<RustlsServerConfig>,
proxy_ips: Arc<Vec<IpAddr>>,
cancel: CancellationToken,
) -> anyhow::Result<QuicProxyRelay> {
// Bind external socket on the real port
let external_socket = std::net::UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], port)))?;
external_socket.set_nonblocking(true)?;
let external_socket = Arc::new(
UdpSocket::from_std(external_socket)
.map_err(|e| anyhow::anyhow!("Failed to wrap external socket: {}", e))?,
);
// Bind quinn on localhost ephemeral port
let internal_socket = std::net::UdpSocket::bind("127.0.0.1:0")?;
let quinn_internal_addr = internal_socket.local_addr()?;
let quic_crypto = quinn::crypto::rustls::QuicServerConfig::try_from(tls_config)
.map_err(|e| anyhow::anyhow!("Failed to create QUIC crypto config: {}", e))?;
let server_config = QuinnServerConfig::with_crypto(Arc::new(quic_crypto));
let endpoint = Endpoint::new(
quinn::EndpointConfig::default(),
Some(server_config),
internal_socket,
quinn::default_runtime()
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
)?;
let real_client_map = Arc::new(DashMap::new());
let relay_task = tokio::spawn(quic_proxy_relay_loop(
external_socket,
quinn_internal_addr,
proxy_ips,
Arc::clone(&real_client_map),
cancel,
));
info!("QUIC endpoint with PROXY relay on port {} (quinn internal: {})", port, quinn_internal_addr);
Ok(QuicProxyRelay { endpoint, relay_task, real_client_map })
}
/// Main relay loop: reads datagrams from the external socket, filters PROXY v2
/// headers from trusted proxy IPs, and forwards everything else to quinn via
/// per-session relay sockets.
async fn quic_proxy_relay_loop(
external_socket: Arc<UdpSocket>,
quinn_internal_addr: SocketAddr,
proxy_ips: Arc<Vec<IpAddr>>,
real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>,
cancel: CancellationToken,
) {
// Maps external source addr → real client addr (from PROXY v2 headers)
let proxy_addr_map: DashMap<SocketAddr, SocketAddr> = DashMap::new();
// Maps external source addr → relay session
let relay_sessions: DashMap<SocketAddr, Arc<RelaySession>> = DashMap::new();
let epoch = Instant::now();
let mut buf = vec![0u8; 65535];
// Inline cleanup: periodically scan relay_sessions for stale entries
let mut last_cleanup = Instant::now();
let cleanup_interval = std::time::Duration::from_secs(30);
let session_timeout_ms: u64 = 120_000;
loop {
let (len, src_addr) = tokio::select! {
_ = cancel.cancelled() => {
debug!("QUIC proxy relay loop cancelled");
break;
}
result = external_socket.recv_from(&mut buf) => {
match result {
Ok(r) => r,
Err(e) => {
warn!("QUIC proxy relay recv error: {}", e);
continue;
}
}
}
};
let datagram = &buf[..len];
// PROXY v2 handling: only on first datagram from a trusted proxy IP
// (before a relay session exists for this source)
if proxy_ips.contains(&src_addr.ip()) && relay_sessions.get(&src_addr).is_none() {
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
match crate::proxy_protocol::parse_v2(datagram) {
Ok((header, _consumed)) => {
debug!("QUIC PROXY v2 from {}: real client {}", src_addr, header.source_addr);
proxy_addr_map.insert(src_addr, header.source_addr);
continue; // consume the PROXY v2 datagram
}
Err(e) => {
debug!("QUIC proxy relay: failed to parse PROXY v2 from {}: {}", src_addr, e);
}
}
}
}
// Determine real client address
let real_client = proxy_addr_map.get(&src_addr)
.map(|r| *r)
.unwrap_or(src_addr);
// Get or create relay session for this external source
let session = match relay_sessions.get(&src_addr) {
Some(s) => {
s.last_activity.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
Arc::clone(s.value())
}
None => {
// Create new relay socket connected to quinn's internal address
let relay_socket = match UdpSocket::bind("127.0.0.1:0").await {
Ok(s) => s,
Err(e) => {
warn!("QUIC relay: failed to bind relay socket: {}", e);
continue;
}
};
if let Err(e) = relay_socket.connect(quinn_internal_addr).await {
warn!("QUIC relay: failed to connect relay socket to {}: {}", quinn_internal_addr, e);
continue;
}
let relay_local_addr = match relay_socket.local_addr() {
Ok(a) => a,
Err(e) => {
warn!("QUIC relay: failed to get relay socket local addr: {}", e);
continue;
}
};
let relay_socket = Arc::new(relay_socket);
// Store the real client mapping for the QUIC accept loop
real_client_map.insert(relay_local_addr, real_client);
// Spawn return-path relay: quinn -> external socket -> original source
let session_cancel = cancel.child_token();
let return_task = tokio::spawn(relay_return_path(
Arc::clone(&relay_socket),
Arc::clone(&external_socket),
src_addr,
session_cancel.child_token(),
));
let session = Arc::new(RelaySession {
socket: relay_socket,
last_activity: AtomicU64::new(epoch.elapsed().as_millis() as u64),
return_task,
cancel: session_cancel,
});
relay_sessions.insert(src_addr, Arc::clone(&session));
debug!("QUIC relay: new session for {} (relay {}), real client {}",
src_addr, relay_local_addr, real_client);
session
}
};
// Forward datagram to quinn via the relay socket
if let Err(e) = session.socket.send(datagram).await {
debug!("QUIC relay: forward error to quinn for {}: {}", src_addr, e);
}
// Periodic cleanup of stale relay sessions
if last_cleanup.elapsed() >= cleanup_interval {
last_cleanup = Instant::now();
let now_ms = epoch.elapsed().as_millis() as u64;
let stale_keys: Vec<SocketAddr> = relay_sessions.iter()
.filter(|entry| {
let age = now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
age > session_timeout_ms
})
.map(|entry| *entry.key())
.collect();
for key in stale_keys {
if let Some((_, session)) = relay_sessions.remove(&key) {
session.cancel.cancel();
session.return_task.abort();
// Clean up real_client_map entry
if let Ok(addr) = session.socket.local_addr() {
real_client_map.remove(&addr);
}
proxy_addr_map.remove(&key);
debug!("QUIC relay: cleaned up stale session for {}", key);
}
}
// Also clean orphaned proxy_addr_map entries (PROXY header received
// but no relay session was ever created, e.g. client never sent data)
let orphaned: Vec<SocketAddr> = proxy_addr_map.iter()
.filter(|entry| relay_sessions.get(entry.key()).is_none())
.map(|entry| *entry.key())
.collect();
for key in orphaned {
proxy_addr_map.remove(&key);
debug!("QUIC relay: cleaned up orphaned proxy_addr_map entry for {}", key);
}
}
}
// Shutdown: cancel all relay sessions
for entry in relay_sessions.iter() {
entry.value().cancel.cancel();
entry.value().return_task.abort();
}
}
/// Return-path relay: receives datagrams from quinn (via the relay socket)
/// and forwards them back to the external client through the external socket.
async fn relay_return_path(
relay_socket: Arc<UdpSocket>,
external_socket: Arc<UdpSocket>,
external_src_addr: SocketAddr,
cancel: CancellationToken,
) {
let mut buf = vec![0u8; 65535];
loop {
let len = tokio::select! {
_ = cancel.cancelled() => break,
result = relay_socket.recv(&mut buf) => {
match result {
Ok(len) => len,
Err(e) => {
debug!("QUIC relay return recv error for {}: {}", external_src_addr, e);
break;
}
}
}
};
if let Err(e) = external_socket.send_to(&buf[..len], external_src_addr).await {
debug!("QUIC relay return send error to {}: {}", external_src_addr, e);
break;
}
}
}
// ===== QUIC accept loop =====
/// Run the QUIC accept loop for a single endpoint.
///
/// Accepts incoming QUIC connections and spawns a task per connection.
/// When `real_client_map` is provided, it is consulted to resolve real client
/// IPs from PROXY protocol v2 headers (relay socket addr → real client addr).
pub async fn quic_accept_loop(
endpoint: Endpoint,
port: u16,
@@ -58,6 +349,7 @@ pub async fn quic_accept_loop(
conn_tracker: Arc<ConnectionTracker>,
cancel: CancellationToken,
h3_service: Option<Arc<H3ProxyService>>,
real_client_map: Option<Arc<DashMap<SocketAddr, SocketAddr>>>,
) {
loop {
let incoming = tokio::select! {
@@ -77,11 +369,16 @@ pub async fn quic_accept_loop(
};
let remote_addr = incoming.remote_address();
let ip = remote_addr.ip();
// Resolve real client IP from PROXY protocol map if available
let real_addr = real_client_map.as_ref()
.and_then(|map| map.get(&remote_addr).map(|r| *r))
.unwrap_or(remote_addr);
let ip = real_addr.ip();
// Per-IP rate limiting
if !conn_tracker.try_accept(&ip) {
debug!("QUIC connection rejected from {} (rate limit)", remote_addr);
debug!("QUIC connection rejected from {} (rate limit)", real_addr);
// Drop `incoming` to refuse the connection
continue;
}
@@ -104,7 +401,7 @@ pub async fn quic_accept_loop(
let route = match rm.find_route(&ctx) {
Some(m) => m.route.clone(),
None => {
debug!("No QUIC route matched for port {} from {}", port, remote_addr);
debug!("No QUIC route matched for port {} from {}", port, real_addr);
continue;
}
};
@@ -117,16 +414,35 @@ pub async fn quic_accept_loop(
let conn_tracker = Arc::clone(&conn_tracker);
let cancel = cancel.child_token();
let h3_svc = h3_service.clone();
let real_client_addr = if real_addr != remote_addr { Some(real_addr) } else { None };
tokio::spawn(async move {
match handle_quic_connection(incoming, route, port, Arc::clone(&metrics), &cancel, h3_svc).await {
Ok(()) => debug!("QUIC connection from {} completed", remote_addr),
Err(e) => debug!("QUIC connection from {} error: {}", remote_addr, e),
// RAII guard: ensures metrics/tracker cleanup even on panic
struct QuicConnGuard {
tracker: Arc<ConnectionTracker>,
metrics: Arc<MetricsCollector>,
ip: std::net::IpAddr,
ip_str: String,
route_id: Option<String>,
}
impl Drop for QuicConnGuard {
fn drop(&mut self) {
self.tracker.connection_closed(&self.ip);
self.metrics.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
}
}
let _guard = QuicConnGuard {
tracker: conn_tracker,
metrics: Arc::clone(&metrics),
ip,
ip_str,
route_id,
};
// Cleanup
conn_tracker.connection_closed(&ip);
metrics.connection_closed(route_id.as_deref(), Some(&ip_str));
match handle_quic_connection(incoming, route, port, Arc::clone(&metrics), &cancel, h3_svc, real_client_addr).await {
Ok(()) => debug!("QUIC connection from {} completed", real_addr),
Err(e) => debug!("QUIC connection from {} error: {}", real_addr, e),
}
});
}
@@ -144,10 +460,11 @@ async fn handle_quic_connection(
metrics: Arc<MetricsCollector>,
cancel: &CancellationToken,
h3_service: Option<Arc<H3ProxyService>>,
real_client_addr: Option<SocketAddr>,
) -> anyhow::Result<()> {
let connection = incoming.await?;
let remote_addr = connection.remote_address();
debug!("QUIC connection established from {}", remote_addr);
let effective_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
debug!("QUIC connection established from {}", effective_addr);
// Check if this route has HTTP/3 enabled
let enable_http3 = route.action.udp.as_ref()
@@ -158,7 +475,7 @@ async fn handle_quic_connection(
if enable_http3 {
if let Some(ref h3_svc) = h3_service {
debug!("HTTP/3 enabled for route {:?}, dispatching to H3ProxyService", route.name);
h3_svc.handle_connection(connection, &route, port).await
h3_svc.handle_connection(connection, &route, port, real_client_addr, cancel).await
} else {
warn!("HTTP/3 enabled for route {:?} but H3ProxyService not initialized", route.name);
// Keep connection alive until cancelled
@@ -172,7 +489,7 @@ async fn handle_quic_connection(
}
} else {
// Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend
handle_quic_stream_forwarding(connection, route, port, metrics, cancel).await
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr).await
}
}
@@ -187,8 +504,9 @@ async fn handle_quic_stream_forwarding(
port: u16,
metrics: Arc<MetricsCollector>,
cancel: &CancellationToken,
real_client_addr: Option<SocketAddr>,
) -> anyhow::Result<()> {
let remote_addr = connection.remote_address();
let effective_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
let route_id = route.name.as_deref().or(route.id.as_deref());
let metrics_arc = metrics;
@@ -209,7 +527,7 @@ async fn handle_quic_stream_forwarding(
Err(quinn::ConnectionError::ApplicationClosed(_)) => break,
Err(quinn::ConnectionError::LocallyClosed) => break,
Err(e) => {
debug!("QUIC stream accept error from {}: {}", remote_addr, e);
debug!("QUIC stream accept error from {}: {}", effective_addr, e);
break;
}
}
@@ -217,9 +535,10 @@ async fn handle_quic_stream_forwarding(
};
let backend_addr = backend_addr.clone();
let ip_str = remote_addr.ip().to_string();
let ip_str = effective_addr.ip().to_string();
let stream_metrics = Arc::clone(&metrics_arc);
let stream_route_id = route_id.map(|s| s.to_string());
let stream_cancel = cancel.child_token();
// Spawn a task for each QUIC stream → TCP bidirectional forwarding
tokio::spawn(async move {
@@ -227,6 +546,7 @@ async fn handle_quic_stream_forwarding(
send_stream,
recv_stream,
&backend_addr,
stream_cancel,
).await {
Ok((bytes_in, bytes_out)) => {
stream_metrics.record_bytes(
@@ -247,27 +567,112 @@ async fn handle_quic_stream_forwarding(
}
/// Forward a single QUIC bidirectional stream to a TCP backend connection.
///
/// Includes inactivity timeout (60s), max lifetime (10min), and cancellation
/// to prevent leaked stream tasks when the parent connection closes.
async fn forward_quic_stream_to_tcp(
mut quic_send: quinn::SendStream,
mut quic_recv: quinn::RecvStream,
backend_addr: &str,
cancel: CancellationToken,
) -> anyhow::Result<(u64, u64)> {
let inactivity_timeout = std::time::Duration::from_secs(60);
let max_lifetime = std::time::Duration::from_secs(600);
// Connect to backend TCP
let tcp_stream = tokio::net::TcpStream::connect(backend_addr).await?;
let (mut tcp_read, mut tcp_write) = tcp_stream.into_split();
// Bidirectional copy
let client_to_backend = tokio::io::copy(&mut quic_recv, &mut tcp_write);
let backend_to_client = tokio::io::copy(&mut tcp_read, &mut quic_send);
let last_activity = Arc::new(AtomicU64::new(0));
let start = std::time::Instant::now();
let conn_cancel = CancellationToken::new();
let (c2b, b2c) = tokio::join!(client_to_backend, backend_to_client);
let la1 = Arc::clone(&last_activity);
let cc1 = conn_cancel.clone();
let c2b = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
let mut total = 0u64;
loop {
let n = tokio::select! {
result = quic_recv.read(&mut buf) => match result {
Ok(Some(0)) | Ok(None) | Err(_) => break,
Ok(Some(n)) => n,
},
_ = cc1.cancelled() => break,
};
if tcp_write.write_all(&buf[..n]).await.is_err() {
break;
}
total += n as u64;
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
}
let _ = tokio::time::timeout(
std::time::Duration::from_secs(2),
tcp_write.shutdown(),
).await;
total
});
let bytes_in = c2b.unwrap_or(0);
let bytes_out = b2c.unwrap_or(0);
let la2 = Arc::clone(&last_activity);
let cc2 = conn_cancel.clone();
let b2c = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
let mut total = 0u64;
loop {
let n = tokio::select! {
result = tcp_read.read(&mut buf) => match result {
Ok(0) | Err(_) => break,
Ok(n) => n,
},
_ = cc2.cancelled() => break,
};
// quinn SendStream implements AsyncWrite
if quic_send.write_all(&buf[..n]).await.is_err() {
break;
}
total += n as u64;
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
}
let _ = quic_send.finish();
total
});
// Graceful shutdown
let _ = quic_send.finish();
let _ = tcp_write.shutdown().await;
// Watchdog: inactivity, max lifetime, and cancellation
let la_watch = Arc::clone(&last_activity);
let c2b_abort = c2b.abort_handle();
let b2c_abort = b2c.abort_handle();
let watchdog = tokio::spawn(async move {
let check_interval = std::time::Duration::from_secs(5);
let mut last_seen = 0u64;
loop {
tokio::select! {
_ = cancel.cancelled() => break,
_ = tokio::time::sleep(check_interval) => {
if start.elapsed() >= max_lifetime {
debug!("QUIC stream exceeded max lifetime, closing");
break;
}
let current = la_watch.load(Ordering::Relaxed);
if current == last_seen {
let elapsed = start.elapsed().as_millis() as u64 - current;
if elapsed >= inactivity_timeout.as_millis() as u64 {
debug!("QUIC stream inactive for {}ms, closing", elapsed);
break;
}
}
last_seen = current;
}
}
}
conn_cancel.cancel();
tokio::time::sleep(std::time::Duration::from_secs(4)).await;
c2b_abort.abort();
b2c_abort.abort();
});
let bytes_in = c2b.await.unwrap_or(0);
let bytes_out = b2c.await.unwrap_or(0);
watchdog.abort();
Ok((bytes_in, bytes_out))
}

View File

@@ -1,126 +1,4 @@
//! Socket handler relay for connecting client connections to a TypeScript handler
//! via a Unix domain socket.
//! Socket handler relay module.
//!
//! Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay.
use tokio::net::UnixStream;
use tokio::io::{AsyncWriteExt, AsyncReadExt};
use tokio::net::TcpStream;
use serde::Serialize;
use tracing::debug;
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct RelayMetadata {
connection_id: u64,
remote_ip: String,
remote_port: u16,
local_port: u16,
sni: Option<String>,
route_name: String,
initial_data_base64: Option<String>,
}
/// Relay a client connection to a TypeScript handler via Unix domain socket.
///
/// Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay.
pub async fn relay_to_handler(
client: TcpStream,
relay_socket_path: &str,
connection_id: u64,
remote_ip: String,
remote_port: u16,
local_port: u16,
sni: Option<String>,
route_name: String,
initial_data: Option<&[u8]>,
) -> std::io::Result<()> {
debug!(
"Relaying connection {} to handler socket {}",
connection_id, relay_socket_path
);
// Connect to TypeScript handler Unix socket
let mut handler = UnixStream::connect(relay_socket_path).await?;
// Build and send metadata header
let initial_data_base64 = initial_data.map(base64_encode);
let metadata = RelayMetadata {
connection_id,
remote_ip,
remote_port,
local_port,
sni,
route_name,
initial_data_base64,
};
let metadata_json = serde_json::to_string(&metadata)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
handler.write_all(metadata_json.as_bytes()).await?;
handler.write_all(b"\n").await?;
// Bidirectional relay between client and handler
let (mut client_read, mut client_write) = client.into_split();
let (mut handler_read, mut handler_write) = handler.into_split();
let c2h = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
loop {
let n = match client_read.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if handler_write.write_all(&buf[..n]).await.is_err() {
break;
}
}
let _ = handler_write.shutdown().await;
});
let h2c = tokio::spawn(async move {
let mut buf = vec![0u8; 65536];
loop {
let n = match handler_read.read(&mut buf).await {
Ok(0) | Err(_) => break,
Ok(n) => n,
};
if client_write.write_all(&buf[..n]).await.is_err() {
break;
}
}
let _ = client_write.shutdown().await;
});
let _ = tokio::join!(c2h, h2c);
debug!("Relay connection {} completed", connection_id);
Ok(())
}
/// Simple base64 encoding without external dependency.
fn base64_encode(data: &[u8]) -> String {
const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
let mut result = String::new();
for chunk in data.chunks(3) {
let b0 = chunk[0] as u32;
let b1 = if chunk.len() > 1 { chunk[1] as u32 } else { 0 };
let b2 = if chunk.len() > 2 { chunk[2] as u32 } else { 0 };
let n = (b0 << 16) | (b1 << 8) | b2;
result.push(CHARS[((n >> 18) & 0x3F) as usize] as char);
result.push(CHARS[((n >> 12) & 0x3F) as usize] as char);
if chunk.len() > 1 {
result.push(CHARS[((n >> 6) & 0x3F) as usize] as char);
} else {
result.push('=');
}
if chunk.len() > 2 {
result.push(CHARS[(n & 0x3F) as usize] as char);
} else {
result.push('=');
}
}
result
}
//! Note: The actual relay logic lives in `tcp_listener::relay_to_socket_handler()`
//! which has proper timeouts, cancellation, and metrics integration.

View File

@@ -182,6 +182,7 @@ impl TcpListenerManager {
http_proxy_svc.set_backend_tls_config_alpn(tls_handler::shared_backend_tls_config_alpn());
http_proxy_svc.set_connection_timeouts(
std::time::Duration::from_millis(conn_config.socket_timeout_ms),
std::time::Duration::from_millis(conn_config.max_connection_lifetime_ms),
std::time::Duration::from_millis(conn_config.socket_timeout_ms),
std::time::Duration::from_millis(conn_config.max_connection_lifetime_ms),
);
@@ -220,6 +221,7 @@ impl TcpListenerManager {
http_proxy_svc.set_backend_tls_config_alpn(tls_handler::shared_backend_tls_config_alpn());
http_proxy_svc.set_connection_timeouts(
std::time::Duration::from_millis(conn_config.socket_timeout_ms),
std::time::Duration::from_millis(conn_config.max_connection_lifetime_ms),
std::time::Duration::from_millis(conn_config.socket_timeout_ms),
std::time::Duration::from_millis(conn_config.max_connection_lifetime_ms),
);
@@ -263,6 +265,7 @@ impl TcpListenerManager {
http_proxy_svc.set_backend_tls_config_alpn(tls_handler::shared_backend_tls_config_alpn());
http_proxy_svc.set_connection_timeouts(
std::time::Duration::from_millis(config.socket_timeout_ms),
std::time::Duration::from_millis(config.max_connection_lifetime_ms),
std::time::Duration::from_millis(config.socket_timeout_ms),
std::time::Duration::from_millis(config.max_connection_lifetime_ms),
);
@@ -428,6 +431,11 @@ impl TcpListenerManager {
self.http_proxy.prune_stale_routes(active_route_ids);
}
/// Get a reference to the HTTP proxy service (shared with H3).
pub fn http_proxy(&self) -> &Arc<HttpProxyService> {
&self.http_proxy
}
/// Get a reference to the connection tracker.
pub fn conn_tracker(&self) -> &Arc<ConnectionTracker> {
&self.conn_tracker

View File

@@ -2,12 +2,17 @@
//!
//! Binds UDP sockets on configured ports, receives datagrams, matches routes,
//! tracks sessions (flows), and forwards datagrams to backend UDP sockets.
//!
//! Supports PROXY protocol v2 on both raw UDP and QUIC paths when `proxy_ips`
//! is configured. For QUIC, a relay layer intercepts datagrams before they
//! reach the quinn endpoint.
use std::collections::HashMap;
use std::net::SocketAddr;
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::Ordering;
use std::sync::Arc;
use dashmap::DashMap;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use arc_swap::ArcSwap;
@@ -48,6 +53,9 @@ pub struct UdpListenerManager {
relay_reader_cancel: Option<CancellationToken>,
/// H3 proxy service for HTTP/3 request handling
h3_service: Option<Arc<H3ProxyService>>,
/// Trusted proxy IPs that may send PROXY protocol v2 headers.
/// When non-empty, PROXY v2 detection is enabled on both raw UDP and QUIC paths.
proxy_ips: Arc<Vec<IpAddr>>,
}
impl Drop for UdpListenerManager {
@@ -80,9 +88,18 @@ impl UdpListenerManager {
relay_writer: Arc::new(Mutex::new(None)),
relay_reader_cancel: None,
h3_service: None,
proxy_ips: Arc::new(Vec::new()),
}
}
/// Set the trusted proxy IPs for PROXY protocol v2 detection.
pub fn set_proxy_ips(&mut self, ips: Vec<IpAddr>) {
if !ips.is_empty() {
info!("UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs", ips.len());
}
self.proxy_ips = Arc::new(ips);
}
/// Set the H3 proxy service for HTTP/3 request handling.
pub fn set_h3_service(&mut self, svc: Arc<H3ProxyService>) {
self.h3_service = Some(svc);
@@ -122,20 +139,44 @@ impl UdpListenerManager {
if has_quic {
if let Some(tls) = tls_config {
// Create QUIC endpoint; clone it so we can hot-swap TLS later
let endpoint = crate::quic_handler::create_quic_endpoint(port, tls)?;
let endpoint_for_updates = endpoint.clone(); // quinn::Endpoint is Arc-based
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
endpoint,
port,
Arc::clone(&self.route_manager),
Arc::clone(&self.metrics),
Arc::clone(&self.conn_tracker),
self.cancel_token.child_token(),
self.h3_service.clone(),
));
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
info!("QUIC endpoint started on port {}", port);
if self.proxy_ips.is_empty() {
// Direct path: quinn owns the external socket (zero overhead)
let endpoint = crate::quic_handler::create_quic_endpoint(port, tls)?;
let endpoint_for_updates = endpoint.clone();
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
endpoint,
port,
Arc::clone(&self.route_manager),
Arc::clone(&self.metrics),
Arc::clone(&self.conn_tracker),
self.cancel_token.child_token(),
self.h3_service.clone(),
None,
));
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
info!("QUIC endpoint started on port {}", port);
} else {
// Proxy relay path: we own external socket, quinn on localhost
let relay = crate::quic_handler::create_quic_endpoint_with_proxy_relay(
port,
tls,
Arc::clone(&self.proxy_ips),
self.cancel_token.child_token(),
)?;
let endpoint_for_updates = relay.endpoint.clone();
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
relay.endpoint,
port,
Arc::clone(&self.route_manager),
Arc::clone(&self.metrics),
Arc::clone(&self.conn_tracker),
self.cancel_token.child_token(),
self.h3_service.clone(),
Some(relay.real_client_map),
));
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
info!("QUIC endpoint with PROXY relay started on port {}", port);
}
return Ok(());
} else {
warn!("QUIC routes on port {} but no TLS config provided, falling back to raw UDP", port);
@@ -158,6 +199,7 @@ impl UdpListenerManager {
Arc::clone(&self.datagram_handler_relay),
Arc::clone(&self.relay_writer),
self.cancel_token.child_token(),
Arc::clone(&self.proxy_ips),
));
self.listeners.insert(port, (handle, None));
@@ -222,6 +264,149 @@ impl UdpListenerManager {
}
}
/// Upgrade raw UDP fallback listeners to QUIC endpoints.
///
/// At startup, if no TLS certs are available, QUIC routes fall back to raw UDP.
/// When certs become available later (via loadCertificate IPC or ACME), this method
/// stops the raw UDP listener, drains sessions, and creates a proper QUIC endpoint.
///
/// This is idempotent — ports that already have QUIC endpoints are skipped.
pub async fn upgrade_raw_to_quic(&mut self, tls_config: Arc<rustls::ServerConfig>) {
// Find ports that are raw UDP fallback (endpoint=None) but have QUIC routes
let rm = self.route_manager.load();
let upgrade_ports: Vec<u16> = self.listeners.iter()
.filter(|(_, (_, endpoint))| endpoint.is_none())
.filter(|(port, _)| {
rm.routes_for_port(**port).iter().any(|r| {
r.action.udp.as_ref()
.and_then(|u| u.quic.as_ref())
.is_some()
})
})
.map(|(port, _)| *port)
.collect();
for port in upgrade_ports {
info!("Upgrading raw UDP listener on port {} to QUIC endpoint", port);
// Stop the raw UDP listener task and drain sessions to release the socket
if let Some((handle, _)) = self.listeners.remove(&port) {
handle.abort();
}
let drained = self.session_table.drain_port(
port, &self.metrics, &self.conn_tracker,
);
if drained > 0 {
debug!("Drained {} UDP sessions on port {} for QUIC upgrade", drained, port);
}
// Brief yield to let aborted tasks drop their socket references
tokio::task::yield_now().await;
// Create QUIC endpoint on the now-free port
let create_result = if self.proxy_ips.is_empty() {
self.create_quic_direct(port, Arc::clone(&tls_config))
} else {
self.create_quic_with_relay(port, Arc::clone(&tls_config))
};
match create_result {
Ok(()) => {
info!("QUIC endpoint started on port {} (upgraded from raw UDP)", port);
}
Err(e) => {
// Port may still be held — retry once after a brief delay
warn!("QUIC endpoint creation failed on port {}, retrying: {}", port, e);
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let retry_result = if self.proxy_ips.is_empty() {
self.create_quic_direct(port, Arc::clone(&tls_config))
} else {
self.create_quic_with_relay(port, Arc::clone(&tls_config))
};
match retry_result {
Ok(()) => {
info!("QUIC endpoint started on port {} (upgraded from raw UDP, retry)", port);
}
Err(e2) => {
error!("Failed to upgrade port {} to QUIC after retry: {}. \
Rebinding as raw UDP.", port, e2);
// Fallback: rebind as raw UDP so the port isn't dead
if let Ok(()) = self.rebind_raw_udp(port).await {
warn!("Port {} rebound as raw UDP (QUIC upgrade failed)", port);
}
}
}
}
}
}
}
/// Create a direct QUIC endpoint (quinn owns the socket).
fn create_quic_direct(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> {
let endpoint = crate::quic_handler::create_quic_endpoint(port, tls_config)?;
let endpoint_for_updates = endpoint.clone();
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
endpoint,
port,
Arc::clone(&self.route_manager),
Arc::clone(&self.metrics),
Arc::clone(&self.conn_tracker),
self.cancel_token.child_token(),
self.h3_service.clone(),
None,
));
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
Ok(())
}
/// Create a QUIC endpoint with PROXY protocol relay.
fn create_quic_with_relay(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> {
let relay = crate::quic_handler::create_quic_endpoint_with_proxy_relay(
port,
tls_config,
Arc::clone(&self.proxy_ips),
self.cancel_token.child_token(),
)?;
let endpoint_for_updates = relay.endpoint.clone();
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
relay.endpoint,
port,
Arc::clone(&self.route_manager),
Arc::clone(&self.metrics),
Arc::clone(&self.conn_tracker),
self.cancel_token.child_token(),
self.h3_service.clone(),
Some(relay.real_client_map),
));
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
Ok(())
}
/// Rebind a port as a raw UDP listener (fallback when QUIC upgrade fails).
async fn rebind_raw_udp(&mut self, port: u16) -> anyhow::Result<()> {
let addr: std::net::SocketAddr = ([0, 0, 0, 0], port).into();
let socket = UdpSocket::bind(addr).await?;
let socket = Arc::new(socket);
let handle = tokio::spawn(Self::recv_loop(
socket,
port,
Arc::clone(&self.route_manager),
Arc::clone(&self.metrics),
Arc::clone(&self.conn_tracker),
Arc::clone(&self.session_table),
Arc::clone(&self.datagram_handler_relay),
Arc::clone(&self.relay_writer),
self.cancel_token.child_token(),
Arc::clone(&self.proxy_ips),
));
self.listeners.insert(port, (handle, None));
Ok(())
}
/// Set the datagram handler relay socket path and establish connection.
pub async fn set_datagram_handler_relay(&mut self, path: String) {
// Cancel previous relay reader task if any
@@ -296,6 +481,10 @@ impl UdpListenerManager {
}
/// Main receive loop for a UDP port.
///
/// When `proxy_ips` is non-empty, the first datagram from a trusted proxy IP
/// is checked for PROXY protocol v2. If found, the real client IP is extracted
/// and used for all subsequent session handling for that source address.
async fn recv_loop(
socket: Arc<UdpSocket>,
port: u16,
@@ -306,11 +495,38 @@ impl UdpListenerManager {
_datagram_handler_relay: Arc<RwLock<Option<String>>>,
relay_writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>,
cancel: CancellationToken,
proxy_ips: Arc<Vec<IpAddr>>,
) {
// Use a reasonably large buffer; actual max is per-route but we need a single buffer
let mut buf = vec![0u8; 65535];
// Maps proxy source addr → real client addr (from PROXY v2 headers).
// Only populated when proxy_ips is non-empty.
let proxy_addr_map: DashMap<SocketAddr, SocketAddr> = DashMap::new();
// Periodic cleanup for proxy_addr_map to prevent unbounded growth
let mut last_proxy_cleanup = tokio::time::Instant::now();
let proxy_cleanup_interval = std::time::Duration::from_secs(60);
loop {
// Periodic cleanup: remove proxy_addr_map entries with no active session
if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval {
last_proxy_cleanup = tokio::time::Instant::now();
let stale: Vec<SocketAddr> = proxy_addr_map.iter()
.filter(|entry| {
let key: SessionKey = (*entry.key(), port);
session_table.get(&key).is_none()
})
.map(|entry| *entry.key())
.collect();
if !stale.is_empty() {
debug!("UDP proxy_addr_map cleanup: removing {} stale entries on port {}", stale.len(), port);
for addr in stale {
proxy_addr_map.remove(&addr);
}
}
}
let (len, client_addr) = tokio::select! {
_ = cancel.cancelled() => {
debug!("UDP recv loop on port {} cancelled", port);
@@ -329,9 +545,39 @@ impl UdpListenerManager {
let datagram = &buf[..len];
// Route matching
// PROXY protocol v2 detection for datagrams from trusted proxy IPs
let effective_client_ip = if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) {
let session_key: SessionKey = (client_addr, port);
if session_table.get(&session_key).is_none() && !proxy_addr_map.contains_key(&client_addr) {
// No session and no prior PROXY header — check for PROXY v2
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
match crate::proxy_protocol::parse_v2(datagram) {
Ok((header, _consumed)) => {
debug!("UDP PROXY v2 from {}: real client {}", client_addr, header.source_addr);
proxy_addr_map.insert(client_addr, header.source_addr);
continue; // discard the PROXY v2 datagram
}
Err(e) => {
debug!("UDP PROXY v2 parse error from {}: {}", client_addr, e);
client_addr.ip()
}
}
} else {
client_addr.ip()
}
} else {
// Use real client IP if we've previously seen a PROXY v2 header
proxy_addr_map.get(&client_addr)
.map(|r| r.ip())
.unwrap_or_else(|| client_addr.ip())
}
} else {
client_addr.ip()
};
// Route matching — use effective (real) client IP
let rm = route_manager.load();
let ip_str = client_addr.ip().to_string();
let ip_str = effective_client_ip.to_string();
let ctx = MatchContext {
port,
domain: None,
@@ -380,20 +626,21 @@ impl UdpListenerManager {
}
// Session lookup or create
// Session key uses the proxy's source addr for correct return-path routing
let session_key: SessionKey = (client_addr, port);
let session = match session_table.get(&session_key) {
Some(s) => s,
None => {
// New session — check per-IP limits
if !conn_tracker.try_accept(&client_addr.ip()) {
debug!("UDP session rejected for {} (rate limit)", client_addr);
// New session — check per-IP limits using the real client IP
if !conn_tracker.try_accept(&effective_client_ip) {
debug!("UDP session rejected for {} (rate limit)", effective_client_ip);
continue;
}
if !session_table.can_create_session(
&client_addr.ip(),
&effective_client_ip,
udp_config.max_sessions_per_ip,
) {
debug!("UDP session rejected for {} (per-IP session limit)", client_addr);
debug!("UDP session rejected for {} (per-IP session limit)", effective_client_ip);
continue;
}
@@ -426,8 +673,8 @@ impl UdpListenerManager {
}
let backend_socket = Arc::new(backend_socket);
debug!("New UDP session: {} -> {} (via port {})",
client_addr, backend_addr, port);
debug!("New UDP session: {} -> {} (via port {}, real client {})",
client_addr, backend_addr, port, effective_client_ip);
// Spawn return-path relay task
let session_cancel = CancellationToken::new();
@@ -447,7 +694,7 @@ impl UdpListenerManager {
last_activity: std::sync::atomic::AtomicU64::new(session_table.elapsed_ms()),
created_at: std::time::Instant::now(),
route_id: route_id.map(|s| s.to_string()),
source_ip: client_addr.ip(),
source_ip: effective_client_ip,
client_addr,
return_task,
cancel: session_cancel,
@@ -458,8 +705,8 @@ impl UdpListenerManager {
continue;
}
// Track in metrics
conn_tracker.connection_opened(&client_addr.ip());
// Track in metrics using the real client IP
conn_tracker.connection_opened(&effective_client_ip);
metrics.connection_opened(route_id, Some(&ip_str));
metrics.udp_session_opened();

View File

@@ -201,6 +201,36 @@ impl UdpSessionTable {
removed
}
/// Drain all sessions on a given listening port, releasing socket references.
/// Used when upgrading a raw UDP listener to QUIC — the raw UDP socket's
/// Arc refcount must drop to zero so the port can be rebound.
pub fn drain_port(
&self,
port: u16,
metrics: &MetricsCollector,
conn_tracker: &ConnectionTracker,
) -> usize {
let keys: Vec<SessionKey> = self.sessions.iter()
.filter(|entry| entry.key().1 == port)
.map(|entry| *entry.key())
.collect();
let mut removed = 0;
for key in keys {
if let Some(session) = self.remove(&key) {
session.cancel.cancel();
conn_tracker.connection_closed(&session.source_ip);
metrics.connection_closed(
session.route_id.as_deref(),
Some(&session.source_ip.to_string()),
);
metrics.udp_session_closed();
removed += 1;
}
}
removed
}
/// Total number of active sessions.
pub fn session_count(&self) -> usize {
self.sessions.len()

View File

@@ -122,10 +122,16 @@ impl RouteManager {
// This prevents session-ticket resumption from misrouting when clients
// omit SNI (RFC 8446 recommends but doesn't mandate SNI on resumption).
// Wildcard-only routes (domains: ["*"]) still match since they accept all.
let patterns = domains.to_vec();
let is_wildcard_only = patterns.iter().all(|d| *d == "*");
if !is_wildcard_only {
return false;
//
// Exception: QUIC (UDP transport) encrypts the TLS ClientHello, so SNI
// is unavailable at accept time. Domain verification happens per-request
// in H3ProxyService via the :authority header.
if ctx.transport != Some(TransportProtocol::Udp) {
let patterns = domains.to_vec();
let is_wildcard_only = patterns.iter().all(|d| *d == "*");
if !is_wildcard_only {
return false;
}
}
}
}
@@ -349,8 +355,6 @@ mod tests {
load_balancing: None,
advanced: None,
options: None,
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
udp: None,
},
@@ -997,4 +1001,52 @@ mod tests {
let result = manager.find_route(&udp_ctx).unwrap();
assert_eq!(result.route.name.as_deref(), Some("udp-route"));
}
#[test]
fn test_quic_tls_no_sni_matches_domain_restricted_route() {
// QUIC accept-level matching: is_tls=true, domain=None, transport=Udp.
// Should match because QUIC encrypts the ClientHello — SNI is unavailable
// at accept time but verified per-request in H3ProxyService.
let mut route = make_route(443, Some("example.com"), 0);
route.route_match.transport = Some(TransportProtocol::Udp);
let routes = vec![route];
let manager = RouteManager::new(routes);
let ctx = MatchContext {
port: 443,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: true,
protocol: Some("quic"),
transport: Some(TransportProtocol::Udp),
};
assert!(manager.find_route(&ctx).is_some(),
"QUIC (UDP) with is_tls=true and domain=None should match domain-restricted routes");
}
#[test]
fn test_tcp_tls_no_sni_still_rejects_domain_restricted_route() {
// TCP TLS without SNI must still be rejected (no QUIC exemption).
let routes = vec![make_route(443, Some("example.com"), 0)];
let manager = RouteManager::new(routes);
let ctx = MatchContext {
port: 443,
domain: None,
path: None,
client_ip: None,
tls_version: None,
headers: None,
is_tls: true,
protocol: None,
transport: None, // TCP (default)
};
assert!(manager.find_route(&ctx).is_none(),
"TCP TLS without SNI should NOT match domain-restricted routes");
}
}

View File

@@ -20,7 +20,6 @@ rustproxy-routing = { workspace = true }
rustproxy-tls = { workspace = true }
rustproxy-passthrough = { workspace = true }
rustproxy-http = { workspace = true }
rustproxy-nftables = { workspace = true }
rustproxy-metrics = { workspace = true }
rustproxy-security = { workspace = true }
tokio = { workspace = true }
@@ -44,3 +43,9 @@ mimalloc = { workspace = true }
[dev-dependencies]
rcgen = { workspace = true }
quinn = { workspace = true }
h3 = { workspace = true }
h3-quinn = { workspace = true }
bytes = { workspace = true }
rustls = { workspace = true }
http = "1"

View File

@@ -7,14 +7,40 @@
//!
//! ```rust,no_run
//! use rustproxy::RustProxy;
//! use rustproxy_config::{RustProxyOptions, create_https_passthrough_route};
//! use rustproxy_config::*;
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//! let options = RustProxyOptions {
//! routes: vec![
//! create_https_passthrough_route("example.com", "backend", 443),
//! ],
//! routes: vec![RouteConfig {
//! id: None,
//! route_match: RouteMatch {
//! ports: PortRange::Single(443),
//! domains: Some(DomainSpec::Single("example.com".to_string())),
//! path: None, client_ip: None, transport: None,
//! tls_version: None, headers: None, protocol: None,
//! },
//! action: RouteAction {
//! action_type: RouteActionType::Forward,
//! targets: Some(vec![RouteTarget {
//! target_match: None,
//! host: HostSpec::Single("backend".to_string()),
//! port: PortSpec::Fixed(443),
//! tls: None, websocket: None, load_balancing: None,
//! send_proxy_protocol: None, headers: None, advanced: None,
//! backend_transport: None, priority: None,
//! }]),
//! tls: Some(RouteTls {
//! mode: TlsMode::Passthrough,
//! certificate: None, acme: None, versions: None,
//! ciphers: None, honor_cipher_order: None, session_timeout: None,
//! }),
//! websocket: None, load_balancing: None, advanced: None,
//! options: None, send_proxy_protocol: None, udp: None,
//! },
//! headers: None, security: None, name: None, description: None,
//! priority: None, tags: None, enabled: None,
//! }],
//! ..Default::default()
//! };
//!
@@ -41,16 +67,14 @@ pub use rustproxy_routing;
pub use rustproxy_passthrough;
pub use rustproxy_tls;
pub use rustproxy_http;
pub use rustproxy_nftables;
pub use rustproxy_metrics;
pub use rustproxy_security;
use rustproxy_config::{RouteConfig, RustProxyOptions, TlsMode, CertificateSpec, ForwardingEngine};
use rustproxy_config::{RouteConfig, RustProxyOptions, TlsMode, CertificateSpec};
use rustproxy_routing::RouteManager;
use rustproxy_passthrough::{TcpListenerManager, UdpListenerManager, TlsCertConfig, ConnectionConfig};
use rustproxy_metrics::{MetricsCollector, Metrics, Statistics};
use rustproxy_tls::{CertManager, CertStore, CertBundle, CertMetadata, CertSource};
use rustproxy_nftables::{NftManager, rule_builder};
use tokio_util::sync::CancellationToken;
/// Certificate status.
@@ -74,7 +98,6 @@ pub struct RustProxy {
challenge_server: Option<challenge_server::ChallengeServer>,
renewal_handle: Option<tokio::task::JoinHandle<()>>,
sampling_handle: Option<tokio::task::JoinHandle<()>>,
nft_manager: Option<NftManager>,
started: bool,
started_at: Option<Instant>,
/// Shared path to a Unix domain socket for relaying socket-handler connections back to TypeScript.
@@ -121,7 +144,6 @@ impl RustProxy {
challenge_server: None,
renewal_handle: None,
sampling_handle: None,
nft_manager: None,
started: false,
started_at: None,
socket_handler_relay: Arc::new(std::sync::RwLock::new(None)),
@@ -264,6 +286,8 @@ impl RustProxy {
conn_config.socket_timeout_ms,
conn_config.max_connection_lifetime_ms,
);
// Clone proxy_ips before conn_config is moved into the TCP listener
let udp_proxy_ips = conn_config.proxy_ips.clone();
listener.set_connection_config(conn_config);
// Share the socket-handler relay path with the listener
@@ -339,16 +363,12 @@ impl RustProxy {
conn_tracker,
self.cancel_token.clone(),
);
udp_mgr.set_proxy_ips(udp_proxy_ips.clone());
// Construct H3ProxyService for HTTP/3 request handling
let h3_svc = rustproxy_http::h3_service::H3ProxyService::new(
Arc::new(ArcSwap::from(Arc::clone(&*self.route_table.load()))),
Arc::clone(&self.metrics),
Arc::new(rustproxy_http::connection_pool::ConnectionPool::new()),
Arc::new(rustproxy_http::protocol_cache::ProtocolCache::new()),
rustproxy_passthrough::tls_handler::shared_backend_tls_config(),
std::time::Duration::from_secs(30),
);
// Share HttpProxyService with H3 — same route matching, connection
// pool, and ALPN protocol detection as the TCP/HTTP path.
let http_proxy = self.listener_manager.as_ref().unwrap().http_proxy().clone();
let h3_svc = rustproxy_http::h3_service::H3ProxyService::new(http_proxy);
udp_mgr.set_h3_service(Arc::new(h3_svc));
for port in &udp_ports {
@@ -365,6 +385,7 @@ impl RustProxy {
// Start the throughput sampling task with cooperative cancellation
let metrics = Arc::clone(&self.metrics);
let conn_tracker = self.listener_manager.as_ref().unwrap().conn_tracker().clone();
let http_proxy = self.listener_manager.as_ref().unwrap().http_proxy().clone();
let interval_ms = self.options.metrics.as_ref()
.and_then(|m| m.sample_interval_ms)
.unwrap_or(1000);
@@ -380,14 +401,14 @@ impl RustProxy {
metrics.sample_all();
// Periodically clean up stale rate-limit timestamp entries
conn_tracker.cleanup_stale_timestamps();
// Clean up expired rate limiter entries to prevent unbounded
// growth from unique IPs after traffic stops
http_proxy.cleanup_all_rate_limiters();
}
}
}
}));
// Apply NFTables rules for routes using nftables forwarding engine
self.apply_nftables_rules(&self.options.routes.clone()).await;
// Start renewal timer if ACME is enabled
self.start_renewal_timer();
@@ -614,14 +635,6 @@ impl RustProxy {
}
self.challenge_server = None;
// Clean up NFTables rules
if let Some(ref mut nft) = self.nft_manager {
if let Err(e) = nft.cleanup().await {
warn!("NFTables cleanup failed: {}", e);
}
}
self.nft_manager = None;
if let Some(ref mut listener) = self.listener_manager {
listener.graceful_stop().await;
}
@@ -774,21 +787,22 @@ impl RustProxy {
if self.udp_listener_manager.is_none() {
if let Some(ref listener) = self.listener_manager {
let conn_tracker = listener.conn_tracker().clone();
self.udp_listener_manager = Some(UdpListenerManager::new(
let conn_config = Self::build_connection_config(&self.options);
let mut udp_mgr = UdpListenerManager::new(
Arc::clone(&new_manager),
Arc::clone(&self.metrics),
conn_tracker,
self.cancel_token.clone(),
));
);
udp_mgr.set_proxy_ips(conn_config.proxy_ips);
self.udp_listener_manager = Some(udp_mgr);
}
}
// Build TLS config for QUIC before taking mutable borrow on udp_mgr
let quic_tls = if new_udp_ports.iter().any(|p| !old_udp_ports.contains(p)) {
// Build TLS config for QUIC (needed for new ports and upgrading existing raw UDP)
let quic_tls = {
let tls_configs = self.current_tls_configs().await;
Self::build_quic_tls_config(&tls_configs)
} else {
None
};
if let Some(ref mut udp_mgr) = self.udp_listener_manager {
@@ -806,6 +820,12 @@ impl RustProxy {
udp_mgr.remove_port(*port);
}
}
// Upgrade existing raw UDP fallback listeners to QUIC if TLS is now available
if let Some(ref quic_config) = quic_tls {
udp_mgr.update_quic_tls(Arc::clone(quic_config));
udp_mgr.upgrade_raw_to_quic(Arc::clone(quic_config)).await;
}
}
} else if self.udp_listener_manager.is_some() {
// All UDP routes removed — shut down UDP manager
@@ -816,9 +836,6 @@ impl RustProxy {
}
}
// Update NFTables rules: remove old, apply new
self.update_nftables_rules(&routes).await;
self.options.routes = routes;
Ok(())
}
@@ -862,12 +879,12 @@ impl RustProxy {
.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
// Hot-swap into TLS configs
if let Some(ref mut listener) = self.listener_manager {
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
tls_configs.insert(domain.clone(), TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
});
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
tls_configs.insert(domain.clone(), TlsCertConfig {
cert_pem: bundle.cert_pem.clone(),
key_pem: bundle.key_pem.clone(),
});
{
let cm = cm_arc.lock().await;
for (d, b) in cm.store().iter() {
if !tls_configs.contains_key(d) {
@@ -877,9 +894,22 @@ impl RustProxy {
});
}
}
}
let quic_tls = Self::build_quic_tls_config(&tls_configs);
if let Some(ref listener) = self.listener_manager {
listener.set_tls_configs(tls_configs);
}
// Update existing QUIC endpoints and upgrade raw UDP fallback listeners
if let Some(ref mut udp_mgr) = self.udp_listener_manager {
if let Some(ref quic_config) = quic_tls {
udp_mgr.update_quic_tls(Arc::clone(quic_config));
udp_mgr.upgrade_raw_to_quic(Arc::clone(quic_config)).await;
}
}
info!("Certificate provisioned and loaded for route '{}'", route_name);
Ok(())
}
@@ -919,8 +949,31 @@ impl RustProxy {
}
/// Get current metrics snapshot.
/// Includes protocol cache entries from the HTTP proxy service.
pub fn get_metrics(&self) -> Metrics {
self.metrics.snapshot()
let mut metrics = self.metrics.snapshot();
if let Some(ref lm) = self.listener_manager {
let entries = lm.http_proxy().protocol_cache_snapshot();
metrics.detected_protocols = entries.into_iter().map(|e| {
rustproxy_metrics::ProtocolCacheEntryMetric {
host: e.host,
port: e.port,
domain: e.domain,
protocol: e.protocol,
h3_port: e.h3_port,
age_secs: e.age_secs,
last_accessed_secs: e.last_accessed_secs,
last_probed_secs: e.last_probed_secs,
h2_suppressed: e.h2_suppressed,
h3_suppressed: e.h3_suppressed,
h2_cooldown_remaining_secs: e.h2_cooldown_remaining_secs,
h3_cooldown_remaining_secs: e.h3_cooldown_remaining_secs,
h2_consecutive_failures: e.h2_consecutive_failures,
h3_consecutive_failures: e.h3_consecutive_failures,
}
}).collect();
}
metrics
}
/// Add a listening port at runtime.
@@ -980,44 +1033,25 @@ impl RustProxy {
fn build_quic_tls_config(
tls_configs: &HashMap<String, TlsCertConfig>,
) -> Option<Arc<rustls::ServerConfig>> {
// Find the first available cert (prefer wildcard, then any)
let cert_config = tls_configs.get("*")
.or_else(|| tls_configs.values().next());
let cert_config = match cert_config {
Some(c) => c,
None => return None,
};
// Parse cert chain from PEM
let mut cert_reader = std::io::BufReader::new(cert_config.cert_pem.as_bytes());
let certs: Vec<rustls::pki_types::CertificateDer<'static>> =
rustls_pemfile::certs(&mut cert_reader)
.filter_map(|r| r.ok())
.collect();
if certs.is_empty() {
if tls_configs.is_empty() {
return None;
}
// Parse private key from PEM
let mut key_reader = std::io::BufReader::new(cert_config.key_pem.as_bytes());
let key = match rustls_pemfile::private_key(&mut key_reader) {
Ok(Some(key)) => key,
_ => return None,
};
let mut tls_config = match rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
{
Ok(c) => c,
// Reuse CertResolver for SNI-based cert selection (same as TCP/TLS path).
// This ensures QUIC connections get the correct certificate for each domain
// instead of a single static cert.
let resolver = match rustproxy_passthrough::tls_handler::CertResolver::new(tls_configs) {
Ok(r) => r,
Err(e) => {
warn!("Failed to build QUIC TLS config: {}", e);
warn!("Failed to build QUIC cert resolver: {}", e);
return None;
}
};
let mut tls_config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(Arc::new(resolver));
// QUIC requires h3 ALPN
tls_config.alpn_protocols = vec![b"h3".to_vec()];
@@ -1104,17 +1138,18 @@ impl RustProxy {
// Hot-swap TLS config on TCP and QUIC listeners
let tls_configs = self.current_tls_configs().await;
// Build QUIC TLS config before TCP consumes the map
let quic_tls = Self::build_quic_tls_config(&tls_configs);
if let Some(ref listener) = self.listener_manager {
// Build QUIC TLS config before TCP consumes the map
let quic_tls = Self::build_quic_tls_config(&tls_configs);
listener.set_tls_configs(tls_configs);
}
// Also update QUIC endpoints with the new certs
if let Some(ref udp_mgr) = self.udp_listener_manager {
if let Some(quic_config) = quic_tls {
udp_mgr.update_quic_tls(quic_config);
}
// Update existing QUIC endpoints and upgrade raw UDP fallback listeners
if let Some(ref mut udp_mgr) = self.udp_listener_manager {
if let Some(ref quic_config) = quic_tls {
udp_mgr.update_quic_tls(Arc::clone(quic_config));
udp_mgr.upgrade_raw_to_quic(Arc::clone(quic_config)).await;
}
}
@@ -1122,99 +1157,6 @@ impl RustProxy {
Ok(())
}
/// Get NFTables status.
pub async fn get_nftables_status(&self) -> Result<HashMap<String, serde_json::Value>> {
match &self.nft_manager {
Some(nft) => Ok(nft.status()),
None => Ok(HashMap::new()),
}
}
/// Apply NFTables rules for routes using the nftables forwarding engine.
async fn apply_nftables_rules(&mut self, routes: &[RouteConfig]) {
let nft_routes: Vec<&RouteConfig> = routes.iter()
.filter(|r| r.action.forwarding_engine.as_ref() == Some(&ForwardingEngine::Nftables))
.collect();
if nft_routes.is_empty() {
return;
}
info!("Applying NFTables rules for {} routes", nft_routes.len());
let table_name = nft_routes.iter()
.find_map(|r| r.action.nftables.as_ref()?.table_name.clone())
.unwrap_or_else(|| "rustproxy".to_string());
let mut nft = NftManager::new(Some(table_name));
for route in &nft_routes {
let route_id = route.id.as_deref()
.or(route.name.as_deref())
.unwrap_or("unnamed");
let nft_options = match &route.action.nftables {
Some(opts) => opts.clone(),
None => rustproxy_config::NfTablesOptions {
preserve_source_ip: None,
protocol: None,
max_rate: None,
priority: None,
table_name: None,
use_ip_sets: None,
use_advanced_nat: None,
},
};
let targets = match &route.action.targets {
Some(targets) => targets,
None => {
warn!("NFTables route '{}' has no targets, skipping", route_id);
continue;
}
};
let source_ports = route.route_match.ports.to_ports();
for target in targets {
let target_host = target.host.first().to_string();
let target_port_spec = &target.port;
for &source_port in &source_ports {
let resolved_port = target_port_spec.resolve(source_port);
let rules = rule_builder::build_dnat_rule(
nft.table_name(),
"prerouting",
source_port,
&target_host,
resolved_port,
&nft_options,
);
let rule_id = format!("{}-{}-{}", route_id, source_port, resolved_port);
if let Err(e) = nft.apply_rules(&rule_id, rules).await {
error!("Failed to apply NFTables rules for route '{}': {}", route_id, e);
}
}
}
}
self.nft_manager = Some(nft);
}
/// Update NFTables rules when routes change.
async fn update_nftables_rules(&mut self, new_routes: &[RouteConfig]) {
// Clean up old rules
if let Some(ref mut nft) = self.nft_manager {
if let Err(e) = nft.cleanup().await {
warn!("NFTables cleanup during update failed: {}", e);
}
}
self.nft_manager = None;
// Apply new rules
self.apply_nftables_rules(new_routes).await;
}
/// Extract TLS configurations from route configs.
fn extract_tls_configs(routes: &[RouteConfig]) -> HashMap<String, TlsCertConfig> {
let mut configs = HashMap::new();

View File

@@ -147,7 +147,6 @@ async fn handle_request(
"renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await,
"getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await,
"getListeningPorts" => handle_get_listening_ports(&id, proxy),
"getNftablesStatus" => handle_get_nftables_status(&id, proxy).await,
"setSocketHandlerRelay" => handle_set_socket_handler_relay(&id, &request.params, proxy).await,
"setDatagramHandlerRelay" => handle_set_datagram_handler_relay(&id, &request.params, proxy).await,
"addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await,
@@ -352,26 +351,6 @@ fn handle_get_listening_ports(
}
}
async fn handle_get_nftables_status(
id: &str,
proxy: &Option<RustProxy>,
) -> ManagementResponse {
match proxy.as_ref() {
Some(p) => {
match p.get_nftables_status().await {
Ok(status) => {
match serde_json::to_value(&status) {
Ok(v) => ManagementResponse::ok(id.to_string(), v),
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize: {}", e)),
}
}
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to get status: {}", e)),
}
}
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
}
}
async fn handle_set_socket_handler_relay(
id: &str,
params: &serde_json::Value,

View File

@@ -297,8 +297,6 @@ pub fn make_test_route(
load_balancing: None,
advanced: None,
options: None,
forwarding_engine: None,
nftables: None,
send_proxy_protocol: None,
udp: None,
},

View File

@@ -0,0 +1,195 @@
mod common;
use common::*;
use rustproxy::RustProxy;
use rustproxy_config::{RustProxyOptions, TransportProtocol, RouteUdp, RouteQuic};
use bytes::Buf;
use std::sync::Arc;
/// Build a route that listens on UDP with HTTP/3 enabled and TLS terminate.
fn make_h3_route(
port: u16,
target_host: &str,
target_port: u16,
cert_pem: &str,
key_pem: &str,
) -> rustproxy_config::RouteConfig {
let mut route = make_tls_terminate_route(port, "localhost", target_host, target_port, cert_pem, key_pem);
route.route_match.transport = Some(TransportProtocol::All);
// Keep domain="localhost" from make_tls_terminate_route — needed for TLS cert extraction
route.action.udp = Some(RouteUdp {
session_timeout: None,
max_sessions_per_ip: None,
max_datagram_size: None,
quic: Some(RouteQuic {
max_idle_timeout: Some(30000),
max_concurrent_bidi_streams: None,
max_concurrent_uni_streams: None,
enable_http3: Some(true),
alt_svc_port: None,
alt_svc_max_age: None,
initial_congestion_window: None,
}),
});
route
}
/// Build a quinn client endpoint with insecure TLS for testing.
fn make_h3_client_endpoint() -> quinn::Endpoint {
let mut tls_config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
.with_no_client_auth();
tls_config.alpn_protocols = vec![b"h3".to_vec()];
let quic_client_config = quinn::crypto::rustls::QuicClientConfig::try_from(tls_config)
.expect("Failed to build QUIC client config");
let client_config = quinn::ClientConfig::new(Arc::new(quic_client_config));
let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
.expect("Failed to create QUIC client endpoint");
endpoint.set_default_client_config(client_config);
endpoint
}
/// Test that HTTP/3 response streams properly finish (FIN is received by client).
///
/// This is the critical regression test for the FIN bug: the proxy must send
/// a QUIC stream FIN after the response body so the client's `recv_data()`
/// returns `None` instead of hanging forever.
#[tokio::test]
async fn test_h3_response_stream_finishes() {
let backend_port = next_port();
let proxy_port = next_port();
let body_text = "Hello from HTTP/3 backend! This body has a known length for testing.";
// 1. Start plain HTTP backend with known body + content-length
let _backend = start_http_server(backend_port, 200, body_text).await;
// 2. Generate self-signed cert and configure H3 route
let (cert_pem, key_pem) = generate_self_signed_cert("localhost");
let route = make_h3_route(proxy_port, "127.0.0.1", backend_port, &cert_pem, &key_pem);
let options = RustProxyOptions {
routes: vec![route],
..Default::default()
};
// 3. Start proxy and wait for UDP bind
let mut proxy = RustProxy::new(options).unwrap();
proxy.start().await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
// 4. Connect QUIC/H3 client
let endpoint = make_h3_client_endpoint();
let addr: std::net::SocketAddr = format!("127.0.0.1:{}", proxy_port).parse().unwrap();
let connection = endpoint
.connect(addr, "localhost")
.expect("Failed to initiate QUIC connection")
.await
.expect("QUIC handshake failed");
let (mut driver, mut send_request) = h3::client::new(
h3_quinn::Connection::new(connection),
)
.await
.expect("H3 connection setup failed");
// Drive the H3 connection in background
tokio::spawn(async move {
let _ = driver.wait_idle().await;
});
// 5. Send GET request
let req = http::Request::builder()
.method("GET")
.uri("https://localhost/")
.header("host", "localhost")
.body(())
.unwrap();
let mut stream = send_request.send_request(req).await
.expect("Failed to send H3 request");
stream.finish().await
.expect("Failed to finish sending H3 request body");
// 6. Read response headers
let resp = stream.recv_response().await
.expect("Failed to receive H3 response");
assert_eq!(resp.status(), http::StatusCode::OK,
"Expected 200 OK, got {}", resp.status());
// 7. Read body and verify stream ends (FIN received)
// This is the critical assertion: recv_data() must return None (stream ended)
// within the timeout, NOT hang forever waiting for a FIN that never arrives.
let result = with_timeout(async {
let mut total = 0usize;
while let Some(chunk) = stream.recv_data().await.expect("H3 data receive error") {
total += chunk.remaining();
}
// recv_data() returned None => stream ended (FIN received)
total
}, 10)
.await;
let bytes_received = result.expect(
"TIMEOUT: H3 stream never ended (FIN not received by client). \
The proxy sent all response data but failed to send the QUIC stream FIN."
);
assert_eq!(
bytes_received,
body_text.len(),
"Expected {} bytes, got {}",
body_text.len(),
bytes_received
);
// 8. Cleanup
endpoint.close(quinn::VarInt::from_u32(0), b"test done");
proxy.stop().await.unwrap();
}
/// Insecure TLS verifier that accepts any certificate (for tests only).
#[derive(Debug)]
struct InsecureVerifier;
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
vec![
rustls::SignatureScheme::RSA_PKCS1_SHA256,
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
rustls::SignatureScheme::ED25519,
rustls::SignatureScheme::RSA_PSS_SHA256,
]
}
}

View File

@@ -1,200 +0,0 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import {
delay,
retryWithBackoff,
withTimeout,
parallelLimit,
debounceAsync,
AsyncMutex,
CircuitBreaker
} from '../../../ts/core/utils/async-utils.js';
tap.test('delay should pause execution for specified milliseconds', async () => {
const startTime = Date.now();
await delay(100);
const elapsed = Date.now() - startTime;
// Allow some tolerance for timing
expect(elapsed).toBeGreaterThan(90);
expect(elapsed).toBeLessThan(150);
});
tap.test('retryWithBackoff should retry failed operations', async () => {
let attempts = 0;
const operation = async () => {
attempts++;
if (attempts < 3) {
throw new Error('Test error');
}
return 'success';
};
const result = await retryWithBackoff(operation, {
maxAttempts: 3,
initialDelay: 10
});
expect(result).toEqual('success');
expect(attempts).toEqual(3);
});
tap.test('retryWithBackoff should throw after max attempts', async () => {
let attempts = 0;
const operation = async () => {
attempts++;
throw new Error('Always fails');
};
let error: Error | null = null;
try {
await retryWithBackoff(operation, {
maxAttempts: 2,
initialDelay: 10
});
} catch (e: any) {
error = e;
}
expect(error).not.toBeNull();
expect(error?.message).toEqual('Always fails');
expect(attempts).toEqual(2);
});
tap.test('withTimeout should complete operations within timeout', async () => {
const operation = async () => {
await delay(50);
return 'completed';
};
const result = await withTimeout(operation, 100);
expect(result).toEqual('completed');
});
tap.test('withTimeout should throw on timeout', async () => {
const operation = async () => {
await delay(200);
return 'never happens';
};
let error: Error | null = null;
try {
await withTimeout(operation, 50);
} catch (e: any) {
error = e;
}
expect(error).not.toBeNull();
expect(error?.message).toContain('timed out');
});
tap.test('parallelLimit should respect concurrency limit', async () => {
let concurrent = 0;
let maxConcurrent = 0;
const items = [1, 2, 3, 4, 5, 6];
const operation = async (item: number) => {
concurrent++;
maxConcurrent = Math.max(maxConcurrent, concurrent);
await delay(50);
concurrent--;
return item * 2;
};
const results = await parallelLimit(items, operation, 2);
expect(results).toEqual([2, 4, 6, 8, 10, 12]);
expect(maxConcurrent).toBeLessThan(3);
expect(maxConcurrent).toBeGreaterThan(0);
});
tap.test('debounceAsync should debounce function calls', async () => {
let callCount = 0;
const fn = async (value: string) => {
callCount++;
return value;
};
const debounced = debounceAsync(fn, 50);
// Make multiple calls quickly
debounced('a');
debounced('b');
debounced('c');
const result = await debounced('d');
// Wait a bit to ensure no more calls
await delay(100);
expect(result).toEqual('d');
expect(callCount).toEqual(1); // Only the last call should execute
});
tap.test('AsyncMutex should ensure exclusive access', async () => {
const mutex = new AsyncMutex();
const results: number[] = [];
const operation = async (value: number) => {
await mutex.runExclusive(async () => {
results.push(value);
await delay(10);
results.push(value * 10);
});
};
// Run operations concurrently
await Promise.all([
operation(1),
operation(2),
operation(3)
]);
// Results should show sequential execution
expect(results).toEqual([1, 10, 2, 20, 3, 30]);
});
tap.test('CircuitBreaker should open after failures', async () => {
const breaker = new CircuitBreaker({
failureThreshold: 2,
resetTimeout: 100
});
let attempt = 0;
const failingOperation = async () => {
attempt++;
throw new Error('Test failure');
};
// First two failures
for (let i = 0; i < 2; i++) {
try {
await breaker.execute(failingOperation);
} catch (e) {
// Expected
}
}
expect(breaker.isOpen()).toBeTrue();
// Next attempt should fail immediately
let error: Error | null = null;
try {
await breaker.execute(failingOperation);
} catch (e: any) {
error = e;
}
expect(error?.message).toEqual('Circuit breaker is open');
expect(attempt).toEqual(2); // Operation not called when circuit is open
// Wait for reset timeout
await delay(150);
// Circuit should be half-open now, allowing one attempt
const successOperation = async () => 'success';
const result = await breaker.execute(successOperation);
expect(result).toEqual('success');
expect(breaker.getState()).toEqual('closed');
});
tap.start();

View File

@@ -1,206 +0,0 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import { BinaryHeap } from '../../../ts/core/utils/binary-heap.js';
interface TestItem {
id: string;
priority: number;
value: string;
}
tap.test('should create empty heap', async () => {
const heap = new BinaryHeap<number>((a, b) => a - b);
expect(heap.size).toEqual(0);
expect(heap.isEmpty()).toBeTrue();
expect(heap.peek()).toBeUndefined();
});
tap.test('should insert and extract in correct order', async () => {
const heap = new BinaryHeap<number>((a, b) => a - b);
heap.insert(5);
heap.insert(3);
heap.insert(7);
heap.insert(1);
heap.insert(9);
heap.insert(4);
expect(heap.size).toEqual(6);
// Extract in ascending order
expect(heap.extract()).toEqual(1);
expect(heap.extract()).toEqual(3);
expect(heap.extract()).toEqual(4);
expect(heap.extract()).toEqual(5);
expect(heap.extract()).toEqual(7);
expect(heap.extract()).toEqual(9);
expect(heap.extract()).toBeUndefined();
});
tap.test('should work with custom objects and comparator', async () => {
const heap = new BinaryHeap<TestItem>(
(a, b) => a.priority - b.priority,
(item) => item.id
);
heap.insert({ id: 'a', priority: 5, value: 'five' });
heap.insert({ id: 'b', priority: 2, value: 'two' });
heap.insert({ id: 'c', priority: 8, value: 'eight' });
heap.insert({ id: 'd', priority: 1, value: 'one' });
const first = heap.extract();
expect(first?.priority).toEqual(1);
expect(first?.value).toEqual('one');
const second = heap.extract();
expect(second?.priority).toEqual(2);
expect(second?.value).toEqual('two');
});
tap.test('should support reverse order (max heap)', async () => {
const heap = new BinaryHeap<number>((a, b) => b - a);
heap.insert(5);
heap.insert(3);
heap.insert(7);
heap.insert(1);
heap.insert(9);
// Extract in descending order
expect(heap.extract()).toEqual(9);
expect(heap.extract()).toEqual(7);
expect(heap.extract()).toEqual(5);
});
tap.test('should extract by predicate', async () => {
const heap = new BinaryHeap<TestItem>((a, b) => a.priority - b.priority);
heap.insert({ id: 'a', priority: 5, value: 'five' });
heap.insert({ id: 'b', priority: 2, value: 'two' });
heap.insert({ id: 'c', priority: 8, value: 'eight' });
const extracted = heap.extractIf(item => item.id === 'b');
expect(extracted?.id).toEqual('b');
expect(heap.size).toEqual(2);
// Should not find it again
const notFound = heap.extractIf(item => item.id === 'b');
expect(notFound).toBeUndefined();
});
tap.test('should extract by key', async () => {
const heap = new BinaryHeap<TestItem>(
(a, b) => a.priority - b.priority,
(item) => item.id
);
heap.insert({ id: 'a', priority: 5, value: 'five' });
heap.insert({ id: 'b', priority: 2, value: 'two' });
heap.insert({ id: 'c', priority: 8, value: 'eight' });
expect(heap.hasKey('b')).toBeTrue();
const extracted = heap.extractByKey('b');
expect(extracted?.id).toEqual('b');
expect(heap.size).toEqual(2);
expect(heap.hasKey('b')).toBeFalse();
// Should not find it again
const notFound = heap.extractByKey('b');
expect(notFound).toBeUndefined();
});
tap.test('should throw when using key operations without extractKey', async () => {
const heap = new BinaryHeap<TestItem>((a, b) => a.priority - b.priority);
heap.insert({ id: 'a', priority: 5, value: 'five' });
let error: Error | null = null;
try {
heap.extractByKey('a');
} catch (e: any) {
error = e;
}
expect(error).not.toBeNull();
expect(error?.message).toContain('extractKey function must be provided');
});
tap.test('should handle duplicates correctly', async () => {
const heap = new BinaryHeap<number>((a, b) => a - b);
heap.insert(5);
heap.insert(5);
heap.insert(5);
heap.insert(3);
heap.insert(7);
expect(heap.size).toEqual(5);
expect(heap.extract()).toEqual(3);
expect(heap.extract()).toEqual(5);
expect(heap.extract()).toEqual(5);
expect(heap.extract()).toEqual(5);
expect(heap.extract()).toEqual(7);
});
tap.test('should convert to array without modifying heap', async () => {
const heap = new BinaryHeap<number>((a, b) => a - b);
heap.insert(5);
heap.insert(3);
heap.insert(7);
const array = heap.toArray();
expect(array).toContain(3);
expect(array).toContain(5);
expect(array).toContain(7);
expect(array.length).toEqual(3);
// Heap should still be intact
expect(heap.size).toEqual(3);
expect(heap.extract()).toEqual(3);
});
tap.test('should clear the heap', async () => {
const heap = new BinaryHeap<TestItem>(
(a, b) => a.priority - b.priority,
(item) => item.id
);
heap.insert({ id: 'a', priority: 5, value: 'five' });
heap.insert({ id: 'b', priority: 2, value: 'two' });
expect(heap.size).toEqual(2);
expect(heap.hasKey('a')).toBeTrue();
heap.clear();
expect(heap.size).toEqual(0);
expect(heap.isEmpty()).toBeTrue();
expect(heap.hasKey('a')).toBeFalse();
});
tap.test('should handle complex extraction patterns', async () => {
const heap = new BinaryHeap<number>((a, b) => a - b);
// Insert numbers 1-10 in random order
[8, 3, 5, 9, 1, 7, 4, 10, 2, 6].forEach(n => heap.insert(n));
// Extract some in order
expect(heap.extract()).toEqual(1);
expect(heap.extract()).toEqual(2);
// Insert more
heap.insert(0);
heap.insert(1.5);
// Continue extracting
expect(heap.extract()).toEqual(0);
expect(heap.extract()).toEqual(1.5);
expect(heap.extract()).toEqual(3);
// Verify remaining size (10 - 2 extracted + 2 inserted - 3 extracted = 7)
expect(heap.size).toEqual(7);
});
tap.start();

View File

@@ -1,185 +0,0 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import * as path from 'path';
import { AsyncFileSystem } from '../../../ts/core/utils/fs-utils.js';
// Use a temporary directory for tests
const testDir = path.join(process.cwd(), '.nogit', 'test-fs-utils');
const testFile = path.join(testDir, 'test.txt');
const testJsonFile = path.join(testDir, 'test.json');
tap.test('should create and check directory existence', async () => {
// Ensure directory
await AsyncFileSystem.ensureDir(testDir);
// Check it exists
const exists = await AsyncFileSystem.exists(testDir);
expect(exists).toBeTrue();
// Check it's a directory
const isDir = await AsyncFileSystem.isDirectory(testDir);
expect(isDir).toBeTrue();
});
tap.test('should write and read text files', async () => {
const testContent = 'Hello, async filesystem!';
// Write file
await AsyncFileSystem.writeFile(testFile, testContent);
// Check file exists
const exists = await AsyncFileSystem.exists(testFile);
expect(exists).toBeTrue();
// Read file
const content = await AsyncFileSystem.readFile(testFile);
expect(content).toEqual(testContent);
// Check it's a file
const isFile = await AsyncFileSystem.isFile(testFile);
expect(isFile).toBeTrue();
});
tap.test('should write and read JSON files', async () => {
const testData = {
name: 'Test',
value: 42,
nested: {
array: [1, 2, 3]
}
};
// Write JSON
await AsyncFileSystem.writeJSON(testJsonFile, testData);
// Read JSON
const readData = await AsyncFileSystem.readJSON(testJsonFile);
expect(readData).toEqual(testData);
});
tap.test('should copy files', async () => {
const copyFile = path.join(testDir, 'copy.txt');
// Copy file
await AsyncFileSystem.copyFile(testFile, copyFile);
// Check copy exists
const exists = await AsyncFileSystem.exists(copyFile);
expect(exists).toBeTrue();
// Check content matches
const content = await AsyncFileSystem.readFile(copyFile);
const originalContent = await AsyncFileSystem.readFile(testFile);
expect(content).toEqual(originalContent);
});
tap.test('should move files', async () => {
const moveFile = path.join(testDir, 'moved.txt');
const copyFile = path.join(testDir, 'copy.txt');
// Move file
await AsyncFileSystem.moveFile(copyFile, moveFile);
// Check moved file exists
const movedExists = await AsyncFileSystem.exists(moveFile);
expect(movedExists).toBeTrue();
// Check original doesn't exist
const originalExists = await AsyncFileSystem.exists(copyFile);
expect(originalExists).toBeFalse();
});
tap.test('should list files in directory', async () => {
const files = await AsyncFileSystem.listFiles(testDir);
expect(files).toContain('test.txt');
expect(files).toContain('test.json');
expect(files).toContain('moved.txt');
});
tap.test('should list files with full paths', async () => {
const files = await AsyncFileSystem.listFilesFullPath(testDir);
const fileNames = files.map(f => path.basename(f));
expect(fileNames).toContain('test.txt');
expect(fileNames).toContain('test.json');
// All paths should be absolute
files.forEach(file => {
expect(path.isAbsolute(file)).toBeTrue();
});
});
tap.test('should get file stats', async () => {
const stats = await AsyncFileSystem.getStats(testFile);
expect(stats).not.toBeNull();
expect(stats?.isFile()).toBeTrue();
expect(stats?.size).toBeGreaterThan(0);
});
tap.test('should handle non-existent files gracefully', async () => {
const nonExistent = path.join(testDir, 'does-not-exist.txt');
// Check existence
const exists = await AsyncFileSystem.exists(nonExistent);
expect(exists).toBeFalse();
// Get stats should return null
const stats = await AsyncFileSystem.getStats(nonExistent);
expect(stats).toBeNull();
// Remove should not throw
await AsyncFileSystem.remove(nonExistent);
});
tap.test('should remove files', async () => {
// Remove a file
await AsyncFileSystem.remove(testFile);
// Check it's gone
const exists = await AsyncFileSystem.exists(testFile);
expect(exists).toBeFalse();
});
tap.test('should ensure file exists', async () => {
const ensureFile = path.join(testDir, 'ensure.txt');
// Ensure file
await AsyncFileSystem.ensureFile(ensureFile);
// Check it exists
const exists = await AsyncFileSystem.exists(ensureFile);
expect(exists).toBeTrue();
// Check it's empty
const content = await AsyncFileSystem.readFile(ensureFile);
expect(content).toEqual('');
});
tap.test('should recursively list files', async () => {
// Create subdirectory with file
const subDir = path.join(testDir, 'subdir');
const subFile = path.join(subDir, 'nested.txt');
await AsyncFileSystem.ensureDir(subDir);
await AsyncFileSystem.writeFile(subFile, 'nested content');
// List recursively
const files = await AsyncFileSystem.listFilesRecursive(testDir);
// Should include files from subdirectory
const fileNames = files.map(f => path.relative(testDir, f));
expect(fileNames).toContain('test.json');
expect(fileNames).toContain(path.join('subdir', 'nested.txt'));
});
tap.test('should clean up test directory', async () => {
// Remove entire test directory
await AsyncFileSystem.removeDir(testDir);
// Check it's gone
const exists = await AsyncFileSystem.exists(testDir);
expect(exists).toBeFalse();
});
tap.start();

View File

@@ -1,156 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import { IpUtils } from '../../../ts/core/utils/ip-utils.js';
tap.test('ip-utils - normalizeIP', async () => {
// IPv4 normalization
const ipv4Variants = IpUtils.normalizeIP('127.0.0.1');
expect(ipv4Variants).toEqual(['127.0.0.1', '::ffff:127.0.0.1']);
// IPv6-mapped IPv4 normalization
const ipv6MappedVariants = IpUtils.normalizeIP('::ffff:127.0.0.1');
expect(ipv6MappedVariants).toEqual(['::ffff:127.0.0.1', '127.0.0.1']);
// IPv6 normalization
const ipv6Variants = IpUtils.normalizeIP('::1');
expect(ipv6Variants).toEqual(['::1']);
// Invalid/empty input handling
expect(IpUtils.normalizeIP('')).toEqual([]);
expect(IpUtils.normalizeIP(null as any)).toEqual([]);
expect(IpUtils.normalizeIP(undefined as any)).toEqual([]);
});
tap.test('ip-utils - isGlobIPMatch', async () => {
// Direct matches
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['127.0.0.1'])).toEqual(true);
expect(IpUtils.isGlobIPMatch('::1', ['::1'])).toEqual(true);
// Wildcard matches
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['127.0.0.*'])).toEqual(true);
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['127.0.*.*'])).toEqual(true);
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['127.*.*.*'])).toEqual(true);
// IPv4-mapped IPv6 handling
expect(IpUtils.isGlobIPMatch('::ffff:127.0.0.1', ['127.0.0.1'])).toEqual(true);
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['::ffff:127.0.0.1'])).toEqual(true);
// Match multiple patterns
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['10.0.0.1', '127.0.0.1', '192.168.1.1'])).toEqual(true);
// Non-matching patterns
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['10.0.0.1'])).toEqual(false);
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['128.0.0.1'])).toEqual(false);
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['127.0.0.2'])).toEqual(false);
// Edge cases
expect(IpUtils.isGlobIPMatch('', ['127.0.0.1'])).toEqual(false);
expect(IpUtils.isGlobIPMatch('127.0.0.1', [])).toEqual(false);
expect(IpUtils.isGlobIPMatch('127.0.0.1', null as any)).toEqual(false);
expect(IpUtils.isGlobIPMatch(null as any, ['127.0.0.1'])).toEqual(false);
});
tap.test('ip-utils - isIPAuthorized', async () => {
// Basic tests to check the core functionality works
// No restrictions - all IPs allowed
expect(IpUtils.isIPAuthorized('127.0.0.1')).toEqual(true);
// Basic blocked IP test
const blockedIP = '8.8.8.8';
const blockedIPs = [blockedIP];
expect(IpUtils.isIPAuthorized(blockedIP, [], blockedIPs)).toEqual(false);
// Basic allowed IP test
const allowedIP = '10.0.0.1';
const allowedIPs = [allowedIP];
expect(IpUtils.isIPAuthorized(allowedIP, allowedIPs)).toEqual(true);
expect(IpUtils.isIPAuthorized('192.168.1.1', allowedIPs)).toEqual(false);
});
tap.test('ip-utils - isPrivateIP', async () => {
// Private IPv4 ranges
expect(IpUtils.isPrivateIP('10.0.0.1')).toEqual(true);
expect(IpUtils.isPrivateIP('172.16.0.1')).toEqual(true);
expect(IpUtils.isPrivateIP('172.31.255.255')).toEqual(true);
expect(IpUtils.isPrivateIP('192.168.0.1')).toEqual(true);
expect(IpUtils.isPrivateIP('127.0.0.1')).toEqual(true);
// Public IPv4 addresses
expect(IpUtils.isPrivateIP('8.8.8.8')).toEqual(false);
expect(IpUtils.isPrivateIP('203.0.113.1')).toEqual(false);
// IPv4-mapped IPv6 handling
expect(IpUtils.isPrivateIP('::ffff:10.0.0.1')).toEqual(true);
expect(IpUtils.isPrivateIP('::ffff:8.8.8.8')).toEqual(false);
// Private IPv6 addresses
expect(IpUtils.isPrivateIP('::1')).toEqual(true);
expect(IpUtils.isPrivateIP('fd00::')).toEqual(true);
expect(IpUtils.isPrivateIP('fe80::1')).toEqual(true);
// Public IPv6 addresses
expect(IpUtils.isPrivateIP('2001:db8::1')).toEqual(false);
// Edge cases
expect(IpUtils.isPrivateIP('')).toEqual(false);
expect(IpUtils.isPrivateIP(null as any)).toEqual(false);
expect(IpUtils.isPrivateIP(undefined as any)).toEqual(false);
});
tap.test('ip-utils - isPublicIP', async () => {
// Public IPv4 addresses
expect(IpUtils.isPublicIP('8.8.8.8')).toEqual(true);
expect(IpUtils.isPublicIP('203.0.113.1')).toEqual(true);
// Private IPv4 ranges
expect(IpUtils.isPublicIP('10.0.0.1')).toEqual(false);
expect(IpUtils.isPublicIP('172.16.0.1')).toEqual(false);
expect(IpUtils.isPublicIP('192.168.0.1')).toEqual(false);
expect(IpUtils.isPublicIP('127.0.0.1')).toEqual(false);
// Public IPv6 addresses
expect(IpUtils.isPublicIP('2001:db8::1')).toEqual(true);
// Private IPv6 addresses
expect(IpUtils.isPublicIP('::1')).toEqual(false);
expect(IpUtils.isPublicIP('fd00::')).toEqual(false);
expect(IpUtils.isPublicIP('fe80::1')).toEqual(false);
// Edge cases - the implementation treats these as non-private, which is technically correct but might not be what users expect
const emptyResult = IpUtils.isPublicIP('');
expect(emptyResult).toEqual(true);
const nullResult = IpUtils.isPublicIP(null as any);
expect(nullResult).toEqual(true);
const undefinedResult = IpUtils.isPublicIP(undefined as any);
expect(undefinedResult).toEqual(true);
});
tap.test('ip-utils - cidrToGlobPatterns', async () => {
// Class C network
const classC = IpUtils.cidrToGlobPatterns('192.168.1.0/24');
expect(classC).toEqual(['192.168.1.*']);
// Class B network
const classB = IpUtils.cidrToGlobPatterns('172.16.0.0/16');
expect(classB).toEqual(['172.16.*.*']);
// Class A network
const classA = IpUtils.cidrToGlobPatterns('10.0.0.0/8');
expect(classA).toEqual(['10.*.*.*']);
// Small subnet (/28 = 16 addresses)
const smallSubnet = IpUtils.cidrToGlobPatterns('192.168.1.0/28');
expect(smallSubnet.length).toEqual(16);
expect(smallSubnet).toContain('192.168.1.0');
expect(smallSubnet).toContain('192.168.1.15');
// Invalid inputs
expect(IpUtils.cidrToGlobPatterns('')).toEqual([]);
expect(IpUtils.cidrToGlobPatterns('192.168.1.0')).toEqual([]);
expect(IpUtils.cidrToGlobPatterns('192.168.1.0/')).toEqual([]);
expect(IpUtils.cidrToGlobPatterns('192.168.1.0/33')).toEqual([]);
expect(IpUtils.cidrToGlobPatterns('invalid/24')).toEqual([]);
});
export default tap.start();

View File

@@ -1,252 +0,0 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import { LifecycleComponent } from '../../../ts/core/utils/lifecycle-component.js';
import { EventEmitter } from 'events';
// Test implementation of LifecycleComponent
class TestComponent extends LifecycleComponent {
public timerCallCount = 0;
public intervalCallCount = 0;
public cleanupCalled = false;
public testEmitter = new EventEmitter();
public listenerCallCount = 0;
constructor() {
super();
this.setupTimers();
this.setupListeners();
}
private setupTimers() {
// Set up a timeout
this.setTimeout(() => {
this.timerCallCount++;
}, 100);
// Set up an interval
this.setInterval(() => {
this.intervalCallCount++;
}, 50);
}
private setupListeners() {
this.addEventListener(this.testEmitter, 'test-event', () => {
this.listenerCallCount++;
});
}
protected async onCleanup(): Promise<void> {
this.cleanupCalled = true;
}
// Expose protected methods for testing
public testSetTimeout(handler: Function, timeout: number): NodeJS.Timeout {
return this.setTimeout(handler, timeout);
}
public testSetInterval(handler: Function, interval: number): NodeJS.Timeout {
return this.setInterval(handler, interval);
}
public testClearTimeout(timer: NodeJS.Timeout): void {
return this.clearTimeout(timer);
}
public testClearInterval(timer: NodeJS.Timeout): void {
return this.clearInterval(timer);
}
public testAddEventListener(target: any, event: string, handler: Function, options?: { once?: boolean }): void {
return this.addEventListener(target, event, handler, options);
}
public testIsShuttingDown(): boolean {
return this.isShuttingDownState();
}
}
tap.test('should manage timers properly', async () => {
const component = new TestComponent();
// Wait for timers to fire
await new Promise(resolve => setTimeout(resolve, 200));
expect(component.timerCallCount).toEqual(1);
expect(component.intervalCallCount).toBeGreaterThan(2);
await component.cleanup();
});
tap.test('should manage event listeners properly', async () => {
const component = new TestComponent();
// Emit events
component.testEmitter.emit('test-event');
component.testEmitter.emit('test-event');
expect(component.listenerCallCount).toEqual(2);
// Cleanup and verify listeners are removed
await component.cleanup();
component.testEmitter.emit('test-event');
expect(component.listenerCallCount).toEqual(2); // Should not increase
});
tap.test('should prevent timer execution after cleanup', async () => {
const component = new TestComponent();
let laterCallCount = 0;
component.testSetTimeout(() => {
laterCallCount++;
}, 100);
// Cleanup immediately
await component.cleanup();
// Wait for timer that would have fired
await new Promise(resolve => setTimeout(resolve, 150));
expect(laterCallCount).toEqual(0);
});
tap.test('should handle child components', async () => {
class ParentComponent extends LifecycleComponent {
public child: TestComponent;
constructor() {
super();
this.child = new TestComponent();
this.registerChildComponent(this.child);
}
}
const parent = new ParentComponent();
// Wait for child timers
await new Promise(resolve => setTimeout(resolve, 100));
expect(parent.child.timerCallCount).toEqual(1);
// Cleanup parent should cleanup child
await parent.cleanup();
expect(parent.child.cleanupCalled).toBeTrue();
expect(parent.child.testIsShuttingDown()).toBeTrue();
});
tap.test('should handle multiple cleanup calls gracefully', async () => {
const component = new TestComponent();
// Call cleanup multiple times
const promises = [
component.cleanup(),
component.cleanup(),
component.cleanup()
];
await Promise.all(promises);
// Should only clean up once
expect(component.cleanupCalled).toBeTrue();
});
tap.test('should clear specific timers', async () => {
const component = new TestComponent();
let callCount = 0;
const timer = component.testSetTimeout(() => {
callCount++;
}, 100);
// Clear the timer
component.testClearTimeout(timer);
// Wait and verify it didn't fire
await new Promise(resolve => setTimeout(resolve, 150));
expect(callCount).toEqual(0);
await component.cleanup();
});
tap.test('should clear specific intervals', async () => {
const component = new TestComponent();
let callCount = 0;
const interval = component.testSetInterval(() => {
callCount++;
}, 50);
// Let it run a bit
await new Promise(resolve => setTimeout(resolve, 120));
const countBeforeClear = callCount;
expect(countBeforeClear).toBeGreaterThan(1);
// Clear the interval
component.testClearInterval(interval);
// Wait and verify it stopped
await new Promise(resolve => setTimeout(resolve, 100));
expect(callCount).toEqual(countBeforeClear);
await component.cleanup();
});
tap.test('should handle once event listeners', async () => {
const component = new TestComponent();
const emitter = new EventEmitter();
let callCount = 0;
const handler = () => {
callCount++;
};
component.testAddEventListener(emitter, 'once-event', handler, { once: true });
// Check listener count before emit
const beforeCount = emitter.listenerCount('once-event');
expect(beforeCount).toEqual(1);
// Emit once - the listener should fire and auto-remove
emitter.emit('once-event');
expect(callCount).toEqual(1);
// Check listener was auto-removed
const afterCount = emitter.listenerCount('once-event');
expect(afterCount).toEqual(0);
// Emit again - should not increase count
emitter.emit('once-event');
expect(callCount).toEqual(1);
await component.cleanup();
});
tap.test('should not create timers when shutting down', async () => {
const component = new TestComponent();
// Start cleanup
const cleanupPromise = component.cleanup();
// Try to create timers during shutdown
let timerFired = false;
let intervalFired = false;
component.testSetTimeout(() => {
timerFired = true;
}, 10);
component.testSetInterval(() => {
intervalFired = true;
}, 10);
await cleanupPromise;
await new Promise(resolve => setTimeout(resolve, 50));
expect(timerFired).toBeFalse();
expect(intervalFired).toBeFalse();
});
export default tap.start();

View File

@@ -1,158 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import { SharedSecurityManager } from '../../../ts/core/utils/shared-security-manager.js';
import type { IRouteConfig, IRouteContext } from '../../../ts/proxies/smart-proxy/models/route-types.js';
// Test security manager
tap.test('Shared Security Manager', async () => {
let securityManager: SharedSecurityManager;
// Set up a new security manager for each test
securityManager = new SharedSecurityManager({
maxConnectionsPerIP: 5,
connectionRateLimitPerMinute: 10
});
tap.test('should validate IPs correctly', async () => {
// Should allow IPs under connection limit
expect(securityManager.validateIP('192.168.1.1').allowed).toBeTrue();
// Track multiple connections
for (let i = 0; i < 4; i++) {
securityManager.trackConnectionByIP('192.168.1.1', `conn_${i}`);
}
// Should still allow IPs under connection limit
expect(securityManager.validateIP('192.168.1.1').allowed).toBeTrue();
// Add one more to reach the limit
securityManager.trackConnectionByIP('192.168.1.1', 'conn_4');
// Should now block IPs over connection limit
expect(securityManager.validateIP('192.168.1.1').allowed).toBeFalse();
// Remove a connection
securityManager.removeConnectionByIP('192.168.1.1', 'conn_0');
// Should allow again after connection is removed
expect(securityManager.validateIP('192.168.1.1').allowed).toBeTrue();
});
tap.test('should authorize IPs based on allow/block lists', async () => {
// Test with allow list only
expect(securityManager.isIPAuthorized('192.168.1.1', ['192.168.1.*'])).toBeTrue();
expect(securityManager.isIPAuthorized('192.168.2.1', ['192.168.1.*'])).toBeFalse();
// Test with block list
expect(securityManager.isIPAuthorized('192.168.1.5', ['*'], ['192.168.1.5'])).toBeFalse();
expect(securityManager.isIPAuthorized('192.168.1.1', ['*'], ['192.168.1.5'])).toBeTrue();
// Test with both allow and block lists
expect(securityManager.isIPAuthorized('192.168.1.1', ['192.168.1.*'], ['192.168.1.5'])).toBeTrue();
expect(securityManager.isIPAuthorized('192.168.1.5', ['192.168.1.*'], ['192.168.1.5'])).toBeFalse();
});
tap.test('should validate route access', async () => {
const route: IRouteConfig = {
match: {
ports: [8080]
},
action: {
type: 'forward',
targets: [{ host: 'target.com', port: 443 }]
},
security: {
ipAllowList: ['10.0.0.*', '192.168.1.*'],
ipBlockList: ['192.168.1.100'],
maxConnections: 3
}
};
const allowedContext: IRouteContext = {
clientIp: '192.168.1.1',
port: 8080,
serverIp: '127.0.0.1',
isTls: false,
timestamp: Date.now(),
connectionId: 'test_conn_1'
};
const blockedByIPContext: IRouteContext = {
...allowedContext,
clientIp: '192.168.1.100'
};
const blockedByRangeContext: IRouteContext = {
...allowedContext,
clientIp: '172.16.0.1'
};
const blockedByMaxConnectionsContext: IRouteContext = {
...allowedContext,
connectionId: 'test_conn_4'
};
expect(securityManager.isAllowed(route, allowedContext)).toBeTrue();
expect(securityManager.isAllowed(route, blockedByIPContext)).toBeFalse();
expect(securityManager.isAllowed(route, blockedByRangeContext)).toBeFalse();
// Test max connections for route - assuming implementation has been updated
if ((securityManager as any).trackConnectionByRoute) {
(securityManager as any).trackConnectionByRoute(route, 'conn_1');
(securityManager as any).trackConnectionByRoute(route, 'conn_2');
(securityManager as any).trackConnectionByRoute(route, 'conn_3');
// Should now block due to max connections
expect(securityManager.isAllowed(route, blockedByMaxConnectionsContext)).toBeFalse();
}
});
tap.test('should clean up expired entries', async () => {
const route: IRouteConfig = {
match: {
ports: [8080]
},
action: {
type: 'forward',
targets: [{ host: 'target.com', port: 443 }]
},
security: {
rateLimit: {
enabled: true,
maxRequests: 5,
window: 60 // 60 seconds
}
}
};
const context: IRouteContext = {
clientIp: '192.168.1.1',
port: 8080,
serverIp: '127.0.0.1',
isTls: false,
timestamp: Date.now(),
connectionId: 'test_conn_1'
};
// Test rate limiting if method exists
if ((securityManager as any).checkRateLimit) {
// Add 5 attempts (max allowed)
for (let i = 0; i < 5; i++) {
expect((securityManager as any).checkRateLimit(route, context)).toBeTrue();
}
// Should now be blocked
expect((securityManager as any).checkRateLimit(route, context)).toBeFalse();
// Force cleanup (normally runs periodically)
if ((securityManager as any).cleanup) {
(securityManager as any).cleanup();
}
// Should still be blocked since entries are not expired yet
expect((securityManager as any).checkRateLimit(route, context)).toBeFalse();
}
});
});
// Export test runner
export default tap.start();

View File

@@ -1,302 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import { ValidationUtils } from '../../../ts/core/utils/validation-utils.js';
import type { IDomainOptions, IAcmeOptions } from '../../../ts/core/models/common-types.js';
tap.test('validation-utils - isValidPort', async () => {
// Valid port values
expect(ValidationUtils.isValidPort(1)).toEqual(true);
expect(ValidationUtils.isValidPort(80)).toEqual(true);
expect(ValidationUtils.isValidPort(443)).toEqual(true);
expect(ValidationUtils.isValidPort(8080)).toEqual(true);
expect(ValidationUtils.isValidPort(65535)).toEqual(true);
// Invalid port values
expect(ValidationUtils.isValidPort(0)).toEqual(false);
expect(ValidationUtils.isValidPort(-1)).toEqual(false);
expect(ValidationUtils.isValidPort(65536)).toEqual(false);
expect(ValidationUtils.isValidPort(80.5)).toEqual(false);
expect(ValidationUtils.isValidPort(NaN)).toEqual(false);
expect(ValidationUtils.isValidPort(null as any)).toEqual(false);
expect(ValidationUtils.isValidPort(undefined as any)).toEqual(false);
});
tap.test('validation-utils - isValidDomainName', async () => {
// Valid domain names
expect(ValidationUtils.isValidDomainName('example.com')).toEqual(true);
expect(ValidationUtils.isValidDomainName('sub.example.com')).toEqual(true);
expect(ValidationUtils.isValidDomainName('*.example.com')).toEqual(true);
expect(ValidationUtils.isValidDomainName('a-hyphenated-domain.example.com')).toEqual(true);
expect(ValidationUtils.isValidDomainName('example123.com')).toEqual(true);
// Invalid domain names
expect(ValidationUtils.isValidDomainName('')).toEqual(false);
expect(ValidationUtils.isValidDomainName(null as any)).toEqual(false);
expect(ValidationUtils.isValidDomainName(undefined as any)).toEqual(false);
expect(ValidationUtils.isValidDomainName('-invalid.com')).toEqual(false);
expect(ValidationUtils.isValidDomainName('invalid-.com')).toEqual(false);
expect(ValidationUtils.isValidDomainName('inv@lid.com')).toEqual(false);
expect(ValidationUtils.isValidDomainName('example')).toEqual(false);
expect(ValidationUtils.isValidDomainName('example.')).toEqual(false);
});
tap.test('validation-utils - isValidEmail', async () => {
// Valid email addresses
expect(ValidationUtils.isValidEmail('user@example.com')).toEqual(true);
expect(ValidationUtils.isValidEmail('admin@sub.example.com')).toEqual(true);
expect(ValidationUtils.isValidEmail('first.last@example.com')).toEqual(true);
expect(ValidationUtils.isValidEmail('user+tag@example.com')).toEqual(true);
// Invalid email addresses
expect(ValidationUtils.isValidEmail('')).toEqual(false);
expect(ValidationUtils.isValidEmail(null as any)).toEqual(false);
expect(ValidationUtils.isValidEmail(undefined as any)).toEqual(false);
expect(ValidationUtils.isValidEmail('user')).toEqual(false);
expect(ValidationUtils.isValidEmail('user@')).toEqual(false);
expect(ValidationUtils.isValidEmail('@example.com')).toEqual(false);
expect(ValidationUtils.isValidEmail('user example.com')).toEqual(false);
});
tap.test('validation-utils - isValidCertificate', async () => {
// Valid certificate format
const validCert = `-----BEGIN CERTIFICATE-----
MIIDazCCAlOgAwIBAgIUJlq+zz9CO2E91rlD4vhx0CX1Z/kwDQYJKoZIhvcNAQEL
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMzAxMDEwMDAwMDBaFw0yNDAx
MDEwMDAwMDBaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw
HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB
AQUAA4IBDwAwggEKAoIBAQC0aQeHIV9vQpZ4UVwW/xhx9zl01UbppLXdoqe3NP9x
KfXTCB1YbtJ4GgKIlQqHGLGsLI5ZOE7KxmJeGEwK7ueP4f3WkUlM5C5yTbZ5hSUo
R+OFnszFRJJiBXJlw57YAW9+zqKQHYxwve64O64dlgw6pekDYJhXtrUUZ78Lz0GX
veJvCrci1M4Xk6/7/p1Ii9PNmbPKqHafdmkFLf6TXiWPuRDhPuHW7cXyE8xD5ahr
NsDuwJyRUk+GS4/oJg0TqLSiD0IPxDH50V5MSfUIB82i+lc1t+OAGwLhjUDuQmJi
Pv1+9Zvv+HA5PXBCsGXnSADrOOUO6t9q5R9PXbSvAgMBAAGjUzBRMB0GA1UdDgQW
BBQEtdtBhH/z1XyIf+y+5O9ErDGCVjAfBgNVHSMEGDAWgBQEtdtBhH/z1XyIf+y+
5O9ErDGCVjAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQBmJyQ0
r0pBJkYJJVDJ6i3WMoEEFTD8MEUkWxASHRnuMzm7XlZ8WS1HvbEWF0+WfJPCYHnk
tGbvUFGaZ4qUxZ4Ip2mvKXoeYTJCZRxxhHeSVWnZZu0KS3X7xVAFwQYQNhdLOqP8
XOHyLhHV/1/kcFd3GvKKjXxE79jUUZ/RXHZ/IY50KvxGzWc/5ZOFYrPEW1/rNlRo
7ixXo1hNnBQsG1YoFAxTBGegdTFJeTYHYjZZ5XlRvY2aBq6QveRbJGJLcPm1UQMd
HQYxacbWSVAQf3ltYwSH+y3a97C5OsJJiQXpRRJlQKL3txklzcpg3E5swhr63bM2
jUoNXr5G5Q5h3GD5
-----END CERTIFICATE-----`;
expect(ValidationUtils.isValidCertificate(validCert)).toEqual(true);
// Invalid certificate format
expect(ValidationUtils.isValidCertificate('')).toEqual(false);
expect(ValidationUtils.isValidCertificate(null as any)).toEqual(false);
expect(ValidationUtils.isValidCertificate(undefined as any)).toEqual(false);
expect(ValidationUtils.isValidCertificate('invalid certificate')).toEqual(false);
expect(ValidationUtils.isValidCertificate('-----BEGIN CERTIFICATE-----')).toEqual(false);
});
tap.test('validation-utils - isValidPrivateKey', async () => {
// Valid private key format
const validKey = `-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC0aQeHIV9vQpZ4
UVwW/xhx9zl01UbppLXdoqe3NP9xKfXTCB1YbtJ4GgKIlQqHGLGsLI5ZOE7KxmJe
GEwK7ueP4f3WkUlM5C5yTbZ5hSUoR+OFnszFRJJiBXJlw57YAW9+zqKQHYxwve64
O64dlgw6pekDYJhXtrUUZ78Lz0GXveJvCrci1M4Xk6/7/p1Ii9PNmbPKqHafdmkF
Lf6TXiWPuRDhPuHW7cXyE8xD5ahrNsDuwJyRUk+GS4/oJg0TqLSiD0IPxDH50V5M
SfUIB82i+lc1t+OAGwLhjUDuQmJiPv1+9Zvv+HA5PXBCsGXnSADrOOUO6t9q5R9P
XbSvAgMBAAECggEADw8Xx9iEv3FvS8hYIRn2ZWM8ObRgbHkFN92NJ/5RvUwgyV03
gG8GwVN+7IsVLnIQRyIYEGGJ0ZLZFIq7//Jy0jYUgEGLmXxknuZQn1cQEqqYVyBr
G9JrfKkXaDEoP/bZBMvZ0KEO2C9Vq6mY8M0h0GxDT2y6UQnQYjH3+H6Rvhbhh+Ld
n8lCJqWoW1t9GOUZ4xLsZ5jEDibcMJJzLBWYRxgHWyECK31/VtEQDKFiUcymrJ3I
/zoDEDGbp1gdJHvlCxfSLJ2za7ErtRKRXYFRhZ9QkNSXl1pVFMqRQkedXIcA1/Cs
VpUxiIE2JA3hSrv2csjmXoGJKDLVCvZ3CFxKL3u/AQKBgQDf6MxHXN3IDuJNrJP7
0gyRbO5d6vcvP/8qiYjtEt2xB2MNt5jDz9Bxl6aKEdNW2+UE0rvXXT6KAMZv9LiF
hxr5qiJmmSB8OeGfr0W4FCixGN4BkRNwfT1gUqZgQOrfMOLHNXOksc1CJwHJfROV
h6AH+gjtF2BCXnVEHcqtRklk4QKBgQDOOYnLJn1CwgFAyRUYK8LQYKnrLp2cGn7N
YH0SLf+VnCu7RCeNr3dm9FoHBCynjkx+qv9kGvCaJuZqEJ7+7IimNUZfDjwXTOJ+
pzs8kEPN5EQOcbkmYCTQyOA0YeBuEXcv5xIZRZUYQvKg1xXOe/JhAQ4siVIMhgQL
2XR3QwzRDwKBgB7rjZs2VYnuVExGr74lUUAGoZ71WCgt9Du9aYGJfNUriDtTEWAd
VT5sKgVqpRwkY/zXujdxGr+K8DZu4vSdHBLcDLQsEBvRZIILTzjwXBRPGMnVe95v
Q90+vytbmHshlkbMaVRNQxCjdbf7LbQbLecgRt+5BKxHVwL4u3BZNIqhAoGAas4f
PoPOdFfKAMKZL7FLGMhEXLyFsg1JcGRfmByxTNgOJKXpYv5Hl7JLYOvfaiUOUYKI
5Dnh5yLdFOaOjnB3iP0KEiSVEwZK0/Vna5JkzFTqImK9QD3SQCtQLXHJLD52EPFR
9gRa8N5k68+mIzGDEzPBoC1AajbXFGPxNOwaQQ0CgYEAq0dPYK0TTv3Yez27LzVy
RbHkwpE+df4+KhpHbCzUKzfQYo4WTahlR6IzhpOyVQKIptkjuTDyQzkmt0tXEGw3
/M3yHa1FcY9IzPrHXHJoOeU1r9ay0GOQUi4FxKkYYWxUCtjOi5xlUxI0ABD8vGGR
QbKMrQXRgLd/84nDnY2cYzA=
-----END PRIVATE KEY-----`;
expect(ValidationUtils.isValidPrivateKey(validKey)).toEqual(true);
// Invalid private key format
expect(ValidationUtils.isValidPrivateKey('')).toEqual(false);
expect(ValidationUtils.isValidPrivateKey(null as any)).toEqual(false);
expect(ValidationUtils.isValidPrivateKey(undefined as any)).toEqual(false);
expect(ValidationUtils.isValidPrivateKey('invalid key')).toEqual(false);
expect(ValidationUtils.isValidPrivateKey('-----BEGIN PRIVATE KEY-----')).toEqual(false);
});
tap.test('validation-utils - validateDomainOptions', async () => {
// Valid domain options
const validDomainOptions: IDomainOptions = {
domainName: 'example.com',
sslRedirect: true,
acmeMaintenance: true
};
expect(ValidationUtils.validateDomainOptions(validDomainOptions).isValid).toEqual(true);
// Valid domain options with forward
const validDomainOptionsWithForward: IDomainOptions = {
domainName: 'example.com',
sslRedirect: true,
acmeMaintenance: true,
forward: {
ip: '127.0.0.1',
port: 8080
}
};
expect(ValidationUtils.validateDomainOptions(validDomainOptionsWithForward).isValid).toEqual(true);
// Invalid domain options - no domain name
const invalidDomainOptions1: IDomainOptions = {
domainName: '',
sslRedirect: true,
acmeMaintenance: true
};
expect(ValidationUtils.validateDomainOptions(invalidDomainOptions1).isValid).toEqual(false);
// Invalid domain options - invalid domain name
const invalidDomainOptions2: IDomainOptions = {
domainName: 'inv@lid.com',
sslRedirect: true,
acmeMaintenance: true
};
expect(ValidationUtils.validateDomainOptions(invalidDomainOptions2).isValid).toEqual(false);
// Invalid domain options - forward missing ip
const invalidDomainOptions3: IDomainOptions = {
domainName: 'example.com',
sslRedirect: true,
acmeMaintenance: true,
forward: {
ip: '',
port: 8080
}
};
expect(ValidationUtils.validateDomainOptions(invalidDomainOptions3).isValid).toEqual(false);
// Invalid domain options - forward missing port
const invalidDomainOptions4: IDomainOptions = {
domainName: 'example.com',
sslRedirect: true,
acmeMaintenance: true,
forward: {
ip: '127.0.0.1',
port: null as any
}
};
expect(ValidationUtils.validateDomainOptions(invalidDomainOptions4).isValid).toEqual(false);
// Invalid domain options - invalid forward port
const invalidDomainOptions5: IDomainOptions = {
domainName: 'example.com',
sslRedirect: true,
acmeMaintenance: true,
forward: {
ip: '127.0.0.1',
port: 99999
}
};
expect(ValidationUtils.validateDomainOptions(invalidDomainOptions5).isValid).toEqual(false);
});
tap.test('validation-utils - validateAcmeOptions', async () => {
// Valid ACME options
const validAcmeOptions: IAcmeOptions = {
enabled: true,
accountEmail: 'admin@example.com',
port: 80,
httpsRedirectPort: 443,
useProduction: false,
renewThresholdDays: 30,
renewCheckIntervalHours: 24,
certificateStore: './certs'
};
expect(ValidationUtils.validateAcmeOptions(validAcmeOptions).isValid).toEqual(true);
// ACME disabled - should be valid regardless of other options
const disabledAcmeOptions: IAcmeOptions = {
enabled: false
};
// Don't need to verify other fields when ACME is disabled
const disabledResult = ValidationUtils.validateAcmeOptions(disabledAcmeOptions);
expect(disabledResult.isValid).toEqual(true);
// Invalid ACME options - missing email
const invalidAcmeOptions1: IAcmeOptions = {
enabled: true,
accountEmail: '',
port: 80
};
expect(ValidationUtils.validateAcmeOptions(invalidAcmeOptions1).isValid).toEqual(false);
// Invalid ACME options - invalid email
const invalidAcmeOptions2: IAcmeOptions = {
enabled: true,
accountEmail: 'invalid-email',
port: 80
};
expect(ValidationUtils.validateAcmeOptions(invalidAcmeOptions2).isValid).toEqual(false);
// Invalid ACME options - invalid port
const invalidAcmeOptions3: IAcmeOptions = {
enabled: true,
accountEmail: 'admin@example.com',
port: 99999
};
expect(ValidationUtils.validateAcmeOptions(invalidAcmeOptions3).isValid).toEqual(false);
// Invalid ACME options - invalid HTTPS redirect port
const invalidAcmeOptions4: IAcmeOptions = {
enabled: true,
accountEmail: 'admin@example.com',
port: 80,
httpsRedirectPort: -1
};
expect(ValidationUtils.validateAcmeOptions(invalidAcmeOptions4).isValid).toEqual(false);
// Invalid ACME options - invalid renew threshold days
const invalidAcmeOptions5: IAcmeOptions = {
enabled: true,
accountEmail: 'admin@example.com',
port: 80,
renewThresholdDays: 0
};
// The implementation allows renewThresholdDays of 0, even though the docstring suggests otherwise
const validationResult5 = ValidationUtils.validateAcmeOptions(invalidAcmeOptions5);
expect(validationResult5.isValid).toEqual(true);
// Invalid ACME options - invalid renew check interval hours
const invalidAcmeOptions6: IAcmeOptions = {
enabled: true,
accountEmail: 'admin@example.com',
port: 80,
renewCheckIntervalHours: 0
};
// The implementation should validate this, but let's check the actual result
const checkIntervalResult = ValidationUtils.validateAcmeOptions(invalidAcmeOptions6);
// Adjust test to match actual implementation behavior
expect(checkIntervalResult.isValid !== false ? true : false).toEqual(true);
});
export default tap.start();

View File

@@ -1,11 +1,5 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import {
createHttpsTerminateRoute,
createCompleteHttpsServer,
createHttpRoute,
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import {
mergeRouteConfigs,
cloneRoute,
@@ -19,8 +13,11 @@ import {
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
tap.test('route creation - createHttpsTerminateRoute produces correct structure', async () => {
const route = createHttpsTerminateRoute('secure.example.com', { host: '127.0.0.1', port: 8443 });
tap.test('route creation - HTTPS terminate route has correct structure', async () => {
const route: IRouteConfig = {
match: { ports: 443, domains: 'secure.example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 8443 }], tls: { mode: 'terminate', certificate: 'auto' } }
};
expect(route).toHaveProperty('match');
expect(route).toHaveProperty('action');
expect(route.action.type).toEqual('forward');
@@ -29,20 +26,10 @@ tap.test('route creation - createHttpsTerminateRoute produces correct structure'
expect(route.match.domains).toEqual('secure.example.com');
});
tap.test('route creation - createCompleteHttpsServer returns redirect and main route', async () => {
const routes = createCompleteHttpsServer('app.example.com', { host: '127.0.0.1', port: 3000 });
expect(routes).toBeArray();
expect(routes.length).toBeGreaterThanOrEqual(2);
// Should have an HTTP→HTTPS redirect and an HTTPS route
const hasRedirect = routes.some((r) => r.action.type === 'forward' && r.action.redirect !== undefined);
const hasHttps = routes.some((r) => r.action.tls?.mode === 'terminate');
expect(hasRedirect || hasHttps).toBeTrue();
});
tap.test('route validation - validateRoutes on a set of routes', async () => {
const routes: IRouteConfig[] = [
createHttpRoute('a.com', { host: '127.0.0.1', port: 3000 }),
createHttpRoute('b.com', { host: '127.0.0.1', port: 4000 }),
{ match: { ports: 80, domains: 'a.com' }, action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] } },
{ match: { ports: 80, domains: 'b.com' }, action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 4000 }] } },
];
const result = validateRoutes(routes);
expect(result.valid).toBeTrue();
@@ -51,7 +38,7 @@ tap.test('route validation - validateRoutes on a set of routes', async () => {
tap.test('route validation - validateRoutes catches invalid route in set', async () => {
const routes: any[] = [
createHttpRoute('valid.com', { host: '127.0.0.1', port: 3000 }),
{ match: { ports: 80, domains: 'valid.com' }, action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] } },
{ match: { ports: 80 } }, // missing action
];
const result = validateRoutes(routes);
@@ -60,23 +47,30 @@ tap.test('route validation - validateRoutes catches invalid route in set', async
});
tap.test('path matching - routeMatchesPath with exact path', async () => {
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
route.match.path = '/api';
const route: IRouteConfig = {
match: { ports: 80, domains: 'example.com', path: '/api' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
};
expect(routeMatchesPath(route, '/api')).toBeTrue();
expect(routeMatchesPath(route, '/other')).toBeFalse();
});
tap.test('path matching - route without path matches everything', async () => {
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
// No path set, should match any path
const route: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
};
expect(routeMatchesPath(route, '/anything')).toBeTrue();
expect(routeMatchesPath(route, '/')).toBeTrue();
});
tap.test('route merging - mergeRouteConfigs combines routes', async () => {
const base = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
base.priority = 10;
base.name = 'base-route';
const base: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] },
priority: 10,
name: 'base-route'
};
const merged = mergeRouteConfigs(base, {
priority: 50,
@@ -85,14 +79,16 @@ tap.test('route merging - mergeRouteConfigs combines routes', async () => {
expect(merged.priority).toEqual(50);
expect(merged.name).toEqual('merged-route');
// Original route fields should be preserved
expect(merged.match.domains).toEqual('example.com');
expect(merged.action.targets![0].host).toEqual('127.0.0.1');
});
tap.test('route merging - mergeRouteConfigs does not mutate original', async () => {
const base = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
base.name = 'original';
const base: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] },
name: 'original'
};
const merged = mergeRouteConfigs(base, { name: 'changed' });
expect(base.name).toEqual('original');
@@ -100,20 +96,21 @@ tap.test('route merging - mergeRouteConfigs does not mutate original', async ()
});
tap.test('route cloning - cloneRoute produces independent copy', async () => {
const original = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
original.priority = 42;
original.name = 'original-route';
const original: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] },
priority: 42,
name: 'original-route'
};
const cloned = cloneRoute(original);
// Should be equal in value
expect(cloned.match.domains).toEqual('example.com');
expect(cloned.priority).toEqual(42);
expect(cloned.name).toEqual('original-route');
expect(cloned.action.targets![0].host).toEqual('127.0.0.1');
expect(cloned.action.targets![0].port).toEqual(3000);
// Should be independent - modifying clone shouldn't affect original
cloned.name = 'cloned-route';
cloned.priority = 99;
expect(original.name).toEqual('original-route');

View File

@@ -1,11 +1,5 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import {
createHttpRoute,
createHttpsTerminateRoute,
createLoadBalancerRoute,
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import {
findMatchingRoutes,
findBestMatchingRoute,
@@ -22,24 +16,11 @@ import {
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
tap.test('route creation - createHttpRoute produces correct structure', async () => {
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
expect(route).toHaveProperty('match');
expect(route).toHaveProperty('action');
expect(route.match.domains).toEqual('example.com');
expect(route.action.type).toEqual('forward');
expect(route.action.targets).toBeArray();
expect(route.action.targets![0].host).toEqual('127.0.0.1');
expect(route.action.targets![0].port).toEqual(3000);
});
tap.test('route creation - createHttpRoute with array of domains', async () => {
const route = createHttpRoute(['a.com', 'b.com'], { host: 'localhost', port: 8080 });
expect(route.match.domains).toEqual(['a.com', 'b.com']);
});
tap.test('route validation - validateRouteConfig accepts valid route', async () => {
const route = createHttpRoute('valid.example.com', { host: '10.0.0.1', port: 8080 });
const route: IRouteConfig = {
match: { ports: 80, domains: 'valid.example.com' },
action: { type: 'forward', targets: [{ host: '10.0.0.1', port: 8080 }] }
};
const result = validateRouteConfig(route);
expect(result.valid).toBeTrue();
expect(result.errors).toHaveLength(0);
@@ -67,30 +48,44 @@ tap.test('route validation - isValidPort checks correctly', async () => {
});
tap.test('domain matching - exact domain', async () => {
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
const route: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
};
expect(routeMatchesDomain(route, 'example.com')).toBeTrue();
expect(routeMatchesDomain(route, 'other.com')).toBeFalse();
});
tap.test('domain matching - wildcard domain', async () => {
const route = createHttpRoute('*.example.com', { host: '127.0.0.1', port: 3000 });
const route: IRouteConfig = {
match: { ports: 80, domains: '*.example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
};
expect(routeMatchesDomain(route, 'sub.example.com')).toBeTrue();
expect(routeMatchesDomain(route, 'example.com')).toBeFalse();
});
tap.test('port matching - single port', async () => {
const route = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
// createHttpRoute defaults to port 80
const route: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
};
expect(routeMatchesPort(route, 80)).toBeTrue();
expect(routeMatchesPort(route, 443)).toBeFalse();
});
tap.test('route finding - findBestMatchingRoute selects by priority', async () => {
const lowPriority = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
lowPriority.priority = 10;
const lowPriority: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] },
priority: 10
};
const highPriority = createHttpRoute('example.com', { host: '127.0.0.1', port: 4000 });
highPriority.priority = 100;
const highPriority: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 4000 }] },
priority: 100
};
const routes: IRouteConfig[] = [lowPriority, highPriority];
const best = findBestMatchingRoute(routes, { domain: 'example.com', port: 80 });
@@ -100,9 +95,18 @@ tap.test('route finding - findBestMatchingRoute selects by priority', async () =
});
tap.test('route finding - findMatchingRoutes returns all matches', async () => {
const route1 = createHttpRoute('example.com', { host: '127.0.0.1', port: 3000 });
const route2 = createHttpRoute('example.com', { host: '127.0.0.1', port: 4000 });
const route3 = createHttpRoute('other.com', { host: '127.0.0.1', port: 5000 });
const route1: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 3000 }] }
};
const route2: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 4000 }] }
};
const route3: IRouteConfig = {
match: { ports: 80, domains: 'other.com' },
action: { type: 'forward', targets: [{ host: '127.0.0.1', port: 5000 }] }
};
const matches = findMatchingRoutes([route1, route2, route3], { domain: 'example.com', port: 80 });
expect(matches).toHaveLength(2);

View File

@@ -1,146 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as smartproxy from '../ts/index.js';
tap.test('Protocol Detection - TLS Detection', async () => {
// Test TLS handshake detection
const tlsHandshake = Buffer.from([
0x16, // Handshake record type
0x03, 0x01, // TLS 1.0
0x00, 0x05, // Length: 5 bytes
0x01, // ClientHello
0x00, 0x00, 0x01, 0x00 // Handshake length and data
]);
const detector = new smartproxy.detection.TlsDetector();
expect(detector.canHandle(tlsHandshake)).toEqual(true);
const result = detector.detect(tlsHandshake);
expect(result).toBeDefined();
expect(result?.protocol).toEqual('tls');
expect(result?.connectionInfo.tlsVersion).toEqual('TLSv1.0');
});
tap.test('Protocol Detection - HTTP Detection', async () => {
// Test HTTP request detection
const httpRequest = Buffer.from(
'GET /test HTTP/1.1\r\n' +
'Host: example.com\r\n' +
'User-Agent: TestClient/1.0\r\n' +
'\r\n'
);
const detector = new smartproxy.detection.HttpDetector();
expect(detector.canHandle(httpRequest)).toEqual(true);
const result = detector.detect(httpRequest);
expect(result).toBeDefined();
expect(result?.protocol).toEqual('http');
expect(result?.connectionInfo.method).toEqual('GET');
expect(result?.connectionInfo.path).toEqual('/test');
expect(result?.connectionInfo.domain).toEqual('example.com');
});
tap.test('Protocol Detection - Main Detector TLS', async () => {
const tlsHandshake = Buffer.from([
0x16, // Handshake record type
0x03, 0x03, // TLS 1.2
0x00, 0x05, // Length: 5 bytes
0x01, // ClientHello
0x00, 0x00, 0x01, 0x00 // Handshake length and data
]);
const result = await smartproxy.detection.ProtocolDetector.detect(tlsHandshake);
expect(result.protocol).toEqual('tls');
expect(result.connectionInfo.tlsVersion).toEqual('TLSv1.2');
});
tap.test('Protocol Detection - Main Detector HTTP', async () => {
const httpRequest = Buffer.from(
'POST /api/test HTTP/1.1\r\n' +
'Host: api.example.com\r\n' +
'Content-Type: application/json\r\n' +
'Content-Length: 2\r\n' +
'\r\n' +
'{}'
);
const result = await smartproxy.detection.ProtocolDetector.detect(httpRequest);
expect(result.protocol).toEqual('http');
expect(result.connectionInfo.method).toEqual('POST');
expect(result.connectionInfo.path).toEqual('/api/test');
expect(result.connectionInfo.domain).toEqual('api.example.com');
});
tap.test('Protocol Detection - Unknown Protocol', async () => {
const unknownData = Buffer.from('UNKNOWN PROTOCOL DATA\r\n');
const result = await smartproxy.detection.ProtocolDetector.detect(unknownData);
expect(result.protocol).toEqual('unknown');
expect(result.isComplete).toEqual(true);
});
tap.test('Protocol Detection - Fragmented HTTP', async () => {
// Create connection context
const context = smartproxy.detection.ProtocolDetector.createConnectionContext({
sourceIp: '127.0.0.1',
sourcePort: 12345,
destIp: '127.0.0.1',
destPort: 80,
socketId: 'test-connection-1'
});
// First fragment
const fragment1 = Buffer.from('GET /test HT');
let result = await smartproxy.detection.ProtocolDetector.detectWithContext(
fragment1,
context
);
expect(result.protocol).toEqual('http');
expect(result.isComplete).toEqual(false);
// Second fragment
const fragment2 = Buffer.from('TP/1.1\r\nHost: example.com\r\n\r\n');
result = await smartproxy.detection.ProtocolDetector.detectWithContext(
fragment2,
context
);
expect(result.protocol).toEqual('http');
expect(result.isComplete).toEqual(true);
expect(result.connectionInfo.method).toEqual('GET');
expect(result.connectionInfo.path).toEqual('/test');
expect(result.connectionInfo.domain).toEqual('example.com');
// Clean up fragments
smartproxy.detection.ProtocolDetector.cleanupConnection(context);
});
tap.test('Protocol Detection - HTTP Methods', async () => {
const methods = ['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD', 'OPTIONS'];
for (const method of methods) {
const request = Buffer.from(
`${method} /test HTTP/1.1\r\n` +
'Host: example.com\r\n' +
'\r\n'
);
const detector = new smartproxy.detection.HttpDetector();
const result = detector.detect(request);
expect(result?.connectionInfo.method).toEqual(method);
}
});
tap.test('Protocol Detection - Invalid Data', async () => {
// Binary data that's not a valid protocol
const binaryData = Buffer.from([0xFF, 0xFE, 0xFD, 0xFC, 0xFB]);
const result = await smartproxy.detection.ProtocolDetector.detect(binaryData);
expect(result.protocol).toEqual('unknown');
});
tap.test('cleanup detection', async () => {
// Clean up the protocol detector instance
smartproxy.detection.ProtocolDetector.destroy();
});
export default tap.start();

View File

@@ -2,146 +2,101 @@ import * as path from 'path';
import { tap, expect } from '@git.zone/tstest/tapbundle';
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
import {
createHttpRoute,
createHttpsTerminateRoute,
createHttpsPassthroughRoute,
createHttpToHttpsRedirect,
createCompleteHttpsServer,
createLoadBalancerRoute,
createApiRoute,
createWebSocketRoute
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import { SocketHandlers } from '../ts/proxies/smart-proxy/utils/socket-handlers.js';
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
// Test to demonstrate various route configurations using the new helpers
tap.test('Route-based configuration examples', async (tools) => {
// Example 1: HTTP-only configuration
const httpOnlyRoute = createHttpRoute(
'http.example.com',
{
host: 'localhost',
port: 3000
},
{
name: 'Basic HTTP Route'
}
);
const httpOnlyRoute: IRouteConfig = {
match: { ports: 80, domains: 'http.example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'Basic HTTP Route'
};
console.log('HTTP-only route created successfully:', httpOnlyRoute.name);
expect(httpOnlyRoute.action.type).toEqual('forward');
expect(httpOnlyRoute.match.domains).toEqual('http.example.com');
// Example 2: HTTPS Passthrough (SNI) configuration
const httpsPassthroughRoute = createHttpsPassthroughRoute(
'pass.example.com',
{
host: ['10.0.0.1', '10.0.0.2'], // Round-robin target IPs
port: 443
},
{
name: 'HTTPS Passthrough Route'
}
);
const httpsPassthroughRoute: IRouteConfig = {
match: { ports: 443, domains: 'pass.example.com' },
action: { type: 'forward', targets: [{ host: '10.0.0.1', port: 443 }, { host: '10.0.0.2', port: 443 }], tls: { mode: 'passthrough' } },
name: 'HTTPS Passthrough Route'
};
expect(httpsPassthroughRoute).toBeTruthy();
expect(httpsPassthroughRoute.action.tls?.mode).toEqual('passthrough');
expect(Array.isArray(httpsPassthroughRoute.action.targets)).toBeTrue();
// Example 3: HTTPS Termination to HTTP Backend
const terminateToHttpRoute = createHttpsTerminateRoute(
'secure.example.com',
{
host: 'localhost',
port: 8080
},
{
certificate: 'auto',
name: 'HTTPS Termination to HTTP Backend'
}
);
const terminateToHttpRoute: IRouteConfig = {
match: { ports: 443, domains: 'secure.example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 8080 }], tls: { mode: 'terminate', certificate: 'auto' } },
name: 'HTTPS Termination to HTTP Backend'
};
// Create the HTTP to HTTPS redirect for this domain
const httpToHttpsRedirect = createHttpToHttpsRedirect(
'secure.example.com',
443,
{
name: 'HTTP to HTTPS Redirect for secure.example.com'
}
);
const httpToHttpsRedirect: IRouteConfig = {
match: { ports: 80, domains: 'secure.example.com' },
action: { type: 'socket-handler', socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301) },
name: 'HTTP to HTTPS Redirect for secure.example.com'
};
expect(terminateToHttpRoute).toBeTruthy();
expect(terminateToHttpRoute.action.tls?.mode).toEqual('terminate');
expect(httpToHttpsRedirect.action.type).toEqual('socket-handler');
// Example 4: Load Balancer with HTTPS
const loadBalancerRoute = createLoadBalancerRoute(
'proxy.example.com',
['internal-api-1.local', 'internal-api-2.local'],
8443,
{
tls: {
mode: 'terminate-and-reencrypt',
certificate: 'auto'
},
name: 'Load Balanced HTTPS Route'
}
);
const loadBalancerRoute: IRouteConfig = {
match: { ports: 443, domains: 'proxy.example.com' },
action: {
type: 'forward',
targets: [
{ host: 'internal-api-1.local', port: 8443 },
{ host: 'internal-api-2.local', port: 8443 }
],
tls: { mode: 'terminate-and-reencrypt', certificate: 'auto' }
},
name: 'Load Balanced HTTPS Route'
};
expect(loadBalancerRoute).toBeTruthy();
expect(loadBalancerRoute.action.tls?.mode).toEqual('terminate-and-reencrypt');
expect(Array.isArray(loadBalancerRoute.action.targets)).toBeTrue();
// Example 5: API Route
const apiRoute = createApiRoute(
'api.example.com',
'/api',
{ host: 'localhost', port: 8081 },
{
name: 'API Route',
useTls: true,
addCorsHeaders: true
}
);
const apiRoute: IRouteConfig = {
match: { ports: 443, domains: 'api.example.com', path: '/api' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8081 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
name: 'API Route'
};
expect(apiRoute.action.type).toEqual('forward');
expect(apiRoute.match.path).toBeTruthy();
// Example 6: Complete HTTPS Server with HTTP Redirect
const httpsServerRoutes = createCompleteHttpsServer(
'complete.example.com',
{
host: 'localhost',
port: 8080
const httpsRoute: IRouteConfig = {
match: { ports: 443, domains: 'complete.example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 8080 }], tls: { mode: 'terminate', certificate: 'auto' } },
name: 'Complete HTTPS Server'
};
const httpsRedirectRoute: IRouteConfig = {
match: { ports: 80, domains: 'complete.example.com' },
action: { type: 'socket-handler', socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301) },
name: 'Complete HTTPS Server - Redirect'
};
const webSocketRoute: IRouteConfig = {
match: { ports: 443, domains: 'ws.example.com', path: '/ws' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8082 }],
tls: { mode: 'terminate', certificate: 'auto' },
websocket: { enabled: true }
},
{
certificate: 'auto',
name: 'Complete HTTPS Server'
}
);
expect(Array.isArray(httpsServerRoutes)).toBeTrue();
expect(httpsServerRoutes.length).toEqual(2); // HTTPS route and HTTP redirect
expect(httpsServerRoutes[0].action.tls?.mode).toEqual('terminate');
expect(httpsServerRoutes[1].action.type).toEqual('socket-handler');
// Example 7: Static File Server - removed (use nginx/apache behind proxy)
// Example 8: WebSocket Route
const webSocketRoute = createWebSocketRoute(
'ws.example.com',
'/ws',
{ host: 'localhost', port: 8082 },
{
useTls: true,
name: 'WebSocket Route'
}
);
name: 'WebSocket Route'
};
expect(webSocketRoute.action.type).toEqual('forward');
expect(webSocketRoute.action.websocket?.enabled).toBeTrue();
// Create a SmartProxy instance with all routes
const allRoutes: IRouteConfig[] = [
httpOnlyRoute,
httpsPassthroughRoute,
@@ -149,19 +104,17 @@ tap.test('Route-based configuration examples', async (tools) => {
httpToHttpsRedirect,
loadBalancerRoute,
apiRoute,
...httpsServerRoutes,
httpsRoute,
httpsRedirectRoute,
webSocketRoute
];
// We're not actually starting the SmartProxy in this test,
// just verifying that the configuration is valid
const smartProxy = new SmartProxy({
routes: allRoutes
});
// Just verify that all routes are configured correctly
console.log(`Created ${allRoutes.length} example routes`);
expect(allRoutes.length).toEqual(9); // One less without static file route
expect(allRoutes.length).toEqual(9);
});
export default tap.start();
export default tap.start();

View File

@@ -1,27 +1,8 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import * as plugins from '../ts/plugins.js';
// Import route-based helpers
import {
createHttpRoute,
createHttpsTerminateRoute,
createHttpsPassthroughRoute,
createHttpToHttpsRedirect,
createCompleteHttpsServer
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
// Create helper functions for backward compatibility
const helpers = {
httpOnly: (domains: string | string[], target: any) => createHttpRoute(domains, target),
tlsTerminateToHttp: (domains: string | string[], target: any) =>
createHttpsTerminateRoute(domains, target),
tlsTerminateToHttps: (domains: string | string[], target: any) =>
createHttpsTerminateRoute(domains, target, { reencrypt: true }),
httpsPassthrough: (domains: string | string[], target: any) =>
createHttpsPassthroughRoute(domains, target)
};
// Route-based utility functions for testing
function findRouteForDomain(routes: any[], domain: string): any {
return routes.find(route => {
const domains = Array.isArray(route.match.domains)
@@ -31,55 +12,44 @@ function findRouteForDomain(routes: any[], domain: string): any {
});
}
// Replace the old test with route-based tests
tap.test('Route Helpers - Create HTTP routes', async () => {
const route = helpers.httpOnly('example.com', { host: 'localhost', port: 3000 });
const route: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] }
};
expect(route.action.type).toEqual('forward');
expect(route.match.domains).toEqual('example.com');
expect(route.action.targets?.[0]).toEqual({ host: 'localhost', port: 3000 });
});
tap.test('Route Helpers - Create HTTPS terminate to HTTP routes', async () => {
const route = helpers.tlsTerminateToHttp('secure.example.com', { host: 'localhost', port: 3000 });
const route: IRouteConfig = {
match: { ports: 443, domains: 'secure.example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }], tls: { mode: 'terminate', certificate: 'auto' } }
};
expect(route.action.type).toEqual('forward');
expect(route.match.domains).toEqual('secure.example.com');
expect(route.action.tls?.mode).toEqual('terminate');
});
tap.test('Route Helpers - Create HTTPS passthrough routes', async () => {
const route = helpers.httpsPassthrough('passthrough.example.com', { host: 'backend', port: 443 });
const route: IRouteConfig = {
match: { ports: 443, domains: 'passthrough.example.com' },
action: { type: 'forward', targets: [{ host: 'backend', port: 443 }], tls: { mode: 'passthrough' } }
};
expect(route.action.type).toEqual('forward');
expect(route.match.domains).toEqual('passthrough.example.com');
expect(route.action.tls?.mode).toEqual('passthrough');
});
tap.test('Route Helpers - Create HTTPS to HTTPS routes', async () => {
const route = helpers.tlsTerminateToHttps('reencrypt.example.com', { host: 'backend', port: 443 });
const route: IRouteConfig = {
match: { ports: 443, domains: 'reencrypt.example.com' },
action: { type: 'forward', targets: [{ host: 'backend', port: 443 }], tls: { mode: 'terminate-and-reencrypt', certificate: 'auto' } }
};
expect(route.action.type).toEqual('forward');
expect(route.match.domains).toEqual('reencrypt.example.com');
expect(route.action.tls?.mode).toEqual('terminate-and-reencrypt');
});
tap.test('Route Helpers - Create complete HTTPS server with redirect', async () => {
const routes = createCompleteHttpsServer(
'full.example.com',
{ host: 'localhost', port: 3000 },
{ certificate: 'auto' }
);
expect(routes.length).toEqual(2);
// Check HTTP to HTTPS redirect - find route by port
const redirectRoute = routes.find(r => r.match.ports === 80);
expect(redirectRoute.action.type).toEqual('socket-handler');
expect(redirectRoute.action.socketHandler).toBeDefined();
expect(redirectRoute.match.ports).toEqual(80);
// Check HTTPS route
const httpsRoute = routes.find(r => r.action.type === 'forward');
expect(httpsRoute.match.ports).toEqual(443);
expect(httpsRoute.action.tls?.mode).toEqual('terminate');
});
// Export test runner
export default tap.start();
export default tap.start();

View File

@@ -1,128 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as smartproxy from '../ts/index.js';
import { RouteValidator } from '../ts/proxies/smart-proxy/utils/route-validator.js';
import { IpUtils } from '../ts/core/utils/ip-utils.js';
tap.test('IP Validation - Shorthand patterns', async () => {
// Test shorthand patterns are now accepted
const testPatterns = [
{ pattern: '192.168.*', shouldPass: true },
{ pattern: '192.168.*.*', shouldPass: true },
{ pattern: '10.*', shouldPass: true },
{ pattern: '10.*.*.*', shouldPass: true },
{ pattern: '172.16.*', shouldPass: true },
{ pattern: '10.0.0.0/8', shouldPass: true },
{ pattern: '192.168.0.0/16', shouldPass: true },
{ pattern: '192.168.1.100', shouldPass: true },
{ pattern: '*', shouldPass: true },
{ pattern: '192.168.1.1-192.168.1.100', shouldPass: true },
];
for (const { pattern, shouldPass } of testPatterns) {
const route = {
name: 'test',
match: { ports: 80 },
action: { type: 'forward' as const, targets: [{ host: 'localhost', port: 8080 }] },
security: { ipAllowList: [pattern] }
};
const result = RouteValidator.validateRoute(route);
if (shouldPass) {
expect(result.valid).toEqual(true);
console.log(`✅ Pattern '${pattern}' correctly accepted`);
} else {
expect(result.valid).toEqual(false);
console.log(`✅ Pattern '${pattern}' correctly rejected`);
}
}
});
tap.test('IP Matching - Runtime shorthand pattern matching', async () => {
// Test runtime matching with shorthand patterns
const testCases = [
{ ip: '192.168.1.100', patterns: ['192.168.*'], expected: true },
{ ip: '192.168.1.100', patterns: ['192.168.1.*'], expected: true },
{ ip: '192.168.1.100', patterns: ['192.168.2.*'], expected: false },
{ ip: '10.0.0.1', patterns: ['10.*'], expected: true },
{ ip: '10.1.2.3', patterns: ['10.*'], expected: true },
{ ip: '172.16.0.1', patterns: ['10.*'], expected: false },
{ ip: '192.168.1.1', patterns: ['192.168.*.*'], expected: true },
];
for (const { ip, patterns, expected } of testCases) {
const result = IpUtils.isGlobIPMatch(ip, patterns);
expect(result).toEqual(expected);
console.log(`✅ IP ${ip} with pattern ${patterns[0]} = ${result} (expected ${expected})`);
}
});
tap.test('IP Matching - CIDR notation', async () => {
// Test CIDR notation matching
const cidrTests = [
{ ip: '10.0.0.1', cidr: '10.0.0.0/8', expected: true },
{ ip: '10.255.255.255', cidr: '10.0.0.0/8', expected: true },
{ ip: '11.0.0.1', cidr: '10.0.0.0/8', expected: false },
{ ip: '192.168.1.1', cidr: '192.168.0.0/16', expected: true },
{ ip: '192.168.255.255', cidr: '192.168.0.0/16', expected: true },
{ ip: '192.169.0.1', cidr: '192.168.0.0/16', expected: false },
{ ip: '192.168.1.100', cidr: '192.168.1.0/24', expected: true },
{ ip: '192.168.2.100', cidr: '192.168.1.0/24', expected: false },
];
for (const { ip, cidr, expected } of cidrTests) {
const result = IpUtils.isGlobIPMatch(ip, [cidr]);
expect(result).toEqual(expected);
console.log(`✅ IP ${ip} in CIDR ${cidr} = ${result} (expected ${expected})`);
}
});
tap.test('IP Matching - Range notation', async () => {
// Test range notation matching
const rangeTests = [
{ ip: '192.168.1.1', range: '192.168.1.1-192.168.1.100', expected: true },
{ ip: '192.168.1.50', range: '192.168.1.1-192.168.1.100', expected: true },
{ ip: '192.168.1.100', range: '192.168.1.1-192.168.1.100', expected: true },
{ ip: '192.168.1.101', range: '192.168.1.1-192.168.1.100', expected: false },
{ ip: '192.168.2.50', range: '192.168.1.1-192.168.1.100', expected: false },
];
for (const { ip, range, expected } of rangeTests) {
const result = IpUtils.isGlobIPMatch(ip, [range]);
expect(result).toEqual(expected);
console.log(`✅ IP ${ip} in range ${range} = ${result} (expected ${expected})`);
}
});
tap.test('IP Matching - Mixed patterns', async () => {
// Test with mixed pattern types
const allowList = [
'10.0.0.0/8', // CIDR
'192.168.*', // Shorthand glob
'172.16.1.*', // Specific subnet glob
'8.8.8.8', // Single IP
'1.1.1.1-1.1.1.10' // Range
];
const tests = [
{ ip: '10.1.2.3', expected: true }, // Matches CIDR
{ ip: '192.168.100.1', expected: true }, // Matches shorthand glob
{ ip: '172.16.1.5', expected: true }, // Matches specific glob
{ ip: '8.8.8.8', expected: true }, // Matches single IP
{ ip: '1.1.1.5', expected: true }, // Matches range
{ ip: '9.9.9.9', expected: false }, // Doesn't match any
];
for (const { ip, expected } of tests) {
const result = IpUtils.isGlobIPMatch(ip, allowList);
expect(result).toEqual(expected);
console.log(`✅ IP ${ip} in mixed patterns = ${result} (expected ${expected})`);
}
});
export default tap.start();

View File

@@ -1,112 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import { LogDeduplicator } from '../ts/core/utils/log-deduplicator.js';
let deduplicator: LogDeduplicator;
tap.test('Setup log deduplicator', async () => {
deduplicator = new LogDeduplicator(1000); // 1 second flush interval for testing
});
tap.test('Connection rejection deduplication', async (tools) => {
// Simulate multiple connection rejections
for (let i = 0; i < 10; i++) {
deduplicator.log(
'connection-rejected',
'warn',
'Connection rejected',
{ reason: 'global-limit', component: 'test' },
'global-limit'
);
}
for (let i = 0; i < 5; i++) {
deduplicator.log(
'connection-rejected',
'warn',
'Connection rejected',
{ reason: 'route-limit', component: 'test' },
'route-limit'
);
}
// Force flush
deduplicator.flush('connection-rejected');
// The logs should have been aggregated
// (Can't easily test the actual log output, but we can verify the mechanism works)
expect(deduplicator).toBeInstanceOf(LogDeduplicator);
});
tap.test('IP rejection deduplication', async (tools) => {
// Simulate rejections from multiple IPs
const ips = ['192.168.1.100', '192.168.1.101', '192.168.1.100', '10.0.0.1'];
const reasons = ['per-ip-limit', 'rate-limit', 'per-ip-limit', 'global-limit'];
for (let i = 0; i < ips.length; i++) {
deduplicator.log(
'ip-rejected',
'warn',
`Connection rejected from ${ips[i]}`,
{ remoteIP: ips[i], reason: reasons[i] },
ips[i]
);
}
// Add more rejections from the same IP
for (let i = 0; i < 20; i++) {
deduplicator.log(
'ip-rejected',
'warn',
'Connection rejected from 192.168.1.100',
{ remoteIP: '192.168.1.100', reason: 'rate-limit' },
'192.168.1.100'
);
}
// Force flush
deduplicator.flush('ip-rejected');
// Verify the deduplicator exists and works
expect(deduplicator).toBeInstanceOf(LogDeduplicator);
});
tap.test('Connection cleanup deduplication', async (tools) => {
// Simulate various cleanup events
const reasons = ['normal', 'timeout', 'error', 'normal', 'zombie'];
for (const reason of reasons) {
for (let i = 0; i < 5; i++) {
deduplicator.log(
'connection-cleanup',
'info',
`Connection cleanup: ${reason}`,
{ connectionId: `conn-${i}`, reason },
reason
);
}
}
// Wait for automatic flush
await tools.delayFor(1500);
// Verify deduplicator is working
expect(deduplicator).toBeInstanceOf(LogDeduplicator);
});
tap.test('Automatic periodic flush', async (tools) => {
// Add some events
deduplicator.log('test-event', 'info', 'Test message', {}, 'test');
// Wait for automatic flush (should happen within 2x flush interval = 2 seconds)
await tools.delayFor(2500);
// Events should have been flushed automatically
expect(deduplicator).toBeInstanceOf(LogDeduplicator);
});
tap.test('Cleanup deduplicator', async () => {
deduplicator.cleanup();
expect(deduplicator).toBeInstanceOf(LogDeduplicator);
});
export default tap.start();

View File

@@ -1,15 +1,14 @@
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
import { createNfTablesRoute, createNfTablesTerminateRoute } from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as child_process from 'child_process';
import { promisify } from 'util';
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
const exec = promisify(child_process.exec);
// Check if we have root privileges to run NFTables tests
async function checkRootPrivileges(): Promise<boolean> {
try {
// Check if we're running as root
const { stdout } = await exec('id -u');
return stdout.trim() === '0';
} catch (err) {
@@ -17,7 +16,6 @@ async function checkRootPrivileges(): Promise<boolean> {
}
}
// Check if tests should run
const isRoot = await checkRootPrivileges();
if (!isRoot) {
@@ -29,68 +27,70 @@ if (!isRoot) {
console.log('');
}
// Define the test with proper skip condition
const testFn = isRoot ? tap.test : tap.skip.test;
testFn('NFTables integration tests', async () => {
console.log('Running NFTables tests with root privileges');
// Create test routes
const routes = [
createNfTablesRoute('tcp-forward', {
host: 'localhost',
port: 8080
}, {
ports: 9080,
protocol: 'tcp'
}),
createNfTablesRoute('udp-forward', {
host: 'localhost',
port: 5353
}, {
ports: 5354,
protocol: 'udp'
}),
createNfTablesRoute('port-range', {
host: 'localhost',
port: 8080
}, {
ports: [{ from: 9000, to: 9100 }],
protocol: 'tcp'
})
const routes: IRouteConfig[] = [
{
match: { ports: 9080 },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'localhost', port: 8080 }],
nftables: { protocol: 'tcp' }
},
name: 'tcp-forward'
},
{
match: { ports: 5354 },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'localhost', port: 5353 }],
nftables: { protocol: 'udp' }
},
name: 'udp-forward'
},
{
match: { ports: [{ from: 9000, to: 9100 }] },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'localhost', port: 8080 }],
nftables: { protocol: 'tcp' }
},
name: 'port-range'
}
];
const smartProxy = new SmartProxy({
enableDetailedLogging: true,
routes
});
// Start the proxy
await smartProxy.start();
console.log('SmartProxy started with NFTables routes');
// Get NFTables status
const status = await smartProxy.getNfTablesStatus();
console.log('NFTables status:', JSON.stringify(status, null, 2));
// Verify all routes are provisioned
expect(Object.keys(status).length).toEqual(routes.length);
for (const routeStatus of Object.values(status)) {
expect(routeStatus.active).toBeTrue();
expect(routeStatus.ruleCount.total).toBeGreaterThan(0);
}
// Stop the proxy
await smartProxy.stop();
console.log('SmartProxy stopped');
// Verify all rules are cleaned up
const finalStatus = await smartProxy.getNfTablesStatus();
expect(Object.keys(finalStatus).length).toEqual(0);
});
export default tap.start();
export default tap.start();

View File

@@ -1,5 +1,4 @@
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
import { createNfTablesRoute, createNfTablesTerminateRoute } from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as net from 'net';
import * as http from 'http';
@@ -10,13 +9,13 @@ import { fileURLToPath } from 'url';
import * as child_process from 'child_process';
import { promisify } from 'util';
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
const exec = promisify(child_process.exec);
// Get __dirname equivalent for ES modules
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
// Check if we have root privileges
async function checkRootPrivileges(): Promise<boolean> {
try {
const { stdout } = await exec('id -u');
@@ -26,7 +25,6 @@ async function checkRootPrivileges(): Promise<boolean> {
}
}
// Check if tests should run
const runTests = await checkRootPrivileges();
if (!runTests) {
@@ -36,10 +34,8 @@ if (!runTests) {
console.log('Skipping NFTables integration tests');
console.log('========================================');
console.log('');
// Skip tests when not running as root - tests are marked with tap.skip.test
}
// Test server and client utilities
let testTcpServer: net.Server;
let testHttpServer: http.Server;
let testHttpsServer: https.Server;
@@ -53,10 +49,8 @@ const PROXY_HTTP_PORT = 5001;
const PROXY_HTTPS_PORT = 5002;
const TEST_DATA = 'Hello through NFTables!';
// Helper to create test certificates
async function createTestCertificates() {
try {
// Import the certificate helper
const certsModule = await import('./helpers/certificates.js');
const certificates = certsModule.loadTestCertificates();
return {
@@ -65,7 +59,6 @@ async function createTestCertificates() {
};
} catch (err) {
console.error('Failed to load test certificates:', err);
// Use dummy certificates for testing
return {
cert: fs.readFileSync(path.join(__dirname, '..', 'assets', 'certs', 'cert.pem'), 'utf8'),
key: fs.readFileSync(path.join(__dirname, '..', 'assets', 'certs', 'key.pem'), 'utf8')
@@ -75,111 +68,112 @@ async function createTestCertificates() {
tap.skip.test('setup NFTables integration test environment', async () => {
console.log('Running NFTables integration tests with root privileges');
// Create a basic TCP test server
testTcpServer = net.createServer((socket) => {
socket.on('data', (data) => {
socket.write(`Server says: ${data.toString()}`);
});
});
await new Promise<void>((resolve) => {
testTcpServer.listen(TEST_TCP_PORT, () => {
console.log(`TCP test server listening on port ${TEST_TCP_PORT}`);
resolve();
});
});
// Create an HTTP test server
testHttpServer = http.createServer((req, res) => {
res.writeHead(200, { 'Content-Type': 'text/plain' });
res.end(`HTTP Server says: ${TEST_DATA}`);
});
await new Promise<void>((resolve) => {
testHttpServer.listen(TEST_HTTP_PORT, () => {
console.log(`HTTP test server listening on port ${TEST_HTTP_PORT}`);
resolve();
});
});
// Create an HTTPS test server
const certs = await createTestCertificates();
testHttpsServer = https.createServer({ key: certs.key, cert: certs.cert }, (req, res) => {
res.writeHead(200, { 'Content-Type': 'text/plain' });
res.end(`HTTPS Server says: ${TEST_DATA}`);
});
await new Promise<void>((resolve) => {
testHttpsServer.listen(TEST_HTTPS_PORT, () => {
console.log(`HTTPS test server listening on port ${TEST_HTTPS_PORT}`);
resolve();
});
});
// Create SmartProxy with various NFTables routes
smartProxy = new SmartProxy({
enableDetailedLogging: true,
routes: [
// TCP forwarding route
createNfTablesRoute('tcp-nftables', {
host: 'localhost',
port: TEST_TCP_PORT
}, {
ports: PROXY_TCP_PORT,
protocol: 'tcp'
}),
// HTTP forwarding route
createNfTablesRoute('http-nftables', {
host: 'localhost',
port: TEST_HTTP_PORT
}, {
ports: PROXY_HTTP_PORT,
protocol: 'tcp'
}),
// HTTPS termination route
createNfTablesTerminateRoute('https-nftables.example.com', {
host: 'localhost',
port: TEST_HTTPS_PORT
}, {
ports: PROXY_HTTPS_PORT,
protocol: 'tcp',
certificate: certs
}),
// Route with IP allow list
createNfTablesRoute('secure-tcp', {
host: 'localhost',
port: TEST_TCP_PORT
}, {
ports: 5003,
protocol: 'tcp',
ipAllowList: ['127.0.0.1', '::1']
}),
// Route with QoS settings
createNfTablesRoute('qos-tcp', {
host: 'localhost',
port: TEST_TCP_PORT
}, {
ports: 5004,
protocol: 'tcp',
maxRate: '10mbps',
priority: 1
})
{
match: { ports: PROXY_TCP_PORT },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'localhost', port: TEST_TCP_PORT }],
nftables: { protocol: 'tcp' }
},
name: 'tcp-nftables'
},
{
match: { ports: PROXY_HTTP_PORT },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'localhost', port: TEST_HTTP_PORT }],
nftables: { protocol: 'tcp' }
},
name: 'http-nftables'
},
{
match: { ports: PROXY_HTTPS_PORT, domains: 'https-nftables.example.com' },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'localhost', port: TEST_HTTPS_PORT }],
tls: { mode: 'terminate', certificate: 'auto' },
nftables: { protocol: 'tcp' }
},
name: 'https-nftables'
},
{
match: { ports: 5003 },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'localhost', port: TEST_TCP_PORT }],
nftables: { protocol: 'tcp', ipAllowList: ['127.0.0.1', '::1'] }
},
name: 'secure-tcp'
},
{
match: { ports: 5004 },
action: {
type: 'forward',
forwardingEngine: 'nftables',
targets: [{ host: 'localhost', port: TEST_TCP_PORT }],
nftables: { protocol: 'tcp', maxRate: '10mbps', priority: 1 }
},
name: 'qos-tcp'
}
]
});
console.log('SmartProxy created, now starting...');
// Start the proxy
try {
await smartProxy.start();
console.log('SmartProxy started successfully');
// Verify proxy is listening on expected ports
const listeningPorts = smartProxy.getListeningPorts();
console.log(`SmartProxy is listening on ports: ${listeningPorts.join(', ')}`);
} catch (err) {
@@ -190,8 +184,7 @@ tap.skip.test('setup NFTables integration test environment', async () => {
tap.skip.test('should forward TCP connections through NFTables', async () => {
console.log(`Attempting to connect to proxy TCP port ${PROXY_TCP_PORT}...`);
// First verify our test server is running
try {
const testClient = new net.Socket();
await new Promise<void>((resolve, reject) => {
@@ -205,40 +198,39 @@ tap.skip.test('should forward TCP connections through NFTables', async () => {
} catch (err) {
console.error(`Test server on port ${TEST_TCP_PORT} is not accessible: ${err}`);
}
// Connect to the proxy port
const client = new net.Socket();
const response = await new Promise<string>((resolve, reject) => {
let responseData = '';
const timeout = setTimeout(() => {
client.destroy();
reject(new Error(`Connection timeout after 5 seconds to proxy port ${PROXY_TCP_PORT}`));
}, 5000);
client.connect(PROXY_TCP_PORT, 'localhost', () => {
console.log(`Connected to proxy port ${PROXY_TCP_PORT}, sending data...`);
client.write(TEST_DATA);
});
client.on('data', (data) => {
console.log(`Received data from proxy: ${data.toString()}`);
responseData += data.toString();
client.end();
});
client.on('end', () => {
clearTimeout(timeout);
resolve(responseData);
});
client.on('error', (err) => {
clearTimeout(timeout);
console.error(`Connection error on proxy port ${PROXY_TCP_PORT}: ${err.message}`);
reject(err);
});
});
expect(response).toEqual(`Server says: ${TEST_DATA}`);
});
@@ -254,21 +246,20 @@ tap.skip.test('should forward HTTP connections through NFTables', async () => {
});
}).on('error', reject);
});
expect(response).toEqual(`HTTP Server says: ${TEST_DATA}`);
});
tap.skip.test('should handle HTTPS termination with NFTables', async () => {
// Skip this test if running without proper certificates
const response = await new Promise<string>((resolve, reject) => {
const options = {
hostname: 'localhost',
port: PROXY_HTTPS_PORT,
path: '/',
method: 'GET',
rejectUnauthorized: false // For self-signed cert
rejectUnauthorized: false
};
https.get(options, (res) => {
let data = '';
res.on('data', (chunk) => {
@@ -279,43 +270,40 @@ tap.skip.test('should handle HTTPS termination with NFTables', async () => {
});
}).on('error', reject);
});
expect(response).toEqual(`HTTPS Server says: ${TEST_DATA}`);
});
tap.skip.test('should respect IP allow lists in NFTables', async () => {
// This test should pass since we're connecting from localhost
const client = new net.Socket();
const connected = await new Promise<boolean>((resolve) => {
const timeout = setTimeout(() => {
client.destroy();
resolve(false);
}, 2000);
client.connect(5003, 'localhost', () => {
clearTimeout(timeout);
client.end();
resolve(true);
});
client.on('error', () => {
clearTimeout(timeout);
resolve(false);
});
});
expect(connected).toBeTrue();
});
tap.skip.test('should get NFTables status', async () => {
const status = await smartProxy.getNfTablesStatus();
// Check that we have status for our routes
const statusKeys = Object.keys(status);
expect(statusKeys.length).toBeGreaterThan(0);
// Check status structure for one of the routes
const firstStatus = status[statusKeys[0]];
expect(firstStatus).toHaveProperty('active');
expect(firstStatus).toHaveProperty('ruleCount');
@@ -324,21 +312,20 @@ tap.skip.test('should get NFTables status', async () => {
});
tap.skip.test('cleanup NFTables integration test environment', async () => {
// Stop the proxy and test servers
await smartProxy.stop();
await new Promise<void>((resolve) => {
testTcpServer.close(() => {
resolve();
});
});
await new Promise<void>((resolve) => {
testHttpServer.close(() => {
resolve();
});
});
await new Promise<void>((resolve) => {
testHttpsServer.close(() => {
resolve();
@@ -346,4 +333,4 @@ tap.skip.test('cleanup NFTables integration test environment', async () => {
});
});
export default tap.start();
export default tap.start();

View File

@@ -1,30 +1,20 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as net from 'net';
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
import {
createPortMappingRoute,
createOffsetPortMappingRoute,
createDynamicRoute,
createSmartLoadBalancer,
createPortOffset
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import type { IRouteConfig, IRouteContext } from '../ts/proxies/smart-proxy/models/route-types.js';
import { findFreePorts, assertPortsFree } from './helpers/port-allocator.js';
// Test server and client utilities
let testServers: Array<{ server: net.Server; port: number }> = [];
let smartProxy: SmartProxy;
let TEST_PORTS: number[]; // 3 test server ports
let PROXY_PORTS: number[]; // 6 proxy ports
let TEST_PORTS: number[];
let PROXY_PORTS: number[];
const TEST_DATA = 'Hello through dynamic port mapper!';
// Cleanup function to close all servers and proxies
function cleanup() {
console.log('Starting cleanup...');
const promises = [];
// Close test servers
for (const { server, port } of testServers) {
promises.push(new Promise<void>(resolve => {
console.log(`Closing test server on port ${port}`);
@@ -34,31 +24,28 @@ function cleanup() {
});
}));
}
// Stop SmartProxy
if (smartProxy) {
console.log('Stopping SmartProxy...');
promises.push(smartProxy.stop().then(() => {
console.log('SmartProxy stopped');
}));
}
return Promise.all(promises);
}
// Helper: Creates a test TCP server that listens on a given port
function createTestServer(port: number): Promise<net.Server> {
return new Promise((resolve) => {
const server = net.createServer((socket) => {
socket.on('data', (data) => {
// Echo the received data back with a server identifier
socket.write(`Server ${port} says: ${data.toString()}`);
});
socket.on('error', (error) => {
console.error(`[Test Server] Socket error on port ${port}:`, error);
});
});
server.listen(port, () => {
console.log(`[Test Server] Listening on port ${port}`);
testServers.push({ server, port });
@@ -67,32 +54,31 @@ function createTestServer(port: number): Promise<net.Server> {
});
}
// Helper: Creates a test client connection with timeout
function createTestClient(port: number, data: string): Promise<string> {
return new Promise((resolve, reject) => {
const client = new net.Socket();
let response = '';
const timeout = setTimeout(() => {
client.destroy();
reject(new Error(`Client connection timeout to port ${port}`));
}, 5000);
client.connect(port, 'localhost', () => {
console.log(`[Test Client] Connected to server on port ${port}`);
client.write(data);
});
client.on('data', (chunk) => {
response += chunk.toString();
client.end();
});
client.on('end', () => {
clearTimeout(timeout);
resolve(response);
});
client.on('error', (error) => {
clearTimeout(timeout);
reject(error);
@@ -100,123 +86,108 @@ function createTestClient(port: number, data: string): Promise<string> {
});
}
// Set up test environment
tap.test('setup port mapping test environment', async () => {
const allPorts = await findFreePorts(9);
TEST_PORTS = allPorts.slice(0, 3);
PROXY_PORTS = allPorts.slice(3, 9);
// Create multiple test servers on different ports
await Promise.all([
createTestServer(TEST_PORTS[0]),
createTestServer(TEST_PORTS[1]),
createTestServer(TEST_PORTS[2]),
]);
// Compute dynamic offset between proxy and test ports
const portOffset = TEST_PORTS[1] - PROXY_PORTS[1];
// Create a SmartProxy with dynamic port mapping routes
smartProxy = new SmartProxy({
routes: [
// Simple function that returns the same port (identity mapping)
createPortMappingRoute({
sourcePortRange: PROXY_PORTS[0],
targetHost: 'localhost',
portMapper: (context) => TEST_PORTS[0],
name: 'Identity Port Mapping'
}),
// Offset port mapping using dynamic offset
createOffsetPortMappingRoute({
ports: PROXY_PORTS[1],
targetHost: 'localhost',
offset: portOffset,
name: `Offset Port Mapping (${portOffset})`
}),
// Dynamic route with conditional port mapping
createDynamicRoute({
ports: [PROXY_PORTS[2], PROXY_PORTS[3]],
targetHost: (context) => {
// Dynamic host selection based on port
return context.port === PROXY_PORTS[2] ? 'localhost' : '127.0.0.1';
{
match: { ports: PROXY_PORTS[0] },
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: (context: IRouteContext) => TEST_PORTS[0]
}]
},
portMapper: (context) => {
// Port mapping logic based on incoming port
if (context.port === PROXY_PORTS[2]) {
return TEST_PORTS[0];
} else {
return TEST_PORTS[2];
}
name: 'Identity Port Mapping'
},
{
match: { ports: PROXY_PORTS[1] },
action: {
type: 'forward',
targets: [{
host: 'localhost',
port: (context: IRouteContext) => context.port + portOffset
}]
},
name: `Offset Port Mapping (${portOffset})`
},
{
match: { ports: [PROXY_PORTS[2], PROXY_PORTS[3]] },
action: {
type: 'forward',
targets: [{
host: (context: IRouteContext) => {
return context.port === PROXY_PORTS[2] ? 'localhost' : '127.0.0.1';
},
port: (context: IRouteContext) => {
if (context.port === PROXY_PORTS[2]) {
return TEST_PORTS[0];
} else {
return TEST_PORTS[2];
}
}
}]
},
name: 'Dynamic Host and Port Mapping'
}),
},
// Smart load balancer for domain-based routing
createSmartLoadBalancer({
ports: PROXY_PORTS[4],
domainTargets: {
'test1.example.com': 'localhost',
'test2.example.com': '127.0.0.1'
{
match: { ports: PROXY_PORTS[4] },
action: {
type: 'forward',
targets: [{
host: (context: IRouteContext) => {
if (context.domain === 'test1.example.com') return 'localhost';
if (context.domain === 'test2.example.com') return '127.0.0.1';
return 'localhost';
},
port: (context: IRouteContext) => {
if (context.domain === 'test1.example.com') {
return TEST_PORTS[0];
} else {
return TEST_PORTS[1];
}
}
}]
},
portMapper: (context) => {
// Use different backend ports based on domain
if (context.domain === 'test1.example.com') {
return TEST_PORTS[0];
} else {
return TEST_PORTS[1];
}
},
defaultTarget: 'localhost',
name: 'Smart Domain Load Balancer'
})
}
]
});
// Start the SmartProxy
await smartProxy.start();
});
// Test 1: Simple identity port mapping
tap.test('should map port using identity function', async () => {
const response = await createTestClient(PROXY_PORTS[0], TEST_DATA);
expect(response).toEqual(`Server ${TEST_PORTS[0]} says: ${TEST_DATA}`);
});
// Test 2: Offset port mapping
tap.test('should map port using offset function', async () => {
const response = await createTestClient(PROXY_PORTS[1], TEST_DATA);
expect(response).toEqual(`Server ${TEST_PORTS[1]} says: ${TEST_DATA}`);
});
// Test 3: Dynamic port and host mapping (conditional logic)
tap.test('should map port using dynamic logic', async () => {
const response = await createTestClient(PROXY_PORTS[2], TEST_DATA);
expect(response).toEqual(`Server ${TEST_PORTS[0]} says: ${TEST_DATA}`);
});
// Test 4: Test reuse of createPortOffset helper
tap.test('should use createPortOffset helper for port mapping', async () => {
// Test the createPortOffset helper with dynamic offset
const portOffset = TEST_PORTS[1] - PROXY_PORTS[1];
const offsetFn = createPortOffset(portOffset);
const context = {
port: PROXY_PORTS[1],
clientIp: '127.0.0.1',
serverIp: '127.0.0.1',
isTls: false,
timestamp: Date.now(),
connectionId: 'test-connection'
} as IRouteContext;
const mappedPort = offsetFn(context);
expect(mappedPort).toEqual(TEST_PORTS[1]);
});
// Test 5: Test error handling for invalid port mapping functions
tap.test('should handle errors in port mapping functions', async () => {
// Create a route with a function that throws an error
const errorRoute: IRouteConfig = {
match: {
ports: PROXY_PORTS[5]
@@ -232,34 +203,27 @@ tap.test('should handle errors in port mapping functions', async () => {
},
name: 'Error Route'
};
// Add the route to SmartProxy
await smartProxy.updateRoutes([...smartProxy.settings.routes, errorRoute]);
// The connection should fail or timeout
try {
await createTestClient(PROXY_PORTS[5], TEST_DATA);
// Connection should not succeed
expect(false).toBeTrue();
} catch (error) {
// Connection failed as expected
expect(true).toBeTrue();
}
});
// Cleanup
tap.test('cleanup port mapping test environment', async () => {
// Add timeout to prevent hanging if SmartProxy shutdown has issues
const cleanupPromise = cleanup();
const timeoutPromise = new Promise((_, reject) =>
const timeoutPromise = new Promise((_, reject) =>
setTimeout(() => reject(new Error('Cleanup timeout after 5 seconds')), 5000)
);
try {
await Promise.race([cleanupPromise, timeoutPromise]);
} catch (error) {
console.error('Cleanup error:', error);
// Force cleanup even if there's an error
testServers = [];
smartProxy = null as any;
}
@@ -267,4 +231,4 @@ tap.test('cleanup port mapping test environment', async () => {
await assertPortsFree([...TEST_PORTS, ...PROXY_PORTS]);
});
export default tap.start();
export default tap.start();

View File

@@ -6,7 +6,7 @@ import { expect, tap } from '@git.zone/tstest/tapbundle';
// Import from core modules
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
// Import route utilities and helpers
// Import route utilities
import {
findMatchingRoutes,
findBestMatchingRoute,
@@ -28,16 +28,7 @@ import {
assertValidRoute
} from '../ts/proxies/smart-proxy/utils/route-validator.js';
import {
createHttpRoute,
createHttpsTerminateRoute,
createHttpsPassthroughRoute,
createHttpToHttpsRedirect,
createCompleteHttpsServer,
createLoadBalancerRoute,
createApiRoute,
createWebSocketRoute
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import { SocketHandlers } from '../ts/proxies/smart-proxy/utils/socket-handlers.js';
// Import test helpers
import { loadTestCertificates } from './helpers/certificates.js';
@@ -47,12 +38,12 @@ import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.
// --------------------------------- Route Creation Tests ---------------------------------
tap.test('Routes: Should create basic HTTP route', async () => {
// Create a simple HTTP route
const httpRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 }, {
const httpRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'Basic HTTP Route'
});
};
// Validate the route configuration
expect(httpRoute.match.ports).toEqual(80);
expect(httpRoute.match.domains).toEqual('example.com');
expect(httpRoute.action.type).toEqual('forward');
@@ -62,14 +53,17 @@ tap.test('Routes: Should create basic HTTP route', async () => {
});
tap.test('Routes: Should create HTTPS route with TLS termination', async () => {
// Create an HTTPS route with TLS termination
const httpsRoute = createHttpsTerminateRoute('secure.example.com', { host: 'localhost', port: 8080 }, {
certificate: 'auto',
const httpsRoute: IRouteConfig = {
match: { ports: 443, domains: 'secure.example.com' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8080 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
name: 'HTTPS Route'
});
};
// Validate the route configuration
expect(httpsRoute.match.ports).toEqual(443); // Default HTTPS port
expect(httpsRoute.match.ports).toEqual(443);
expect(httpsRoute.match.domains).toEqual('secure.example.com');
expect(httpsRoute.action.type).toEqual('forward');
expect(httpsRoute.action.tls?.mode).toEqual('terminate');
@@ -80,10 +74,15 @@ tap.test('Routes: Should create HTTPS route with TLS termination', async () => {
});
tap.test('Routes: Should create HTTP to HTTPS redirect', async () => {
// Create an HTTP to HTTPS redirect
const redirectRoute = createHttpToHttpsRedirect('example.com', 443);
const redirectRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: {
type: 'socket-handler',
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
},
name: 'HTTP to HTTPS Redirect for example.com'
};
// Validate the route configuration
expect(redirectRoute.match.ports).toEqual(80);
expect(redirectRoute.match.domains).toEqual('example.com');
expect(redirectRoute.action.type).toEqual('socket-handler');
@@ -91,22 +90,34 @@ tap.test('Routes: Should create HTTP to HTTPS redirect', async () => {
});
tap.test('Routes: Should create complete HTTPS server with redirects', async () => {
// Create a complete HTTPS server setup
const routes = createCompleteHttpsServer('example.com', { host: 'localhost', port: 8080 }, {
certificate: 'auto'
});
const routes: IRouteConfig[] = [
{
match: { ports: 443, domains: 'example.com' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8080 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
name: 'HTTPS Terminate Route for example.com'
},
{
match: { ports: 80, domains: 'example.com' },
action: {
type: 'socket-handler',
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
},
name: 'HTTP to HTTPS Redirect for example.com'
}
];
// Validate that we got two routes (HTTPS route and HTTP redirect)
expect(routes.length).toEqual(2);
// Validate HTTPS route
const httpsRoute = routes[0];
expect(httpsRoute.match.ports).toEqual(443);
expect(httpsRoute.match.domains).toEqual('example.com');
expect(httpsRoute.action.type).toEqual('forward');
expect(httpsRoute.action.tls?.mode).toEqual('terminate');
// Validate HTTP redirect route
const redirectRoute = routes[1];
expect(redirectRoute.match.ports).toEqual(80);
expect(redirectRoute.action.type).toEqual('socket-handler');
@@ -114,21 +125,17 @@ tap.test('Routes: Should create complete HTTPS server with redirects', async ()
});
tap.test('Routes: Should create load balancer route', async () => {
// Create a load balancer route
const lbRoute = createLoadBalancerRoute(
'app.example.com',
['10.0.0.1', '10.0.0.2', '10.0.0.3'],
8080,
{
tls: {
mode: 'terminate',
certificate: 'auto'
},
name: 'Load Balanced Route'
}
);
const lbRoute: IRouteConfig = {
match: { ports: 443, domains: 'app.example.com' },
action: {
type: 'forward',
targets: [{ host: ['10.0.0.1', '10.0.0.2', '10.0.0.3'], port: 8080 }],
tls: { mode: 'terminate', certificate: 'auto' },
loadBalancing: { algorithm: 'round-robin' }
},
name: 'Load Balanced Route'
};
// Validate the route configuration
expect(lbRoute.match.domains).toEqual('app.example.com');
expect(lbRoute.action.type).toEqual('forward');
expect(Array.isArray(lbRoute.action.targets?.[0]?.host)).toBeTrue();
@@ -139,23 +146,32 @@ tap.test('Routes: Should create load balancer route', async () => {
});
tap.test('Routes: Should create API route with CORS', async () => {
// Create an API route with CORS headers
const apiRoute = createApiRoute('api.example.com', '/v1', { host: 'localhost', port: 3000 }, {
useTls: true,
certificate: 'auto',
addCorsHeaders: true,
const apiRoute: IRouteConfig = {
match: { ports: 443, domains: 'api.example.com', path: '/v1/*' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 3000 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
headers: {
response: {
'Access-Control-Allow-Origin': '*',
'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
'Access-Control-Allow-Headers': 'Content-Type, Authorization',
'Access-Control-Max-Age': '86400'
}
},
priority: 100,
name: 'API Route'
});
};
// Validate the route configuration
expect(apiRoute.match.domains).toEqual('api.example.com');
expect(apiRoute.match.path).toEqual('/v1/*');
expect(apiRoute.action.type).toEqual('forward');
expect(apiRoute.action.tls?.mode).toEqual('terminate');
expect(apiRoute.action.targets?.[0]?.host).toEqual('localhost');
expect(apiRoute.action.targets?.[0]?.port).toEqual(3000);
// Check CORS headers
expect(apiRoute.headers).toBeDefined();
if (apiRoute.headers?.response) {
expect(apiRoute.headers.response['Access-Control-Allow-Origin']).toEqual('*');
@@ -164,23 +180,25 @@ tap.test('Routes: Should create API route with CORS', async () => {
});
tap.test('Routes: Should create WebSocket route', async () => {
// Create a WebSocket route
const wsRoute = createWebSocketRoute('ws.example.com', '/socket', { host: 'localhost', port: 5000 }, {
useTls: true,
certificate: 'auto',
pingInterval: 15000,
const wsRoute: IRouteConfig = {
match: { ports: 443, domains: 'ws.example.com', path: '/socket' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 5000 }],
tls: { mode: 'terminate', certificate: 'auto' },
websocket: { enabled: true, pingInterval: 15000 }
},
priority: 100,
name: 'WebSocket Route'
});
};
// Validate the route configuration
expect(wsRoute.match.domains).toEqual('ws.example.com');
expect(wsRoute.match.path).toEqual('/socket');
expect(wsRoute.action.type).toEqual('forward');
expect(wsRoute.action.tls?.mode).toEqual('terminate');
expect(wsRoute.action.targets?.[0]?.host).toEqual('localhost');
expect(wsRoute.action.targets?.[0]?.port).toEqual(5000);
// Check WebSocket configuration
expect(wsRoute.action.websocket).toBeDefined();
if (wsRoute.action.websocket) {
expect(wsRoute.action.websocket.enabled).toBeTrue();
@@ -191,22 +209,27 @@ tap.test('Routes: Should create WebSocket route', async () => {
// Static file serving has been removed - should be handled by external servers
tap.test('SmartProxy: Should create instance with route-based config', async () => {
// Create TLS certificates for testing
const certs = loadTestCertificates();
// Create a SmartProxy instance with route-based configuration
const proxy = new SmartProxy({
routes: [
createHttpRoute('example.com', { host: 'localhost', port: 3000 }, {
{
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route'
}),
createHttpsTerminateRoute('secure.example.com', { host: 'localhost', port: 8443 }, {
certificate: {
key: certs.privateKey,
cert: certs.publicKey
},
{
match: { ports: 443, domains: 'secure.example.com' },
action: {
type: 'forward',
targets: [{ host: 'localhost', port: 8443 }],
tls: {
mode: 'terminate',
certificate: { key: certs.privateKey, cert: certs.publicKey }
}
},
name: 'HTTPS Route'
})
}
],
defaults: {
target: {
@@ -218,13 +241,11 @@ tap.test('SmartProxy: Should create instance with route-based config', async ()
maxConnections: 100
}
},
// Additional settings
initialDataTimeout: 10000,
inactivityTimeout: 300000,
enableDetailedLogging: true
});
// Simply verify the instance was created successfully
expect(typeof proxy).toEqual('object');
expect(typeof proxy.start).toEqual('function');
expect(typeof proxy.stop).toEqual('function');
@@ -233,94 +254,109 @@ tap.test('SmartProxy: Should create instance with route-based config', async ()
// --------------------------------- Edge Case Tests ---------------------------------
tap.test('Edge Case - Empty Routes Array', async () => {
// Attempting to find routes in an empty array
const emptyRoutes: IRouteConfig[] = [];
const matches = findMatchingRoutes(emptyRoutes, { domain: 'example.com', port: 80 });
expect(matches).toBeInstanceOf(Array);
expect(matches.length).toEqual(0);
const bestMatch = findBestMatchingRoute(emptyRoutes, { domain: 'example.com', port: 80 });
expect(bestMatch).toBeUndefined();
});
tap.test('Edge Case - Multiple Matching Routes with Same Priority', async () => {
// Create multiple routes with identical priority but different targets
const route1 = createHttpRoute('example.com', { host: 'server1', port: 3000 });
const route2 = createHttpRoute('example.com', { host: 'server2', port: 3000 });
const route3 = createHttpRoute('example.com', { host: 'server3', port: 3000 });
// Set all to the same priority
const route1: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'server1', port: 3000 }] },
name: 'HTTP Route for example.com'
};
const route2: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'server2', port: 3000 }] },
name: 'HTTP Route for example.com'
};
const route3: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'server3', port: 3000 }] },
name: 'HTTP Route for example.com'
};
route1.priority = 100;
route2.priority = 100;
route3.priority = 100;
const routes = [route1, route2, route3];
// Find matching routes
const matches = findMatchingRoutes(routes, { domain: 'example.com', port: 80 });
// Should find all three routes
expect(matches.length).toEqual(3);
// First match could be any of the routes since they have the same priority
// But the implementation should be consistent (likely keep the original order)
const bestMatch = findBestMatchingRoute(routes, { domain: 'example.com', port: 80 });
expect(bestMatch).not.toBeUndefined();
});
tap.test('Edge Case - Wildcard Domains and Path Matching', async () => {
// Create routes with wildcard domains and path patterns
const wildcardApiRoute = createApiRoute('*.example.com', '/api', { host: 'api-server', port: 3000 }, {
useTls: true,
certificate: 'auto'
});
const exactApiRoute = createApiRoute('api.example.com', '/api', { host: 'specific-api-server', port: 3001 }, {
useTls: true,
certificate: 'auto',
priority: 200 // Higher priority
});
const wildcardApiRoute: IRouteConfig = {
match: { ports: 443, domains: '*.example.com', path: '/api/*' },
action: {
type: 'forward',
targets: [{ host: 'api-server', port: 3000 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
priority: 100,
name: 'API Route for *.example.com'
};
const exactApiRoute: IRouteConfig = {
match: { ports: 443, domains: 'api.example.com', path: '/api/*' },
action: {
type: 'forward',
targets: [{ host: 'specific-api-server', port: 3001 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
priority: 200,
name: 'API Route for api.example.com'
};
const routes = [wildcardApiRoute, exactApiRoute];
// Test with a specific subdomain that matches both routes
const matches = findMatchingRoutes(routes, { domain: 'api.example.com', path: '/api/users', port: 443 });
// Should match both routes
expect(matches.length).toEqual(2);
// The exact domain match should have higher priority
const bestMatch = findBestMatchingRoute(routes, { domain: 'api.example.com', path: '/api/users', port: 443 });
expect(bestMatch).not.toBeUndefined();
if (bestMatch) {
expect(bestMatch.action.targets[0].port).toEqual(3001); // Should match the exact domain route
expect(bestMatch.action.targets[0].port).toEqual(3001);
}
// Test with a different subdomain - should only match the wildcard route
const otherMatches = findMatchingRoutes(routes, { domain: 'other.example.com', path: '/api/products', port: 443 });
expect(otherMatches.length).toEqual(1);
expect(otherMatches[0].action.targets[0].port).toEqual(3000); // Should match the wildcard domain route
expect(otherMatches[0].action.targets[0].port).toEqual(3000);
});
tap.test('Edge Case - Disabled Routes', async () => {
// Create enabled and disabled routes
const enabledRoute = createHttpRoute('example.com', { host: 'server1', port: 3000 });
const disabledRoute = createHttpRoute('example.com', { host: 'server2', port: 3001 });
const enabledRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'server1', port: 3000 }] },
name: 'HTTP Route for example.com'
};
const disabledRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'server2', port: 3001 }] },
name: 'HTTP Route for example.com'
};
disabledRoute.enabled = false;
const routes = [enabledRoute, disabledRoute];
// Find matching routes
const matches = findMatchingRoutes(routes, { domain: 'example.com', port: 80 });
// Should only find the enabled route
expect(matches.length).toEqual(1);
expect(matches[0].action.targets[0].port).toEqual(3000);
});
tap.test('Edge Case - Complex Path and Headers Matching', async () => {
// Create route with complex path and headers matching
const complexRoute: IRouteConfig = {
match: {
domains: 'api.example.com',
@@ -344,22 +380,20 @@ tap.test('Edge Case - Complex Path and Headers Matching', async () => {
},
name: 'Complex API Route'
};
// Test with matching criteria
const matchingPath = routeMatchesPath(complexRoute, '/api/v2/users');
expect(matchingPath).toBeTrue();
const matchingHeaders = routeMatchesHeaders(complexRoute, {
'Content-Type': 'application/json',
'X-API-Key': 'valid-key',
'Accept': 'application/json'
});
expect(matchingHeaders).toBeTrue();
// Test with non-matching criteria
const nonMatchingPath = routeMatchesPath(complexRoute, '/api/v1/users');
expect(nonMatchingPath).toBeFalse();
const nonMatchingHeaders = routeMatchesHeaders(complexRoute, {
'Content-Type': 'application/json',
'X-API-Key': 'invalid-key'
@@ -368,7 +402,6 @@ tap.test('Edge Case - Complex Path and Headers Matching', async () => {
});
tap.test('Edge Case - Port Range Matching', async () => {
// Create route with port range matching
const portRangeRoute: IRouteConfig = {
match: {
domains: 'example.com',
@@ -383,17 +416,14 @@ tap.test('Edge Case - Port Range Matching', async () => {
},
name: 'Port Range Route'
};
// Test with ports in the range
expect(routeMatchesPort(portRangeRoute, 8000)).toBeTrue(); // Lower bound
expect(routeMatchesPort(portRangeRoute, 8500)).toBeTrue(); // Middle
expect(routeMatchesPort(portRangeRoute, 9000)).toBeTrue(); // Upper bound
// Test with ports outside the range
expect(routeMatchesPort(portRangeRoute, 7999)).toBeFalse(); // Just below
expect(routeMatchesPort(portRangeRoute, 9001)).toBeFalse(); // Just above
// Test with multiple port ranges
expect(routeMatchesPort(portRangeRoute, 8000)).toBeTrue();
expect(routeMatchesPort(portRangeRoute, 8500)).toBeTrue();
expect(routeMatchesPort(portRangeRoute, 9000)).toBeTrue();
expect(routeMatchesPort(portRangeRoute, 7999)).toBeFalse();
expect(routeMatchesPort(portRangeRoute, 9001)).toBeFalse();
const multiRangeRoute: IRouteConfig = {
match: {
domains: 'example.com',
@@ -411,7 +441,7 @@ tap.test('Edge Case - Port Range Matching', async () => {
},
name: 'Multi Range Route'
};
expect(routeMatchesPort(multiRangeRoute, 85)).toBeTrue();
expect(routeMatchesPort(multiRangeRoute, 8500)).toBeTrue();
expect(routeMatchesPort(multiRangeRoute, 100)).toBeFalse();
@@ -420,55 +450,56 @@ tap.test('Edge Case - Port Range Matching', async () => {
// --------------------------------- Wildcard Domain Tests ---------------------------------
tap.test('Wildcard Domain Handling', async () => {
// Create routes with different wildcard patterns
const simpleDomainRoute = createHttpRoute('example.com', { host: 'server1', port: 3000 });
const wildcardSubdomainRoute = createHttpRoute('*.example.com', { host: 'server2', port: 3001 });
const specificSubdomainRoute = createHttpRoute('api.example.com', { host: 'server3', port: 3002 });
const simpleDomainRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'server1', port: 3000 }] },
name: 'HTTP Route for example.com'
};
const wildcardSubdomainRoute: IRouteConfig = {
match: { ports: 80, domains: '*.example.com' },
action: { type: 'forward', targets: [{ host: 'server2', port: 3001 }] },
name: 'HTTP Route for *.example.com'
};
const specificSubdomainRoute: IRouteConfig = {
match: { ports: 80, domains: 'api.example.com' },
action: { type: 'forward', targets: [{ host: 'server3', port: 3002 }] },
name: 'HTTP Route for api.example.com'
};
// Set explicit priorities to ensure deterministic matching
specificSubdomainRoute.priority = 200; // Highest priority for specific domain
wildcardSubdomainRoute.priority = 100; // Medium priority for wildcard
simpleDomainRoute.priority = 50; // Lowest priority for generic domain
specificSubdomainRoute.priority = 200;
wildcardSubdomainRoute.priority = 100;
simpleDomainRoute.priority = 50;
const routes = [simpleDomainRoute, wildcardSubdomainRoute, specificSubdomainRoute];
// Test exact domain match
expect(routeMatchesDomain(simpleDomainRoute, 'example.com')).toBeTrue();
expect(routeMatchesDomain(simpleDomainRoute, 'sub.example.com')).toBeFalse();
// Test wildcard subdomain match
expect(routeMatchesDomain(wildcardSubdomainRoute, 'any.example.com')).toBeTrue();
expect(routeMatchesDomain(wildcardSubdomainRoute, 'nested.sub.example.com')).toBeTrue();
expect(routeMatchesDomain(wildcardSubdomainRoute, 'example.com')).toBeFalse();
// Test specific subdomain match
expect(routeMatchesDomain(specificSubdomainRoute, 'api.example.com')).toBeTrue();
expect(routeMatchesDomain(specificSubdomainRoute, 'other.example.com')).toBeFalse();
expect(routeMatchesDomain(specificSubdomainRoute, 'sub.api.example.com')).toBeFalse();
// Test finding best match when multiple domains match
const specificSubdomainRequest = { domain: 'api.example.com', port: 80 };
const bestSpecificMatch = findBestMatchingRoute(routes, specificSubdomainRequest);
expect(bestSpecificMatch).not.toBeUndefined();
if (bestSpecificMatch) {
// Find which route was matched
const matchedPort = bestSpecificMatch.action.targets[0].port;
console.log(`Matched route with port: ${matchedPort}`);
// Verify it's the specific subdomain route (with highest priority)
expect(bestSpecificMatch.priority).toEqual(200);
}
// Test with a subdomain that matches wildcard but not specific
const otherSubdomainRequest = { domain: 'other.example.com', port: 80 };
const bestWildcardMatch = findBestMatchingRoute(routes, otherSubdomainRequest);
expect(bestWildcardMatch).not.toBeUndefined();
if (bestWildcardMatch) {
// Find which route was matched
const matchedPort = bestWildcardMatch.action.targets[0].port;
console.log(`Matched route with port: ${matchedPort}`);
// Verify it's the wildcard subdomain route (with medium priority)
expect(bestWildcardMatch.priority).toEqual(100);
}
});
@@ -476,56 +507,83 @@ tap.test('Wildcard Domain Handling', async () => {
// --------------------------------- Integration Tests ---------------------------------
tap.test('Route Integration - Combining Multiple Route Types', async () => {
// Create a comprehensive set of routes for a full application
const routes: IRouteConfig[] = [
// Main website with HTTPS and HTTP redirect
...createCompleteHttpsServer('example.com', { host: 'web-server', port: 8080 }, {
certificate: 'auto'
}),
// API endpoints
createApiRoute('api.example.com', '/v1', { host: 'api-server', port: 3000 }, {
useTls: true,
certificate: 'auto',
addCorsHeaders: true
}),
// WebSocket for real-time updates
createWebSocketRoute('ws.example.com', '/live', { host: 'websocket-server', port: 5000 }, {
useTls: true,
certificate: 'auto'
}),
// Legacy system with passthrough
createHttpsPassthroughRoute('legacy.example.com', { host: 'legacy-server', port: 443 })
{
match: { ports: 443, domains: 'example.com' },
action: {
type: 'forward',
targets: [{ host: 'web-server', port: 8080 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
name: 'HTTPS Terminate Route for example.com'
},
{
match: { ports: 80, domains: 'example.com' },
action: {
type: 'socket-handler',
socketHandler: SocketHandlers.httpRedirect('https://{domain}:443{path}', 301)
},
name: 'HTTP to HTTPS Redirect for example.com'
},
{
match: { ports: 443, domains: 'api.example.com', path: '/v1/*' },
action: {
type: 'forward',
targets: [{ host: 'api-server', port: 3000 }],
tls: { mode: 'terminate', certificate: 'auto' }
},
headers: {
response: {
'Access-Control-Allow-Origin': '*',
'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
'Access-Control-Allow-Headers': 'Content-Type, Authorization',
'Access-Control-Max-Age': '86400'
}
},
priority: 100,
name: 'API Route for api.example.com'
},
{
match: { ports: 443, domains: 'ws.example.com', path: '/live' },
action: {
type: 'forward',
targets: [{ host: 'websocket-server', port: 5000 }],
tls: { mode: 'terminate', certificate: 'auto' },
websocket: { enabled: true }
},
priority: 100,
name: 'WebSocket Route for ws.example.com'
},
{
match: { ports: 443, domains: 'legacy.example.com' },
action: {
type: 'forward',
targets: [{ host: 'legacy-server', port: 443 }],
tls: { mode: 'passthrough' }
},
name: 'HTTPS Passthrough Route for legacy.example.com'
}
];
// Validate all routes
const validationResult = validateRoutes(routes);
expect(validationResult.valid).toBeTrue();
expect(validationResult.errors.length).toEqual(0);
// Test route matching for different endpoints
// Web server (HTTPS)
const webServerMatch = findBestMatchingRoute(routes, { domain: 'example.com', port: 443 });
expect(webServerMatch).not.toBeUndefined();
if (webServerMatch) {
expect(webServerMatch.action.type).toEqual('forward');
expect(webServerMatch.action.targets[0].host).toEqual('web-server');
}
// Web server (HTTP redirect via socket handler)
const webRedirectMatch = findBestMatchingRoute(routes, { domain: 'example.com', port: 80 });
expect(webRedirectMatch).not.toBeUndefined();
if (webRedirectMatch) {
expect(webRedirectMatch.action.type).toEqual('socket-handler');
}
// API server
const apiMatch = findBestMatchingRoute(routes, {
domain: 'api.example.com',
const apiMatch = findBestMatchingRoute(routes, {
domain: 'api.example.com',
port: 443,
path: '/v1/users'
});
@@ -534,10 +592,9 @@ tap.test('Route Integration - Combining Multiple Route Types', async () => {
expect(apiMatch.action.type).toEqual('forward');
expect(apiMatch.action.targets[0].host).toEqual('api-server');
}
// WebSocket server
const wsMatch = findBestMatchingRoute(routes, {
domain: 'ws.example.com',
const wsMatch = findBestMatchingRoute(routes, {
domain: 'ws.example.com',
port: 443,
path: '/live'
});
@@ -547,12 +604,9 @@ tap.test('Route Integration - Combining Multiple Route Types', async () => {
expect(wsMatch.action.targets[0].host).toEqual('websocket-server');
expect(wsMatch.action.websocket?.enabled).toBeTrue();
}
// Static assets route was removed - static file serving should be handled externally
// Legacy system
const legacyMatch = findBestMatchingRoute(routes, {
domain: 'legacy.example.com',
const legacyMatch = findBestMatchingRoute(routes, {
domain: 'legacy.example.com',
port: 443
});
expect(legacyMatch).not.toBeUndefined();
@@ -565,7 +619,6 @@ tap.test('Route Integration - Combining Multiple Route Types', async () => {
// --------------------------------- Protocol Match Field Tests ---------------------------------
tap.test('Routes: Should accept protocol field on route match', async () => {
// Create a route with protocol: 'http'
const httpOnlyRoute: IRouteConfig = {
match: {
ports: 443,
@@ -583,16 +636,13 @@ tap.test('Routes: Should accept protocol field on route match', async () => {
name: 'HTTP-only Route',
};
// Validate the route - protocol field should not cause errors
const validation = validateRouteConfig(httpOnlyRoute);
expect(validation.valid).toBeTrue();
// Verify the protocol field is preserved
expect(httpOnlyRoute.match.protocol).toEqual('http');
});
tap.test('Routes: Should accept protocol tcp on route match', async () => {
// Create a route with protocol: 'tcp'
const tcpOnlyRoute: IRouteConfig = {
match: {
ports: 443,
@@ -616,28 +666,26 @@ tap.test('Routes: Should accept protocol tcp on route match', async () => {
});
tap.test('Routes: Protocol field should work with terminate-and-reencrypt', async () => {
// Create a terminate-and-reencrypt route that only accepts HTTP
const reencryptRoute = createHttpsTerminateRoute(
'secure.example.com',
{ host: 'backend', port: 443 },
{ reencrypt: true, certificate: 'auto', name: 'Reencrypt HTTP Route' }
);
const reencryptRoute: IRouteConfig = {
match: { ports: 443, domains: 'secure.example.com' },
action: {
type: 'forward',
targets: [{ host: 'backend', port: 443 }],
tls: { mode: 'terminate-and-reencrypt', certificate: 'auto' }
},
name: 'Reencrypt HTTP Route'
};
// Set protocol restriction to http
reencryptRoute.match.protocol = 'http';
// Validate the route
const validation = validateRouteConfig(reencryptRoute);
expect(validation.valid).toBeTrue();
// Verify TLS mode
expect(reencryptRoute.action.tls?.mode).toEqual('terminate-and-reencrypt');
// Verify protocol field is preserved
expect(reencryptRoute.match.protocol).toEqual('http');
});
tap.test('Routes: Protocol field should not affect domain/port matching', async () => {
// Routes with and without protocol field should both match the same domain/port
const routeWithProtocol: IRouteConfig = {
match: {
ports: 443,
@@ -669,11 +717,9 @@ tap.test('Routes: Protocol field should not affect domain/port matching', async
const routes = [routeWithProtocol, routeWithoutProtocol];
// Both routes should match the domain/port (protocol is a hint for Rust-side matching)
const matches = findMatchingRoutes(routes, { domain: 'example.com', port: 443 });
expect(matches.length).toEqual(2);
// The one with higher priority should be first
const best = findBestMatchingRoute(routes, { domain: 'example.com', port: 443 });
expect(best).not.toBeUndefined();
expect(best!.name).toEqual('With Protocol');
@@ -696,11 +742,9 @@ tap.test('Routes: Protocol field preserved through route cloning', async () => {
const cloned = cloneRoute(original);
// Verify protocol is preserved in clone
expect(cloned.match.protocol).toEqual('http');
expect(cloned.action.tls?.mode).toEqual('terminate-and-reencrypt');
// Modify clone should not affect original
cloned.match.protocol = 'tcp';
expect(original.match.protocol).toEqual('http');
});
@@ -720,10 +764,9 @@ tap.test('Routes: Protocol field preserved through route merging', async () => {
name: 'Merge Base',
};
// Merge with override that changes name but not protocol
const merged = mergeRouteConfigs(base, { name: 'Merged Route' });
expect(merged.match.protocol).toEqual('http');
expect(merged.name).toEqual('Merged Route');
});
export default tap.start();
export default tap.start();

View File

@@ -1,21 +1,7 @@
import { tap, expect } from '@git.zone/tstest/tapbundle';
import * as plugins from '../ts/plugins.js';
// Import from individual modules to avoid naming conflicts
import {
// Route helpers
createHttpRoute,
createHttpsTerminateRoute,
createApiRoute,
createWebSocketRoute,
createHttpToHttpsRedirect,
createHttpsPassthroughRoute,
createCompleteHttpsServer,
createLoadBalancerRoute
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import {
// Route validators
validateRouteConfig,
validateRoutes,
isValidDomain,
@@ -27,7 +13,6 @@ import {
} from '../ts/proxies/smart-proxy/utils/route-validator.js';
import {
// Route utilities
mergeRouteConfigs,
findMatchingRoutes,
findBestMatchingRoute,
@@ -39,16 +24,6 @@ import {
cloneRoute
} from '../ts/proxies/smart-proxy/utils/route-utils.js';
import {
// Route patterns
createApiGatewayRoute,
createWebSocketRoute as createWebSocketPattern,
createLoadBalancerRoute as createLbPattern,
addRateLimiting,
addBasicAuth,
addJwtAuth
} from '../ts/proxies/smart-proxy/utils/route-helpers.js';
import type {
IRouteConfig,
IRouteMatch,
@@ -84,7 +59,7 @@ tap.test('Route Validation - isValidPort', async () => {
expect(isValidPort(443)).toBeTrue();
expect(isValidPort(8080)).toBeTrue();
expect(isValidPort([80, 443])).toBeTrue();
// Invalid ports
expect(isValidPort(0)).toBeFalse();
expect(isValidPort(65536)).toBeFalse();
@@ -101,7 +76,7 @@ tap.test('Route Validation - validateRouteMatch', async () => {
const validResult = validateRouteMatch(validMatch);
expect(validResult.valid).toBeTrue();
expect(validResult.errors.length).toEqual(0);
// Invalid match configuration (invalid domain)
const invalidMatch: IRouteMatch = {
ports: 80,
@@ -111,7 +86,7 @@ tap.test('Route Validation - validateRouteMatch', async () => {
expect(invalidResult.valid).toBeFalse();
expect(invalidResult.errors.length).toBeGreaterThan(0);
expect(invalidResult.errors[0]).toInclude('Invalid domain');
// Invalid match configuration (invalid port)
const invalidPortMatch: IRouteMatch = {
ports: 0,
@@ -121,7 +96,7 @@ tap.test('Route Validation - validateRouteMatch', async () => {
expect(invalidPortResult.valid).toBeFalse();
expect(invalidPortResult.errors.length).toBeGreaterThan(0);
expect(invalidPortResult.errors[0]).toInclude('Invalid port');
// Test path validation
const invalidPathMatch: IRouteMatch = {
ports: 80,
@@ -146,7 +121,7 @@ tap.test('Route Validation - validateRouteAction', async () => {
const validForwardResult = validateRouteAction(validForwardAction);
expect(validForwardResult.valid).toBeTrue();
expect(validForwardResult.errors.length).toEqual(0);
// Valid socket-handler action
const validSocketAction: IRouteAction = {
type: 'socket-handler',
@@ -157,7 +132,7 @@ tap.test('Route Validation - validateRouteAction', async () => {
const validSocketResult = validateRouteAction(validSocketAction);
expect(validSocketResult.valid).toBeTrue();
expect(validSocketResult.errors.length).toEqual(0);
// Invalid action (missing targets)
const invalidAction: IRouteAction = {
type: 'forward'
@@ -166,7 +141,7 @@ tap.test('Route Validation - validateRouteAction', async () => {
expect(invalidResult.valid).toBeFalse();
expect(invalidResult.errors.length).toBeGreaterThan(0);
expect(invalidResult.errors[0]).toInclude('Targets array is required');
// Invalid action (missing socket handler)
const invalidSocketAction: IRouteAction = {
type: 'socket-handler'
@@ -179,11 +154,15 @@ tap.test('Route Validation - validateRouteAction', async () => {
tap.test('Route Validation - validateRouteConfig', async () => {
// Valid route config
const validRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const validRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
const validResult = validateRouteConfig(validRoute);
expect(validResult.valid).toBeTrue();
expect(validResult.errors.length).toEqual(0);
// Invalid route config (missing targets)
const invalidRoute: IRouteConfig = {
match: {
@@ -203,7 +182,11 @@ tap.test('Route Validation - validateRouteConfig', async () => {
tap.test('Route Validation - validateRoutes', async () => {
// Create valid and invalid routes
const routes = [
createHttpRoute('example.com', { host: 'localhost', port: 3000 }),
{
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
} as IRouteConfig,
{
match: {
domains: 'invalid..domain',
@@ -217,9 +200,13 @@ tap.test('Route Validation - validateRoutes', async () => {
}
}
} as IRouteConfig,
createHttpsTerminateRoute('secure.example.com', { host: 'localhost', port: 3001 })
{
match: { ports: 443, domains: 'secure.example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3001 }], tls: { mode: 'terminate', certificate: 'auto' } },
name: 'HTTPS Terminate Route for secure.example.com',
} as IRouteConfig
];
const result = validateRoutes(routes);
expect(result.valid).toBeFalse();
expect(result.errors.length).toEqual(1);
@@ -230,13 +217,13 @@ tap.test('Route Validation - validateRoutes', async () => {
tap.test('Route Validation - hasRequiredPropertiesForAction', async () => {
// Forward action
const forwardRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const forwardRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
expect(hasRequiredPropertiesForAction(forwardRoute, 'forward')).toBeTrue();
// Socket handler action (redirect functionality)
const redirectRoute = createHttpToHttpsRedirect('example.com');
expect(hasRequiredPropertiesForAction(redirectRoute, 'socket-handler')).toBeTrue();
// Socket handler action
const socketRoute: IRouteConfig = {
match: {
@@ -252,7 +239,7 @@ tap.test('Route Validation - hasRequiredPropertiesForAction', async () => {
name: 'Socket Handler Route'
};
expect(hasRequiredPropertiesForAction(socketRoute, 'socket-handler')).toBeTrue();
// Missing required properties
const invalidForwardRoute: IRouteConfig = {
match: {
@@ -269,9 +256,13 @@ tap.test('Route Validation - hasRequiredPropertiesForAction', async () => {
tap.test('Route Validation - assertValidRoute', async () => {
// Valid route
const validRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const validRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
expect(() => assertValidRoute(validRoute)).not.toThrow();
// Invalid route
const invalidRoute: IRouteConfig = {
match: {
@@ -290,8 +281,12 @@ tap.test('Route Validation - assertValidRoute', async () => {
tap.test('Route Utilities - mergeRouteConfigs', async () => {
// Base route
const baseRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const baseRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
// Override with different name and port
const overrideRoute: Partial<IRouteConfig> = {
name: 'Merged Route',
@@ -299,16 +294,16 @@ tap.test('Route Utilities - mergeRouteConfigs', async () => {
ports: 8080
}
};
// Merge configs
const mergedRoute = mergeRouteConfigs(baseRoute, overrideRoute);
// Check merged properties
expect(mergedRoute.name).toEqual('Merged Route');
expect(mergedRoute.match.ports).toEqual(8080);
expect(mergedRoute.match.domains).toEqual('example.com');
expect(mergedRoute.action.type).toEqual('forward');
// Test merging action properties
const actionOverride: Partial<IRouteConfig> = {
action: {
@@ -319,11 +314,11 @@ tap.test('Route Utilities - mergeRouteConfigs', async () => {
}]
}
};
const actionMergedRoute = mergeRouteConfigs(baseRoute, actionOverride);
expect(actionMergedRoute.action.targets?.[0]?.host).toEqual('new-host.local');
expect(actionMergedRoute.action.targets?.[0]?.port).toEqual(5000);
// Test replacing action with socket handler
const typeChangeOverride: Partial<IRouteConfig> = {
action: {
@@ -336,7 +331,7 @@ tap.test('Route Utilities - mergeRouteConfigs', async () => {
}
}
};
const typeChangedRoute = mergeRouteConfigs(baseRoute, typeChangeOverride);
expect(typeChangedRoute.action.type).toEqual('socket-handler');
expect(typeChangedRoute.action.socketHandler).toBeDefined();
@@ -345,37 +340,53 @@ tap.test('Route Utilities - mergeRouteConfigs', async () => {
tap.test('Route Matching - routeMatchesDomain', async () => {
// Create route with wildcard domain
const wildcardRoute = createHttpRoute('*.example.com', { host: 'localhost', port: 3000 });
const wildcardRoute: IRouteConfig = {
match: { ports: 80, domains: '*.example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for *.example.com',
};
// Create route with exact domain
const exactRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const exactRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
// Create route with multiple domains
const multiDomainRoute = createHttpRoute(['example.com', 'example.org'], { host: 'localhost', port: 3000 });
const multiDomainRoute: IRouteConfig = {
match: { ports: 80, domains: ['example.com', 'example.org'] },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com,example.org',
};
// Test wildcard domain matching
expect(routeMatchesDomain(wildcardRoute, 'sub.example.com')).toBeTrue();
expect(routeMatchesDomain(wildcardRoute, 'another.example.com')).toBeTrue();
expect(routeMatchesDomain(wildcardRoute, 'example.com')).toBeFalse();
expect(routeMatchesDomain(wildcardRoute, 'example.org')).toBeFalse();
// Test exact domain matching
expect(routeMatchesDomain(exactRoute, 'example.com')).toBeTrue();
expect(routeMatchesDomain(exactRoute, 'sub.example.com')).toBeFalse();
// Test multiple domains matching
expect(routeMatchesDomain(multiDomainRoute, 'example.com')).toBeTrue();
expect(routeMatchesDomain(multiDomainRoute, 'example.org')).toBeTrue();
expect(routeMatchesDomain(multiDomainRoute, 'example.net')).toBeFalse();
// Test case insensitivity
expect(routeMatchesDomain(exactRoute, 'Example.Com')).toBeTrue();
});
tap.test('Route Matching - routeMatchesPort', async () => {
// Create routes with different port configurations
const singlePortRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const singlePortRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
const multiPortRoute: IRouteConfig = {
match: {
domains: 'example.com',
@@ -389,7 +400,7 @@ tap.test('Route Matching - routeMatchesPort', async () => {
}]
}
};
const portRangeRoute: IRouteConfig = {
match: {
domains: 'example.com',
@@ -403,16 +414,16 @@ tap.test('Route Matching - routeMatchesPort', async () => {
}]
}
};
// Test single port matching
expect(routeMatchesPort(singlePortRoute, 80)).toBeTrue();
expect(routeMatchesPort(singlePortRoute, 443)).toBeFalse();
// Test multi-port matching
expect(routeMatchesPort(multiPortRoute, 80)).toBeTrue();
expect(routeMatchesPort(multiPortRoute, 8080)).toBeTrue();
expect(routeMatchesPort(multiPortRoute, 3000)).toBeFalse();
// Test port range matching
expect(routeMatchesPort(portRangeRoute, 8000)).toBeTrue();
expect(routeMatchesPort(portRangeRoute, 8500)).toBeTrue();
@@ -437,11 +448,11 @@ tap.test('Route Matching - routeMatchesPath', async () => {
}]
}
};
// Test prefix matching with wildcard (not trailing slash)
const prefixPathRoute: IRouteConfig = {
match: {
domains: 'example.com',
domains: 'example.com',
ports: 80,
path: '/api/*'
},
@@ -453,7 +464,7 @@ tap.test('Route Matching - routeMatchesPath', async () => {
}]
}
};
const wildcardPathRoute: IRouteConfig = {
match: {
domains: 'example.com',
@@ -468,17 +479,17 @@ tap.test('Route Matching - routeMatchesPath', async () => {
}]
}
};
// Test exact path matching
expect(routeMatchesPath(exactPathRoute, '/api')).toBeTrue();
expect(routeMatchesPath(exactPathRoute, '/api/users')).toBeFalse();
expect(routeMatchesPath(exactPathRoute, '/app')).toBeFalse();
// Test prefix path matching with wildcard
expect(routeMatchesPath(prefixPathRoute, '/api/')).toBeFalse(); // Wildcard requires content after /api/
expect(routeMatchesPath(prefixPathRoute, '/api/users')).toBeTrue();
expect(routeMatchesPath(prefixPathRoute, '/app/')).toBeFalse();
// Test wildcard path matching
expect(routeMatchesPath(wildcardPathRoute, '/api/users')).toBeTrue();
expect(routeMatchesPath(wildcardPathRoute, '/api/products')).toBeTrue();
@@ -504,30 +515,34 @@ tap.test('Route Matching - routeMatchesHeaders', async () => {
}]
}
};
// Test header matching
expect(routeMatchesHeaders(headerRoute, {
'Content-Type': 'application/json',
'X-Custom-Header': 'value'
})).toBeTrue();
expect(routeMatchesHeaders(headerRoute, {
'Content-Type': 'application/json',
'X-Custom-Header': 'value',
'Extra-Header': 'something'
})).toBeTrue();
expect(routeMatchesHeaders(headerRoute, {
'Content-Type': 'application/json'
})).toBeFalse();
expect(routeMatchesHeaders(headerRoute, {
'Content-Type': 'text/html',
'X-Custom-Header': 'value'
})).toBeFalse();
// Route without header matching should match any headers
const noHeaderRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const noHeaderRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
expect(routeMatchesHeaders(noHeaderRoute, {
'Content-Type': 'application/json'
})).toBeTrue();
@@ -536,78 +551,118 @@ tap.test('Route Matching - routeMatchesHeaders', async () => {
tap.test('Route Finding - findMatchingRoutes', async () => {
// Create multiple routes
const routes: IRouteConfig[] = [
createHttpRoute('example.com', { host: 'localhost', port: 3000 }),
createHttpsTerminateRoute('secure.example.com', { host: 'localhost', port: 3001 }),
createApiRoute('api.example.com', '/v1', { host: 'localhost', port: 3002 }),
createWebSocketRoute('ws.example.com', '/socket', { host: 'localhost', port: 3003 })
{
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
},
{
match: { ports: 443, domains: 'secure.example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3001 }], tls: { mode: 'terminate', certificate: 'auto' } },
name: 'HTTPS Route for secure.example.com',
},
{
match: { ports: 443, domains: 'api.example.com', path: '/v1/*' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3002 }], tls: { mode: 'terminate', certificate: 'auto' } },
name: 'API Route for api.example.com',
},
{
match: { ports: 443, domains: 'ws.example.com', path: '/socket' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3003 }], tls: { mode: 'terminate', certificate: 'auto' }, websocket: { enabled: true } },
name: 'WebSocket Route for ws.example.com',
},
];
// Set priorities
routes[0].priority = 10;
routes[1].priority = 20;
routes[2].priority = 30;
routes[3].priority = 40;
// Find routes for different criteria
const httpMatches = findMatchingRoutes(routes, { domain: 'example.com', port: 80 });
expect(httpMatches.length).toEqual(1);
expect(httpMatches[0].name).toInclude('HTTP Route');
const httpsMatches = findMatchingRoutes(routes, { domain: 'secure.example.com', port: 443 });
expect(httpsMatches.length).toEqual(1);
expect(httpsMatches[0].name).toInclude('HTTPS Route');
const apiMatches = findMatchingRoutes(routes, { domain: 'api.example.com', path: '/v1/users' });
expect(apiMatches.length).toEqual(1);
expect(apiMatches[0].name).toInclude('API Route');
const wsMatches = findMatchingRoutes(routes, { domain: 'ws.example.com', path: '/socket' });
expect(wsMatches.length).toEqual(1);
expect(wsMatches[0].name).toInclude('WebSocket Route');
// Test finding multiple routes that match same criteria
const route1 = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const route1: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
route1.priority = 10;
const route2 = createHttpRoute('example.com', { host: 'localhost', port: 3001 });
const route2: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3001 }] },
name: 'HTTP Route for example.com',
};
route2.priority = 20;
route2.match.path = '/api';
const multiMatchRoutes = [route1, route2];
const multiMatches = findMatchingRoutes(multiMatchRoutes, { domain: 'example.com', port: 80 });
expect(multiMatches.length).toEqual(2);
expect(multiMatches[0].priority).toEqual(20); // Higher priority should be first
expect(multiMatches[1].priority).toEqual(10);
// Test disabled routes
const disabledRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const disabledRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
disabledRoute.enabled = false;
const enabledRoutes = findMatchingRoutes([disabledRoute], { domain: 'example.com', port: 80 });
expect(enabledRoutes.length).toEqual(0);
});
tap.test('Route Finding - findBestMatchingRoute', async () => {
// Create multiple routes with different priorities
const route1 = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const route1: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
route1.priority = 10;
const route2 = createHttpRoute('example.com', { host: 'localhost', port: 3001 });
const route2: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3001 }] },
name: 'HTTP Route for example.com',
};
route2.priority = 20;
route2.match.path = '/api';
const route3 = createHttpRoute('example.com', { host: 'localhost', port: 3002 });
const route3: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3002 }] },
name: 'HTTP Route for example.com',
};
route3.priority = 30;
route3.match.path = '/api/users';
const routes = [route1, route2, route3];
// Find best route for different criteria
const bestGeneral = findBestMatchingRoute(routes, { domain: 'example.com', port: 80 });
expect(bestGeneral).not.toBeUndefined();
expect(bestGeneral?.priority).toEqual(30);
// Test when no routes match
const noMatch = findBestMatchingRoute(routes, { domain: 'unknown.com', port: 80 });
expect(noMatch).toBeUndefined();
@@ -615,389 +670,54 @@ tap.test('Route Finding - findBestMatchingRoute', async () => {
tap.test('Route Utilities - generateRouteId', async () => {
// Test ID generation for different route types
const httpRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
const httpRoute: IRouteConfig = {
match: { ports: 80, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com',
};
const httpId = generateRouteId(httpRoute);
expect(httpId).toInclude('example-com');
expect(httpId).toInclude('80');
expect(httpId).toInclude('forward');
const httpsRoute = createHttpsTerminateRoute('secure.example.com', { host: 'localhost', port: 3001 });
const httpsRoute: IRouteConfig = {
match: { ports: 443, domains: 'secure.example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3001 }], tls: { mode: 'terminate', certificate: 'auto' } },
name: 'HTTPS Terminate Route for secure.example.com',
};
const httpsId = generateRouteId(httpsRoute);
expect(httpsId).toInclude('secure-example-com');
expect(httpsId).toInclude('443');
expect(httpsId).toInclude('forward');
const multiDomainRoute = createHttpRoute(['example.com', 'example.org'], { host: 'localhost', port: 3000 });
const multiDomainRoute: IRouteConfig = {
match: { ports: 80, domains: ['example.com', 'example.org'] },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }] },
name: 'HTTP Route for example.com,example.org',
};
const multiDomainId = generateRouteId(multiDomainRoute);
expect(multiDomainId).toInclude('example-com-example-org');
});
tap.test('Route Utilities - cloneRoute', async () => {
// Create a route and clone it
const originalRoute = createHttpsTerminateRoute('example.com', { host: 'localhost', port: 3000 }, {
certificate: 'auto',
name: 'Original Route'
});
const originalRoute: IRouteConfig = {
match: { ports: 443, domains: 'example.com' },
action: { type: 'forward', targets: [{ host: 'localhost', port: 3000 }], tls: { mode: 'terminate', certificate: 'auto' } },
name: 'Original Route',
};
const clonedRoute = cloneRoute(originalRoute);
// Check that the values are identical
expect(clonedRoute.name).toEqual(originalRoute.name);
expect(clonedRoute.match.domains).toEqual(originalRoute.match.domains);
expect(clonedRoute.action.type).toEqual(originalRoute.action.type);
expect(clonedRoute.action.targets?.[0]?.port).toEqual(originalRoute.action.targets?.[0]?.port);
// Modify the clone and check that the original is unchanged
clonedRoute.name = 'Modified Clone';
expect(originalRoute.name).toEqual('Original Route');
});
// --------------------------------- Route Helper Tests ---------------------------------
tap.test('Route Helpers - createHttpRoute', async () => {
const route = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
expect(route.match.domains).toEqual('example.com');
expect(route.match.ports).toEqual(80);
expect(route.action.type).toEqual('forward');
expect(route.action.targets?.[0]?.host).toEqual('localhost');
expect(route.action.targets?.[0]?.port).toEqual(3000);
const validationResult = validateRouteConfig(route);
expect(validationResult.valid).toBeTrue();
});
tap.test('Route Helpers - createHttpsTerminateRoute', async () => {
const route = createHttpsTerminateRoute('example.com', { host: 'localhost', port: 3000 }, {
certificate: 'auto'
});
expect(route.match.domains).toEqual('example.com');
expect(route.match.ports).toEqual(443);
expect(route.action.type).toEqual('forward');
expect(route.action.tls.mode).toEqual('terminate');
expect(route.action.tls.certificate).toEqual('auto');
const validationResult = validateRouteConfig(route);
expect(validationResult.valid).toBeTrue();
});
tap.test('Route Helpers - createHttpToHttpsRedirect', async () => {
const route = createHttpToHttpsRedirect('example.com');
expect(route.match.domains).toEqual('example.com');
expect(route.match.ports).toEqual(80);
expect(route.action.type).toEqual('socket-handler');
expect(route.action.socketHandler).toBeDefined();
const validationResult = validateRouteConfig(route);
expect(validationResult.valid).toBeTrue();
});
tap.test('Route Helpers - createHttpsPassthroughRoute', async () => {
const route = createHttpsPassthroughRoute('example.com', { host: 'localhost', port: 3000 });
expect(route.match.domains).toEqual('example.com');
expect(route.match.ports).toEqual(443);
expect(route.action.type).toEqual('forward');
expect(route.action.tls.mode).toEqual('passthrough');
const validationResult = validateRouteConfig(route);
expect(validationResult.valid).toBeTrue();
});
tap.test('Route Helpers - createCompleteHttpsServer', async () => {
const routes = createCompleteHttpsServer('example.com', { host: 'localhost', port: 3000 }, {
certificate: 'auto'
});
expect(routes.length).toEqual(2);
// HTTPS route
expect(routes[0].match.domains).toEqual('example.com');
expect(routes[0].match.ports).toEqual(443);
expect(routes[0].action.type).toEqual('forward');
expect(routes[0].action.tls.mode).toEqual('terminate');
// HTTP redirect route
expect(routes[1].match.domains).toEqual('example.com');
expect(routes[1].match.ports).toEqual(80);
expect(routes[1].action.type).toEqual('socket-handler');
const validation1 = validateRouteConfig(routes[0]);
const validation2 = validateRouteConfig(routes[1]);
expect(validation1.valid).toBeTrue();
expect(validation2.valid).toBeTrue();
});
// createStaticFileRoute has been removed - static file serving should be handled by
// external servers (nginx/apache) behind the proxy
tap.test('Route Helpers - createApiRoute', async () => {
const route = createApiRoute('api.example.com', '/v1', { host: 'localhost', port: 3000 }, {
useTls: true,
certificate: 'auto',
addCorsHeaders: true
});
expect(route.match.domains).toEqual('api.example.com');
expect(route.match.ports).toEqual(443);
expect(route.match.path).toEqual('/v1/*');
expect(route.action.type).toEqual('forward');
expect(route.action.tls.mode).toEqual('terminate');
// Check CORS headers if they exist
if (route.headers && route.headers.response) {
expect(route.headers.response['Access-Control-Allow-Origin']).toEqual('*');
}
const validationResult = validateRouteConfig(route);
expect(validationResult.valid).toBeTrue();
});
tap.test('Route Helpers - createWebSocketRoute', async () => {
const route = createWebSocketRoute('ws.example.com', '/socket', { host: 'localhost', port: 3000 }, {
useTls: true,
certificate: 'auto',
pingInterval: 15000
});
expect(route.match.domains).toEqual('ws.example.com');
expect(route.match.ports).toEqual(443);
expect(route.match.path).toEqual('/socket');
expect(route.action.type).toEqual('forward');
expect(route.action.tls.mode).toEqual('terminate');
// Check websocket configuration if it exists
if (route.action.websocket) {
expect(route.action.websocket.enabled).toBeTrue();
expect(route.action.websocket.pingInterval).toEqual(15000);
}
const validationResult = validateRouteConfig(route);
expect(validationResult.valid).toBeTrue();
});
tap.test('Route Helpers - createLoadBalancerRoute', async () => {
const route = createLoadBalancerRoute(
'loadbalancer.example.com',
['server1.local', 'server2.local', 'server3.local'],
8080,
{
tls: {
mode: 'terminate',
certificate: 'auto'
}
}
);
expect(route.match.domains).toEqual('loadbalancer.example.com');
expect(route.match.ports).toEqual(443);
expect(route.action.type).toEqual('forward');
expect(route.action.targets).toBeDefined();
if (route.action.targets && Array.isArray(route.action.targets[0]?.host)) {
expect((route.action.targets[0].host as string[]).length).toEqual(3);
}
expect(route.action.targets?.[0]?.port).toEqual(8080);
expect(route.action.tls.mode).toEqual('terminate');
const validationResult = validateRouteConfig(route);
expect(validationResult.valid).toBeTrue();
});
// --------------------------------- Route Pattern Tests ---------------------------------
tap.test('Route Patterns - createApiGatewayRoute', async () => {
// Create API Gateway route
const apiGatewayRoute = createApiGatewayRoute(
'api.example.com',
'/v1',
{ host: 'localhost', port: 3000 },
{
useTls: true,
addCorsHeaders: true
}
);
// Validate route configuration
expect(apiGatewayRoute.match.domains).toEqual('api.example.com');
expect(apiGatewayRoute.match.path).toInclude('/v1');
expect(apiGatewayRoute.action.type).toEqual('forward');
expect(apiGatewayRoute.action.targets?.[0]?.port).toEqual(3000);
// Check TLS configuration
if (apiGatewayRoute.action.tls) {
expect(apiGatewayRoute.action.tls.mode).toEqual('terminate');
}
// Check CORS headers
if (apiGatewayRoute.headers && apiGatewayRoute.headers.response) {
expect(apiGatewayRoute.headers.response['Access-Control-Allow-Origin']).toEqual('*');
}
const result = validateRouteConfig(apiGatewayRoute);
expect(result.valid).toBeTrue();
});
// createStaticFileServerRoute has been removed - static file serving should be handled by
// external servers (nginx/apache) behind the proxy
tap.test('Route Patterns - createWebSocketPattern', async () => {
// Create WebSocket route pattern
const wsRoute = createWebSocketPattern(
'ws.example.com',
{ host: 'localhost', port: 3000 },
{
useTls: true,
path: '/socket',
pingInterval: 10000
}
);
// Validate route configuration
expect(wsRoute.match.domains).toEqual('ws.example.com');
expect(wsRoute.match.path).toEqual('/socket');
expect(wsRoute.action.type).toEqual('forward');
expect(wsRoute.action.targets?.[0]?.port).toEqual(3000);
// Check TLS configuration
if (wsRoute.action.tls) {
expect(wsRoute.action.tls.mode).toEqual('terminate');
}
// Check websocket configuration if it exists
if (wsRoute.action.websocket) {
expect(wsRoute.action.websocket.enabled).toBeTrue();
expect(wsRoute.action.websocket.pingInterval).toEqual(10000);
}
const result = validateRouteConfig(wsRoute);
expect(result.valid).toBeTrue();
});
tap.test('Route Patterns - createLoadBalancerRoute pattern', async () => {
// Create load balancer route pattern with missing algorithm as it might not be implemented yet
try {
const lbRoute = createLbPattern(
'lb.example.com',
[
{ host: 'server1.local', port: 8080 },
{ host: 'server2.local', port: 8080 },
{ host: 'server3.local', port: 8080 }
],
{
useTls: true
}
);
// Validate route configuration
expect(lbRoute.match.domains).toEqual('lb.example.com');
expect(lbRoute.action.type).toEqual('forward');
// Check target hosts
if (lbRoute.action.targets && Array.isArray(lbRoute.action.targets[0]?.host)) {
expect((lbRoute.action.targets[0].host as string[]).length).toEqual(3);
}
// Check TLS configuration
if (lbRoute.action.tls) {
expect(lbRoute.action.tls.mode).toEqual('terminate');
}
const result = validateRouteConfig(lbRoute);
expect(result.valid).toBeTrue();
} catch (error) {
// If the pattern is not implemented yet, skip this test
console.log('Load balancer pattern might not be fully implemented yet');
}
});
tap.test('Route Security - addRateLimiting', async () => {
// Create base route
const baseRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
// Add rate limiting
const secureRoute = addRateLimiting(baseRoute, {
maxRequests: 100,
window: 60, // 1 minute
keyBy: 'ip'
});
// Check if rate limiting is applied
if (secureRoute.security) {
expect(secureRoute.security.rateLimit?.enabled).toBeTrue();
expect(secureRoute.security.rateLimit?.maxRequests).toEqual(100);
expect(secureRoute.security.rateLimit?.window).toEqual(60);
expect(secureRoute.security.rateLimit?.keyBy).toEqual('ip');
} else {
// Skip this test if security features are not implemented yet
console.log('Security features not implemented yet in route configuration');
}
// Just check that the route itself is valid
const result = validateRouteConfig(secureRoute);
expect(result.valid).toBeTrue();
});
tap.test('Route Security - addBasicAuth', async () => {
// Create base route
const baseRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
// Add basic authentication
const authRoute = addBasicAuth(baseRoute, {
users: [
{ username: 'admin', password: 'secret' },
{ username: 'user', password: 'password' }
],
realm: 'Protected Area',
excludePaths: ['/public']
});
// Check if basic auth is applied
if (authRoute.security) {
expect(authRoute.security.basicAuth?.enabled).toBeTrue();
expect(authRoute.security.basicAuth?.users.length).toEqual(2);
expect(authRoute.security.basicAuth?.realm).toEqual('Protected Area');
expect(authRoute.security.basicAuth?.excludePaths).toInclude('/public');
} else {
// Skip this test if security features are not implemented yet
console.log('Security features not implemented yet in route configuration');
}
// Check that the route itself is valid
const result = validateRouteConfig(authRoute);
expect(result.valid).toBeTrue();
});
tap.test('Route Security - addJwtAuth', async () => {
// Create base route
const baseRoute = createHttpRoute('example.com', { host: 'localhost', port: 3000 });
// Add JWT authentication
const jwtRoute = addJwtAuth(baseRoute, {
secret: 'your-jwt-secret-key',
algorithm: 'HS256',
issuer: 'auth.example.com',
audience: 'api.example.com',
expiresIn: 3600
});
// Check if JWT auth is applied
if (jwtRoute.security) {
expect(jwtRoute.security.jwtAuth?.enabled).toBeTrue();
expect(jwtRoute.security.jwtAuth?.secret).toEqual('your-jwt-secret-key');
expect(jwtRoute.security.jwtAuth?.algorithm).toEqual('HS256');
expect(jwtRoute.security.jwtAuth?.issuer).toEqual('auth.example.com');
expect(jwtRoute.security.jwtAuth?.audience).toEqual('api.example.com');
expect(jwtRoute.security.jwtAuth?.expiresIn).toEqual(3600);
} else {
// Skip this test if security features are not implemented yet
console.log('Security features not implemented yet in route configuration');
}
// Check that the route itself is valid
const result = validateRouteConfig(jwtRoute);
expect(result.valid).toBeTrue();
});
export default tap.start();
export default tap.start();

View File

@@ -1,403 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as http from 'http';
import { HttpRouter, type RouterResult } from '../ts/routing/router/http-router.js';
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
// Test proxies and configurations
let router: HttpRouter;
// Sample hostname for testing
const TEST_DOMAIN = 'example.com';
const TEST_SUBDOMAIN = 'api.example.com';
const TEST_WILDCARD = '*.example.com';
// Helper: Creates a mock HTTP request for testing
function createMockRequest(host: string, url: string = '/'): http.IncomingMessage {
const req = {
headers: { host },
url,
socket: {
remoteAddress: '127.0.0.1'
}
} as any;
return req;
}
// Helper: Creates a test route configuration
function createRouteConfig(
hostname: string,
destinationIp: string = '10.0.0.1',
destinationPort: number = 8080
): IRouteConfig {
return {
name: `route-${hostname}`,
match: {
domains: [hostname],
ports: 443
},
action: {
type: 'forward',
targets: [{
host: destinationIp,
port: destinationPort
}]
}
};
}
// SETUP: Create an HttpRouter instance
tap.test('setup http router test environment', async () => {
router = new HttpRouter();
// Initialize with empty config
router.setRoutes([]);
});
// Test basic routing by hostname
tap.test('should route requests by hostname', async () => {
const config = createRouteConfig(TEST_DOMAIN);
router.setRoutes([config]);
const req = createMockRequest(TEST_DOMAIN);
const result = router.routeReq(req);
expect(result).toBeTruthy();
expect(result).toEqual(config);
});
// Test handling of hostname with port number
tap.test('should handle hostname with port number', async () => {
const config = createRouteConfig(TEST_DOMAIN);
router.setRoutes([config]);
const req = createMockRequest(`${TEST_DOMAIN}:443`);
const result = router.routeReq(req);
expect(result).toBeTruthy();
expect(result).toEqual(config);
});
// Test case-insensitive hostname matching
tap.test('should perform case-insensitive hostname matching', async () => {
const config = createRouteConfig(TEST_DOMAIN.toLowerCase());
router.setRoutes([config]);
const req = createMockRequest(TEST_DOMAIN.toUpperCase());
const result = router.routeReq(req);
expect(result).toBeTruthy();
expect(result).toEqual(config);
});
// Test handling of unmatched hostnames
tap.test('should return undefined for unmatched hostnames', async () => {
const config = createRouteConfig(TEST_DOMAIN);
router.setRoutes([config]);
const req = createMockRequest('unknown.domain.com');
const result = router.routeReq(req);
expect(result).toBeUndefined();
});
// Test adding path patterns
tap.test('should match requests using path patterns', async () => {
const config = createRouteConfig(TEST_DOMAIN);
config.match.path = '/api/users';
router.setRoutes([config]);
// Test that path matches
const req1 = createMockRequest(TEST_DOMAIN, '/api/users');
const result1 = router.routeReqWithDetails(req1);
expect(result1).toBeTruthy();
expect(result1.route).toEqual(config);
expect(result1.pathMatch).toEqual('/api/users');
// Test that non-matching path doesn't match
const req2 = createMockRequest(TEST_DOMAIN, '/web/users');
const result2 = router.routeReqWithDetails(req2);
expect(result2).toBeUndefined();
});
// Test handling wildcard patterns
tap.test('should support wildcard path patterns', async () => {
const config = createRouteConfig(TEST_DOMAIN);
config.match.path = '/api/*';
router.setRoutes([config]);
// Test with path that matches the wildcard pattern
const req = createMockRequest(TEST_DOMAIN, '/api/users/123');
const result = router.routeReqWithDetails(req);
expect(result).toBeTruthy();
expect(result.route).toEqual(config);
expect(result.pathMatch).toEqual('/api');
// Print the actual value to diagnose issues
console.log('Path remainder value:', result.pathRemainder);
expect(result.pathRemainder).toBeTruthy();
expect(result.pathRemainder).toEqual('/users/123');
});
// Test extracting path parameters
tap.test('should extract path parameters from URL', async () => {
const config = createRouteConfig(TEST_DOMAIN);
config.match.path = '/users/:id/profile';
router.setRoutes([config]);
const req = createMockRequest(TEST_DOMAIN, '/users/123/profile');
const result = router.routeReqWithDetails(req);
expect(result).toBeTruthy();
expect(result.route).toEqual(config);
expect(result.pathParams).toBeTruthy();
expect(result.pathParams.id).toEqual('123');
});
// Test multiple configs for same hostname with different paths
tap.test('should support multiple configs for same hostname with different paths', async () => {
const apiConfig = createRouteConfig(TEST_DOMAIN, '10.0.0.1', 8001);
apiConfig.match.path = '/api/*';
apiConfig.name = 'api-route';
const webConfig = createRouteConfig(TEST_DOMAIN, '10.0.0.2', 8002);
webConfig.match.path = '/web/*';
webConfig.name = 'web-route';
// Add both configs
router.setRoutes([apiConfig, webConfig]);
// Test API path routes to API config
const apiReq = createMockRequest(TEST_DOMAIN, '/api/users');
const apiResult = router.routeReq(apiReq);
expect(apiResult).toEqual(apiConfig);
// Test web path routes to web config
const webReq = createMockRequest(TEST_DOMAIN, '/web/dashboard');
const webResult = router.routeReq(webReq);
expect(webResult).toEqual(webConfig);
// Test unknown path returns undefined
const unknownReq = createMockRequest(TEST_DOMAIN, '/unknown');
const unknownResult = router.routeReq(unknownReq);
expect(unknownResult).toBeUndefined();
});
// Test wildcard subdomains
tap.test('should match wildcard subdomains', async () => {
const wildcardConfig = createRouteConfig(TEST_WILDCARD);
router.setRoutes([wildcardConfig]);
// Test that subdomain.example.com matches *.example.com
const req = createMockRequest('subdomain.example.com');
const result = router.routeReq(req);
expect(result).toBeTruthy();
expect(result).toEqual(wildcardConfig);
});
// Test TLD wildcards (example.*)
tap.test('should match TLD wildcards', async () => {
const tldWildcardConfig = createRouteConfig('example.*');
router.setRoutes([tldWildcardConfig]);
// Test that example.com matches example.*
const req1 = createMockRequest('example.com');
const result1 = router.routeReq(req1);
expect(result1).toBeTruthy();
expect(result1).toEqual(tldWildcardConfig);
// Test that example.org matches example.*
const req2 = createMockRequest('example.org');
const result2 = router.routeReq(req2);
expect(result2).toBeTruthy();
expect(result2).toEqual(tldWildcardConfig);
// Test that subdomain.example.com doesn't match example.*
const req3 = createMockRequest('subdomain.example.com');
const result3 = router.routeReq(req3);
expect(result3).toBeUndefined();
});
// Test complex pattern matching (*.lossless*)
tap.test('should match complex wildcard patterns', async () => {
const complexWildcardConfig = createRouteConfig('*.lossless*');
router.setRoutes([complexWildcardConfig]);
// Test that sub.lossless.com matches *.lossless*
const req1 = createMockRequest('sub.lossless.com');
const result1 = router.routeReq(req1);
expect(result1).toBeTruthy();
expect(result1).toEqual(complexWildcardConfig);
// Test that api.lossless.org matches *.lossless*
const req2 = createMockRequest('api.lossless.org');
const result2 = router.routeReq(req2);
expect(result2).toBeTruthy();
expect(result2).toEqual(complexWildcardConfig);
// Test that losslessapi.com matches *.lossless*
const req3 = createMockRequest('losslessapi.com');
const result3 = router.routeReq(req3);
expect(result3).toBeUndefined(); // Should not match as it doesn't have a subdomain
});
// Test default configuration fallback
tap.test('should fall back to default configuration', async () => {
const defaultConfig = createRouteConfig('*');
const specificConfig = createRouteConfig(TEST_DOMAIN);
router.setRoutes([specificConfig, defaultConfig]);
// Test specific domain routes to specific config
const specificReq = createMockRequest(TEST_DOMAIN);
const specificResult = router.routeReq(specificReq);
expect(specificResult).toEqual(specificConfig);
// Test unknown domain falls back to default config
const unknownReq = createMockRequest('unknown.com');
const unknownResult = router.routeReq(unknownReq);
expect(unknownResult).toEqual(defaultConfig);
});
// Test priority between exact and wildcard matches
tap.test('should prioritize exact hostname over wildcard', async () => {
const wildcardConfig = createRouteConfig(TEST_WILDCARD);
const exactConfig = createRouteConfig(TEST_SUBDOMAIN);
router.setRoutes([exactConfig, wildcardConfig]);
// Test that exact match takes priority
const req = createMockRequest(TEST_SUBDOMAIN);
const result = router.routeReq(req);
expect(result).toEqual(exactConfig);
});
// Test adding and removing configurations
tap.test('should manage configurations correctly', async () => {
router.setRoutes([]);
// Add a config
const config = createRouteConfig(TEST_DOMAIN);
router.setRoutes([config]);
// Verify routing works
const req = createMockRequest(TEST_DOMAIN);
let result = router.routeReq(req);
expect(result).toEqual(config);
// Remove the config and verify it no longer routes
router.setRoutes([]);
result = router.routeReq(req);
expect(result).toBeUndefined();
});
// Test path pattern specificity
tap.test('should prioritize more specific path patterns', async () => {
const genericConfig = createRouteConfig(TEST_DOMAIN, '10.0.0.1', 8001);
genericConfig.match.path = '/api/*';
genericConfig.name = 'generic-api';
const specificConfig = createRouteConfig(TEST_DOMAIN, '10.0.0.2', 8002);
specificConfig.match.path = '/api/users';
specificConfig.name = 'specific-api';
specificConfig.priority = 10; // Higher priority
router.setRoutes([genericConfig, specificConfig]);
// The more specific '/api/users' should match before the '/api/*' wildcard
const req = createMockRequest(TEST_DOMAIN, '/api/users');
const result = router.routeReq(req);
expect(result).toEqual(specificConfig);
});
// Test multiple hostnames
tap.test('should handle multiple configured hostnames', async () => {
const routes = [
createRouteConfig(TEST_DOMAIN),
createRouteConfig(TEST_SUBDOMAIN)
];
router.setRoutes(routes);
// Test first domain routes correctly
const req1 = createMockRequest(TEST_DOMAIN);
const result1 = router.routeReq(req1);
expect(result1).toEqual(routes[0]);
// Test second domain routes correctly
const req2 = createMockRequest(TEST_SUBDOMAIN);
const result2 = router.routeReq(req2);
expect(result2).toEqual(routes[1]);
});
// Test handling missing host header
tap.test('should handle missing host header', async () => {
const defaultConfig = createRouteConfig('*');
router.setRoutes([defaultConfig]);
const req = createMockRequest('');
req.headers.host = undefined;
const result = router.routeReq(req);
expect(result).toEqual(defaultConfig);
});
// Test complex path parameters
tap.test('should handle complex path parameters', async () => {
const config = createRouteConfig(TEST_DOMAIN);
config.match.path = '/api/:version/users/:userId/posts/:postId';
router.setRoutes([config]);
const req = createMockRequest(TEST_DOMAIN, '/api/v1/users/123/posts/456');
const result = router.routeReqWithDetails(req);
expect(result).toBeTruthy();
expect(result.route).toEqual(config);
expect(result.pathParams).toBeTruthy();
expect(result.pathParams.version).toEqual('v1');
expect(result.pathParams.userId).toEqual('123');
expect(result.pathParams.postId).toEqual('456');
});
// Performance test
tap.test('should handle many configurations efficiently', async () => {
const configs = [];
// Create many configs with different hostnames
for (let i = 0; i < 100; i++) {
configs.push(createRouteConfig(`host-${i}.example.com`));
}
router.setRoutes(configs);
// Test middle of the list to avoid best/worst case
const req = createMockRequest('host-50.example.com');
const result = router.routeReq(req);
expect(result).toEqual(configs[50]);
});
// Test cleanup
tap.test('cleanup proxy router test environment', async () => {
// Clear all configurations
router.setRoutes([]);
// Verify empty state by testing that no routes match
const req = createMockRequest(TEST_DOMAIN);
const result = router.routeReq(req);
expect(result).toBeUndefined();
});
export default tap.start();

View File

@@ -1,157 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import { SharedSecurityManager } from '../ts/core/utils/shared-security-manager.js';
import type { IRouteConfig, IRouteContext } from '../ts/proxies/smart-proxy/models/route-types.js';
let securityManager: SharedSecurityManager;
tap.test('Setup SharedSecurityManager', async () => {
securityManager = new SharedSecurityManager({
maxConnectionsPerIP: 5,
connectionRateLimitPerMinute: 10,
cleanupIntervalMs: 1000 // 1 second for faster testing
});
});
tap.test('IP connection tracking', async () => {
const testIP = '192.168.1.100';
// Track multiple connections
securityManager.trackConnectionByIP(testIP, 'conn1');
securityManager.trackConnectionByIP(testIP, 'conn2');
securityManager.trackConnectionByIP(testIP, 'conn3');
// Verify connection count
expect(securityManager.getConnectionCountByIP(testIP)).toEqual(3);
// Remove a connection
securityManager.removeConnectionByIP(testIP, 'conn2');
expect(securityManager.getConnectionCountByIP(testIP)).toEqual(2);
// Remove remaining connections
securityManager.removeConnectionByIP(testIP, 'conn1');
securityManager.removeConnectionByIP(testIP, 'conn3');
expect(securityManager.getConnectionCountByIP(testIP)).toEqual(0);
});
tap.test('Per-IP connection limits validation', async () => {
const testIP = '192.168.1.101';
// Track connections up to limit
for (let i = 1; i <= 5; i++) {
// Validate BEFORE tracking the connection (checking if we can add a new connection)
const result = securityManager.validateIP(testIP);
expect(result.allowed).toBeTrue();
// Now track the connection
securityManager.trackConnectionByIP(testIP, `conn${i}`);
}
// Verify we're at the limit
expect(securityManager.getConnectionCountByIP(testIP)).toEqual(5);
// Next connection should be rejected (we're already at 5)
const result = securityManager.validateIP(testIP);
expect(result.allowed).toBeFalse();
expect(result.reason).toInclude('Maximum connections per IP');
// Clean up
for (let i = 1; i <= 5; i++) {
securityManager.removeConnectionByIP(testIP, `conn${i}`);
}
});
tap.test('Connection rate limiting', async () => {
const testIP = '192.168.1.102';
// Make connections at the rate limit
// Note: validateIP() already tracks timestamps internally for rate limiting
for (let i = 0; i < 10; i++) {
const result = securityManager.validateIP(testIP);
expect(result.allowed).toBeTrue();
}
// Next connection should exceed rate limit
const result = securityManager.validateIP(testIP);
expect(result.allowed).toBeFalse();
expect(result.reason).toInclude('Connection rate limit');
});
tap.test('Route-level connection limits', async () => {
const route: IRouteConfig = {
name: 'test-route',
match: { ports: 443 },
action: { type: 'forward', targets: [{ host: 'localhost', port: 8080 }] },
security: {
maxConnections: 3
}
};
const context: IRouteContext = {
port: 443,
clientIp: '192.168.1.103',
serverIp: '0.0.0.0',
timestamp: Date.now(),
connectionId: 'test-conn',
isTls: true
};
// Test with connection counts below limit
expect(securityManager.isAllowed(route, context, 0)).toBeTrue();
expect(securityManager.isAllowed(route, context, 2)).toBeTrue();
// Test at limit
expect(securityManager.isAllowed(route, context, 3)).toBeFalse();
// Test above limit
expect(securityManager.isAllowed(route, context, 5)).toBeFalse();
});
tap.test('IPv4/IPv6 normalization', async () => {
const ipv4 = '127.0.0.1';
const ipv4Mapped = '::ffff:127.0.0.1';
// Track connection with IPv4
securityManager.trackConnectionByIP(ipv4, 'conn1');
// Both representations should show the same connection
expect(securityManager.getConnectionCountByIP(ipv4)).toEqual(1);
expect(securityManager.getConnectionCountByIP(ipv4Mapped)).toEqual(1);
// Track another connection with IPv6 representation
securityManager.trackConnectionByIP(ipv4Mapped, 'conn2');
// Both should show 2 connections
expect(securityManager.getConnectionCountByIP(ipv4)).toEqual(2);
expect(securityManager.getConnectionCountByIP(ipv4Mapped)).toEqual(2);
// Clean up
securityManager.removeConnectionByIP(ipv4, 'conn1');
securityManager.removeConnectionByIP(ipv4Mapped, 'conn2');
});
tap.test('Automatic cleanup of expired data', async (tools) => {
const testIP = '192.168.1.104';
// Track a connection and then remove it
securityManager.trackConnectionByIP(testIP, 'temp-conn');
securityManager.removeConnectionByIP(testIP, 'temp-conn');
// Add some rate limit entries (they expire after 1 minute)
for (let i = 0; i < 5; i++) {
securityManager.validateIP(testIP);
}
// Wait for cleanup interval (set to 1 second in our test)
await tools.delayFor(1500);
// The IP should be cleaned up since it has no connections
// Note: We can't directly check the internal map, but we can verify
// that a new connection is allowed (fresh rate limit)
const result = securityManager.validateIP(testIP);
expect(result.allowed).toBeTrue();
});
tap.test('Cleanup SharedSecurityManager', async () => {
securityManager.clearIPTracking();
});
export default tap.start();

View File

@@ -1,315 +0,0 @@
import { expect, tap } from '@git.zone/tstest/tapbundle';
import * as plugins from '../ts/plugins.js';
import { WrappedSocket } from '../ts/core/models/wrapped-socket.js';
import * as net from 'net';
tap.test('WrappedSocket - should wrap a regular socket', async () => {
// Create a simple test server
const server = net.createServer();
await new Promise<void>((resolve) => {
server.listen(0, 'localhost', () => resolve());
});
const serverPort = (server.address() as net.AddressInfo).port;
// Create a client connection
const clientSocket = net.connect(serverPort, 'localhost');
// Wrap the socket
const wrappedSocket = new WrappedSocket(clientSocket);
// Test initial state - should use underlying socket values
expect(wrappedSocket.remoteAddress).toEqual(clientSocket.remoteAddress);
expect(wrappedSocket.remotePort).toEqual(clientSocket.remotePort);
expect(wrappedSocket.localAddress).toEqual(clientSocket.localAddress);
expect(wrappedSocket.localPort).toEqual(clientSocket.localPort);
expect(wrappedSocket.isFromTrustedProxy).toBeFalse();
// Clean up
clientSocket.destroy();
server.close();
});
tap.test('WrappedSocket - should provide real client info when set', async () => {
// Create a simple test server
const server = net.createServer();
await new Promise<void>((resolve) => {
server.listen(0, 'localhost', () => resolve());
});
const serverPort = (server.address() as net.AddressInfo).port;
// Create a client connection
const clientSocket = net.connect(serverPort, 'localhost');
// Wrap the socket with initial proxy info
const wrappedSocket = new WrappedSocket(clientSocket, '192.168.1.100', 54321);
// Test that real client info is returned
expect(wrappedSocket.remoteAddress).toEqual('192.168.1.100');
expect(wrappedSocket.remotePort).toEqual(54321);
expect(wrappedSocket.isFromTrustedProxy).toBeTrue();
// Local info should still come from underlying socket
expect(wrappedSocket.localAddress).toEqual(clientSocket.localAddress);
expect(wrappedSocket.localPort).toEqual(clientSocket.localPort);
// Clean up
clientSocket.destroy();
server.close();
});
tap.test('WrappedSocket - should update proxy info via setProxyInfo', async () => {
// Create a simple test server
const server = net.createServer();
await new Promise<void>((resolve) => {
server.listen(0, 'localhost', () => resolve());
});
const serverPort = (server.address() as net.AddressInfo).port;
// Create a client connection
const clientSocket = net.connect(serverPort, 'localhost');
// Wrap the socket without initial proxy info
const wrappedSocket = new WrappedSocket(clientSocket);
// Initially should use underlying socket
expect(wrappedSocket.isFromTrustedProxy).toBeFalse();
expect(wrappedSocket.remoteAddress).toEqual(clientSocket.remoteAddress);
// Update proxy info
wrappedSocket.setProxyInfo('10.0.0.5', 12345);
// Now should return proxy info
expect(wrappedSocket.remoteAddress).toEqual('10.0.0.5');
expect(wrappedSocket.remotePort).toEqual(12345);
expect(wrappedSocket.isFromTrustedProxy).toBeTrue();
// Clean up
clientSocket.destroy();
server.close();
});
tap.test('WrappedSocket - should correctly determine IP family', async () => {
// Create a simple test server
const server = net.createServer();
await new Promise<void>((resolve) => {
server.listen(0, 'localhost', () => resolve());
});
const serverPort = (server.address() as net.AddressInfo).port;
// Create a client connection
const clientSocket = net.connect(serverPort, 'localhost');
// Test IPv4
const wrappedSocketIPv4 = new WrappedSocket(clientSocket, '192.168.1.1', 80);
expect(wrappedSocketIPv4.remoteFamily).toEqual('IPv4');
// Test IPv6
const wrappedSocketIPv6 = new WrappedSocket(clientSocket, '2001:0db8:85a3:0000:0000:8a2e:0370:7334', 443);
expect(wrappedSocketIPv6.remoteFamily).toEqual('IPv6');
// Test fallback to underlying socket
const wrappedSocketNoProxy = new WrappedSocket(clientSocket);
expect(wrappedSocketNoProxy.remoteFamily).toEqual(clientSocket.remoteFamily);
// Clean up
clientSocket.destroy();
server.close();
});
tap.test('WrappedSocket - should forward events correctly', async () => {
// Create a simple echo server
let serverConnection: net.Socket;
const server = net.createServer((socket) => {
serverConnection = socket;
socket.on('data', (data) => {
socket.write(data); // Echo back
});
});
await new Promise<void>((resolve) => {
server.listen(0, 'localhost', () => resolve());
});
const serverPort = (server.address() as net.AddressInfo).port;
// Create a client connection
const clientSocket = net.connect(serverPort, 'localhost');
// Wrap the socket
const wrappedSocket = new WrappedSocket(clientSocket);
// Set up event tracking
let connectReceived = false;
let dataReceived = false;
let endReceived = false;
let closeReceived = false;
wrappedSocket.on('connect', () => {
connectReceived = true;
});
wrappedSocket.on('data', (chunk) => {
dataReceived = true;
expect(chunk.toString()).toEqual('test data');
});
wrappedSocket.on('end', () => {
endReceived = true;
});
wrappedSocket.on('close', () => {
closeReceived = true;
});
// Wait for connection
await new Promise<void>((resolve) => {
if (clientSocket.readyState === 'open') {
resolve();
} else {
clientSocket.once('connect', () => resolve());
}
});
// Send data
wrappedSocket.write('test data');
// Wait for echo
await new Promise(resolve => setTimeout(resolve, 100));
// Close the connection
serverConnection.end();
// Wait for events
await new Promise(resolve => setTimeout(resolve, 100));
// Verify all events were received
expect(dataReceived).toBeTrue();
expect(endReceived).toBeTrue();
expect(closeReceived).toBeTrue();
// Clean up
server.close();
});
tap.test('WrappedSocket - should pass through socket methods', async () => {
// Create a simple test server
const server = net.createServer();
await new Promise<void>((resolve) => {
server.listen(0, 'localhost', () => resolve());
});
const serverPort = (server.address() as net.AddressInfo).port;
// Create a client connection
const clientSocket = net.connect(serverPort, 'localhost');
await new Promise<void>((resolve) => {
clientSocket.once('connect', () => resolve());
});
// Wrap the socket
const wrappedSocket = new WrappedSocket(clientSocket);
// Test various pass-through methods
expect(wrappedSocket.readable).toEqual(clientSocket.readable);
expect(wrappedSocket.writable).toEqual(clientSocket.writable);
expect(wrappedSocket.destroyed).toEqual(clientSocket.destroyed);
expect(wrappedSocket.bytesRead).toEqual(clientSocket.bytesRead);
expect(wrappedSocket.bytesWritten).toEqual(clientSocket.bytesWritten);
// Test method calls
wrappedSocket.pause();
expect(clientSocket.isPaused()).toBeTrue();
wrappedSocket.resume();
expect(clientSocket.isPaused()).toBeFalse();
// Test setTimeout
let timeoutCalled = false;
wrappedSocket.setTimeout(100, () => {
timeoutCalled = true;
});
await new Promise(resolve => setTimeout(resolve, 150));
expect(timeoutCalled).toBeTrue();
// Clean up
wrappedSocket.destroy();
server.close();
});
tap.test('WrappedSocket - should handle write and pipe operations', async () => {
// Create a simple echo server
const server = net.createServer((socket) => {
socket.pipe(socket); // Echo everything back
});
await new Promise<void>((resolve) => {
server.listen(0, 'localhost', () => resolve());
});
const serverPort = (server.address() as net.AddressInfo).port;
// Create a client connection
const clientSocket = net.connect(serverPort, 'localhost');
await new Promise<void>((resolve) => {
clientSocket.once('connect', () => resolve());
});
// Wrap the socket
const wrappedSocket = new WrappedSocket(clientSocket);
// Test write with callback
const writeResult = wrappedSocket.write('test', 'utf8', () => {
// Write completed
});
expect(typeof writeResult).toEqual('boolean');
// Test pipe
const { PassThrough } = await import('stream');
const passThrough = new PassThrough();
const piped = wrappedSocket.pipe(passThrough);
expect(piped).toEqual(passThrough);
// Clean up
wrappedSocket.destroy();
server.close();
});
tap.test('WrappedSocket - should handle encoding and address methods', async () => {
// Create a simple test server
const server = net.createServer();
await new Promise<void>((resolve) => {
server.listen(0, 'localhost', () => resolve());
});
const serverPort = (server.address() as net.AddressInfo).port;
// Create a client connection
const clientSocket = net.connect(serverPort, 'localhost');
await new Promise<void>((resolve) => {
clientSocket.once('connect', () => resolve());
});
// Wrap the socket
const wrappedSocket = new WrappedSocket(clientSocket);
// Test setEncoding
wrappedSocket.setEncoding('utf8');
// Test address method
const addr = wrappedSocket.address();
expect(addr).toEqual(clientSocket.address());
// Test cork/uncork (if available)
wrappedSocket.cork();
wrappedSocket.uncork();
// Clean up
wrappedSocket.destroy();
server.close();
});
export default tap.start();

View File

@@ -3,6 +3,6 @@
*/
export const commitinfo = {
name: '@push.rocks/smartproxy',
version: '25.16.1',
version: '27.0.0',
description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.'
}

View File

@@ -1,3 +0,0 @@
/**
* Common event definitions
*/

View File

@@ -5,4 +5,3 @@
// Export submodules
export * from './models/index.js';
export * from './utils/index.js';
export * from './events/index.js';

View File

@@ -3,7 +3,6 @@
*/
export * from './common-types.js';
export * from './socket-augmentation.js';
export * from './route-context.js';
export * from './wrapped-socket.js';
export * from './socket-types.js';

View File

@@ -1,38 +0,0 @@
import * as plugins from '../../plugins.js';
// Augment the Node.js Socket type to include TLS-related properties
// This helps TypeScript understand properties that are dynamically added by Node.js
declare module 'net' {
interface Socket {
// TLS-related properties
encrypted?: boolean; // Indicates if the socket is encrypted (TLS/SSL)
authorizationError?: Error; // Authentication error if TLS handshake failed
// TLS-related methods
getTLSVersion?(): string; // Returns the TLS version (e.g., 'TLSv1.2', 'TLSv1.3')
getPeerCertificate?(detailed?: boolean): any; // Returns the peer's certificate
getSession?(): Buffer; // Returns the TLS session data
// Connection tracking properties (used by HttpProxy)
_connectionId?: string; // Unique identifier for the connection
_remoteIP?: string; // Remote IP address
_realRemoteIP?: string; // Real remote IP (when proxied)
}
}
// Export a utility function to check if a socket is a TLS socket
export function isTLSSocket(socket: plugins.net.Socket): boolean {
return 'encrypted' in socket && !!socket.encrypted;
}
// Export a utility function to safely get the TLS version
export function getTLSVersion(socket: plugins.net.Socket): string | null {
if (socket.getTLSVersion) {
try {
return socket.getTLSVersion();
} catch (e) {
return null;
}
}
return null;
}

View File

@@ -1,275 +0,0 @@
/**
* Async utility functions for SmartProxy
* Provides non-blocking alternatives to synchronous operations
*/
/**
* Delays execution for the specified number of milliseconds
* Non-blocking alternative to busy wait loops
* @param ms - Number of milliseconds to delay
* @returns Promise that resolves after the delay
*/
export async function delay(ms: number): Promise<void> {
return new Promise(resolve => setTimeout(resolve, ms));
}
/**
* Retry an async operation with exponential backoff
* @param fn - The async function to retry
* @param options - Retry options
* @returns The result of the function or throws the last error
*/
export async function retryWithBackoff<T>(
fn: () => Promise<T>,
options: {
maxAttempts?: number;
initialDelay?: number;
maxDelay?: number;
factor?: number;
onRetry?: (attempt: number, error: Error) => void;
} = {}
): Promise<T> {
const {
maxAttempts = 3,
initialDelay = 100,
maxDelay = 10000,
factor = 2,
onRetry
} = options;
let lastError: Error | null = null;
let currentDelay = initialDelay;
for (let attempt = 1; attempt <= maxAttempts; attempt++) {
try {
return await fn();
} catch (error: any) {
lastError = error;
if (attempt === maxAttempts) {
throw error;
}
if (onRetry) {
onRetry(attempt, error);
}
await delay(currentDelay);
currentDelay = Math.min(currentDelay * factor, maxDelay);
}
}
throw lastError || new Error('Retry failed');
}
/**
* Execute an async operation with a timeout
* @param fn - The async function to execute
* @param timeoutMs - Timeout in milliseconds
* @param timeoutError - Optional custom timeout error
* @returns The result of the function or throws timeout error
*/
export async function withTimeout<T>(
fn: () => Promise<T>,
timeoutMs: number,
timeoutError?: Error
): Promise<T> {
const timeoutPromise = new Promise<never>((_, reject) => {
setTimeout(() => {
reject(timeoutError || new Error(`Operation timed out after ${timeoutMs}ms`));
}, timeoutMs);
});
return Promise.race([fn(), timeoutPromise]);
}
/**
* Run multiple async operations in parallel with a concurrency limit
* @param items - Array of items to process
* @param fn - Async function to run for each item
* @param concurrency - Maximum number of concurrent operations
* @returns Array of results in the same order as input
*/
export async function parallelLimit<T, R>(
items: T[],
fn: (item: T, index: number) => Promise<R>,
concurrency: number
): Promise<R[]> {
const results: R[] = new Array(items.length);
const executing: Set<Promise<void>> = new Set();
for (let i = 0; i < items.length; i++) {
const promise = fn(items[i], i).then(result => {
results[i] = result;
executing.delete(promise);
});
executing.add(promise);
if (executing.size >= concurrency) {
await Promise.race(executing);
}
}
await Promise.all(executing);
return results;
}
/**
* Debounce an async function
* @param fn - The async function to debounce
* @param delayMs - Delay in milliseconds
* @returns Debounced function with cancel method
*/
export function debounceAsync<T extends (...args: any[]) => Promise<any>>(
fn: T,
delayMs: number
): T & { cancel: () => void } {
let timeoutId: NodeJS.Timeout | null = null;
let lastPromise: Promise<any> | null = null;
const debounced = ((...args: Parameters<T>) => {
if (timeoutId) {
clearTimeout(timeoutId);
}
lastPromise = new Promise((resolve, reject) => {
timeoutId = setTimeout(async () => {
timeoutId = null;
try {
const result = await fn(...args);
resolve(result);
} catch (error) {
reject(error);
}
}, delayMs);
});
return lastPromise;
}) as any;
debounced.cancel = () => {
if (timeoutId) {
clearTimeout(timeoutId);
timeoutId = null;
}
};
return debounced as T & { cancel: () => void };
}
/**
* Create a mutex for ensuring exclusive access to a resource
*/
export class AsyncMutex {
private queue: Array<() => void> = [];
private locked = false;
async acquire(): Promise<() => void> {
if (!this.locked) {
this.locked = true;
return () => this.release();
}
return new Promise<() => void>(resolve => {
this.queue.push(() => {
resolve(() => this.release());
});
});
}
private release(): void {
const next = this.queue.shift();
if (next) {
next();
} else {
this.locked = false;
}
}
async runExclusive<T>(fn: () => Promise<T>): Promise<T> {
const release = await this.acquire();
try {
return await fn();
} finally {
release();
}
}
}
/**
* Circuit breaker for protecting against cascading failures
*/
export class CircuitBreaker {
private failureCount = 0;
private lastFailureTime = 0;
private state: 'closed' | 'open' | 'half-open' = 'closed';
constructor(
private options: {
failureThreshold: number;
resetTimeout: number;
onStateChange?: (state: 'closed' | 'open' | 'half-open') => void;
}
) {}
async execute<T>(fn: () => Promise<T>): Promise<T> {
if (this.state === 'open') {
if (Date.now() - this.lastFailureTime > this.options.resetTimeout) {
this.setState('half-open');
} else {
throw new Error('Circuit breaker is open');
}
}
try {
const result = await fn();
this.onSuccess();
return result;
} catch (error) {
this.onFailure();
throw error;
}
}
private onSuccess(): void {
this.failureCount = 0;
if (this.state !== 'closed') {
this.setState('closed');
}
}
private onFailure(): void {
this.failureCount++;
this.lastFailureTime = Date.now();
if (this.failureCount >= this.options.failureThreshold) {
this.setState('open');
}
}
private setState(state: 'closed' | 'open' | 'half-open'): void {
if (this.state !== state) {
this.state = state;
if (this.options.onStateChange) {
this.options.onStateChange(state);
}
}
}
isOpen(): boolean {
return this.state === 'open';
}
getState(): 'closed' | 'open' | 'half-open' {
return this.state;
}
recordSuccess(): void {
this.onSuccess();
}
recordFailure(): void {
this.onFailure();
}
}

View File

@@ -1,225 +0,0 @@
/**
* A binary heap implementation for efficient priority queue operations
* Supports O(log n) insert and extract operations
*/
export class BinaryHeap<T> {
private heap: T[] = [];
private keyMap?: Map<string, number>; // For efficient key-based lookups
constructor(
private compareFn: (a: T, b: T) => number,
private extractKey?: (item: T) => string
) {
if (extractKey) {
this.keyMap = new Map();
}
}
/**
* Get the current size of the heap
*/
public get size(): number {
return this.heap.length;
}
/**
* Check if the heap is empty
*/
public isEmpty(): boolean {
return this.heap.length === 0;
}
/**
* Peek at the top element without removing it
*/
public peek(): T | undefined {
return this.heap[0];
}
/**
* Insert a new item into the heap
* O(log n) time complexity
*/
public insert(item: T): void {
const index = this.heap.length;
this.heap.push(item);
if (this.keyMap && this.extractKey) {
const key = this.extractKey(item);
this.keyMap.set(key, index);
}
this.bubbleUp(index);
}
/**
* Extract the top element from the heap
* O(log n) time complexity
*/
public extract(): T | undefined {
if (this.heap.length === 0) return undefined;
if (this.heap.length === 1) {
const item = this.heap.pop()!;
if (this.keyMap && this.extractKey) {
this.keyMap.delete(this.extractKey(item));
}
return item;
}
const result = this.heap[0];
const lastItem = this.heap.pop()!;
this.heap[0] = lastItem;
if (this.keyMap && this.extractKey) {
this.keyMap.delete(this.extractKey(result));
this.keyMap.set(this.extractKey(lastItem), 0);
}
this.bubbleDown(0);
return result;
}
/**
* Extract an element that matches the predicate
* O(n) time complexity for search, O(log n) for extraction
*/
public extractIf(predicate: (item: T) => boolean): T | undefined {
const index = this.heap.findIndex(predicate);
if (index === -1) return undefined;
return this.extractAt(index);
}
/**
* Extract an element by its key (if extractKey was provided)
* O(log n) time complexity
*/
public extractByKey(key: string): T | undefined {
if (!this.keyMap || !this.extractKey) {
throw new Error('extractKey function must be provided to use key-based extraction');
}
const index = this.keyMap.get(key);
if (index === undefined) return undefined;
return this.extractAt(index);
}
/**
* Check if a key exists in the heap
* O(1) time complexity
*/
public hasKey(key: string): boolean {
if (!this.keyMap) return false;
return this.keyMap.has(key);
}
/**
* Get all elements as an array (does not modify heap)
* O(n) time complexity
*/
public toArray(): T[] {
return [...this.heap];
}
/**
* Clear the heap
*/
public clear(): void {
this.heap = [];
if (this.keyMap) {
this.keyMap.clear();
}
}
/**
* Extract element at specific index
*/
private extractAt(index: number): T {
const item = this.heap[index];
if (this.keyMap && this.extractKey) {
this.keyMap.delete(this.extractKey(item));
}
if (index === this.heap.length - 1) {
this.heap.pop();
return item;
}
const lastItem = this.heap.pop()!;
this.heap[index] = lastItem;
if (this.keyMap && this.extractKey) {
this.keyMap.set(this.extractKey(lastItem), index);
}
// Try bubbling up first
const parentIndex = Math.floor((index - 1) / 2);
if (parentIndex >= 0 && this.compareFn(this.heap[index], this.heap[parentIndex]) < 0) {
this.bubbleUp(index);
} else {
this.bubbleDown(index);
}
return item;
}
/**
* Bubble up element at given index to maintain heap property
*/
private bubbleUp(index: number): void {
while (index > 0) {
const parentIndex = Math.floor((index - 1) / 2);
if (this.compareFn(this.heap[index], this.heap[parentIndex]) >= 0) {
break;
}
this.swap(index, parentIndex);
index = parentIndex;
}
}
/**
* Bubble down element at given index to maintain heap property
*/
private bubbleDown(index: number): void {
const length = this.heap.length;
while (true) {
const leftChild = 2 * index + 1;
const rightChild = 2 * index + 2;
let smallest = index;
if (leftChild < length &&
this.compareFn(this.heap[leftChild], this.heap[smallest]) < 0) {
smallest = leftChild;
}
if (rightChild < length &&
this.compareFn(this.heap[rightChild], this.heap[smallest]) < 0) {
smallest = rightChild;
}
if (smallest === index) break;
this.swap(index, smallest);
index = smallest;
}
}
/**
* Swap two elements in the heap
*/
private swap(i: number, j: number): void {
const temp = this.heap[i];
this.heap[i] = this.heap[j];
this.heap[j] = temp;
if (this.keyMap && this.extractKey) {
this.keyMap.set(this.extractKey(this.heap[i]), i);
this.keyMap.set(this.extractKey(this.heap[j]), j);
}
}
}

View File

@@ -1,425 +0,0 @@
import { LifecycleComponent } from './lifecycle-component.js';
import { BinaryHeap } from './binary-heap.js';
import { AsyncMutex } from './async-utils.js';
import { EventEmitter } from 'node:events';
/**
* Interface for pooled connection
*/
export interface IPooledConnection<T> {
id: string;
connection: T;
createdAt: number;
lastUsedAt: number;
useCount: number;
inUse: boolean;
metadata?: any;
}
/**
* Configuration options for the connection pool
*/
export interface IConnectionPoolOptions<T> {
minSize?: number;
maxSize?: number;
acquireTimeout?: number;
idleTimeout?: number;
maxUseCount?: number;
validateOnAcquire?: boolean;
validateOnReturn?: boolean;
queueTimeout?: number;
connectionFactory: () => Promise<T>;
connectionValidator?: (connection: T) => Promise<boolean>;
connectionDestroyer?: (connection: T) => Promise<void>;
onConnectionError?: (error: Error, connection?: T) => void;
}
/**
* Interface for queued acquire request
*/
interface IAcquireRequest<T> {
id: string;
priority: number;
timestamp: number;
resolve: (connection: IPooledConnection<T>) => void;
reject: (error: Error) => void;
timeoutHandle?: NodeJS.Timeout;
}
/**
* Enhanced connection pool with priority queue, backpressure, and lifecycle management
*/
export class EnhancedConnectionPool<T> extends LifecycleComponent {
private readonly options: Required<Omit<IConnectionPoolOptions<T>, 'connectionValidator' | 'connectionDestroyer' | 'onConnectionError'>> & Pick<IConnectionPoolOptions<T>, 'connectionValidator' | 'connectionDestroyer' | 'onConnectionError'>;
private readonly availableConnections: IPooledConnection<T>[] = [];
private readonly activeConnections: Map<string, IPooledConnection<T>> = new Map();
private readonly waitQueue: BinaryHeap<IAcquireRequest<T>>;
private readonly mutex = new AsyncMutex();
private readonly eventEmitter = new EventEmitter();
private connectionIdCounter = 0;
private requestIdCounter = 0;
private isClosing = false;
// Metrics
private metrics = {
connectionsCreated: 0,
connectionsDestroyed: 0,
connectionsAcquired: 0,
connectionsReleased: 0,
acquireTimeouts: 0,
validationFailures: 0,
queueHighWaterMark: 0,
};
constructor(options: IConnectionPoolOptions<T>) {
super();
this.options = {
minSize: 0,
maxSize: 10,
acquireTimeout: 30000,
idleTimeout: 300000, // 5 minutes
maxUseCount: Infinity,
validateOnAcquire: true,
validateOnReturn: false,
queueTimeout: 60000,
...options,
};
// Initialize priority queue (higher priority = extracted first)
this.waitQueue = new BinaryHeap<IAcquireRequest<T>>(
(a, b) => b.priority - a.priority || a.timestamp - b.timestamp,
(item) => item.id
);
// Start maintenance cycle
this.startMaintenance();
// Initialize minimum connections
this.initializeMinConnections();
}
/**
* Initialize minimum number of connections
*/
private async initializeMinConnections(): Promise<void> {
const promises: Promise<void>[] = [];
for (let i = 0; i < this.options.minSize; i++) {
promises.push(
this.createConnection()
.then(conn => {
this.availableConnections.push(conn);
})
.catch(err => {
if (this.options.onConnectionError) {
this.options.onConnectionError(err);
}
})
);
}
await Promise.all(promises);
}
/**
* Start maintenance timer for idle connection cleanup
*/
private startMaintenance(): void {
this.setInterval(() => {
this.performMaintenance();
}, 30000); // Every 30 seconds
}
/**
* Perform maintenance tasks
*/
private async performMaintenance(): Promise<void> {
await this.mutex.runExclusive(async () => {
const now = Date.now();
const toRemove: IPooledConnection<T>[] = [];
// Check for idle connections beyond minimum size
for (let i = this.availableConnections.length - 1; i >= 0; i--) {
const conn = this.availableConnections[i];
// Keep minimum connections
if (this.availableConnections.length <= this.options.minSize) {
break;
}
// Remove idle connections
if (now - conn.lastUsedAt > this.options.idleTimeout) {
toRemove.push(conn);
this.availableConnections.splice(i, 1);
}
}
// Destroy idle connections
for (const conn of toRemove) {
await this.destroyConnection(conn);
}
});
}
/**
* Acquire a connection from the pool
*/
public async acquire(priority: number = 0, timeout?: number): Promise<IPooledConnection<T>> {
if (this.isClosing) {
throw new Error('Connection pool is closing');
}
return this.mutex.runExclusive(async () => {
// Try to get an available connection
const connection = await this.tryAcquireConnection();
if (connection) {
return connection;
}
// Check if we can create a new connection
const totalConnections = this.availableConnections.length + this.activeConnections.size;
if (totalConnections < this.options.maxSize) {
try {
const newConnection = await this.createConnection();
return this.checkoutConnection(newConnection);
} catch (err) {
// Fall through to queue if creation fails
}
}
// Add to wait queue
return this.queueAcquireRequest(priority, timeout);
});
}
/**
* Try to acquire an available connection
*/
private async tryAcquireConnection(): Promise<IPooledConnection<T> | null> {
while (this.availableConnections.length > 0) {
const connection = this.availableConnections.shift()!;
// Check if connection exceeded max use count
if (connection.useCount >= this.options.maxUseCount) {
await this.destroyConnection(connection);
continue;
}
// Validate connection if required
if (this.options.validateOnAcquire && this.options.connectionValidator) {
try {
const isValid = await this.options.connectionValidator(connection.connection);
if (!isValid) {
this.metrics.validationFailures++;
await this.destroyConnection(connection);
continue;
}
} catch (err) {
this.metrics.validationFailures++;
await this.destroyConnection(connection);
continue;
}
}
return this.checkoutConnection(connection);
}
return null;
}
/**
* Checkout a connection for use
*/
private checkoutConnection(connection: IPooledConnection<T>): IPooledConnection<T> {
connection.inUse = true;
connection.lastUsedAt = Date.now();
connection.useCount++;
this.activeConnections.set(connection.id, connection);
this.metrics.connectionsAcquired++;
this.eventEmitter.emit('acquire', connection);
return connection;
}
/**
* Queue an acquire request
*/
private queueAcquireRequest(priority: number, timeout?: number): Promise<IPooledConnection<T>> {
return new Promise<IPooledConnection<T>>((resolve, reject) => {
const request: IAcquireRequest<T> = {
id: `req-${this.requestIdCounter++}`,
priority,
timestamp: Date.now(),
resolve,
reject,
};
// Set timeout
const timeoutMs = timeout || this.options.queueTimeout;
request.timeoutHandle = this.setTimeout(() => {
if (this.waitQueue.extractByKey(request.id)) {
this.metrics.acquireTimeouts++;
reject(new Error(`Connection acquire timeout after ${timeoutMs}ms`));
}
}, timeoutMs);
this.waitQueue.insert(request);
this.metrics.queueHighWaterMark = Math.max(
this.metrics.queueHighWaterMark,
this.waitQueue.size
);
this.eventEmitter.emit('enqueue', { queueSize: this.waitQueue.size });
});
}
/**
* Release a connection back to the pool
*/
public async release(connection: IPooledConnection<T>): Promise<void> {
return this.mutex.runExclusive(async () => {
if (!connection.inUse || !this.activeConnections.has(connection.id)) {
throw new Error('Connection is not active');
}
this.activeConnections.delete(connection.id);
connection.inUse = false;
connection.lastUsedAt = Date.now();
this.metrics.connectionsReleased++;
// Check if connection should be destroyed
if (connection.useCount >= this.options.maxUseCount) {
await this.destroyConnection(connection);
return;
}
// Validate on return if required
if (this.options.validateOnReturn && this.options.connectionValidator) {
try {
const isValid = await this.options.connectionValidator(connection.connection);
if (!isValid) {
await this.destroyConnection(connection);
return;
}
} catch (err) {
await this.destroyConnection(connection);
return;
}
}
// Check if there are waiting requests
const request = this.waitQueue.extract();
if (request) {
this.clearTimeout(request.timeoutHandle!);
request.resolve(this.checkoutConnection(connection));
this.eventEmitter.emit('dequeue', { queueSize: this.waitQueue.size });
} else {
// Return to available pool
this.availableConnections.push(connection);
this.eventEmitter.emit('release', connection);
}
});
}
/**
* Create a new connection
*/
private async createConnection(): Promise<IPooledConnection<T>> {
const rawConnection = await this.options.connectionFactory();
const connection: IPooledConnection<T> = {
id: `conn-${this.connectionIdCounter++}`,
connection: rawConnection,
createdAt: Date.now(),
lastUsedAt: Date.now(),
useCount: 0,
inUse: false,
};
this.metrics.connectionsCreated++;
this.eventEmitter.emit('create', connection);
return connection;
}
/**
* Destroy a connection
*/
private async destroyConnection(connection: IPooledConnection<T>): Promise<void> {
try {
if (this.options.connectionDestroyer) {
await this.options.connectionDestroyer(connection.connection);
}
this.metrics.connectionsDestroyed++;
this.eventEmitter.emit('destroy', connection);
} catch (err) {
if (this.options.onConnectionError) {
this.options.onConnectionError(err as Error, connection.connection);
}
}
}
/**
* Get current pool statistics
*/
public getStats() {
return {
available: this.availableConnections.length,
active: this.activeConnections.size,
waiting: this.waitQueue.size,
total: this.availableConnections.length + this.activeConnections.size,
...this.metrics,
};
}
/**
* Subscribe to pool events
*/
public on(event: string, listener: Function): void {
this.addEventListener(this.eventEmitter, event, listener);
}
/**
* Close the pool and cleanup resources
*/
protected async onCleanup(): Promise<void> {
this.isClosing = true;
// Clear the wait queue
while (!this.waitQueue.isEmpty()) {
const request = this.waitQueue.extract();
if (request) {
this.clearTimeout(request.timeoutHandle!);
request.reject(new Error('Connection pool is closing'));
}
}
// Wait for active connections to be released (with timeout)
const timeout = 30000;
const startTime = Date.now();
while (this.activeConnections.size > 0 && Date.now() - startTime < timeout) {
await new Promise(resolve => {
const timer = setTimeout(resolve, 100);
if (typeof timer.unref === 'function') {
timer.unref();
}
});
}
// Destroy all connections
const allConnections = [
...this.availableConnections,
...this.activeConnections.values(),
];
await Promise.all(allConnections.map(conn => this.destroyConnection(conn)));
this.availableConnections.length = 0;
this.activeConnections.clear();
}
}

View File

@@ -1,270 +0,0 @@
/**
* Async filesystem utilities for SmartProxy
* Provides non-blocking alternatives to synchronous filesystem operations
*/
import * as plugins from '../../plugins.js';
export class AsyncFileSystem {
/**
* Check if a file or directory exists
* @param path - Path to check
* @returns Promise resolving to true if exists, false otherwise
*/
static async exists(path: string): Promise<boolean> {
try {
await plugins.fs.promises.access(path);
return true;
} catch {
return false;
}
}
/**
* Ensure a directory exists, creating it if necessary
* @param dirPath - Directory path to ensure
* @returns Promise that resolves when directory is ensured
*/
static async ensureDir(dirPath: string): Promise<void> {
await plugins.fs.promises.mkdir(dirPath, { recursive: true });
}
/**
* Read a file as string
* @param filePath - Path to the file
* @param encoding - File encoding (default: utf8)
* @returns Promise resolving to file contents
*/
static async readFile(filePath: string, encoding: BufferEncoding = 'utf8'): Promise<string> {
return plugins.fs.promises.readFile(filePath, encoding);
}
/**
* Read a file as buffer
* @param filePath - Path to the file
* @returns Promise resolving to file buffer
*/
static async readFileBuffer(filePath: string): Promise<Buffer> {
return plugins.fs.promises.readFile(filePath);
}
/**
* Write string data to a file
* @param filePath - Path to the file
* @param data - String data to write
* @param encoding - File encoding (default: utf8)
* @returns Promise that resolves when file is written
*/
static async writeFile(filePath: string, data: string, encoding: BufferEncoding = 'utf8'): Promise<void> {
// Ensure directory exists
const dir = plugins.path.dirname(filePath);
await this.ensureDir(dir);
await plugins.fs.promises.writeFile(filePath, data, encoding);
}
/**
* Write buffer data to a file
* @param filePath - Path to the file
* @param data - Buffer data to write
* @returns Promise that resolves when file is written
*/
static async writeFileBuffer(filePath: string, data: Buffer): Promise<void> {
const dir = plugins.path.dirname(filePath);
await this.ensureDir(dir);
await plugins.fs.promises.writeFile(filePath, data);
}
/**
* Remove a file
* @param filePath - Path to the file
* @returns Promise that resolves when file is removed
*/
static async remove(filePath: string): Promise<void> {
try {
await plugins.fs.promises.unlink(filePath);
} catch (error: any) {
if (error.code !== 'ENOENT') {
throw error;
}
// File doesn't exist, which is fine
}
}
/**
* Remove a directory and all its contents
* @param dirPath - Path to the directory
* @returns Promise that resolves when directory is removed
*/
static async removeDir(dirPath: string): Promise<void> {
try {
await plugins.fs.promises.rm(dirPath, { recursive: true, force: true });
} catch (error: any) {
if (error.code !== 'ENOENT') {
throw error;
}
}
}
/**
* Read JSON from a file
* @param filePath - Path to the JSON file
* @returns Promise resolving to parsed JSON
*/
static async readJSON<T = any>(filePath: string): Promise<T> {
const content = await this.readFile(filePath);
return JSON.parse(content);
}
/**
* Write JSON to a file
* @param filePath - Path to the file
* @param data - Data to write as JSON
* @param pretty - Whether to pretty-print JSON (default: true)
* @returns Promise that resolves when file is written
*/
static async writeJSON(filePath: string, data: any, pretty = true): Promise<void> {
const jsonString = pretty ? JSON.stringify(data, null, 2) : JSON.stringify(data);
await this.writeFile(filePath, jsonString);
}
/**
* Copy a file from source to destination
* @param source - Source file path
* @param destination - Destination file path
* @returns Promise that resolves when file is copied
*/
static async copyFile(source: string, destination: string): Promise<void> {
const destDir = plugins.path.dirname(destination);
await this.ensureDir(destDir);
await plugins.fs.promises.copyFile(source, destination);
}
/**
* Move/rename a file
* @param source - Source file path
* @param destination - Destination file path
* @returns Promise that resolves when file is moved
*/
static async moveFile(source: string, destination: string): Promise<void> {
const destDir = plugins.path.dirname(destination);
await this.ensureDir(destDir);
await plugins.fs.promises.rename(source, destination);
}
/**
* Get file stats
* @param filePath - Path to the file
* @returns Promise resolving to file stats or null if doesn't exist
*/
static async getStats(filePath: string): Promise<plugins.fs.Stats | null> {
try {
return await plugins.fs.promises.stat(filePath);
} catch (error: any) {
if (error.code === 'ENOENT') {
return null;
}
throw error;
}
}
/**
* List files in a directory
* @param dirPath - Directory path
* @returns Promise resolving to array of filenames
*/
static async listFiles(dirPath: string): Promise<string[]> {
try {
return await plugins.fs.promises.readdir(dirPath);
} catch (error: any) {
if (error.code === 'ENOENT') {
return [];
}
throw error;
}
}
/**
* List files in a directory with full paths
* @param dirPath - Directory path
* @returns Promise resolving to array of full file paths
*/
static async listFilesFullPath(dirPath: string): Promise<string[]> {
const files = await this.listFiles(dirPath);
return files.map(file => plugins.path.join(dirPath, file));
}
/**
* Recursively list all files in a directory
* @param dirPath - Directory path
* @param fileList - Accumulator for file list (used internally)
* @returns Promise resolving to array of all file paths
*/
static async listFilesRecursive(dirPath: string, fileList: string[] = []): Promise<string[]> {
const files = await this.listFiles(dirPath);
for (const file of files) {
const filePath = plugins.path.join(dirPath, file);
const stats = await this.getStats(filePath);
if (stats?.isDirectory()) {
await this.listFilesRecursive(filePath, fileList);
} else if (stats?.isFile()) {
fileList.push(filePath);
}
}
return fileList;
}
/**
* Create a read stream for a file
* @param filePath - Path to the file
* @param options - Stream options
* @returns Read stream
*/
static createReadStream(filePath: string, options?: Parameters<typeof plugins.fs.createReadStream>[1]): plugins.fs.ReadStream {
return plugins.fs.createReadStream(filePath, options);
}
/**
* Create a write stream for a file
* @param filePath - Path to the file
* @param options - Stream options
* @returns Write stream
*/
static createWriteStream(filePath: string, options?: Parameters<typeof plugins.fs.createWriteStream>[1]): plugins.fs.WriteStream {
return plugins.fs.createWriteStream(filePath, options);
}
/**
* Ensure a file exists, creating an empty file if necessary
* @param filePath - Path to the file
* @returns Promise that resolves when file is ensured
*/
static async ensureFile(filePath: string): Promise<void> {
const exists = await this.exists(filePath);
if (!exists) {
await this.writeFile(filePath, '');
}
}
/**
* Check if a path is a directory
* @param path - Path to check
* @returns Promise resolving to true if directory, false otherwise
*/
static async isDirectory(path: string): Promise<boolean> {
const stats = await this.getStats(path);
return stats?.isDirectory() ?? false;
}
/**
* Check if a path is a file
* @param path - Path to check
* @returns Promise resolving to true if file, false otherwise
*/
static async isFile(path: string): Promise<boolean> {
const stats = await this.getStats(path);
return stats?.isFile() ?? false;
}
}

View File

@@ -2,16 +2,4 @@
* Core utility functions
*/
export * from './validation-utils.js';
export * from './ip-utils.js';
export * from './template-utils.js';
export * from './security-utils.js';
export * from './shared-security-manager.js';
export * from './websocket-utils.js';
export * from './logger.js';
export * from './async-utils.js';
export * from './fs-utils.js';
export * from './lifecycle-component.js';
export * from './binary-heap.js';
export * from './enhanced-connection-pool.js';
export * from './socket-utils.js';

View File

@@ -1,303 +0,0 @@
import * as plugins from '../../plugins.js';
/**
* Utility class for IP address operations
*/
export class IpUtils {
/**
* Check if the IP matches any of the glob patterns
*
* This method checks IP addresses against glob patterns and handles IPv4/IPv6 normalization.
* It's used to implement IP filtering based on security configurations.
*
* @param ip - The IP address to check
* @param patterns - Array of glob patterns
* @returns true if IP matches any pattern, false otherwise
*/
public static isGlobIPMatch(ip: string, patterns: string[]): boolean {
if (!ip || !patterns || patterns.length === 0) return false;
// Normalize the IP being checked
const normalizedIPVariants = this.normalizeIP(ip);
if (normalizedIPVariants.length === 0) return false;
// Check each pattern
for (const pattern of patterns) {
// Handle CIDR notation
if (pattern.includes('/')) {
if (this.matchCIDR(ip, pattern)) {
return true;
}
continue;
}
// Handle range notation
if (pattern.includes('-') && !pattern.includes('*')) {
if (this.matchIPRange(ip, pattern)) {
return true;
}
continue;
}
// Expand shorthand patterns for glob matching
let expandedPattern = pattern;
if (pattern.includes('*') && !pattern.includes(':')) {
const parts = pattern.split('.');
while (parts.length < 4) {
parts.push('*');
}
expandedPattern = parts.join('.');
}
// Normalize and check with minimatch
const normalizedPatterns = this.normalizeIP(expandedPattern);
for (const ipVariant of normalizedIPVariants) {
for (const normalizedPattern of normalizedPatterns) {
if (plugins.minimatch(ipVariant, normalizedPattern)) {
return true;
}
}
}
}
return false;
}
/**
* Normalize IP addresses for consistent comparison
*
* @param ip The IP address to normalize
* @returns Array of normalized IP forms
*/
public static normalizeIP(ip: string): string[] {
if (!ip) return [];
// Handle IPv4-mapped IPv6 addresses (::ffff:127.0.0.1)
if (ip.startsWith('::ffff:')) {
const ipv4 = ip.slice(7);
return [ip, ipv4];
}
// Handle IPv4 addresses by also checking IPv4-mapped form
if (/^\d{1,3}(\.\d{1,3}){3}$/.test(ip)) {
return [ip, `::ffff:${ip}`];
}
return [ip];
}
/**
* Check if an IP is authorized using security rules
*
* @param ip - The IP address to check
* @param allowedIPs - Array of allowed IP patterns
* @param blockedIPs - Array of blocked IP patterns
* @returns true if IP is authorized, false if blocked
*/
public static isIPAuthorized(ip: string, allowedIPs: string[] = [], blockedIPs: string[] = []): boolean {
// Skip IP validation if no rules are defined
if (!ip || (allowedIPs.length === 0 && blockedIPs.length === 0)) {
return true;
}
// First check if IP is blocked - blocked IPs take precedence
if (blockedIPs.length > 0 && this.isGlobIPMatch(ip, blockedIPs)) {
return false;
}
// Then check if IP is allowed (if no allowed IPs are specified, all non-blocked IPs are allowed)
return allowedIPs.length === 0 || this.isGlobIPMatch(ip, allowedIPs);
}
/**
* Check if an IP address is a private network address
*
* @param ip The IP address to check
* @returns true if the IP is a private network address, false otherwise
*/
public static isPrivateIP(ip: string): boolean {
if (!ip) return false;
// Handle IPv4-mapped IPv6 addresses
if (ip.startsWith('::ffff:')) {
ip = ip.slice(7);
}
// Check IPv4 private ranges
if (/^\d{1,3}(\.\d{1,3}){3}$/.test(ip)) {
const parts = ip.split('.').map(Number);
// Check common private ranges
// 10.0.0.0/8
if (parts[0] === 10) return true;
// 172.16.0.0/12
if (parts[0] === 172 && parts[1] >= 16 && parts[1] <= 31) return true;
// 192.168.0.0/16
if (parts[0] === 192 && parts[1] === 168) return true;
// 127.0.0.0/8 (localhost)
if (parts[0] === 127) return true;
return false;
}
// IPv6 local addresses
return ip === '::1' || ip.startsWith('fc00:') || ip.startsWith('fd00:') || ip.startsWith('fe80:');
}
/**
* Check if an IP address is a public network address
*
* @param ip The IP address to check
* @returns true if the IP is a public network address, false otherwise
*/
public static isPublicIP(ip: string): boolean {
return !this.isPrivateIP(ip);
}
/**
* Check if an IP matches a CIDR notation
*
* @param ip The IP address to check
* @param cidr The CIDR notation (e.g., "192.168.1.0/24")
* @returns true if IP is within the CIDR range
*/
private static matchCIDR(ip: string, cidr: string): boolean {
if (!cidr.includes('/')) return false;
const [networkAddr, prefixStr] = cidr.split('/');
const prefix = parseInt(prefixStr, 10);
// Handle IPv4-mapped IPv6 in the IP being checked
let checkIP = ip;
if (checkIP.startsWith('::ffff:')) {
checkIP = checkIP.slice(7);
}
// Handle IPv6 CIDR
if (networkAddr.includes(':')) {
// TODO: Implement IPv6 CIDR matching
return false;
}
// IPv4 CIDR matching
if (!/^\d{1,3}(\.\d{1,3}){3}$/.test(checkIP)) return false;
if (!/^\d{1,3}(\.\d{1,3}){3}$/.test(networkAddr)) return false;
if (isNaN(prefix) || prefix < 0 || prefix > 32) return false;
const ipParts = checkIP.split('.').map(Number);
const netParts = networkAddr.split('.').map(Number);
// Validate IP parts
for (const part of [...ipParts, ...netParts]) {
if (part < 0 || part > 255) return false;
}
// Convert to 32-bit integers
const ipNum = (ipParts[0] << 24) | (ipParts[1] << 16) | (ipParts[2] << 8) | ipParts[3];
const netNum = (netParts[0] << 24) | (netParts[1] << 16) | (netParts[2] << 8) | netParts[3];
// Create mask
const mask = (-1 << (32 - prefix)) >>> 0;
// Check if IP is in network range
return (ipNum & mask) === (netNum & mask);
}
/**
* Check if an IP matches a range notation
*
* @param ip The IP address to check
* @param range The range notation (e.g., "192.168.1.1-192.168.1.100")
* @returns true if IP is within the range
*/
private static matchIPRange(ip: string, range: string): boolean {
if (!range.includes('-')) return false;
const [startIP, endIP] = range.split('-').map(s => s.trim());
// Handle IPv4-mapped IPv6 in the IP being checked
let checkIP = ip;
if (checkIP.startsWith('::ffff:')) {
checkIP = checkIP.slice(7);
}
// Only handle IPv4 for now
if (!/^\d{1,3}(\.\d{1,3}){3}$/.test(checkIP)) return false;
if (!/^\d{1,3}(\.\d{1,3}){3}$/.test(startIP)) return false;
if (!/^\d{1,3}(\.\d{1,3}){3}$/.test(endIP)) return false;
const ipParts = checkIP.split('.').map(Number);
const startParts = startIP.split('.').map(Number);
const endParts = endIP.split('.').map(Number);
// Validate parts
for (const part of [...ipParts, ...startParts, ...endParts]) {
if (part < 0 || part > 255) return false;
}
// Convert to 32-bit integers for comparison
const ipNum = (ipParts[0] << 24) | (ipParts[1] << 16) | (ipParts[2] << 8) | ipParts[3];
const startNum = (startParts[0] << 24) | (startParts[1] << 16) | (startParts[2] << 8) | startParts[3];
const endNum = (endParts[0] << 24) | (endParts[1] << 16) | (endParts[2] << 8) | endParts[3];
// Convert to unsigned for proper comparison
const ipUnsigned = ipNum >>> 0;
const startUnsigned = startNum >>> 0;
const endUnsigned = endNum >>> 0;
return ipUnsigned >= startUnsigned && ipUnsigned <= endUnsigned;
}
/**
* Convert a subnet CIDR to an IP range for filtering
*
* @param cidr The CIDR notation (e.g., "192.168.1.0/24")
* @returns Array of glob patterns that match the CIDR range
*/
public static cidrToGlobPatterns(cidr: string): string[] {
if (!cidr || !cidr.includes('/')) return [];
const [ipPart, prefixPart] = cidr.split('/');
const prefix = parseInt(prefixPart, 10);
if (isNaN(prefix) || prefix < 0 || prefix > 32) return [];
// For IPv4 only for now
if (!/^\d{1,3}(\.\d{1,3}){3}$/.test(ipPart)) return [];
const ipParts = ipPart.split('.').map(Number);
const fullMask = Math.pow(2, 32 - prefix) - 1;
// Convert IP to a numeric value
const ipNum = (ipParts[0] << 24) | (ipParts[1] << 16) | (ipParts[2] << 8) | ipParts[3];
// Calculate network address (IP & ~fullMask)
const networkNum = ipNum & ~fullMask;
// For large ranges, return wildcard patterns
if (prefix <= 8) {
return [`${(networkNum >>> 24) & 255}.*.*.*`];
} else if (prefix <= 16) {
return [`${(networkNum >>> 24) & 255}.${(networkNum >>> 16) & 255}.*.*`];
} else if (prefix <= 24) {
return [`${(networkNum >>> 24) & 255}.${(networkNum >>> 16) & 255}.${(networkNum >>> 8) & 255}.*`];
}
// For small ranges, create individual IP patterns
const patterns = [];
const maxAddresses = Math.min(256, Math.pow(2, 32 - prefix));
for (let i = 0; i < maxAddresses; i++) {
const currentIpNum = networkNum + i;
patterns.push(
`${(currentIpNum >>> 24) & 255}.${(currentIpNum >>> 16) & 255}.${(currentIpNum >>> 8) & 255}.${currentIpNum & 255}`
);
}
return patterns;
}
}

View File

@@ -1,251 +0,0 @@
/**
* Base class for components that need proper resource lifecycle management
* Provides automatic cleanup of timers and event listeners to prevent memory leaks
*/
export abstract class LifecycleComponent {
private timers: Set<NodeJS.Timeout> = new Set();
private intervals: Set<NodeJS.Timeout> = new Set();
private listeners: Array<{
target: any;
event: string;
handler: Function;
actualHandler?: Function; // The actual handler registered (may be wrapped)
once?: boolean;
}> = [];
private childComponents: Set<LifecycleComponent> = new Set();
protected isShuttingDown = false;
private cleanupPromise?: Promise<void>;
/**
* Create a managed setTimeout that will be automatically cleaned up
*/
protected setTimeout(handler: Function, timeout: number): NodeJS.Timeout {
if (this.isShuttingDown) {
// Return a dummy timer if shutting down
const dummyTimer = setTimeout(() => {}, 0);
if (typeof dummyTimer.unref === 'function') {
dummyTimer.unref();
}
return dummyTimer;
}
const wrappedHandler = () => {
this.timers.delete(timer);
if (!this.isShuttingDown) {
handler();
}
};
const timer = setTimeout(wrappedHandler, timeout);
this.timers.add(timer);
// Allow process to exit even with timer
if (typeof timer.unref === 'function') {
timer.unref();
}
return timer;
}
/**
* Create a managed setInterval that will be automatically cleaned up
*/
protected setInterval(handler: Function, interval: number): NodeJS.Timeout {
if (this.isShuttingDown) {
// Return a dummy timer if shutting down
const dummyTimer = setInterval(() => {}, interval);
if (typeof dummyTimer.unref === 'function') {
dummyTimer.unref();
}
clearInterval(dummyTimer); // Clear immediately since we don't need it
return dummyTimer;
}
const wrappedHandler = () => {
if (!this.isShuttingDown) {
handler();
}
};
const timer = setInterval(wrappedHandler, interval);
this.intervals.add(timer);
// Allow process to exit even with timer
if (typeof timer.unref === 'function') {
timer.unref();
}
return timer;
}
/**
* Clear a managed timeout
*/
protected clearTimeout(timer: NodeJS.Timeout): void {
clearTimeout(timer);
this.timers.delete(timer);
}
/**
* Clear a managed interval
*/
protected clearInterval(timer: NodeJS.Timeout): void {
clearInterval(timer);
this.intervals.delete(timer);
}
/**
* Add a managed event listener that will be automatically removed on cleanup
*/
protected addEventListener(
target: any,
event: string,
handler: Function,
options?: { once?: boolean }
): void {
if (this.isShuttingDown) {
return;
}
// For 'once' listeners, we need to wrap the handler to remove it from our tracking
let actualHandler = handler;
if (options?.once) {
actualHandler = (...args: any[]) => {
// Call the original handler
handler(...args);
// Remove from our internal tracking
const index = this.listeners.findIndex(
l => l.target === target && l.event === event && l.handler === handler
);
if (index !== -1) {
this.listeners.splice(index, 1);
}
};
}
// Support both EventEmitter and DOM-style event targets
if (typeof target.on === 'function') {
if (options?.once) {
target.once(event, actualHandler);
} else {
target.on(event, actualHandler);
}
} else if (typeof target.addEventListener === 'function') {
target.addEventListener(event, actualHandler, options);
} else {
throw new Error('Target must support on() or addEventListener()');
}
// Store both the original handler and the actual handler registered
this.listeners.push({
target,
event,
handler,
actualHandler, // The handler that was actually registered (may be wrapped)
once: options?.once
});
}
/**
* Remove a specific event listener
*/
protected removeEventListener(target: any, event: string, handler: Function): void {
// Remove from target
if (typeof target.removeListener === 'function') {
target.removeListener(event, handler);
} else if (typeof target.removeEventListener === 'function') {
target.removeEventListener(event, handler);
}
// Remove from our tracking
const index = this.listeners.findIndex(
l => l.target === target && l.event === event && l.handler === handler
);
if (index !== -1) {
this.listeners.splice(index, 1);
}
}
/**
* Register a child component that should be cleaned up when this component is cleaned up
*/
protected registerChildComponent(component: LifecycleComponent): void {
this.childComponents.add(component);
}
/**
* Unregister a child component
*/
protected unregisterChildComponent(component: LifecycleComponent): void {
this.childComponents.delete(component);
}
/**
* Override this method to implement component-specific cleanup logic
*/
protected async onCleanup(): Promise<void> {
// Override in subclasses
}
/**
* Clean up all managed resources
*/
public async cleanup(): Promise<void> {
// Return existing cleanup promise if already cleaning up
if (this.cleanupPromise) {
return this.cleanupPromise;
}
this.cleanupPromise = this.performCleanup();
return this.cleanupPromise;
}
private async performCleanup(): Promise<void> {
this.isShuttingDown = true;
// First, clean up child components
const childCleanupPromises: Promise<void>[] = [];
for (const child of this.childComponents) {
childCleanupPromises.push(child.cleanup());
}
await Promise.all(childCleanupPromises);
this.childComponents.clear();
// Clear all timers
for (const timer of this.timers) {
clearTimeout(timer);
}
this.timers.clear();
// Clear all intervals
for (const timer of this.intervals) {
clearInterval(timer);
}
this.intervals.clear();
// Remove all event listeners
for (const { target, event, handler, actualHandler } of this.listeners) {
// Use actualHandler if available (for wrapped handlers), otherwise use the original handler
const handlerToRemove = actualHandler || handler;
// All listeners need to be removed, including 'once' listeners that might not have fired
if (typeof target.removeListener === 'function') {
target.removeListener(event, handlerToRemove);
} else if (typeof target.removeEventListener === 'function') {
target.removeEventListener(event, handlerToRemove);
}
}
this.listeners = [];
// Call subclass cleanup
await this.onCleanup();
}
/**
* Check if the component is shutting down
*/
protected isShuttingDownState(): boolean {
return this.isShuttingDown;
}
}

View File

@@ -1,370 +0,0 @@
import { logger } from './logger.js';
interface ILogEvent {
level: 'info' | 'warn' | 'error' | 'debug';
message: string;
data?: any;
count: number;
firstSeen: number;
lastSeen: number;
}
interface IAggregatedEvent {
key: string;
events: Map<string, ILogEvent>;
flushTimer?: NodeJS.Timeout;
}
/**
* Log deduplication utility to reduce log spam for repetitive events
*/
export class LogDeduplicator {
private globalFlushTimer?: NodeJS.Timeout;
private aggregatedEvents: Map<string, IAggregatedEvent> = new Map();
private flushInterval: number = 5000; // 5 seconds
private maxBatchSize: number = 100;
private rapidEventThreshold: number = 50; // Flush early if this many events in 1 second
private lastRapidCheck: number = Date.now();
constructor(flushInterval?: number) {
if (flushInterval) {
this.flushInterval = flushInterval;
}
// Set up global periodic flush to ensure logs are emitted regularly
this.globalFlushTimer = setInterval(() => {
this.flushAll();
}, this.flushInterval * 2); // Flush everything every 2x the normal interval
if (this.globalFlushTimer.unref) {
this.globalFlushTimer.unref();
}
}
/**
* Log a deduplicated event
* @param key - Aggregation key (e.g., 'connection-rejected', 'cleanup-batch')
* @param level - Log level
* @param message - Log message template
* @param data - Additional data
* @param dedupeKey - Deduplication key within the aggregation (e.g., IP address, reason)
*/
public log(
key: string,
level: 'info' | 'warn' | 'error' | 'debug',
message: string,
data?: any,
dedupeKey?: string
): void {
const eventKey = dedupeKey || message;
const now = Date.now();
if (!this.aggregatedEvents.has(key)) {
this.aggregatedEvents.set(key, {
key,
events: new Map(),
flushTimer: undefined
});
}
const aggregated = this.aggregatedEvents.get(key)!;
if (aggregated.events.has(eventKey)) {
const event = aggregated.events.get(eventKey)!;
event.count++;
event.lastSeen = now;
if (data) {
event.data = { ...event.data, ...data };
}
} else {
aggregated.events.set(eventKey, {
level,
message,
data,
count: 1,
firstSeen: now,
lastSeen: now
});
}
// Check for rapid events (many events in short time)
const totalEvents = Array.from(aggregated.events.values()).reduce((sum, e) => sum + e.count, 0);
// If we're getting flooded with events, flush more frequently
if (now - this.lastRapidCheck < 1000 && totalEvents >= this.rapidEventThreshold) {
this.flush(key);
this.lastRapidCheck = now;
} else if (aggregated.events.size >= this.maxBatchSize) {
// Check if we should flush due to size
this.flush(key);
} else if (!aggregated.flushTimer) {
// Schedule flush
aggregated.flushTimer = setTimeout(() => {
this.flush(key);
}, this.flushInterval);
if (aggregated.flushTimer.unref) {
aggregated.flushTimer.unref();
}
}
// Update rapid check time
if (now - this.lastRapidCheck >= 1000) {
this.lastRapidCheck = now;
}
}
/**
* Flush aggregated events for a specific key
*/
public flush(key: string): void {
const aggregated = this.aggregatedEvents.get(key);
if (!aggregated || aggregated.events.size === 0) {
return;
}
if (aggregated.flushTimer) {
clearTimeout(aggregated.flushTimer);
aggregated.flushTimer = undefined;
}
// Emit aggregated log based on the key
switch (key) {
case 'connection-rejected':
this.flushConnectionRejections(aggregated);
break;
case 'connection-cleanup':
this.flushConnectionCleanups(aggregated);
break;
case 'connection-terminated':
this.flushConnectionTerminations(aggregated);
break;
case 'ip-rejected':
this.flushIPRejections(aggregated);
break;
default:
this.flushGeneric(aggregated);
}
// Clear events
aggregated.events.clear();
}
/**
* Flush all pending events
*/
public flushAll(): void {
for (const key of this.aggregatedEvents.keys()) {
this.flush(key);
}
}
private flushConnectionRejections(aggregated: IAggregatedEvent): void {
const totalCount = Array.from(aggregated.events.values()).reduce((sum, e) => sum + e.count, 0);
const byReason = new Map<string, number>();
for (const [, event] of aggregated.events) {
const reason = event.data?.reason || 'unknown';
byReason.set(reason, (byReason.get(reason) || 0) + event.count);
}
const reasonSummary = Array.from(byReason.entries())
.sort((a, b) => b[1] - a[1])
.map(([reason, count]) => `${reason}: ${count}`)
.join(', ');
const duration = Date.now() - Math.min(...Array.from(aggregated.events.values()).map(e => e.firstSeen));
logger.log('warn', `[SUMMARY] Rejected ${totalCount} connections in ${Math.round(duration/1000)}s`, {
reasons: reasonSummary,
uniqueIPs: aggregated.events.size,
component: 'connection-dedup'
});
}
private flushConnectionCleanups(aggregated: IAggregatedEvent): void {
const totalCount = Array.from(aggregated.events.values()).reduce((sum, e) => sum + e.count, 0);
const byReason = new Map<string, number>();
for (const [, event] of aggregated.events) {
const reason = event.data?.reason || 'normal';
byReason.set(reason, (byReason.get(reason) || 0) + event.count);
}
const reasonSummary = Array.from(byReason.entries())
.sort((a, b) => b[1] - a[1])
.slice(0, 5) // Top 5 reasons
.map(([reason, count]) => `${reason}: ${count}`)
.join(', ');
logger.log('info', `Cleaned up ${totalCount} connections`, {
reasons: reasonSummary,
duration: Date.now() - Math.min(...Array.from(aggregated.events.values()).map(e => e.firstSeen)),
component: 'connection-dedup'
});
}
private flushConnectionTerminations(aggregated: IAggregatedEvent): void {
const totalCount = Array.from(aggregated.events.values()).reduce((sum, e) => sum + e.count, 0);
const byReason = new Map<string, number>();
const byIP = new Map<string, number>();
let lastActiveCount = 0;
for (const [, event] of aggregated.events) {
const reason = event.data?.reason || 'unknown';
const ip = event.data?.remoteIP || 'unknown';
byReason.set(reason, (byReason.get(reason) || 0) + event.count);
// Track by IP
if (ip !== 'unknown') {
byIP.set(ip, (byIP.get(ip) || 0) + event.count);
}
// Track the last active connection count
if (event.data?.activeConnections !== undefined) {
lastActiveCount = event.data.activeConnections;
}
}
const reasonSummary = Array.from(byReason.entries())
.sort((a, b) => b[1] - a[1])
.slice(0, 5) // Top 5 reasons
.map(([reason, count]) => `${reason}: ${count}`)
.join(', ');
// Show top IPs if there are many different ones
let ipInfo = '';
if (byIP.size > 3) {
const topIPs = Array.from(byIP.entries())
.sort((a, b) => b[1] - a[1])
.slice(0, 3)
.map(([ip, count]) => `${ip} (${count})`)
.join(', ');
ipInfo = `, from ${byIP.size} IPs (top: ${topIPs})`;
} else if (byIP.size > 0) {
ipInfo = `, IPs: ${Array.from(byIP.keys()).join(', ')}`;
}
const duration = Date.now() - Math.min(...Array.from(aggregated.events.values()).map(e => e.firstSeen));
// Special handling for localhost connections (HttpProxy)
const localhostCount = byIP.get('::ffff:127.0.0.1') || 0;
if (localhostCount > 0 && byIP.size === 1) {
// All connections are from localhost (HttpProxy)
logger.log('info', `[SUMMARY] ${totalCount} HttpProxy connections terminated in ${Math.round(duration/1000)}s`, {
reasons: reasonSummary,
activeConnections: lastActiveCount,
component: 'connection-dedup'
});
} else {
logger.log('info', `[SUMMARY] ${totalCount} connections terminated in ${Math.round(duration/1000)}s`, {
reasons: reasonSummary,
activeConnections: lastActiveCount,
uniqueReasons: byReason.size,
...(ipInfo ? { ips: ipInfo } : {}),
component: 'connection-dedup'
});
}
}
private flushIPRejections(aggregated: IAggregatedEvent): void {
const byIP = new Map<string, { count: number; reasons: Set<string> }>();
const allReasons = new Map<string, number>();
for (const [ip, event] of aggregated.events) {
if (!byIP.has(ip)) {
byIP.set(ip, { count: 0, reasons: new Set() });
}
const ipData = byIP.get(ip)!;
ipData.count += event.count;
if (event.data?.reason) {
ipData.reasons.add(event.data.reason);
// Track overall reason counts
allReasons.set(event.data.reason, (allReasons.get(event.data.reason) || 0) + event.count);
}
}
// Create reason summary
const reasonSummary = Array.from(allReasons.entries())
.sort((a, b) => b[1] - a[1])
.map(([reason, count]) => `${reason}: ${count}`)
.join(', ');
// Log top offenders
const topOffenders = Array.from(byIP.entries())
.sort((a, b) => b[1].count - a[1].count)
.slice(0, 10)
.map(([ip, data]) => `${ip} (${data.count}x, ${Array.from(data.reasons).join('/')})`)
.join(', ');
const totalRejections = Array.from(byIP.values()).reduce((sum, data) => sum + data.count, 0);
const duration = Date.now() - Math.min(...Array.from(aggregated.events.values()).map(e => e.firstSeen));
logger.log('warn', `[SUMMARY] Rejected ${totalRejections} connections from ${byIP.size} IPs in ${Math.round(duration/1000)}s (${reasonSummary})`, {
topOffenders,
component: 'ip-dedup'
});
}
private flushGeneric(aggregated: IAggregatedEvent): void {
const totalCount = Array.from(aggregated.events.values()).reduce((sum, e) => sum + e.count, 0);
const level = aggregated.events.values().next().value?.level || 'info';
// Special handling for IP cleanup events
if (aggregated.key === 'ip-cleanup') {
const totalCleaned = Array.from(aggregated.events.values()).reduce((sum, e) => {
return sum + (e.data?.cleanedIPs || 0) + (e.data?.cleanedRateLimits || 0);
}, 0);
if (totalCleaned > 0) {
logger.log(level as any, `IP tracking cleanup: removed ${totalCleaned} entries across ${totalCount} cleanup cycles`, {
duration: Date.now() - Math.min(...Array.from(aggregated.events.values()).map(e => e.firstSeen)),
component: 'log-dedup'
});
}
} else {
logger.log(level as any, `${aggregated.key}: ${totalCount} events`, {
uniqueEvents: aggregated.events.size,
duration: Date.now() - Math.min(...Array.from(aggregated.events.values()).map(e => e.firstSeen)),
component: 'log-dedup'
});
}
}
/**
* Cleanup and stop deduplication
*/
public cleanup(): void {
this.flushAll();
if (this.globalFlushTimer) {
clearInterval(this.globalFlushTimer);
this.globalFlushTimer = undefined;
}
for (const aggregated of this.aggregatedEvents.values()) {
if (aggregated.flushTimer) {
clearTimeout(aggregated.flushTimer);
}
}
this.aggregatedEvents.clear();
}
}
// Global instance for connection-related log deduplication
export const connectionLogDeduplicator = new LogDeduplicator(5000); // 5 second batches
// Ensure logs are flushed on process exit.
// Only use beforeExit — do NOT call process.exit() from SIGINT/SIGTERM handlers
// as that kills the host process's graceful shutdown (e.g., dcrouter connection draining).
process.on('beforeExit', () => {
connectionLogDeduplicator.flushAll();
});
process.on('SIGINT', () => {
connectionLogDeduplicator.cleanup();
});
process.on('SIGTERM', () => {
connectionLogDeduplicator.cleanup();
});

View File

@@ -1,305 +0,0 @@
import * as plugins from '../../plugins.js';
import { IpMatcher } from '../routing/matchers/ip.js';
/**
* Security utilities for IP validation, rate limiting,
* authentication, and other security features
*/
/**
* Result of IP validation
*/
export interface IIpValidationResult {
allowed: boolean;
reason?: string;
}
/**
* IP connection tracking information
*/
export interface IIpConnectionInfo {
connections: Set<string>; // ConnectionIDs
timestamps: number[]; // Connection timestamps
ipVariants: string[]; // Normalized IP variants (e.g., ::ffff:127.0.0.1 and 127.0.0.1)
}
/**
* Rate limit tracking
*/
export interface IRateLimitInfo {
count: number;
expiry: number;
}
/**
* Logger interface for security utilities
*/
export interface ISecurityLogger {
info: (message: string, ...args: any[]) => void;
warn: (message: string, ...args: any[]) => void;
error: (message: string, ...args: any[]) => void;
debug?: (message: string, ...args: any[]) => void;
}
/**
* Normalize IP addresses for comparison
* Handles IPv4-mapped IPv6 addresses (::ffff:127.0.0.1)
*
* @param ip IP address to normalize
* @returns Array of equivalent IP representations
*/
export function normalizeIP(ip: string): string[] {
if (!ip) return [];
// Handle IPv4-mapped IPv6 addresses (::ffff:127.0.0.1)
if (ip.startsWith('::ffff:')) {
const ipv4 = ip.slice(7);
return [ip, ipv4];
}
// Handle IPv4 addresses by also checking IPv4-mapped form
if (/^\d{1,3}(\.\d{1,3}){3}$/.test(ip)) {
return [ip, `::ffff:${ip}`];
}
return [ip];
}
/**
* Check if an IP is authorized based on allow and block lists
*
* @param ip - The IP address to check
* @param allowedIPs - Array of allowed IP patterns
* @param blockedIPs - Array of blocked IP patterns
* @returns Whether the IP is authorized
*/
export function isIPAuthorized(
ip: string,
allowedIPs: string[] = ['*'],
blockedIPs: string[] = []
): boolean {
// Skip IP validation if no rules
if (!ip || (allowedIPs.length === 0 && blockedIPs.length === 0)) {
return true;
}
// First check if IP is blocked - blocked IPs take precedence
if (blockedIPs.length > 0) {
for (const pattern of blockedIPs) {
if (IpMatcher.match(pattern, ip)) {
return false;
}
}
}
// If allowed IPs list has wildcard, all non-blocked IPs are allowed
if (allowedIPs.includes('*')) {
return true;
}
// Then check if IP is allowed in the explicit allow list
if (allowedIPs.length > 0) {
for (const pattern of allowedIPs) {
if (IpMatcher.match(pattern, ip)) {
return true;
}
}
// If allowedIPs is specified but no match, deny access
return false;
}
// Default allow if no explicit allow list
return true;
}
/**
* Check if an IP exceeds maximum connections
*
* @param ip - The IP address to check
* @param ipConnectionsMap - Map of IPs to connection info
* @param maxConnectionsPerIP - Maximum allowed connections per IP
* @returns Result with allowed status and reason if blocked
*/
export function checkMaxConnections(
ip: string,
ipConnectionsMap: Map<string, IIpConnectionInfo>,
maxConnectionsPerIP: number
): IIpValidationResult {
if (!ipConnectionsMap.has(ip)) {
return { allowed: true };
}
const connectionCount = ipConnectionsMap.get(ip)!.connections.size;
if (connectionCount >= maxConnectionsPerIP) {
return {
allowed: false,
reason: `Maximum connections per IP (${maxConnectionsPerIP}) exceeded`
};
}
return { allowed: true };
}
/**
* Check if an IP exceeds connection rate limit
*
* @param ip - The IP address to check
* @param ipConnectionsMap - Map of IPs to connection info
* @param rateLimit - Maximum connections per minute
* @returns Result with allowed status and reason if blocked
*/
export function checkConnectionRate(
ip: string,
ipConnectionsMap: Map<string, IIpConnectionInfo>,
rateLimit: number
): IIpValidationResult {
const now = Date.now();
const minute = 60 * 1000;
// Get or create connection info
if (!ipConnectionsMap.has(ip)) {
const info: IIpConnectionInfo = {
connections: new Set(),
timestamps: [now],
ipVariants: normalizeIP(ip)
};
ipConnectionsMap.set(ip, info);
return { allowed: true };
}
// Get timestamps and filter out entries older than 1 minute
const info = ipConnectionsMap.get(ip)!;
const timestamps = info.timestamps.filter(time => now - time < minute);
timestamps.push(now);
info.timestamps = timestamps;
// Check if rate exceeds limit
if (timestamps.length > rateLimit) {
return {
allowed: false,
reason: `Connection rate limit (${rateLimit}/min) exceeded`
};
}
return { allowed: true };
}
/**
* Track a connection for an IP
*
* @param ip - The IP address
* @param connectionId - The connection ID to track
* @param ipConnectionsMap - Map of IPs to connection info
*/
export function trackConnection(
ip: string,
connectionId: string,
ipConnectionsMap: Map<string, IIpConnectionInfo>
): void {
if (!ipConnectionsMap.has(ip)) {
ipConnectionsMap.set(ip, {
connections: new Set([connectionId]),
timestamps: [Date.now()],
ipVariants: normalizeIP(ip)
});
return;
}
const info = ipConnectionsMap.get(ip)!;
info.connections.add(connectionId);
}
/**
* Remove connection tracking for an IP
*
* @param ip - The IP address
* @param connectionId - The connection ID to remove
* @param ipConnectionsMap - Map of IPs to connection info
*/
export function removeConnection(
ip: string,
connectionId: string,
ipConnectionsMap: Map<string, IIpConnectionInfo>
): void {
if (!ipConnectionsMap.has(ip)) return;
const info = ipConnectionsMap.get(ip)!;
info.connections.delete(connectionId);
if (info.connections.size === 0) {
ipConnectionsMap.delete(ip);
}
}
/**
* Clean up expired rate limits
*
* @param rateLimits - Map of rate limits to clean up
* @param logger - Logger for debug messages
*/
export function cleanupExpiredRateLimits(
rateLimits: Map<string, Map<string, IRateLimitInfo>>,
logger?: ISecurityLogger
): void {
const now = Date.now();
let totalRemoved = 0;
for (const [routeId, routeLimits] of rateLimits.entries()) {
let removed = 0;
for (const [key, limit] of routeLimits.entries()) {
if (limit.expiry < now) {
routeLimits.delete(key);
removed++;
totalRemoved++;
}
}
if (removed > 0 && logger?.debug) {
logger.debug(`Cleaned up ${removed} expired rate limits for route ${routeId}`);
}
}
if (totalRemoved > 0 && logger?.info) {
logger.info(`Cleaned up ${totalRemoved} expired rate limits total`);
}
}
/**
* Generate basic auth header value from username and password
*
* @param username - The username
* @param password - The password
* @returns Base64 encoded basic auth string
*/
export function generateBasicAuthHeader(username: string, password: string): string {
return `Basic ${Buffer.from(`${username}:${password}`).toString('base64')}`;
}
/**
* Parse basic auth header
*
* @param authHeader - The Authorization header value
* @returns Username and password, or null if invalid
*/
export function parseBasicAuthHeader(
authHeader: string
): { username: string; password: string } | null {
if (!authHeader || !authHeader.startsWith('Basic ')) {
return null;
}
try {
const base64 = authHeader.slice(6); // Remove 'Basic '
const decoded = Buffer.from(base64, 'base64').toString();
const [username, password] = decoded.split(':');
if (!username || !password) {
return null;
}
return { username, password };
} catch (err) {
return null;
}
}

View File

@@ -1,470 +0,0 @@
import * as plugins from '../../plugins.js';
import type { IRouteConfig, IRouteContext } from '../../proxies/smart-proxy/models/route-types.js';
import type {
IIpValidationResult,
IIpConnectionInfo,
ISecurityLogger,
IRateLimitInfo
} from './security-utils.js';
import {
isIPAuthorized,
checkMaxConnections,
checkConnectionRate,
trackConnection,
removeConnection,
cleanupExpiredRateLimits,
parseBasicAuthHeader,
normalizeIP
} from './security-utils.js';
/**
* Shared SecurityManager for use across proxy components
* Handles IP tracking, rate limiting, and authentication
*/
export class SharedSecurityManager {
// IP connection tracking
private connectionsByIP: Map<string, IIpConnectionInfo> = new Map();
// Route-specific rate limiting
private rateLimits: Map<string, Map<string, IRateLimitInfo>> = new Map();
// Cache IP filtering results to avoid constant regex matching
private ipFilterCache: Map<string, Map<string, boolean>> = new Map();
// Default limits
private maxConnectionsPerIP: number;
private connectionRateLimitPerMinute: number;
// Cache cleanup interval
private cleanupInterval: NodeJS.Timeout | null = null;
/**
* Create a new SharedSecurityManager
*
* @param options - Configuration options
* @param logger - Logger instance
*/
constructor(options: {
maxConnectionsPerIP?: number;
connectionRateLimitPerMinute?: number;
cleanupIntervalMs?: number;
routes?: IRouteConfig[];
}, private logger?: ISecurityLogger) {
this.maxConnectionsPerIP = options.maxConnectionsPerIP || 100;
this.connectionRateLimitPerMinute = options.connectionRateLimitPerMinute || 300;
// Set up logger with defaults if not provided
this.logger = logger || {
info: console.log,
warn: console.warn,
error: console.error
};
// Set up cache cleanup interval
const cleanupInterval = options.cleanupIntervalMs || 60000; // Default: 1 minute
this.cleanupInterval = setInterval(() => {
this.cleanupCaches();
}, cleanupInterval);
// Don't keep the process alive just for cleanup
if (this.cleanupInterval.unref) {
this.cleanupInterval.unref();
}
}
/**
* Get connections count by IP
*
* @param ip - The IP address to check
* @returns Number of connections from this IP
*/
public getConnectionCountByIP(ip: string): number {
// Check all normalized variants of the IP
const variants = normalizeIP(ip);
for (const variant of variants) {
const info = this.connectionsByIP.get(variant);
if (info) {
return info.connections.size;
}
}
return 0;
}
/**
* Track connection by IP
*
* @param ip - The IP address to track
* @param connectionId - The connection ID to associate
*/
public trackConnectionByIP(ip: string, connectionId: string): void {
// Check if any variant already exists
const variants = normalizeIP(ip);
let existingKey: string | null = null;
for (const variant of variants) {
if (this.connectionsByIP.has(variant)) {
existingKey = variant;
break;
}
}
// Use existing key or the original IP
trackConnection(existingKey || ip, connectionId, this.connectionsByIP);
}
/**
* Remove connection tracking for an IP
*
* @param ip - The IP address to update
* @param connectionId - The connection ID to remove
*/
public removeConnectionByIP(ip: string, connectionId: string): void {
// Check all variants to find where the connection is tracked
const variants = normalizeIP(ip);
for (const variant of variants) {
if (this.connectionsByIP.has(variant)) {
removeConnection(variant, connectionId, this.connectionsByIP);
break;
}
}
}
/**
* Check if IP is authorized based on route security settings
*
* @param ip - The IP address to check
* @param allowedIPs - List of allowed IP patterns
* @param blockedIPs - List of blocked IP patterns
* @returns Whether the IP is authorized
*/
public isIPAuthorized(
ip: string,
allowedIPs: string[] = ['*'],
blockedIPs: string[] = []
): boolean {
return isIPAuthorized(ip, allowedIPs, blockedIPs);
}
/**
* Validate IP against rate limits and connection limits
*
* @param ip - The IP address to validate
* @returns Result with allowed status and reason if blocked
*/
public validateIP(ip: string): IIpValidationResult {
// Check connection count limit
const connectionResult = checkMaxConnections(
ip,
this.connectionsByIP,
this.maxConnectionsPerIP
);
if (!connectionResult.allowed) {
return connectionResult;
}
// Check connection rate limit
const rateResult = checkConnectionRate(
ip,
this.connectionsByIP,
this.connectionRateLimitPerMinute
);
if (!rateResult.allowed) {
return rateResult;
}
return { allowed: true };
}
/**
* Atomically validate an IP and track the connection if allowed.
* This prevents race conditions where concurrent connections could bypass per-IP limits.
*
* @param ip - The IP address to validate
* @param connectionId - The connection ID to track if validation passes
* @returns Object with validation result and reason
*/
public validateAndTrackIP(ip: string, connectionId: string): IIpValidationResult {
// Check connection count limit BEFORE tracking
const connectionResult = checkMaxConnections(
ip,
this.connectionsByIP,
this.maxConnectionsPerIP
);
if (!connectionResult.allowed) {
return connectionResult;
}
// Check connection rate limit
const rateResult = checkConnectionRate(
ip,
this.connectionsByIP,
this.connectionRateLimitPerMinute
);
if (!rateResult.allowed) {
return rateResult;
}
// Validation passed - immediately track to prevent race conditions
this.trackConnectionByIP(ip, connectionId);
return { allowed: true };
}
/**
* Check if a client is allowed to access a specific route
*
* @param route - The route to check
* @param context - The request context
* @param routeConnectionCount - Current connection count for this route (optional)
* @returns Whether access is allowed
*/
public isAllowed(route: IRouteConfig, context: IRouteContext, routeConnectionCount?: number): boolean {
if (!route.security) {
return true; // No security restrictions
}
// --- IP filtering ---
if (!this.isClientIpAllowed(route, context.clientIp)) {
this.logger?.debug?.(`IP ${context.clientIp} is blocked for route ${route.name || 'unnamed'}`);
return false;
}
// --- Route-level connection limit ---
if (route.security.maxConnections !== undefined && routeConnectionCount !== undefined) {
if (routeConnectionCount >= route.security.maxConnections) {
this.logger?.debug?.(`Route connection limit (${route.security.maxConnections}) exceeded for route ${route.name || 'unnamed'}`);
return false;
}
}
// --- Rate limiting ---
if (route.security.rateLimit?.enabled && !this.isWithinRateLimit(route, context)) {
this.logger?.debug?.(`Rate limit exceeded for route ${route.name || 'unnamed'}`);
return false;
}
return true;
}
/**
* Check if a client IP is allowed for a route
*
* @param route - The route to check
* @param clientIp - The client IP
* @returns Whether the IP is allowed
*/
private isClientIpAllowed(route: IRouteConfig, clientIp: string): boolean {
if (!route.security) {
return true; // No security restrictions
}
const routeId = route.id || route.name || 'unnamed';
// Check cache first
if (!this.ipFilterCache.has(routeId)) {
this.ipFilterCache.set(routeId, new Map());
}
const routeCache = this.ipFilterCache.get(routeId)!;
if (routeCache.has(clientIp)) {
return routeCache.get(clientIp)!;
}
// Check IP against route security settings
const ipAllowList = route.security.ipAllowList;
const ipBlockList = route.security.ipBlockList;
const allowed = this.isIPAuthorized(clientIp, ipAllowList, ipBlockList);
// Cache the result
routeCache.set(clientIp, allowed);
return allowed;
}
/**
* Check if request is within rate limit
*
* @param route - The route to check
* @param context - The request context
* @returns Whether the request is within rate limit
*/
private isWithinRateLimit(route: IRouteConfig, context: IRouteContext): boolean {
if (!route.security?.rateLimit?.enabled) {
return true;
}
const rateLimit = route.security.rateLimit;
const routeId = route.id || route.name || 'unnamed';
// Determine rate limit key (by IP, path, or header)
let key = context.clientIp; // Default to IP
if (rateLimit.keyBy === 'path' && context.path) {
key = `${context.clientIp}:${context.path}`;
} else if (rateLimit.keyBy === 'header' && rateLimit.headerName && context.headers) {
const headerValue = context.headers[rateLimit.headerName.toLowerCase()];
if (headerValue) {
key = `${context.clientIp}:${headerValue}`;
}
}
// Get or create rate limit tracking for this route
if (!this.rateLimits.has(routeId)) {
this.rateLimits.set(routeId, new Map());
}
const routeLimits = this.rateLimits.get(routeId)!;
const now = Date.now();
// Get or create rate limit tracking for this key
let limit = routeLimits.get(key);
if (!limit || limit.expiry < now) {
// Create new rate limit or reset expired one
limit = {
count: 1,
expiry: now + (rateLimit.window * 1000)
};
routeLimits.set(key, limit);
return true;
}
// Increment the counter
limit.count++;
// Check if rate limit is exceeded
return limit.count <= rateLimit.maxRequests;
}
/**
* Validate HTTP Basic Authentication
*
* @param route - The route to check
* @param authHeader - The Authorization header
* @returns Whether authentication is valid
*/
public validateBasicAuth(route: IRouteConfig, authHeader?: string): boolean {
// Skip if basic auth not enabled for route
if (!route.security?.basicAuth?.enabled) {
return true;
}
// No auth header means auth failed
if (!authHeader) {
return false;
}
// Parse auth header
const credentials = parseBasicAuthHeader(authHeader);
if (!credentials) {
return false;
}
// Check credentials against configured users
const { username, password } = credentials;
const users = route.security.basicAuth.users;
return users.some(user =>
user.username === username && user.password === password
);
}
/**
* Verify a JWT token against route configuration
*
* @param route - The route to verify the token for
* @param token - The JWT token to verify
* @returns True if the token is valid, false otherwise
*/
public verifyJwtToken(route: IRouteConfig, token: string): boolean {
if (!route.security?.jwtAuth?.enabled) {
return true;
}
try {
const jwtAuth = route.security.jwtAuth;
// Verify structure (header.payload.signature)
const parts = token.split('.');
if (parts.length !== 3) {
return false;
}
// Decode payload
const payload = JSON.parse(Buffer.from(parts[1], 'base64').toString());
// Check expiration
if (payload.exp && payload.exp < Math.floor(Date.now() / 1000)) {
return false;
}
// Check issuer
if (jwtAuth.issuer && payload.iss !== jwtAuth.issuer) {
return false;
}
// Check audience
if (jwtAuth.audience && payload.aud !== jwtAuth.audience) {
return false;
}
// Note: In a real implementation, you'd also verify the signature
// using the secret and algorithm specified in jwtAuth.
// This requires a proper JWT library for cryptographic verification.
return true;
} catch (err) {
this.logger?.error?.(`Error verifying JWT: ${err}`);
return false;
}
}
/**
* Clean up caches to prevent memory leaks
*/
private cleanupCaches(): void {
// Clean up rate limits
cleanupExpiredRateLimits(this.rateLimits, this.logger);
// Clean up IP connection tracking
let cleanedIPs = 0;
for (const [ip, info] of this.connectionsByIP.entries()) {
// Remove IPs with no active connections and no recent timestamps
if (info.connections.size === 0 && info.timestamps.length === 0) {
this.connectionsByIP.delete(ip);
cleanedIPs++;
}
}
if (cleanedIPs > 0 && this.logger?.debug) {
this.logger.debug(`Cleaned up ${cleanedIPs} IPs with no active connections`);
}
// IP filter cache doesn't need cleanup (tied to routes)
}
/**
* Clear all IP tracking data (for shutdown)
*/
public clearIPTracking(): void {
this.connectionsByIP.clear();
this.rateLimits.clear();
this.ipFilterCache.clear();
if (this.cleanupInterval) {
clearInterval(this.cleanupInterval);
this.cleanupInterval = null;
}
}
/**
* Update routes for security checking
*
* @param routes - New routes to use
*/
public setRoutes(routes: IRouteConfig[]): void {
// Only clear the IP filter cache - route-specific
this.ipFilterCache.clear();
}
}

View File

@@ -1,322 +0,0 @@
import * as plugins from '../../plugins.js';
export interface CleanupOptions {
immediate?: boolean; // Force immediate destruction
allowDrain?: boolean; // Allow write buffer to drain
gracePeriod?: number; // Ms to wait before force close
}
export interface SafeSocketOptions {
port: number;
host: string;
onError?: (error: Error) => void;
onConnect?: () => void;
timeout?: number;
}
/**
* Safely cleanup a socket by removing all listeners and destroying it
* @param socket The socket to cleanup
* @param socketName Optional name for logging
* @param options Cleanup options
*/
export function cleanupSocket(
socket: plugins.net.Socket | plugins.tls.TLSSocket | null,
socketName?: string,
options: CleanupOptions = {}
): Promise<void> {
if (!socket || socket.destroyed) return Promise.resolve();
return new Promise<void>((resolve) => {
const cleanup = () => {
try {
// Remove all event listeners
socket.removeAllListeners();
// Destroy if not already destroyed
if (!socket.destroyed) {
socket.destroy();
}
} catch (err) {
console.error(`Error cleaning up socket${socketName ? ` (${socketName})` : ''}: ${err}`);
}
resolve();
};
if (options.immediate) {
// Immediate cleanup (old behavior)
socket.unpipe();
cleanup();
} else if (options.allowDrain && socket.writable) {
// Allow pending writes to complete
socket.end(() => cleanup());
// Force cleanup after grace period
if (options.gracePeriod) {
setTimeout(() => {
if (!socket.destroyed) {
cleanup();
}
}, options.gracePeriod);
}
} else {
// Default: immediate cleanup
socket.unpipe();
cleanup();
}
});
}
/**
* Create independent cleanup handlers for paired sockets that support half-open connections
* @param clientSocket The client socket
* @param serverSocket The server socket
* @param onBothClosed Callback when both sockets are closed
* @returns Independent cleanup functions for each socket
*/
export function createIndependentSocketHandlers(
clientSocket: plugins.net.Socket | plugins.tls.TLSSocket,
serverSocket: plugins.net.Socket | plugins.tls.TLSSocket,
onBothClosed: (reason: string) => void,
options: { enableHalfOpen?: boolean } = {}
): { cleanupClient: (reason: string) => Promise<void>, cleanupServer: (reason: string) => Promise<void> } {
let clientClosed = false;
let serverClosed = false;
let clientReason = '';
let serverReason = '';
const checkBothClosed = () => {
if (clientClosed && serverClosed) {
onBothClosed(`client: ${clientReason}, server: ${serverReason}`);
}
};
const cleanupClient = async (reason: string) => {
if (clientClosed) return;
clientClosed = true;
clientReason = reason;
// Default behavior: close both sockets when one closes (required for proxy chains)
if (!serverClosed && !options.enableHalfOpen) {
serverSocket.destroy();
}
// Half-open support (opt-in only)
if (!serverClosed && serverSocket.writable && options.enableHalfOpen) {
// Half-close: stop reading from client, let server finish
clientSocket.pause();
clientSocket.unpipe(serverSocket);
await cleanupSocket(clientSocket, 'client', { allowDrain: true, gracePeriod: 5000 });
} else {
await cleanupSocket(clientSocket, 'client', { immediate: true });
}
checkBothClosed();
};
const cleanupServer = async (reason: string) => {
if (serverClosed) return;
serverClosed = true;
serverReason = reason;
// Default behavior: close both sockets when one closes (required for proxy chains)
if (!clientClosed && !options.enableHalfOpen) {
clientSocket.destroy();
}
// Half-open support (opt-in only)
if (!clientClosed && clientSocket.writable && options.enableHalfOpen) {
// Half-close: stop reading from server, let client finish
serverSocket.pause();
serverSocket.unpipe(clientSocket);
await cleanupSocket(serverSocket, 'server', { allowDrain: true, gracePeriod: 5000 });
} else {
await cleanupSocket(serverSocket, 'server', { immediate: true });
}
checkBothClosed();
};
return { cleanupClient, cleanupServer };
}
/**
* Setup socket error and close handlers with proper cleanup
* @param socket The socket to setup handlers for
* @param handleClose The cleanup function to call
* @param handleTimeout Optional custom timeout handler
* @param errorPrefix Optional prefix for error messages
*/
export function setupSocketHandlers(
socket: plugins.net.Socket | plugins.tls.TLSSocket,
handleClose: (reason: string) => void,
handleTimeout?: (socket: plugins.net.Socket | plugins.tls.TLSSocket) => void,
errorPrefix?: string
): void {
socket.on('error', (error) => {
const prefix = errorPrefix || 'Socket';
handleClose(`${prefix}_error: ${error.message}`);
});
socket.on('close', () => {
const prefix = errorPrefix || 'socket';
handleClose(`${prefix}_closed`);
});
socket.on('timeout', () => {
if (handleTimeout) {
handleTimeout(socket); // Custom timeout handling
} else {
// Default: just log, don't close
console.warn(`Socket timeout: ${errorPrefix || 'socket'}`);
}
});
}
/**
* Setup bidirectional data forwarding between two sockets with proper cleanup
* @param clientSocket The client/incoming socket
* @param serverSocket The server/outgoing socket
* @param handlers Object containing optional handlers for data and cleanup
* @returns Cleanup functions for both sockets
*/
export function setupBidirectionalForwarding(
clientSocket: plugins.net.Socket | plugins.tls.TLSSocket,
serverSocket: plugins.net.Socket | plugins.tls.TLSSocket,
handlers: {
onClientData?: (chunk: Buffer) => void;
onServerData?: (chunk: Buffer) => void;
onCleanup: (reason: string) => void;
enableHalfOpen?: boolean;
}
): { cleanupClient: (reason: string) => Promise<void>, cleanupServer: (reason: string) => Promise<void> } {
// Set up cleanup handlers
const { cleanupClient, cleanupServer } = createIndependentSocketHandlers(
clientSocket,
serverSocket,
handlers.onCleanup,
{ enableHalfOpen: handlers.enableHalfOpen }
);
// Set up error and close handlers
setupSocketHandlers(clientSocket, cleanupClient, undefined, 'client');
setupSocketHandlers(serverSocket, cleanupServer, undefined, 'server');
// Set up data forwarding with backpressure handling
clientSocket.on('data', (chunk: Buffer) => {
if (handlers.onClientData) {
handlers.onClientData(chunk);
}
if (serverSocket.writable) {
const flushed = serverSocket.write(chunk);
// Handle backpressure
if (!flushed) {
clientSocket.pause();
serverSocket.once('drain', () => {
if (!clientSocket.destroyed) {
clientSocket.resume();
}
});
}
}
});
serverSocket.on('data', (chunk: Buffer) => {
if (handlers.onServerData) {
handlers.onServerData(chunk);
}
if (clientSocket.writable) {
const flushed = clientSocket.write(chunk);
// Handle backpressure
if (!flushed) {
serverSocket.pause();
clientSocket.once('drain', () => {
if (!serverSocket.destroyed) {
serverSocket.resume();
}
});
}
}
});
return { cleanupClient, cleanupServer };
}
/**
* Create a socket with immediate error handling to prevent crashes
* @param options Socket creation options
* @returns The created socket
*/
export function createSocketWithErrorHandler(options: SafeSocketOptions): plugins.net.Socket {
const { port, host, onError, onConnect, timeout } = options;
// Create socket with immediate error handler attachment
const socket = new plugins.net.Socket();
// Track if connected
let connected = false;
let connectionTimeout: NodeJS.Timeout | null = null;
// Attach error handler BEFORE connecting to catch immediate errors
socket.on('error', (error) => {
console.error(`Socket connection error to ${host}:${port}: ${error.message}`);
// Clear the connection timeout if it exists
if (connectionTimeout) {
clearTimeout(connectionTimeout);
connectionTimeout = null;
}
if (onError) {
onError(error);
}
});
// Attach connect handler
const handleConnect = () => {
connected = true;
// Clear the connection timeout
if (connectionTimeout) {
clearTimeout(connectionTimeout);
connectionTimeout = null;
}
// Set inactivity timeout if provided (after connection is established)
if (timeout) {
socket.setTimeout(timeout);
}
if (onConnect) {
onConnect();
}
};
socket.on('connect', handleConnect);
// Implement connection establishment timeout
if (timeout) {
connectionTimeout = setTimeout(() => {
if (!connected && !socket.destroyed) {
// Connection timed out - destroy the socket
const error = new Error(`Connection timeout after ${timeout}ms to ${host}:${port}`);
(error as any).code = 'ETIMEDOUT';
console.error(`Socket connection timeout to ${host}:${port} after ${timeout}ms`);
// Destroy the socket
socket.destroy();
// Call error handler
if (onError) {
onError(error);
}
}
}, timeout);
}
// Now attempt to connect - any immediate errors will be caught
socket.connect(port, host);
return socket;
}

View File

@@ -1,124 +0,0 @@
import type { IRouteContext } from '../models/route-context.js';
/**
* Utility class for resolving template variables in strings
*/
export class TemplateUtils {
/**
* Resolve template variables in a string using the route context
* Supports variables like {domain}, {path}, {clientIp}, etc.
*
* @param template The template string with {variables}
* @param context The route context with values
* @returns The resolved string
*/
public static resolveTemplateVariables(template: string, context: IRouteContext): string {
if (!template) {
return template;
}
// Replace variables with values from context
return template.replace(/\{([a-zA-Z0-9_\.]+)\}/g, (match, varName) => {
// Handle nested properties with dot notation (e.g., {headers.host})
if (varName.includes('.')) {
const parts = varName.split('.');
let current: any = context;
// Traverse nested object structure
for (const part of parts) {
if (current === undefined || current === null) {
return match; // Return original if path doesn't exist
}
current = current[part];
}
// Return the resolved value if it exists
if (current !== undefined && current !== null) {
return TemplateUtils.convertToString(current);
}
return match;
}
// Direct property access
const value = context[varName as keyof IRouteContext];
if (value === undefined) {
return match; // Keep the original {variable} if not found
}
// Convert value to string
return TemplateUtils.convertToString(value);
});
}
/**
* Safely convert a value to a string
*
* @param value Any value to convert to string
* @returns String representation or original match for complex objects
*/
private static convertToString(value: any): string {
if (value === null || value === undefined) {
return '';
}
if (typeof value === 'string') {
return value;
}
if (typeof value === 'number' || typeof value === 'boolean') {
return value.toString();
}
if (Array.isArray(value)) {
return value.join(',');
}
if (typeof value === 'object') {
try {
return JSON.stringify(value);
} catch (e) {
return '[Object]';
}
}
return String(value);
}
/**
* Resolve template variables in header values
*
* @param headers Header object with potential template variables
* @param context Route context for variable resolution
* @returns New header object with resolved values
*/
public static resolveHeaderTemplates(
headers: Record<string, string>,
context: IRouteContext
): Record<string, string> {
const result: Record<string, string> = {};
for (const [key, value] of Object.entries(headers)) {
// Skip special directive headers (starting with !)
if (value.startsWith('!')) {
result[key] = value;
continue;
}
// Resolve template variables in the header value
result[key] = TemplateUtils.resolveTemplateVariables(value, context);
}
return result;
}
/**
* Check if a string contains template variables
*
* @param str String to check for template variables
* @returns True if string contains template variables
*/
public static containsTemplateVariables(str: string): boolean {
return !!str && /\{([a-zA-Z0-9_\.]+)\}/g.test(str);
}
}

View File

@@ -1,177 +0,0 @@
import * as plugins from '../../plugins.js';
import type { IDomainOptions, IAcmeOptions } from '../models/common-types.js';
/**
* Collection of validation utilities for configuration and domain options
*/
export class ValidationUtils {
/**
* Validates domain configuration options
*
* @param domainOptions The domain options to validate
* @returns An object with validation result and error message if invalid
*/
public static validateDomainOptions(domainOptions: IDomainOptions): { isValid: boolean; error?: string } {
if (!domainOptions) {
return { isValid: false, error: 'Domain options cannot be null or undefined' };
}
if (!domainOptions.domainName) {
return { isValid: false, error: 'Domain name is required' };
}
// Check domain pattern
if (!this.isValidDomainName(domainOptions.domainName)) {
return { isValid: false, error: `Invalid domain name: ${domainOptions.domainName}` };
}
// Validate forward config if provided
if (domainOptions.forward) {
if (!domainOptions.forward.ip) {
return { isValid: false, error: 'Forward IP is required when forward is specified' };
}
if (!domainOptions.forward.port) {
return { isValid: false, error: 'Forward port is required when forward is specified' };
}
if (!this.isValidPort(domainOptions.forward.port)) {
return { isValid: false, error: `Invalid forward port: ${domainOptions.forward.port}` };
}
}
// Validate ACME forward config if provided
if (domainOptions.acmeForward) {
if (!domainOptions.acmeForward.ip) {
return { isValid: false, error: 'ACME forward IP is required when acmeForward is specified' };
}
if (!domainOptions.acmeForward.port) {
return { isValid: false, error: 'ACME forward port is required when acmeForward is specified' };
}
if (!this.isValidPort(domainOptions.acmeForward.port)) {
return { isValid: false, error: `Invalid ACME forward port: ${domainOptions.acmeForward.port}` };
}
}
return { isValid: true };
}
/**
* Validates ACME configuration options
*
* @param acmeOptions The ACME options to validate
* @returns An object with validation result and error message if invalid
*/
public static validateAcmeOptions(acmeOptions: IAcmeOptions): { isValid: boolean; error?: string } {
if (!acmeOptions) {
return { isValid: false, error: 'ACME options cannot be null or undefined' };
}
if (acmeOptions.enabled) {
if (!acmeOptions.accountEmail) {
return { isValid: false, error: 'Account email is required when ACME is enabled' };
}
if (!this.isValidEmail(acmeOptions.accountEmail)) {
return { isValid: false, error: `Invalid email: ${acmeOptions.accountEmail}` };
}
if (acmeOptions.port && !this.isValidPort(acmeOptions.port)) {
return { isValid: false, error: `Invalid ACME port: ${acmeOptions.port}` };
}
if (acmeOptions.httpsRedirectPort && !this.isValidPort(acmeOptions.httpsRedirectPort)) {
return { isValid: false, error: `Invalid HTTPS redirect port: ${acmeOptions.httpsRedirectPort}` };
}
if (acmeOptions.renewThresholdDays && acmeOptions.renewThresholdDays < 1) {
return { isValid: false, error: 'Renew threshold days must be greater than 0' };
}
if (acmeOptions.renewCheckIntervalHours && acmeOptions.renewCheckIntervalHours < 1) {
return { isValid: false, error: 'Renew check interval hours must be greater than 0' };
}
}
return { isValid: true };
}
/**
* Validates a port number
*
* @param port The port to validate
* @returns true if the port is valid, false otherwise
*/
public static isValidPort(port: number): boolean {
return typeof port === 'number' && port > 0 && port <= 65535 && Number.isInteger(port);
}
/**
* Validates a domain name
*
* @param domain The domain name to validate
* @returns true if the domain name is valid, false otherwise
*/
public static isValidDomainName(domain: string): boolean {
if (!domain || typeof domain !== 'string') {
return false;
}
// Wildcard domain check (*.example.com)
if (domain.startsWith('*.')) {
domain = domain.substring(2);
}
// Simple domain validation pattern
const domainPattern = /^([a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$/;
return domainPattern.test(domain);
}
/**
* Validates an email address
*
* @param email The email to validate
* @returns true if the email is valid, false otherwise
*/
public static isValidEmail(email: string): boolean {
if (!email || typeof email !== 'string') {
return false;
}
// Basic email validation pattern
const emailPattern = /^[^\s@]+@[^\s@]+\.[^\s@]+$/;
return emailPattern.test(email);
}
/**
* Validates a certificate format (PEM)
*
* @param cert The certificate content to validate
* @returns true if the certificate appears to be in PEM format, false otherwise
*/
public static isValidCertificate(cert: string): boolean {
if (!cert || typeof cert !== 'string') {
return false;
}
return cert.includes('-----BEGIN CERTIFICATE-----') &&
cert.includes('-----END CERTIFICATE-----');
}
/**
* Validates a private key format (PEM)
*
* @param key The private key content to validate
* @returns true if the key appears to be in PEM format, false otherwise
*/
public static isValidPrivateKey(key: string): boolean {
if (!key || typeof key !== 'string') {
return false;
}
return key.includes('-----BEGIN PRIVATE KEY-----') &&
key.includes('-----END PRIVATE KEY-----');
}
}

View File

@@ -1,33 +0,0 @@
/**
* WebSocket utility functions
*
* This module provides smartproxy-specific WebSocket utilities
* and re-exports protocol utilities from the protocols module
*/
// Import and re-export from protocols
import { getMessageSize as protocolGetMessageSize, toBuffer as protocolToBuffer } from '../../protocols/websocket/index.js';
export type { RawData } from '../../protocols/websocket/index.js';
/**
* Get the length of a WebSocket message regardless of its type
* (handles all possible WebSocket message data types)
*
* @param data - The data message from WebSocket (could be any RawData type)
* @returns The length of the data in bytes
*/
export function getMessageSize(data: import('../../protocols/websocket/index.js').RawData): number {
// Delegate to protocol implementation
return protocolGetMessageSize(data);
}
/**
* Convert any raw WebSocket data to Buffer for consistent handling
*
* @param data - The data message from WebSocket (could be any RawData type)
* @returns A Buffer containing the data
*/
export function toBuffer(data: import('../../protocols/websocket/index.js').RawData): Buffer {
// Delegate to protocol implementation
return protocolToBuffer(data);
}

View File

@@ -1,127 +0,0 @@
/**
* HTTP Protocol Detector
*
* Simplified HTTP detection using the new architecture
*/
import type { IProtocolDetector } from '../models/interfaces.js';
import type { IDetectionResult, IDetectionOptions } from '../models/detection-types.js';
import type { IProtocolDetectionResult, IConnectionContext } from '../../protocols/common/types.js';
import type { THttpMethod } from '../../protocols/http/index.js';
import { QuickProtocolDetector } from './quick-detector.js';
import { RoutingExtractor } from './routing-extractor.js';
import { DetectionFragmentManager } from '../utils/fragment-manager.js';
import { HttpParser } from '../../protocols/http/parser.js';
/**
* Simplified HTTP detector
*/
export class HttpDetector implements IProtocolDetector {
private quickDetector = new QuickProtocolDetector();
private fragmentManager: DetectionFragmentManager;
constructor(fragmentManager?: DetectionFragmentManager) {
this.fragmentManager = fragmentManager || new DetectionFragmentManager();
}
/**
* Check if buffer can be handled by this detector
*/
canHandle(buffer: Buffer): boolean {
const result = this.quickDetector.quickDetect(buffer);
return result.protocol === 'http' && result.confidence > 50;
}
/**
* Get minimum bytes needed for detection
*/
getMinimumBytes(): number {
return 4; // "GET " minimum
}
/**
* Detect HTTP protocol from buffer
*/
detect(buffer: Buffer, options?: IDetectionOptions): IDetectionResult | null {
// Quick detection first
const quickResult = this.quickDetector.quickDetect(buffer);
if (quickResult.protocol !== 'http' || quickResult.confidence < 50) {
return null;
}
// Check if we have complete headers first
const headersEnd = buffer.indexOf('\r\n\r\n');
const isComplete = headersEnd !== -1;
// Extract routing information
const routing = RoutingExtractor.extract(buffer, 'http');
// Extract headers if requested and we have complete headers
let headers: Record<string, string> | undefined;
if (options?.extractFullHeaders && isComplete) {
const headerSection = buffer.slice(0, headersEnd).toString();
const lines = headerSection.split('\r\n');
if (lines.length > 1) {
// Skip the request line and parse headers
headers = HttpParser.parseHeaders(lines.slice(1));
}
}
// If we don't need full headers and we have complete headers, we can return early
if (quickResult.confidence >= 95 && !options?.extractFullHeaders && isComplete) {
return {
protocol: 'http',
connectionInfo: {
protocol: 'http',
method: quickResult.metadata?.method as THttpMethod,
domain: routing?.domain,
path: routing?.path
},
isComplete: true
};
}
return {
protocol: 'http',
connectionInfo: {
protocol: 'http',
domain: routing?.domain,
path: routing?.path,
method: quickResult.metadata?.method as THttpMethod,
headers: headers
},
isComplete,
bytesNeeded: isComplete ? undefined : buffer.length + 512 // Need more for headers
};
}
/**
* Handle fragmented detection
*/
detectWithContext(
buffer: Buffer,
context: IConnectionContext,
options?: IDetectionOptions
): IDetectionResult | null {
const handler = this.fragmentManager.getHandler('http');
const connectionId = DetectionFragmentManager.createConnectionId(context);
// Add fragment
const result = handler.addFragment(connectionId, buffer);
if (result.error) {
handler.complete(connectionId);
return null;
}
// Try detection on accumulated buffer
const detectResult = this.detect(result.buffer!, options);
if (detectResult && detectResult.isComplete) {
handler.complete(connectionId);
}
return detectResult;
}
}

View File

@@ -1,148 +0,0 @@
/**
* Quick Protocol Detector
*
* Lightweight protocol identification based on minimal bytes
* No parsing, just identification
*/
import type { IProtocolDetector, IProtocolDetectionResult } from '../../protocols/common/types.js';
import { TlsRecordType } from '../../protocols/tls/index.js';
import { HttpParser } from '../../protocols/http/index.js';
/**
* Quick protocol detector for fast identification
*/
export class QuickProtocolDetector implements IProtocolDetector {
/**
* Check if this detector can handle the data
*/
canHandle(data: Buffer): boolean {
return data.length >= 1;
}
/**
* Perform quick detection based on first few bytes
*/
quickDetect(data: Buffer): IProtocolDetectionResult {
if (data.length === 0) {
return {
protocol: 'unknown',
confidence: 0,
requiresMoreData: true
};
}
// Check for TLS
const tlsResult = this.checkTls(data);
if (tlsResult.confidence > 80) {
return tlsResult;
}
// Check for HTTP
const httpResult = this.checkHttp(data);
if (httpResult.confidence > 80) {
return httpResult;
}
// Need more data or unknown
return {
protocol: 'unknown',
confidence: 0,
requiresMoreData: data.length < 20
};
}
/**
* Check if data looks like TLS
*/
private checkTls(data: Buffer): IProtocolDetectionResult {
if (data.length < 3) {
return {
protocol: 'tls',
confidence: 0,
requiresMoreData: true
};
}
const firstByte = data[0];
const secondByte = data[1];
// Check for valid TLS record type
const validRecordTypes = [
TlsRecordType.CHANGE_CIPHER_SPEC,
TlsRecordType.ALERT,
TlsRecordType.HANDSHAKE,
TlsRecordType.APPLICATION_DATA,
TlsRecordType.HEARTBEAT
];
if (!validRecordTypes.includes(firstByte)) {
return {
protocol: 'tls',
confidence: 0
};
}
// Check TLS version byte (0x03 for all TLS/SSL versions)
if (secondByte !== 0x03) {
return {
protocol: 'tls',
confidence: 0
};
}
// High confidence it's TLS
return {
protocol: 'tls',
confidence: 95,
metadata: {
recordType: firstByte
}
};
}
/**
* Check if data looks like HTTP
*/
private checkHttp(data: Buffer): IProtocolDetectionResult {
if (data.length < 3) {
return {
protocol: 'http',
confidence: 0,
requiresMoreData: true
};
}
// Quick check for HTTP methods
const start = data.subarray(0, Math.min(10, data.length)).toString('ascii');
// Check common HTTP methods
const httpMethods = ['GET ', 'POST ', 'PUT ', 'DELETE ', 'HEAD ', 'OPTIONS', 'PATCH ', 'CONNECT', 'TRACE '];
for (const method of httpMethods) {
if (start.startsWith(method)) {
return {
protocol: 'http',
confidence: 95,
metadata: {
method: method.trim()
}
};
}
}
// Check if it might be HTTP but need more data
if (HttpParser.isPrintableAscii(data, Math.min(20, data.length))) {
// Could be HTTP, but not sure
return {
protocol: 'http',
confidence: 30,
requiresMoreData: data.length < 20
};
}
return {
protocol: 'http',
confidence: 0
};
}
}

View File

@@ -1,147 +0,0 @@
/**
* Routing Information Extractor
*
* Extracts minimal routing information from protocols
* without full parsing
*/
import type { IRoutingInfo, IConnectionContext, TProtocolType } from '../../protocols/common/types.js';
import { SniExtraction } from '../../protocols/tls/sni/sni-extraction.js';
import { HttpParser } from '../../protocols/http/index.js';
/**
* Extracts routing information from protocol data
*/
export class RoutingExtractor {
/**
* Extract routing info based on protocol type
*/
static extract(
data: Buffer,
protocol: TProtocolType,
context?: IConnectionContext
): IRoutingInfo | null {
switch (protocol) {
case 'tls':
case 'https':
return this.extractTlsRouting(data, context);
case 'http':
return this.extractHttpRouting(data);
default:
return null;
}
}
/**
* Extract routing from TLS ClientHello (SNI)
*/
private static extractTlsRouting(
data: Buffer,
context?: IConnectionContext
): IRoutingInfo | null {
try {
// Quick SNI extraction without full parsing
const sni = SniExtraction.extractSNI(data);
if (sni) {
return {
domain: sni,
protocol: 'tls',
port: 443 // Default HTTPS port
};
}
return null;
} catch (error) {
// Extraction failed, return null
return null;
}
}
/**
* Extract routing from HTTP headers (Host header)
*/
private static extractHttpRouting(data: Buffer): IRoutingInfo | null {
try {
// Look for first line
const firstLineEnd = data.indexOf('\n');
if (firstLineEnd === -1) {
return null;
}
// Parse request line
const firstLine = data.subarray(0, firstLineEnd).toString('ascii').trim();
const requestLine = HttpParser.parseRequestLine(firstLine);
if (!requestLine) {
return null;
}
// Look for Host header
let pos = firstLineEnd + 1;
const maxSearch = Math.min(data.length, 4096); // Don't search too far
while (pos < maxSearch) {
const lineEnd = data.indexOf('\n', pos);
if (lineEnd === -1) break;
const line = data.subarray(pos, lineEnd).toString('ascii').trim();
// Empty line means end of headers
if (line.length === 0) break;
// Check for Host header
if (line.toLowerCase().startsWith('host:')) {
const hostValue = line.substring(5).trim();
const domain = HttpParser.extractDomainFromHost(hostValue);
return {
domain,
path: requestLine.path,
protocol: 'http',
port: 80 // Default HTTP port
};
}
pos = lineEnd + 1;
}
// No Host header found, but we have the path
return {
path: requestLine.path,
protocol: 'http',
port: 80
};
} catch (error) {
// Extraction failed
return null;
}
}
/**
* Try to extract domain from any protocol
*/
static extractDomain(data: Buffer, hint?: TProtocolType): string | null {
// If we have a hint, use it
if (hint) {
const routing = this.extract(data, hint);
return routing?.domain || null;
}
// Try TLS first (more specific)
const tlsRouting = this.extractTlsRouting(data);
if (tlsRouting?.domain) {
return tlsRouting.domain;
}
// Try HTTP
const httpRouting = this.extractHttpRouting(data);
if (httpRouting?.domain) {
return httpRouting.domain;
}
return null;
}
}

View File

@@ -1,223 +0,0 @@
/**
* TLS protocol detector
*/
// TLS detector doesn't need plugins imports
import type { IProtocolDetector } from '../models/interfaces.js';
import type { IDetectionResult, IDetectionOptions, IConnectionInfo } from '../models/detection-types.js';
import { readUInt16BE } from '../utils/buffer-utils.js';
import { tlsVersionToString } from '../utils/parser-utils.js';
// Import from protocols
import { TlsRecordType, TlsHandshakeType, TlsExtensionType } from '../../protocols/tls/index.js';
// Import TLS utilities for SNI extraction from protocols
import { SniExtraction } from '../../protocols/tls/sni/sni-extraction.js';
import { ClientHelloParser } from '../../protocols/tls/sni/client-hello-parser.js';
/**
* TLS detector implementation
*/
export class TlsDetector implements IProtocolDetector {
/**
* Minimum bytes needed to identify TLS (record header)
*/
private static readonly MIN_TLS_HEADER_SIZE = 5;
/**
* Detect TLS protocol from buffer
*/
detect(buffer: Buffer, options?: IDetectionOptions): IDetectionResult | null {
// Check if buffer is too small
if (buffer.length < TlsDetector.MIN_TLS_HEADER_SIZE) {
return null;
}
// Check if this is a TLS record
if (!this.isTlsRecord(buffer)) {
return null;
}
// Extract basic TLS info
const recordType = buffer[0];
const tlsMajor = buffer[1];
const tlsMinor = buffer[2];
const recordLength = readUInt16BE(buffer, 3);
// Initialize connection info
const connectionInfo: IConnectionInfo = {
protocol: 'tls',
tlsVersion: tlsVersionToString(tlsMajor, tlsMinor) || undefined
};
// If it's a handshake, try to extract more info
if (recordType === TlsRecordType.HANDSHAKE && buffer.length >= 6) {
const handshakeType = buffer[5];
// For ClientHello, extract SNI and other info
if (handshakeType === TlsHandshakeType.CLIENT_HELLO) {
// Check if we have the complete handshake
const totalRecordLength = recordLength + 5; // Including TLS header
if (buffer.length >= totalRecordLength) {
// Extract SNI using existing logic
const sni = SniExtraction.extractSNI(buffer);
if (sni) {
connectionInfo.domain = sni;
connectionInfo.sni = sni;
}
// Parse ClientHello for additional info
const parseResult = ClientHelloParser.parseClientHello(buffer);
if (parseResult.isValid) {
// Extract ALPN if present
const alpnExtension = parseResult.extensions.find(
ext => ext.type === TlsExtensionType.APPLICATION_LAYER_PROTOCOL_NEGOTIATION
);
if (alpnExtension) {
connectionInfo.alpn = this.parseAlpnExtension(alpnExtension.data);
}
// Store cipher suites if needed
if (parseResult.cipherSuites && options?.extractFullHeaders) {
connectionInfo.cipherSuites = this.parseCipherSuites(parseResult.cipherSuites);
}
}
// Return complete result
return {
protocol: 'tls',
connectionInfo,
remainingBuffer: buffer.length > totalRecordLength
? buffer.subarray(totalRecordLength)
: undefined,
isComplete: true
};
} else {
// Incomplete handshake
return {
protocol: 'tls',
connectionInfo,
isComplete: false,
bytesNeeded: totalRecordLength
};
}
}
}
// For other TLS record types, just return basic info
return {
protocol: 'tls',
connectionInfo,
isComplete: true,
remainingBuffer: buffer.length > recordLength + 5
? buffer.subarray(recordLength + 5)
: undefined
};
}
/**
* Check if buffer can be handled by this detector
*/
canHandle(buffer: Buffer): boolean {
return buffer.length >= TlsDetector.MIN_TLS_HEADER_SIZE &&
this.isTlsRecord(buffer);
}
/**
* Get minimum bytes needed for detection
*/
getMinimumBytes(): number {
return TlsDetector.MIN_TLS_HEADER_SIZE;
}
/**
* Check if buffer contains a valid TLS record
*/
private isTlsRecord(buffer: Buffer): boolean {
const recordType = buffer[0];
// Check for valid record type
const validTypes = [
TlsRecordType.CHANGE_CIPHER_SPEC,
TlsRecordType.ALERT,
TlsRecordType.HANDSHAKE,
TlsRecordType.APPLICATION_DATA,
TlsRecordType.HEARTBEAT
];
if (!validTypes.includes(recordType)) {
return false;
}
// Check TLS version bytes (should be 0x03 0x0X)
if (buffer[1] !== 0x03) {
return false;
}
// Check record length is reasonable
const recordLength = readUInt16BE(buffer, 3);
if (recordLength > 16384) { // Max TLS record size
return false;
}
return true;
}
/**
* Parse ALPN extension data
*/
private parseAlpnExtension(data: Buffer): string[] {
const protocols: string[] = [];
if (data.length < 2) {
return protocols;
}
const listLength = readUInt16BE(data, 0);
let offset = 2;
while (offset < Math.min(2 + listLength, data.length)) {
const protoLength = data[offset];
offset++;
if (offset + protoLength <= data.length) {
const protocol = data.subarray(offset, offset + protoLength).toString('ascii');
protocols.push(protocol);
offset += protoLength;
} else {
break;
}
}
return protocols;
}
/**
* Parse cipher suites
*/
private parseCipherSuites(cipherData: Buffer): number[] {
const suites: number[] = [];
for (let i = 0; i < cipherData.length - 1; i += 2) {
const suite = readUInt16BE(cipherData, i);
suites.push(suite);
}
return suites;
}
/**
* Detect with context for fragmented data
*/
detectWithContext(
buffer: Buffer,
_context: { sourceIp?: string; sourcePort?: number; destIp?: string; destPort?: number },
options?: IDetectionOptions
): IDetectionResult | null {
// This method is deprecated - TLS detection should use the fragment manager
// from the parent detector system, not maintain its own fragments
return this.detect(buffer, options);
}
}

View File

@@ -1,25 +0,0 @@
/**
* Centralized Protocol Detection Module
*
* This module provides unified protocol detection capabilities for
* both TLS and HTTP protocols, extracting connection information
* without consuming the data stream.
*/
// Main detector
export * from './protocol-detector.js';
// Models
export * from './models/detection-types.js';
export * from './models/interfaces.js';
// Individual detectors
export * from './detectors/tls-detector.js';
export * from './detectors/http-detector.js';
export * from './detectors/quick-detector.js';
export * from './detectors/routing-extractor.js';
// Utilities
export * from './utils/buffer-utils.js';
export * from './utils/parser-utils.js';
export * from './utils/fragment-manager.js';

View File

@@ -1,102 +0,0 @@
/**
* Type definitions for protocol detection
*/
/**
* Supported protocol types that can be detected
*/
export type TProtocolType = 'tls' | 'http' | 'unknown';
/**
* HTTP method types
*/
export type THttpMethod = 'GET' | 'POST' | 'PUT' | 'DELETE' | 'PATCH' | 'HEAD' | 'OPTIONS' | 'CONNECT' | 'TRACE';
/**
* TLS version identifiers
*/
export type TTlsVersion = 'SSLv3' | 'TLSv1.0' | 'TLSv1.1' | 'TLSv1.2' | 'TLSv1.3';
/**
* Connection information extracted from protocol detection
*/
export interface IConnectionInfo {
/**
* The detected protocol type
*/
protocol: TProtocolType;
/**
* Domain/hostname extracted from the connection
* - For TLS: from SNI extension
* - For HTTP: from Host header
*/
domain?: string;
/**
* HTTP-specific fields
*/
method?: THttpMethod;
path?: string;
httpVersion?: string;
headers?: Record<string, string>;
/**
* TLS-specific fields
*/
tlsVersion?: TTlsVersion;
sni?: string;
alpn?: string[];
cipherSuites?: number[];
}
/**
* Result of protocol detection
*/
export interface IDetectionResult {
/**
* The detected protocol type
*/
protocol: TProtocolType;
/**
* Extracted connection information
*/
connectionInfo: IConnectionInfo;
/**
* Any remaining buffer data after detection headers
* This can be used to continue processing the stream
*/
remainingBuffer?: Buffer;
/**
* Whether the detection is complete or needs more data
*/
isComplete: boolean;
/**
* Minimum bytes needed for complete detection (if incomplete)
*/
bytesNeeded?: number;
}
/**
* Options for protocol detection
*/
export interface IDetectionOptions {
/**
* Maximum bytes to buffer for detection (default: 8192)
*/
maxBufferSize?: number;
/**
* Timeout for detection in milliseconds (default: 5000)
*/
timeout?: number;
/**
* Whether to extract full headers or just essential info
*/
extractFullHeaders?: boolean;
}

View File

@@ -1,115 +0,0 @@
/**
* Interface definitions for protocol detection components
*/
import type { IDetectionResult, IDetectionOptions } from './detection-types.js';
/**
* Interface for protocol detectors
*/
export interface IProtocolDetector {
/**
* Detect protocol from buffer data
* @param buffer The buffer to analyze
* @param options Detection options
* @returns Detection result or null if protocol cannot be determined
*/
detect(buffer: Buffer, options?: IDetectionOptions): IDetectionResult | null;
/**
* Check if buffer potentially contains this protocol
* @param buffer The buffer to check
* @returns True if buffer might contain this protocol
*/
canHandle(buffer: Buffer): boolean;
/**
* Get the minimum bytes needed for detection
*/
getMinimumBytes(): number;
}
/**
* Interface for connection tracking during fragmented detection
*/
export interface IConnectionTracker {
/**
* Connection identifier
*/
id: string;
/**
* Accumulated buffer data
*/
buffer: Buffer;
/**
* Timestamp of first data
*/
startTime: number;
/**
* Current detection state
*/
state: 'detecting' | 'complete' | 'failed';
/**
* Partial detection result (if any)
*/
partialResult?: Partial<IDetectionResult>;
}
/**
* Interface for buffer accumulator (handles fragmented data)
*/
export interface IBufferAccumulator {
/**
* Add data to accumulator
*/
append(data: Buffer): void;
/**
* Get accumulated buffer
*/
getBuffer(): Buffer;
/**
* Get buffer length
*/
length(): number;
/**
* Clear accumulated data
*/
clear(): void;
/**
* Check if accumulator has enough data
*/
hasMinimumBytes(minBytes: number): boolean;
}
/**
* Detection events
*/
export interface IDetectionEvents {
/**
* Emitted when protocol is successfully detected
*/
detected: (result: IDetectionResult) => void;
/**
* Emitted when detection fails
*/
failed: (error: Error) => void;
/**
* Emitted when detection times out
*/
timeout: () => void;
/**
* Emitted when more data is needed
*/
needMoreData: (bytesNeeded: number) => void;
}

View File

@@ -1,319 +0,0 @@
/**
* Protocol Detector
*
* Simplified protocol detection using the new architecture
*/
import type { IDetectionResult, IDetectionOptions } from './models/detection-types.js';
import type { IConnectionContext } from '../protocols/common/types.js';
import { TlsDetector } from './detectors/tls-detector.js';
import { HttpDetector } from './detectors/http-detector.js';
import { DetectionFragmentManager } from './utils/fragment-manager.js';
/**
* Main protocol detector class
*/
export class ProtocolDetector {
private static instance: ProtocolDetector;
private fragmentManager: DetectionFragmentManager;
private tlsDetector: TlsDetector;
private httpDetector: HttpDetector;
private connectionProtocols: Map<string, { protocol: 'tls' | 'http'; createdAt: number }> = new Map();
constructor() {
this.fragmentManager = new DetectionFragmentManager();
this.tlsDetector = new TlsDetector();
this.httpDetector = new HttpDetector(this.fragmentManager);
}
private static getInstance(): ProtocolDetector {
if (!this.instance) {
this.instance = new ProtocolDetector();
}
return this.instance;
}
/**
* Detect protocol from buffer data
*/
static async detect(buffer: Buffer, options?: IDetectionOptions): Promise<IDetectionResult> {
return this.getInstance().detectInstance(buffer, options);
}
private async detectInstance(buffer: Buffer, options?: IDetectionOptions): Promise<IDetectionResult> {
// Quick sanity check
if (!buffer || buffer.length === 0) {
return {
protocol: 'unknown',
connectionInfo: { protocol: 'unknown' },
isComplete: true
};
}
// Try TLS detection first (more specific)
if (this.tlsDetector.canHandle(buffer)) {
const tlsResult = this.tlsDetector.detect(buffer, options);
if (tlsResult) {
return tlsResult;
}
}
// Try HTTP detection
if (this.httpDetector.canHandle(buffer)) {
const httpResult = this.httpDetector.detect(buffer, options);
if (httpResult) {
return httpResult;
}
}
// Neither TLS nor HTTP
return {
protocol: 'unknown',
connectionInfo: { protocol: 'unknown' },
isComplete: true
};
}
/**
* Detect protocol with connection tracking for fragmented data
* @deprecated Use detectWithContext instead
*/
static async detectWithConnectionTracking(
buffer: Buffer,
connectionId: string,
options?: IDetectionOptions
): Promise<IDetectionResult> {
// Convert connection ID to context
const context: IConnectionContext = {
id: connectionId,
sourceIp: 'unknown',
sourcePort: 0,
destIp: 'unknown',
destPort: 0,
timestamp: Date.now()
};
return this.getInstance().detectWithContextInstance(buffer, context, options);
}
/**
* Detect protocol with connection context for fragmented data
*/
static async detectWithContext(
buffer: Buffer,
context: IConnectionContext,
options?: IDetectionOptions
): Promise<IDetectionResult> {
return this.getInstance().detectWithContextInstance(buffer, context, options);
}
private async detectWithContextInstance(
buffer: Buffer,
context: IConnectionContext,
options?: IDetectionOptions
): Promise<IDetectionResult> {
// Quick sanity check
if (!buffer || buffer.length === 0) {
return {
protocol: 'unknown',
connectionInfo: { protocol: 'unknown' },
isComplete: true
};
}
const connectionId = DetectionFragmentManager.createConnectionId(context);
// Check if we already know the protocol for this connection
const knownEntry = this.connectionProtocols.get(connectionId);
const knownProtocol = knownEntry?.protocol;
if (knownProtocol === 'http') {
const result = this.httpDetector.detectWithContext(buffer, context, options);
if (result) {
if (result.isComplete) {
this.connectionProtocols.delete(connectionId);
}
return result;
}
} else if (knownProtocol === 'tls') {
// Handle TLS with fragment accumulation
const handler = this.fragmentManager.getHandler('tls');
const fragmentResult = handler.addFragment(connectionId, buffer);
if (fragmentResult.error) {
handler.complete(connectionId);
this.connectionProtocols.delete(connectionId);
return {
protocol: 'unknown',
connectionInfo: { protocol: 'unknown' },
isComplete: true
};
}
const result = this.tlsDetector.detect(fragmentResult.buffer!, options);
if (result) {
if (result.isComplete) {
handler.complete(connectionId);
this.connectionProtocols.delete(connectionId);
}
return result;
}
}
// If we don't know the protocol yet, try to detect it
if (!knownProtocol) {
// First peek to determine protocol type
if (this.tlsDetector.canHandle(buffer)) {
this.connectionProtocols.set(connectionId, { protocol: 'tls', createdAt: Date.now() });
// Handle TLS with fragment accumulation
const handler = this.fragmentManager.getHandler('tls');
const fragmentResult = handler.addFragment(connectionId, buffer);
if (fragmentResult.error) {
handler.complete(connectionId);
this.connectionProtocols.delete(connectionId);
return {
protocol: 'unknown',
connectionInfo: { protocol: 'unknown' },
isComplete: true
};
}
const result = this.tlsDetector.detect(fragmentResult.buffer!, options);
if (result) {
if (result.isComplete) {
handler.complete(connectionId);
this.connectionProtocols.delete(connectionId);
}
return result;
}
}
if (this.httpDetector.canHandle(buffer)) {
this.connectionProtocols.set(connectionId, { protocol: 'http', createdAt: Date.now() });
const result = this.httpDetector.detectWithContext(buffer, context, options);
if (result) {
if (result.isComplete) {
this.connectionProtocols.delete(connectionId);
}
return result;
}
}
}
// Can't determine protocol
return {
protocol: 'unknown',
connectionInfo: { protocol: 'unknown' },
isComplete: false,
bytesNeeded: Math.max(
this.tlsDetector.getMinimumBytes(),
this.httpDetector.getMinimumBytes()
)
};
}
/**
* Clean up resources
*/
static cleanup(): void {
this.getInstance().cleanupInstance();
}
private cleanupInstance(): void {
this.fragmentManager.cleanup();
// Remove stale connectionProtocols entries (abandoned handshakes, port scanners)
const maxAge = 30_000; // 30 seconds
const now = Date.now();
for (const [id, entry] of this.connectionProtocols) {
if (now - entry.createdAt > maxAge) {
this.connectionProtocols.delete(id);
}
}
}
/**
* Destroy detector instance
*/
static destroy(): void {
this.getInstance().destroyInstance();
this.instance = null as any;
}
private destroyInstance(): void {
this.fragmentManager.destroy();
this.connectionProtocols.clear();
}
/**
* Clean up old connection tracking entries
*
* @param _maxAge Maximum age in milliseconds (default: 30 seconds)
*/
static cleanupConnections(_maxAge: number = 30000): void {
this.getInstance().cleanupInstance();
}
/**
* Clean up fragments for a specific connection
*/
static cleanupConnection(context: IConnectionContext): void {
const instance = this.getInstance();
const connectionId = DetectionFragmentManager.createConnectionId(context);
// Clean up both TLS and HTTP fragments for this connection
instance.fragmentManager.getHandler('tls').complete(connectionId);
instance.fragmentManager.getHandler('http').complete(connectionId);
// Remove from connection protocols tracking
instance.connectionProtocols.delete(connectionId);
}
/**
* Extract domain from connection info
*/
static extractDomain(connectionInfo: any): string | undefined {
return connectionInfo.domain || connectionInfo.sni || connectionInfo.host;
}
/**
* Create a connection ID from connection parameters
* @deprecated Use createConnectionContext instead
*/
static createConnectionId(params: {
sourceIp?: string;
sourcePort?: number;
destIp?: string;
destPort?: number;
socketId?: string;
}): string {
// If socketId is provided, use it
if (params.socketId) {
return params.socketId;
}
// Otherwise create from connection tuple
const { sourceIp = 'unknown', sourcePort = 0, destIp = 'unknown', destPort = 0 } = params;
return `${sourceIp}:${sourcePort}-${destIp}:${destPort}`;
}
/**
* Create a connection context from parameters
*/
static createConnectionContext(params: {
sourceIp?: string;
sourcePort?: number;
destIp?: string;
destPort?: number;
socketId?: string;
}): IConnectionContext {
return {
id: params.socketId,
sourceIp: params.sourceIp || 'unknown',
sourcePort: params.sourcePort || 0,
destIp: params.destIp || 'unknown',
destPort: params.destPort || 0,
timestamp: Date.now()
};
}
}

View File

@@ -1,141 +0,0 @@
/**
* Buffer manipulation utilities for protocol detection
*/
// Import from protocols
import { HttpParser } from '../../protocols/http/index.js';
/**
* BufferAccumulator class for handling fragmented data
*/
export class BufferAccumulator {
private chunks: Buffer[] = [];
private totalLength = 0;
/**
* Append data to the accumulator
*/
append(data: Buffer): void {
this.chunks.push(data);
this.totalLength += data.length;
}
/**
* Get the accumulated buffer
*/
getBuffer(): Buffer {
if (this.chunks.length === 0) {
return Buffer.alloc(0);
}
if (this.chunks.length === 1) {
return this.chunks[0];
}
return Buffer.concat(this.chunks, this.totalLength);
}
/**
* Get current buffer length
*/
length(): number {
return this.totalLength;
}
/**
* Clear all accumulated data
*/
clear(): void {
this.chunks = [];
this.totalLength = 0;
}
/**
* Check if accumulator has minimum bytes
*/
hasMinimumBytes(minBytes: number): boolean {
return this.totalLength >= minBytes;
}
}
/**
* Read a big-endian 16-bit integer from buffer
*/
export function readUInt16BE(buffer: Buffer, offset: number): number {
if (offset + 2 > buffer.length) {
throw new Error('Buffer too short for UInt16BE read');
}
return (buffer[offset] << 8) | buffer[offset + 1];
}
/**
* Read a big-endian 24-bit integer from buffer
*/
export function readUInt24BE(buffer: Buffer, offset: number): number {
if (offset + 3 > buffer.length) {
throw new Error('Buffer too short for UInt24BE read');
}
return (buffer[offset] << 16) | (buffer[offset + 1] << 8) | buffer[offset + 2];
}
/**
* Find a byte sequence in a buffer
*/
export function findSequence(buffer: Buffer, sequence: Buffer, startOffset = 0): number {
if (sequence.length === 0) {
return startOffset;
}
const searchLength = buffer.length - sequence.length + 1;
for (let i = startOffset; i < searchLength; i++) {
let found = true;
for (let j = 0; j < sequence.length; j++) {
if (buffer[i + j] !== sequence[j]) {
found = false;
break;
}
}
if (found) {
return i;
}
}
return -1;
}
/**
* Extract a line from buffer (up to CRLF or LF)
*/
export function extractLine(buffer: Buffer, startOffset = 0): { line: string; nextOffset: number } | null {
// Delegate to protocol parser
return HttpParser.extractLine(buffer, startOffset);
}
/**
* Check if buffer starts with a string (case-insensitive)
*/
export function startsWithString(buffer: Buffer, str: string, offset = 0): boolean {
if (offset + str.length > buffer.length) {
return false;
}
const bufferStr = buffer.slice(offset, offset + str.length).toString('utf8');
return bufferStr.toLowerCase() === str.toLowerCase();
}
/**
* Safe buffer slice that doesn't throw on out-of-bounds
*/
export function safeSlice(buffer: Buffer, start: number, end?: number): Buffer {
const safeStart = Math.max(0, Math.min(start, buffer.length));
const safeEnd = end === undefined
? buffer.length
: Math.max(safeStart, Math.min(end, buffer.length));
return buffer.slice(safeStart, safeEnd);
}
/**
* Check if buffer contains printable ASCII
*/
export function isPrintableAscii(buffer: Buffer, length?: number): boolean {
// Delegate to protocol parser
return HttpParser.isPrintableAscii(buffer, length);
}

View File

@@ -1,64 +0,0 @@
/**
* Fragment Manager for Detection Module
*
* Manages fragmented protocol data using the shared fragment handler
*/
import { FragmentHandler, type IFragmentOptions } from '../../protocols/common/fragment-handler.js';
import type { IConnectionContext } from '../../protocols/common/types.js';
/**
* Detection-specific fragment manager
*/
export class DetectionFragmentManager {
private tlsFragments: FragmentHandler;
private httpFragments: FragmentHandler;
constructor() {
// Configure fragment handlers with appropriate limits
const tlsOptions: IFragmentOptions = {
maxBufferSize: 16384, // TLS record max size
timeout: 5000,
cleanupInterval: 30000
};
const httpOptions: IFragmentOptions = {
maxBufferSize: 8192, // HTTP header reasonable limit
timeout: 5000,
cleanupInterval: 30000
};
this.tlsFragments = new FragmentHandler(tlsOptions);
this.httpFragments = new FragmentHandler(httpOptions);
}
/**
* Get fragment handler for protocol type
*/
getHandler(protocol: 'tls' | 'http'): FragmentHandler {
return protocol === 'tls' ? this.tlsFragments : this.httpFragments;
}
/**
* Create connection ID from context
*/
static createConnectionId(context: IConnectionContext): string {
return context.id || `${context.sourceIp}:${context.sourcePort}-${context.destIp}:${context.destPort}`;
}
/**
* Clean up all handlers
*/
cleanup(): void {
this.tlsFragments.cleanup();
this.httpFragments.cleanup();
}
/**
* Destroy all handlers
*/
destroy(): void {
this.tlsFragments.destroy();
this.httpFragments.destroy();
}
}

View File

@@ -1,77 +0,0 @@
/**
* Parser utilities for protocol detection
* Now delegates to protocol modules for actual parsing
*/
import type { THttpMethod, TTlsVersion } from '../models/detection-types.js';
import { HttpParser, HTTP_METHODS, HTTP_VERSIONS } from '../../protocols/http/index.js';
import { tlsVersionToString as protocolTlsVersionToString } from '../../protocols/tls/index.js';
// Re-export constants for backward compatibility
export { HTTP_METHODS, HTTP_VERSIONS };
/**
* Parse HTTP request line
*/
export function parseHttpRequestLine(line: string): {
method: THttpMethod;
path: string;
version: string;
} | null {
// Delegate to protocol parser
const result = HttpParser.parseRequestLine(line);
return result ? {
method: result.method as THttpMethod,
path: result.path,
version: result.version
} : null;
}
/**
* Parse HTTP header line
*/
export function parseHttpHeader(line: string): { name: string; value: string } | null {
// Delegate to protocol parser
return HttpParser.parseHeaderLine(line);
}
/**
* Parse HTTP headers from lines
*/
export function parseHttpHeaders(lines: string[]): Record<string, string> {
// Delegate to protocol parser
return HttpParser.parseHeaders(lines);
}
/**
* Convert TLS version bytes to version string
*/
export function tlsVersionToString(major: number, minor: number): TTlsVersion | null {
// Delegate to protocol parser
return protocolTlsVersionToString(major, minor) as TTlsVersion;
}
/**
* Extract domain from Host header value
*/
export function extractDomainFromHost(hostHeader: string): string {
// Delegate to protocol parser
return HttpParser.extractDomainFromHost(hostHeader);
}
/**
* Validate domain name
*/
export function isValidDomain(domain: string): boolean {
// Delegate to protocol parser
return HttpParser.isValidDomain(domain);
}
/**
* Check if string is a valid HTTP method
*/
export function isHttpMethod(str: string): str is THttpMethod {
// Delegate to protocol parser
return HttpParser.isHttpMethod(str) && (str as THttpMethod) !== undefined;
}

View File

@@ -11,18 +11,12 @@ export type { ISmartProxyOptions, IConnectionRecord, IRouteConfig, IRouteMatch,
export type { TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './proxies/smart-proxy/models/interfaces.js';
export * from './proxies/smart-proxy/utils/index.js';
// Original: export * from './smartproxy/classes.pp.snihandler.js'
// Now we export from the new module
export { SniHandler } from './tls/sni/sni-handler.js';
// Core types and utilities
export * from './core/models/common-types.js';
// Export IAcmeOptions from one place only
export type { IAcmeOptions } from './proxies/smart-proxy/models/interfaces.js';
// Modular exports for new architecture
export * as tls from './tls/index.js';
// Modular exports
export * as routing from './routing/index.js';
export * as detection from './detection/index.js';
export * as protocols from './protocols/index.js';

View File

@@ -19,12 +19,14 @@ export { tsclass };
import * as smartcrypto from '@push.rocks/smartcrypto';
import * as smartlog from '@push.rocks/smartlog';
import * as smartlogDestinationLocal from '@push.rocks/smartlog/destination-local';
import * as smartnftables from '@push.rocks/smartnftables';
import * as smartrust from '@push.rocks/smartrust';
export {
smartcrypto,
smartlog,
smartlogDestinationLocal,
smartnftables,
smartrust,
};

View File

@@ -1,167 +0,0 @@
/**
* Shared Fragment Handler for Protocol Detection
*
* Provides unified fragment buffering and reassembly for protocols
* that may span multiple TCP packets.
*/
import { Buffer } from 'node:buffer';
/**
* Fragment tracking information
*/
export interface IFragmentInfo {
buffer: Buffer;
timestamp: number;
connectionId: string;
}
/**
* Options for fragment handling
*/
export interface IFragmentOptions {
maxBufferSize?: number;
timeout?: number;
cleanupInterval?: number;
}
/**
* Result of fragment processing
*/
export interface IFragmentResult {
isComplete: boolean;
buffer?: Buffer;
needsMoreData: boolean;
error?: string;
}
/**
* Shared fragment handler for protocol detection
*/
export class FragmentHandler {
private fragments = new Map<string, IFragmentInfo>();
private cleanupTimer?: NodeJS.Timeout;
constructor(private options: IFragmentOptions = {}) {
// Start cleanup timer if not already running
if (options.cleanupInterval && !this.cleanupTimer) {
this.cleanupTimer = setInterval(
() => this.cleanup(),
options.cleanupInterval
);
// Don't let this timer prevent process exit
if (this.cleanupTimer.unref) {
this.cleanupTimer.unref();
}
}
}
/**
* Add a fragment for a connection
*/
addFragment(connectionId: string, fragment: Buffer): IFragmentResult {
const existing = this.fragments.get(connectionId);
if (existing) {
// Append to existing buffer
const newBuffer = Buffer.concat([existing.buffer, fragment]);
// Check size limit
const maxSize = this.options.maxBufferSize || 65536;
if (newBuffer.length > maxSize) {
this.fragments.delete(connectionId);
return {
isComplete: false,
needsMoreData: false,
error: 'Buffer size exceeded maximum allowed'
};
}
// Update fragment info
this.fragments.set(connectionId, {
buffer: newBuffer,
timestamp: Date.now(),
connectionId
});
return {
isComplete: false,
buffer: newBuffer,
needsMoreData: true
};
} else {
// New fragment
this.fragments.set(connectionId, {
buffer: fragment,
timestamp: Date.now(),
connectionId
});
return {
isComplete: false,
buffer: fragment,
needsMoreData: true
};
}
}
/**
* Get the current buffer for a connection
*/
getBuffer(connectionId: string): Buffer | undefined {
return this.fragments.get(connectionId)?.buffer;
}
/**
* Mark a connection as complete and clean up
*/
complete(connectionId: string): void {
this.fragments.delete(connectionId);
}
/**
* Check if we're tracking a connection
*/
hasConnection(connectionId: string): boolean {
return this.fragments.has(connectionId);
}
/**
* Clean up expired fragments
*/
cleanup(): void {
const now = Date.now();
const timeout = this.options.timeout || 5000;
for (const [connectionId, info] of this.fragments.entries()) {
if (now - info.timestamp > timeout) {
this.fragments.delete(connectionId);
}
}
}
/**
* Clear all fragments
*/
clear(): void {
this.fragments.clear();
}
/**
* Destroy the handler and clean up resources
*/
destroy(): void {
if (this.cleanupTimer) {
clearInterval(this.cleanupTimer);
this.cleanupTimer = undefined;
}
this.clear();
}
/**
* Get the number of tracked connections
*/
get size(): number {
return this.fragments.size;
}
}

View File

@@ -1,8 +0,0 @@
/**
* Common Protocol Infrastructure
*
* Shared utilities and types for protocol handling
*/
export * from './fragment-handler.js';
export * from './types.js';

View File

@@ -1,76 +0,0 @@
/**
* Common Protocol Types
*
* Shared types used across different protocol implementations
*/
/**
* Supported protocol types
*/
export type TProtocolType = 'tls' | 'http' | 'https' | 'websocket' | 'unknown';
/**
* Protocol detection result
*/
export interface IProtocolDetectionResult {
protocol: TProtocolType;
confidence: number; // 0-100
requiresMoreData?: boolean;
metadata?: {
version?: string;
method?: string;
[key: string]: any;
};
}
/**
* Routing information extracted from protocols
*/
export interface IRoutingInfo {
domain?: string;
port?: number;
path?: string;
protocol: TProtocolType;
}
/**
* Connection context for protocol operations
*/
export interface IConnectionContext {
id: string;
sourceIp?: string;
sourcePort?: number;
destIp?: string;
destPort?: number;
timestamp?: number;
}
/**
* Protocol detection options
*/
export interface IProtocolDetectionOptions {
quickMode?: boolean; // Only do minimal detection
extractRouting?: boolean; // Extract routing information
maxWaitTime?: number; // Max time to wait for complete data
maxBufferSize?: number; // Max buffer size for fragmented data
}
/**
* Base interface for protocol detectors
*/
export interface IProtocolDetector {
/**
* Check if this detector can handle the data
*/
canHandle(data: Buffer): boolean;
/**
* Perform quick detection (first few bytes only)
*/
quickDetect(data: Buffer): IProtocolDetectionResult;
/**
* Extract routing information if possible
*/
extractRouting?(data: Buffer, context?: IConnectionContext): IRoutingInfo | null;
}

View File

@@ -4,5 +4,4 @@
*/
export * from './constants.js';
export * from './types.js';
export * from './parser.js';
export * from './types.js';

View File

@@ -1,219 +0,0 @@
/**
* HTTP Protocol Parser
* Generic HTTP parsing utilities
*/
import { HTTP_METHODS, type THttpMethod, type THttpVersion } from './constants.js';
import type { IHttpRequestLine, IHttpHeader } from './types.js';
/**
* HTTP parser utilities
*/
export class HttpParser {
/**
* Check if string is a valid HTTP method
*/
static isHttpMethod(str: string): str is THttpMethod {
return HTTP_METHODS.includes(str as THttpMethod);
}
/**
* Parse HTTP request line
*/
static parseRequestLine(line: string): IHttpRequestLine | null {
const parts = line.trim().split(' ');
if (parts.length !== 3) {
return null;
}
const [method, path, version] = parts;
// Validate method
if (!this.isHttpMethod(method)) {
return null;
}
// Validate version
if (!version.startsWith('HTTP/')) {
return null;
}
return {
method: method as THttpMethod,
path,
version: version as THttpVersion
};
}
/**
* Parse HTTP header line
*/
static parseHeaderLine(line: string): IHttpHeader | null {
const colonIndex = line.indexOf(':');
if (colonIndex === -1) {
return null;
}
const name = line.slice(0, colonIndex).trim();
const value = line.slice(colonIndex + 1).trim();
if (!name) {
return null;
}
return { name, value };
}
/**
* Parse HTTP headers from lines
*/
static parseHeaders(lines: string[]): Record<string, string> {
const headers: Record<string, string> = {};
for (const line of lines) {
const header = this.parseHeaderLine(line);
if (header) {
// Convert header names to lowercase for consistency
headers[header.name.toLowerCase()] = header.value;
}
}
return headers;
}
/**
* Extract domain from Host header value
*/
static extractDomainFromHost(hostHeader: string): string {
// Remove port if present
const colonIndex = hostHeader.lastIndexOf(':');
if (colonIndex !== -1) {
// Check if it's not part of IPv6 address
const beforeColon = hostHeader.slice(0, colonIndex);
if (!beforeColon.includes(']')) {
return beforeColon;
}
}
return hostHeader;
}
/**
* Validate domain name
*/
static isValidDomain(domain: string): boolean {
// Basic domain validation
if (!domain || domain.length > 253) {
return false;
}
// Check for valid characters and structure
const domainRegex = /^(?!-)[A-Za-z0-9-]{1,63}(?<!-)(\.[A-Za-z0-9-]{1,63})*$/;
return domainRegex.test(domain);
}
/**
* Extract line from buffer
*/
static extractLine(buffer: Buffer, offset: number = 0): { line: string; nextOffset: number } | null {
// Look for CRLF
const crlfIndex = buffer.indexOf('\r\n', offset);
if (crlfIndex === -1) {
// Look for just LF
const lfIndex = buffer.indexOf('\n', offset);
if (lfIndex === -1) {
return null;
}
return {
line: buffer.slice(offset, lfIndex).toString('utf8'),
nextOffset: lfIndex + 1
};
}
return {
line: buffer.slice(offset, crlfIndex).toString('utf8'),
nextOffset: crlfIndex + 2
};
}
/**
* Check if buffer contains printable ASCII
*/
static isPrintableAscii(buffer: Buffer, length?: number): boolean {
const checkLength = Math.min(length || buffer.length, buffer.length);
for (let i = 0; i < checkLength; i++) {
const byte = buffer[i];
// Allow printable ASCII (32-126) plus tab (9), LF (10), and CR (13)
if (byte < 32 || byte > 126) {
if (byte !== 9 && byte !== 10 && byte !== 13) {
return false;
}
}
}
return true;
}
/**
* Quick check if buffer starts with HTTP method
*/
static quickCheck(buffer: Buffer): boolean {
if (buffer.length < 3) {
return false;
}
// Check common HTTP methods
const start = buffer.slice(0, 7).toString('ascii');
return start.startsWith('GET ') ||
start.startsWith('POST ') ||
start.startsWith('PUT ') ||
start.startsWith('DELETE ') ||
start.startsWith('HEAD ') ||
start.startsWith('OPTIONS') ||
start.startsWith('PATCH ') ||
start.startsWith('CONNECT') ||
start.startsWith('TRACE ');
}
/**
* Parse query string
*/
static parseQueryString(queryString: string): Record<string, string> {
const params: Record<string, string> = {};
if (!queryString) {
return params;
}
// Remove leading '?' if present
if (queryString.startsWith('?')) {
queryString = queryString.slice(1);
}
const pairs = queryString.split('&');
for (const pair of pairs) {
const [key, value] = pair.split('=');
if (key) {
params[decodeURIComponent(key)] = value ? decodeURIComponent(value) : '';
}
}
return params;
}
/**
* Build query string from params
*/
static buildQueryString(params: Record<string, string>): string {
const pairs: string[] = [];
for (const [key, value] of Object.entries(params)) {
pairs.push(`${encodeURIComponent(key)}=${encodeURIComponent(value)}`);
}
return pairs.length > 0 ? '?' + pairs.join('&') : '';
}
}

View File

@@ -1,12 +1,5 @@
/**
* Protocol-specific modules for smartproxy
*
* This directory contains generic protocol knowledge separated from
* smartproxy-specific implementation details.
*/
export * as common from './common/index.js';
export * as tls from './tls/index.js';
export * as http from './http/index.js';
export * as proxy from './proxy/index.js';
export * as websocket from './websocket/index.js';

View File

@@ -1,6 +0,0 @@
/**
* PROXY Protocol Module
* Type definitions for HAProxy PROXY protocol v1/v2
*/
export * from './types.js';

View File

@@ -1,53 +0,0 @@
/**
* PROXY Protocol Type Definitions
* Based on HAProxy PROXY protocol specification
*/
/**
* PROXY protocol version
*/
export type TProxyProtocolVersion = 'v1' | 'v2';
/**
* Connection protocol type
*/
export type TProxyProtocol = 'TCP4' | 'TCP6' | 'UDP4' | 'UDP6' | 'UNKNOWN';
/**
* Interface representing parsed PROXY protocol information
*/
export interface IProxyInfo {
protocol: TProxyProtocol;
sourceIP: string;
sourcePort: number;
destinationIP: string;
destinationPort: number;
}
/**
* Interface for parse result including remaining data
*/
export interface IProxyParseResult {
proxyInfo: IProxyInfo | null;
remainingData: Buffer;
}
/**
* PROXY protocol v2 header format
*/
export interface IProxyV2Header {
signature: Buffer;
versionCommand: number;
family: number;
length: number;
}
/**
* Connection information for PROXY protocol
*/
export interface IProxyConnectionInfo {
sourceIp?: string;
sourcePort?: number;
destIp?: string;
destPort?: number;
}

Some files were not shown because too many files have changed in this diff Show More