Compare commits
34 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d12812bb8d | |||
| fc04a0210b | |||
| 33fdf42a70 | |||
| fb1c59ac9a | |||
| ea8224c400 | |||
| da1cc58a3d | |||
| 606c620849 | |||
| 4ae09ac6ae | |||
| 2fce910795 | |||
| ff09cef350 | |||
| d0148b2ac3 | |||
| 7217e15649 | |||
| bfcf92a855 | |||
| 8e0804cd20 | |||
| c63f6fcd5f | |||
| f3cd4d193e | |||
| 81de611255 | |||
| 91598b3be9 | |||
| 4e3c548012 | |||
| 1a2d7529db | |||
| 31514f54ae | |||
| 247653c9d0 | |||
| 07d88f6f6a | |||
| 4b64de2c67 | |||
| e8db7bc96d | |||
| 2621dea9fa | |||
| bb5b9b3d12 | |||
| d70c2d77ed | |||
| 4cf13c36f8 | |||
| 37c7233780 | |||
| 15d0a721d5 | |||
| af970c447e | |||
| 9e1103e7a7 | |||
| 2b990527ac |
115
changelog.md
115
changelog.md
@@ -1,5 +1,120 @@
|
|||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
|
## 2026-03-21 - 26.0.0 - BREAKING CHANGE(ts-api,rustproxy)
|
||||||
|
remove deprecated TypeScript protocol and utility exports while hardening QUIC, HTTP/3, WebSocket, and rate limiter cleanup paths
|
||||||
|
|
||||||
|
- Removes large parts of the public TypeScript surface including detection, TLS, router, websocket, proxy/common protocol, and multiple core utility exports and files.
|
||||||
|
- Adds parent-child cancellation handling for HTTP/3 and QUIC stream forwarding to stop orphaned tasks and close idle or overlong streams.
|
||||||
|
- Improves cleanup reliability with RAII guards for WebSocket upstream tracking and QUIC connection metrics, plus periodic cleanup for rate limiter and proxy address maps.
|
||||||
|
- Cleans backend metrics state when active backend connections drop to zero and tracks passthrough backend sockets for shutdown cleanup.
|
||||||
|
|
||||||
|
## 2026-03-20 - 25.17.10 - fix(rustproxy-http)
|
||||||
|
reuse the shared HTTP proxy service for HTTP/3 request handling
|
||||||
|
|
||||||
|
- Refactors H3ProxyService to delegate requests to the shared HttpProxyService instead of maintaining separate routing and backend forwarding logic.
|
||||||
|
- Aligns HTTP/3 with the TCP/HTTP path for route matching, connection pooling, and ALPN-based upstream protocol detection.
|
||||||
|
- Generalizes request handling and filters to accept boxed/generic HTTP bodies so both HTTP/3 and existing HTTP paths share the same proxy pipeline.
|
||||||
|
- Updates the HTTP/3 integration route matcher to allow transport matching across shared HTTP and QUIC handling.
|
||||||
|
|
||||||
|
## 2026-03-20 - 25.17.9 - fix(rustproxy-http)
|
||||||
|
correct HTTP/3 host extraction and avoid protocol filtering during UDP route lookup
|
||||||
|
|
||||||
|
- Use the URI host or strip the port from the Host header so HTTP/3 requests match routes consistently with TCP/HTTP handling.
|
||||||
|
- Remove protocol filtering from HTTP/3 route lookup because QUIC transport already constrains routing to UDP and protocol validation happens earlier.
|
||||||
|
|
||||||
|
## 2026-03-20 - 25.17.8 - fix(rustproxy)
|
||||||
|
use SNI-based certificate resolution for QUIC TLS connections
|
||||||
|
|
||||||
|
- Replaces static first-certificate selection with the shared CertResolver used by the TCP/TLS path.
|
||||||
|
- Ensures QUIC connections can present the correct certificate per requested domain.
|
||||||
|
- Keeps HTTP/3 ALPN configuration while improving multi-domain TLS handling.
|
||||||
|
|
||||||
|
## 2026-03-20 - 25.17.7 - fix(readme)
|
||||||
|
document QUIC and HTTP/3 compatibility caveats
|
||||||
|
|
||||||
|
- Add notes explaining that GREASE frames are disabled on both server and client HTTP/3 paths to avoid interoperability issues
|
||||||
|
- Document that the current HTTP/3 stack depends on pre-1.0 h3 ecosystem components and may still have rough edges
|
||||||
|
|
||||||
|
## 2026-03-20 - 25.17.6 - fix(rustproxy-http)
|
||||||
|
disable HTTP/3 GREASE for client and server connections
|
||||||
|
|
||||||
|
- Switch the HTTP/3 server connection setup to use the builder API with send_grease(false)
|
||||||
|
- Switch the HTTP/3 client handshake to use the builder API with send_grease(false) to improve compatibility
|
||||||
|
|
||||||
|
## 2026-03-20 - 25.17.5 - fix(rustproxy)
|
||||||
|
add HTTP/3 integration test for QUIC response stream FIN handling
|
||||||
|
|
||||||
|
- adds an integration test covering HTTP/3 proxying over QUIC with TLS termination
|
||||||
|
- verifies response bodies fully arrive and the client receives stream termination instead of hanging
|
||||||
|
- adds test-only dependencies for quinn, h3, h3-quinn, rustls, bytes, and http
|
||||||
|
|
||||||
|
## 2026-03-20 - 25.17.4 - fix(rustproxy-http)
|
||||||
|
prevent HTTP/3 response body streaming from hanging on backend completion
|
||||||
|
|
||||||
|
- extract and track Content-Length before consuming the response body
|
||||||
|
- stop the HTTP/3 body loop when the stream reports end-of-stream or the expected byte count has been sent
|
||||||
|
- add a per-frame idle timeout to avoid indefinite waits on stalled or close-delimited backend bodies
|
||||||
|
|
||||||
|
## 2026-03-20 - 25.17.3 - fix(repository)
|
||||||
|
no changes detected
|
||||||
|
|
||||||
|
|
||||||
|
## 2026-03-20 - 25.17.2 - fix(rustproxy-http)
|
||||||
|
enable TLS connections for HTTP/3 upstream requests when backend re-encryption or TLS is configured
|
||||||
|
|
||||||
|
- Pass backend TLS client configuration into the HTTP/3 request handler.
|
||||||
|
- Detect TLS-required upstream targets using route and target TLS settings before connecting.
|
||||||
|
- Build backend request URIs with the correct http or https scheme to match the upstream connection.
|
||||||
|
|
||||||
|
## 2026-03-20 - 25.17.1 - fix(rustproxy-routing)
|
||||||
|
allow QUIC UDP TLS connections without SNI to match domain-restricted routes
|
||||||
|
|
||||||
|
- Exempts UDP transport from the no-SNI rejection logic because QUIC encrypts the TLS ClientHello and SNI is unavailable at accept time
|
||||||
|
- Adds regression tests to confirm QUIC route matching succeeds without SNI while TCP TLS without SNI remains rejected
|
||||||
|
|
||||||
|
## 2026-03-19 - 25.17.0 - feat(rustproxy-passthrough)
|
||||||
|
add PROXY protocol v2 client IP handling for UDP and QUIC listeners
|
||||||
|
|
||||||
|
- propagate trusted proxy IP configuration into UDP and QUIC listener managers
|
||||||
|
- extract and preserve real client addresses from PROXY protocol v2 headers for HTTP/3 and QUIC stream handling
|
||||||
|
- apply rate limiting, session limits, routing, and metrics using the resolved client IP while preserving correct proxy return-path routing
|
||||||
|
|
||||||
|
## 2026-03-19 - 25.16.3 - fix(rustproxy)
|
||||||
|
upgrade fallback UDP listeners to QUIC when TLS certificates become available
|
||||||
|
|
||||||
|
- Rebuild and apply QUIC TLS configuration during route and certificate updates instead of only when adding new UDP ports.
|
||||||
|
- Add logic to drain UDP sessions, stop raw fallback listeners, and start QUIC endpoints on existing ports once TLS is available.
|
||||||
|
- Retry QUIC endpoint creation during upgrade and fall back to rebinding raw UDP if the upgrade cannot complete.
|
||||||
|
|
||||||
|
## 2026-03-19 - 25.16.2 - fix(rustproxy-http)
|
||||||
|
cache backend Alt-Svc only from original upstream responses during protocol auto-detection
|
||||||
|
|
||||||
|
- Moves Alt-Svc discovery into streaming response construction so it reads backend headers before response filters inject client-facing Alt-Svc values
|
||||||
|
- Stores the protocol cache key in connection activity during auto-detect mode and clears it after HTTP/3 connection failure to avoid re-caching failed H3 routes
|
||||||
|
- Prevents fallback requests from reintroducing stale or self-injected Alt-Svc entries that could cause repeated H3 retry loops
|
||||||
|
|
||||||
|
## 2026-03-19 - 25.16.1 - fix(http-proxy)
|
||||||
|
avoid repeated HTTP/3 recaching after QUIC fallback and document backend protocol selection
|
||||||
|
|
||||||
|
- Suppress Alt-Svc HTTP/3 recaching after a failed QUIC backend connection to prevent repeated H3 timeout fallback loops
|
||||||
|
- Force an ALPN probe on TCP fallback so auto detection correctly reselects HTTP/2 or HTTP/1.1 after H3 connection failure
|
||||||
|
- Add README documentation for best-effort backendProtocol selection and supported protocol modes
|
||||||
|
|
||||||
|
## 2026-03-19 - 25.16.0 - feat(quic,http3)
|
||||||
|
add HTTP/3 proxy handling and hot-reload QUIC TLS configuration
|
||||||
|
|
||||||
|
- initialize and wire H3ProxyService into QUIC listeners so HTTP/3 requests are handled instead of being kept as placeholder connections
|
||||||
|
- add backend HTTP/3 support with protocol caching that stores Alt-Svc advertised H3 ports for auto-detection
|
||||||
|
- hot-swap TLS certificates across active QUIC endpoints and require terminating TLS for QUIC route validation
|
||||||
|
- document QUIC route setup with required TLS and ACME configuration
|
||||||
|
|
||||||
|
## 2026-03-19 - 25.15.0 - feat(readme)
|
||||||
|
document UDP, QUIC, and HTTP/3 support in the README
|
||||||
|
|
||||||
|
- Adds README examples for UDP datagram handlers, QUIC/HTTP3 forwarding, and dual-stack TCP/UDP routes
|
||||||
|
- Expands configuration and API reference sections to cover transport matching, UDP/QUIC options, backend transport selection, and UDP metrics
|
||||||
|
- Updates architecture and feature descriptions to reflect UDP, QUIC, HTTP/3, and datagram handler capabilities
|
||||||
|
|
||||||
## 2026-03-19 - 25.14.1 - fix(deps)
|
## 2026-03-19 - 25.14.1 - fix(deps)
|
||||||
update build and runtime dependencies and align route validation test expectations
|
update build and runtime dependencies and align route validation test expectations
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@push.rocks/smartproxy",
|
"name": "@push.rocks/smartproxy",
|
||||||
"version": "25.14.1",
|
"version": "26.0.0",
|
||||||
"private": false,
|
"private": false,
|
||||||
"description": "A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.",
|
"description": "A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.",
|
||||||
"main": "dist_ts/index.js",
|
"main": "dist_ts/index.js",
|
||||||
|
|||||||
265
readme.md
265
readme.md
@@ -1,6 +1,6 @@
|
|||||||
# @push.rocks/smartproxy 🚀
|
# @push.rocks/smartproxy 🚀
|
||||||
|
|
||||||
**A high-performance, Rust-powered proxy toolkit for Node.js** — unified route-based configuration for SSL/TLS termination, HTTP/HTTPS reverse proxying, WebSocket support, load balancing, custom protocol handlers, and kernel-level NFTables forwarding.
|
**A high-performance, Rust-powered proxy toolkit for Node.js** — unified route-based configuration for SSL/TLS termination, HTTP/HTTPS reverse proxying, WebSocket support, UDP/QUIC/HTTP3, load balancing, custom protocol handlers, and kernel-level NFTables forwarding.
|
||||||
|
|
||||||
## 📦 Installation
|
## 📦 Installation
|
||||||
|
|
||||||
@@ -16,9 +16,9 @@ For reporting bugs, issues, or security vulnerabilities, please visit [community
|
|||||||
|
|
||||||
## 🎯 What is SmartProxy?
|
## 🎯 What is SmartProxy?
|
||||||
|
|
||||||
SmartProxy is a production-ready proxy solution that takes the complexity out of traffic management. Under the hood, all networking — TCP, TLS, HTTP reverse proxy, connection tracking, security enforcement, and NFTables — is handled by a **Rust engine** for maximum performance, while you configure everything through a clean TypeScript API with full type safety.
|
SmartProxy is a production-ready proxy solution that takes the complexity out of traffic management. Under the hood, all networking — TCP, UDP, TLS, HTTP reverse proxy, QUIC/HTTP3, connection tracking, security enforcement, and NFTables — is handled by a **Rust engine** for maximum performance, while you configure everything through a clean TypeScript API with full type safety.
|
||||||
|
|
||||||
Whether you're building microservices, deploying edge infrastructure, or need a battle-tested reverse proxy with automatic Let's Encrypt certificates, SmartProxy has you covered.
|
Whether you're building microservices, deploying edge infrastructure, proxying UDP-based protocols, or need a battle-tested reverse proxy with automatic Let's Encrypt certificates, SmartProxy has you covered.
|
||||||
|
|
||||||
### ⚡ Key Features
|
### ⚡ Key Features
|
||||||
|
|
||||||
@@ -29,11 +29,12 @@ Whether you're building microservices, deploying edge infrastructure, or need a
|
|||||||
| 🔒 **Automatic SSL/TLS** | Zero-config HTTPS with Let's Encrypt ACME integration |
|
| 🔒 **Automatic SSL/TLS** | Zero-config HTTPS with Let's Encrypt ACME integration |
|
||||||
| 🎯 **Flexible Matching** | Route by port, domain, path, protocol, client IP, TLS version, headers, or custom logic |
|
| 🎯 **Flexible Matching** | Route by port, domain, path, protocol, client IP, TLS version, headers, or custom logic |
|
||||||
| 🚄 **High-Performance** | Choose between user-space or kernel-level (NFTables) forwarding |
|
| 🚄 **High-Performance** | Choose between user-space or kernel-level (NFTables) forwarding |
|
||||||
|
| 📡 **UDP & QUIC/HTTP3** | First-class UDP transport, datagram handlers, QUIC tunneling, and HTTP/3 support |
|
||||||
| ⚖️ **Load Balancing** | Round-robin, least-connections, IP-hash with health checks |
|
| ⚖️ **Load Balancing** | Round-robin, least-connections, IP-hash with health checks |
|
||||||
| 🛡️ **Enterprise Security** | IP filtering, rate limiting, basic auth, JWT auth, connection limits |
|
| 🛡️ **Enterprise Security** | IP filtering, rate limiting, basic auth, JWT auth, connection limits |
|
||||||
| 🔌 **WebSocket Support** | First-class WebSocket proxying with ping/pong keep-alive |
|
| 🔌 **WebSocket Support** | First-class WebSocket proxying with ping/pong keep-alive |
|
||||||
| 🎮 **Custom Protocols** | Socket handlers for implementing any protocol in TypeScript |
|
| 🎮 **Custom Protocols** | Socket and datagram handlers for implementing any protocol in TypeScript |
|
||||||
| 📊 **Live Metrics** | Real-time throughput, connection counts, and performance data |
|
| 📊 **Live Metrics** | Real-time throughput, connection counts, UDP sessions, and performance data |
|
||||||
| 🔧 **Dynamic Management** | Add/remove ports and routes at runtime without restarts |
|
| 🔧 **Dynamic Management** | Add/remove ports and routes at runtime without restarts |
|
||||||
| 🔄 **PROXY Protocol** | Full PROXY protocol v1/v2 support for preserving client information |
|
| 🔄 **PROXY Protocol** | Full PROXY protocol v1/v2 support for preserving client information |
|
||||||
| 💾 **Consumer Cert Storage** | Bring your own persistence — SmartProxy never writes certs to disk |
|
| 💾 **Consumer Cert Storage** | Bring your own persistence — SmartProxy never writes certs to disk |
|
||||||
@@ -89,7 +90,7 @@ SmartProxy uses a powerful **match/action** pattern that makes routing predictab
|
|||||||
```
|
```
|
||||||
|
|
||||||
Every route consists of:
|
Every route consists of:
|
||||||
- **Match** — What traffic to capture (ports, domains, paths, protocol, IPs, headers)
|
- **Match** — What traffic to capture (ports, domains, paths, transport, protocol, IPs, headers)
|
||||||
- **Action** — What to do with it (`forward` or `socket-handler`)
|
- **Action** — What to do with it (`forward` or `socket-handler`)
|
||||||
- **Security** (optional) — IP allow/block lists, rate limits, authentication
|
- **Security** (optional) — IP allow/block lists, rate limits, authentication
|
||||||
- **Headers** (optional) — Request/response header manipulation with template variables
|
- **Headers** (optional) — Request/response header manipulation with template variables
|
||||||
@@ -197,7 +198,7 @@ apiRoute = addRateLimiting(apiRoute, {
|
|||||||
const proxy = new SmartProxy({ routes: [apiRoute] });
|
const proxy = new SmartProxy({ routes: [apiRoute] });
|
||||||
```
|
```
|
||||||
|
|
||||||
### 🎮 Custom Protocol Handler
|
### 🎮 Custom Protocol Handler (TCP)
|
||||||
|
|
||||||
SmartProxy lets you implement any protocol with full socket control. Routes with JavaScript socket handlers are automatically relayed from the Rust engine back to your TypeScript code:
|
SmartProxy lets you implement any protocol with full socket control. Routes with JavaScript socket handlers are automatically relayed from the Rust engine back to your TypeScript code:
|
||||||
|
|
||||||
@@ -247,6 +248,140 @@ const proxy = new SmartProxy({ routes: [echoRoute, customRoute] });
|
|||||||
| `SocketHandlers.httpBlock(status, message)` | HTTP block response |
|
| `SocketHandlers.httpBlock(status, message)` | HTTP block response |
|
||||||
| `SocketHandlers.block(message)` | Block with optional message |
|
| `SocketHandlers.block(message)` | Block with optional message |
|
||||||
|
|
||||||
|
### 📡 UDP Datagram Handler
|
||||||
|
|
||||||
|
Handle raw UDP datagrams with custom TypeScript logic — perfect for DNS, game servers, IoT protocols, or any UDP-based service:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { SmartProxy } from '@push.rocks/smartproxy';
|
||||||
|
import type { IRouteConfig, TDatagramHandler, IDatagramInfo } from '@push.rocks/smartproxy';
|
||||||
|
|
||||||
|
// Custom UDP echo handler
|
||||||
|
const udpHandler: TDatagramHandler = (datagram, info, reply) => {
|
||||||
|
console.log(`UDP from ${info.sourceIp}:${info.sourcePort} on port ${info.destPort}`);
|
||||||
|
reply(datagram); // Echo it back
|
||||||
|
};
|
||||||
|
|
||||||
|
const proxy = new SmartProxy({
|
||||||
|
routes: [{
|
||||||
|
name: 'udp-echo',
|
||||||
|
match: {
|
||||||
|
ports: 5353,
|
||||||
|
transport: 'udp' // 👈 Listen for UDP datagrams
|
||||||
|
},
|
||||||
|
action: {
|
||||||
|
type: 'socket-handler',
|
||||||
|
datagramHandler: udpHandler, // 👈 Process each datagram
|
||||||
|
udp: {
|
||||||
|
sessionTimeout: 60000, // Session idle timeout (ms)
|
||||||
|
maxSessionsPerIP: 100,
|
||||||
|
maxDatagramSize: 65535
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
});
|
||||||
|
|
||||||
|
await proxy.start();
|
||||||
|
```
|
||||||
|
|
||||||
|
### 📡 QUIC / HTTP3 Forwarding
|
||||||
|
|
||||||
|
Forward QUIC traffic to backends with optional protocol translation (e.g., receive QUIC, forward as TCP/HTTP1):
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { SmartProxy } from '@push.rocks/smartproxy';
|
||||||
|
import type { IRouteConfig } from '@push.rocks/smartproxy';
|
||||||
|
|
||||||
|
const quicRoute: IRouteConfig = {
|
||||||
|
name: 'quic-to-backend',
|
||||||
|
match: {
|
||||||
|
ports: 443,
|
||||||
|
transport: 'udp',
|
||||||
|
protocol: 'quic' // 👈 Match QUIC protocol
|
||||||
|
},
|
||||||
|
action: {
|
||||||
|
type: 'forward',
|
||||||
|
targets: [{
|
||||||
|
host: 'backend-server',
|
||||||
|
port: 8443,
|
||||||
|
backendTransport: 'tcp' // 👈 Translate QUIC → TCP for backend
|
||||||
|
}],
|
||||||
|
tls: {
|
||||||
|
mode: 'terminate',
|
||||||
|
certificate: 'auto' // 👈 QUIC requires TLS 1.3
|
||||||
|
},
|
||||||
|
udp: {
|
||||||
|
quic: {
|
||||||
|
enableHttp3: true,
|
||||||
|
maxIdleTimeout: 30000,
|
||||||
|
maxConcurrentBidiStreams: 100,
|
||||||
|
altSvcPort: 443, // Advertise in Alt-Svc header
|
||||||
|
altSvcMaxAge: 86400
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const proxy = new SmartProxy({
|
||||||
|
acme: { email: 'ssl@example.com' },
|
||||||
|
routes: [quicRoute]
|
||||||
|
});
|
||||||
|
```
|
||||||
|
|
||||||
|
### 🚄 Best-Effort Backend Protocol (H3 > H2 > H1)
|
||||||
|
|
||||||
|
SmartProxy automatically uses the **highest protocol your backend supports** for HTTP requests. The backend protocol is independent of the client protocol — a client using HTTP/1.1 can be forwarded over HTTP/3 to the backend, and vice versa.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const route: IRouteConfig = {
|
||||||
|
name: 'auto-protocol',
|
||||||
|
match: { ports: 443, domains: 'app.example.com' },
|
||||||
|
action: {
|
||||||
|
type: 'forward',
|
||||||
|
targets: [{ host: 'backend', port: 8443 }],
|
||||||
|
tls: { mode: 'terminate', certificate: 'auto' },
|
||||||
|
options: {
|
||||||
|
backendProtocol: 'auto' // 👈 Default — best-effort selection
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
```
|
||||||
|
|
||||||
|
**How protocol discovery works (browser model):**
|
||||||
|
|
||||||
|
1. First request → TLS ALPN probe detects H2 or H1
|
||||||
|
2. Backend response inspected for `Alt-Svc: h3=":port"` header
|
||||||
|
3. If H3 advertised → cached and used for subsequent requests via QUIC
|
||||||
|
4. Graceful fallback: H3 failure → H2 → H1 with automatic cache invalidation
|
||||||
|
|
||||||
|
| `backendProtocol` | Behavior |
|
||||||
|
|---|---|
|
||||||
|
| `'auto'` (default) | Best-effort: H3 > H2 > H1 with Alt-Svc discovery |
|
||||||
|
| `'http1'` | Always HTTP/1.1 |
|
||||||
|
| `'http2'` | Always HTTP/2 (hard-fail if unsupported) |
|
||||||
|
| `'http3'` | Always HTTP/3 via QUIC (hard-fail if unsupported) |
|
||||||
|
|
||||||
|
> **Note:** WebSocket upgrades always use HTTP/1.1 to the backend regardless of `backendProtocol`, since there's no performance benefit from H2/H3 Extended CONNECT for tunneled connections, and backend support is rare.
|
||||||
|
|
||||||
|
### 🔁 Dual-Stack TCP + UDP Route
|
||||||
|
|
||||||
|
Listen on both TCP and UDP with a single route — handle each transport with its own handler:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const dualStackRoute: IRouteConfig = {
|
||||||
|
name: 'dual-stack-dns',
|
||||||
|
match: {
|
||||||
|
ports: 53,
|
||||||
|
transport: 'all' // 👈 Listen on both TCP and UDP
|
||||||
|
},
|
||||||
|
action: {
|
||||||
|
type: 'socket-handler',
|
||||||
|
socketHandler: handleTcpDns, // 👈 TCP connections
|
||||||
|
datagramHandler: handleUdpDns, // 👈 UDP datagrams
|
||||||
|
}
|
||||||
|
};
|
||||||
|
```
|
||||||
|
|
||||||
### ⚡ High-Performance NFTables Forwarding
|
### ⚡ High-Performance NFTables Forwarding
|
||||||
|
|
||||||
For ultra-low latency on Linux, use kernel-level forwarding (requires root):
|
For ultra-low latency on Linux, use kernel-level forwarding (requires root):
|
||||||
@@ -419,6 +554,10 @@ console.log(`Bytes in: ${metrics.totals.bytesIn()}`);
|
|||||||
console.log(`Requests/sec: ${metrics.requests.perSecond()}`);
|
console.log(`Requests/sec: ${metrics.requests.perSecond()}`);
|
||||||
console.log(`Throughput in: ${metrics.throughput.instant().in} bytes/sec`);
|
console.log(`Throughput in: ${metrics.throughput.instant().in} bytes/sec`);
|
||||||
|
|
||||||
|
// UDP metrics
|
||||||
|
console.log(`UDP sessions: ${metrics.udp.activeSessions()}`);
|
||||||
|
console.log(`Datagrams in: ${metrics.udp.datagramsIn()}`);
|
||||||
|
|
||||||
// Get detailed statistics from the Rust engine
|
// Get detailed statistics from the Rust engine
|
||||||
const stats = await proxy.getStatistics();
|
const stats = await proxy.getStatistics();
|
||||||
|
|
||||||
@@ -545,7 +684,7 @@ SmartProxy uses a hybrid **Rust + TypeScript** architecture:
|
|||||||
```
|
```
|
||||||
┌─────────────────────────────────────────────────────┐
|
┌─────────────────────────────────────────────────────┐
|
||||||
│ Your Application │
|
│ Your Application │
|
||||||
│ (TypeScript — routes, config, socket handlers) │
|
│ (TypeScript — routes, config, handlers) │
|
||||||
└──────────────────┬──────────────────────────────────┘
|
└──────────────────┬──────────────────────────────────┘
|
||||||
│ IPC (JSON over stdin/stdout)
|
│ IPC (JSON over stdin/stdout)
|
||||||
┌──────────────────▼──────────────────────────────────┐
|
┌──────────────────▼──────────────────────────────────┐
|
||||||
@@ -556,22 +695,23 @@ SmartProxy uses a hybrid **Rust + TypeScript** architecture:
|
|||||||
│ │ │ │ Proxy │ │ │ │ │ │
|
│ │ │ │ Proxy │ │ │ │ │ │
|
||||||
│ └─────────┘ └─────────┘ └─────────┘ └──────────┘ │
|
│ └─────────┘ └─────────┘ └─────────┘ └──────────┘ │
|
||||||
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌──────────┐ │
|
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌──────────┐ │
|
||||||
│ │ Security│ │ Metrics │ │ Connec- │ │ NFTables │ │
|
│ │ UDP │ │ Security│ │ Metrics │ │ NFTables │ │
|
||||||
│ │ Enforce │ │ Collect │ │ tion │ │ Mgr │ │
|
│ │ QUIC │ │ Enforce │ │ Collect │ │ Mgr │ │
|
||||||
│ │ │ │ │ │ Tracker │ │ │ │
|
│ │ HTTP/3 │ │ │ │ │ │ │ │
|
||||||
│ └─────────┘ └─────────┘ └─────────┘ └──────────┘ │
|
│ └─────────┘ └─────────┘ └─────────┘ └──────────┘ │
|
||||||
└──────────────────┬──────────────────────────────────┘
|
└──────────────────┬──────────────────────────────────┘
|
||||||
│ Unix Socket Relay
|
│ Unix Socket Relay
|
||||||
┌──────────────────▼──────────────────────────────────┐
|
┌──────────────────▼──────────────────────────────────┐
|
||||||
│ TypeScript Socket Handler Server │
|
│ TypeScript Socket & Datagram Handler Servers │
|
||||||
│ (for JS-defined socket handlers & dynamic routes) │
|
│ (for JS socket handlers, datagram handlers, │
|
||||||
|
│ and dynamic routes) │
|
||||||
└─────────────────────────────────────────────────────┘
|
└─────────────────────────────────────────────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
- **Rust Engine** handles all networking, TLS, HTTP proxying, connection management, security, and metrics
|
- **Rust Engine** handles all networking: TCP, UDP, TLS, QUIC, HTTP proxying, connection management, security, and metrics
|
||||||
- **TypeScript** provides the npm API, configuration types, route helpers, validation, and socket handler callbacks
|
- **TypeScript** provides the npm API, configuration types, route helpers, validation, and handler callbacks
|
||||||
- **IPC** — The TypeScript wrapper uses JSON commands/events over stdin/stdout to communicate with the Rust binary
|
- **IPC** — The TypeScript wrapper uses JSON commands/events over stdin/stdout to communicate with the Rust binary
|
||||||
- **Socket Relay** — A Unix domain socket server for routes requiring TypeScript-side handling (socket handlers, dynamic host/port functions)
|
- **Socket/Datagram Relay** — Unix domain socket servers for routes requiring TypeScript-side handling (socket handlers, datagram handlers, dynamic host/port functions)
|
||||||
|
|
||||||
## 🎯 Route Configuration Reference
|
## 🎯 Route Configuration Reference
|
||||||
|
|
||||||
@@ -579,22 +719,26 @@ SmartProxy uses a hybrid **Rust + TypeScript** architecture:
|
|||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
interface IRouteMatch {
|
interface IRouteMatch {
|
||||||
ports: number | number[] | Array<{ from: number; to: number }>; // Required — port(s) to listen on
|
ports: TPortRange; // Required — port(s) to listen on
|
||||||
|
transport?: 'tcp' | 'udp' | 'all'; // Transport protocol (default: 'tcp')
|
||||||
domains?: string | string[]; // 'example.com', '*.example.com'
|
domains?: string | string[]; // 'example.com', '*.example.com'
|
||||||
path?: string; // '/api/*', '/users/:id'
|
path?: string; // '/api/*', '/users/:id'
|
||||||
clientIp?: string[]; // ['10.0.0.0/8', '192.168.*']
|
clientIp?: string[]; // ['10.0.0.0/8', '192.168.*']
|
||||||
tlsVersion?: string[]; // ['TLSv1.2', 'TLSv1.3']
|
tlsVersion?: string[]; // ['TLSv1.2', 'TLSv1.3']
|
||||||
headers?: Record<string, string | RegExp>; // Match by HTTP headers
|
headers?: Record<string, string | RegExp>; // Match by HTTP headers
|
||||||
protocol?: 'http' | 'tcp'; // Match specific protocol ('http' includes h2 + WebSocket upgrades)
|
protocol?: 'http' | 'tcp' | 'udp' | 'quic' | 'http3'; // Application-layer protocol
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Port range supports single numbers, arrays, and ranges
|
||||||
|
type TPortRange = number | Array<number | { from: number; to: number }>;
|
||||||
```
|
```
|
||||||
|
|
||||||
### Action Types
|
### Action Types
|
||||||
|
|
||||||
| Type | Description |
|
| Type | Description |
|
||||||
|------|-------------|
|
|------|-------------|
|
||||||
| `forward` | Proxy to one or more backend targets (with optional TLS, WebSocket, load balancing) |
|
| `forward` | Proxy to one or more backend targets (with optional TLS, WebSocket, load balancing, UDP/QUIC) |
|
||||||
| `socket-handler` | Custom socket handling function in TypeScript |
|
| `socket-handler` | Custom socket/datagram handling function in TypeScript |
|
||||||
|
|
||||||
### Target Options
|
### Target Options
|
||||||
|
|
||||||
@@ -610,6 +754,7 @@ interface IRouteTarget {
|
|||||||
sendProxyProtocol?: boolean;
|
sendProxyProtocol?: boolean;
|
||||||
headers?: IRouteHeaders;
|
headers?: IRouteHeaders;
|
||||||
advanced?: IRouteAdvanced;
|
advanced?: IRouteAdvanced;
|
||||||
|
backendTransport?: 'tcp' | 'udp'; // Backend transport (e.g., receive QUIC, forward as TCP)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -666,6 +811,49 @@ interface IRouteLoadBalancing {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Backend Protocol Options
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Set on action.options
|
||||||
|
{
|
||||||
|
action: {
|
||||||
|
type: 'forward',
|
||||||
|
targets: [...],
|
||||||
|
options: {
|
||||||
|
backendProtocol: 'auto' | 'http1' | 'http2' | 'http3'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
| Value | Backend Behavior |
|
||||||
|
|-------|-----------------|
|
||||||
|
| `'auto'` | Best-effort: discovers H3 via Alt-Svc, probes H2 via ALPN, falls back to H1 |
|
||||||
|
| `'http1'` | Always HTTP/1.1 (no ALPN probe) |
|
||||||
|
| `'http2'` | Always HTTP/2 (hard-fail if handshake fails) |
|
||||||
|
| `'http3'` | Always HTTP/3 over QUIC (3s connect timeout, hard-fail if unreachable) |
|
||||||
|
|
||||||
|
### UDP & QUIC Options
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
interface IRouteUdp {
|
||||||
|
sessionTimeout?: number; // Idle timeout per UDP session (ms, default: 60000)
|
||||||
|
maxSessionsPerIP?: number; // Max concurrent sessions per IP (default: 1000)
|
||||||
|
maxDatagramSize?: number; // Max datagram size in bytes (default: 65535)
|
||||||
|
quic?: IRouteQuic;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface IRouteQuic {
|
||||||
|
maxIdleTimeout?: number; // QUIC idle timeout (ms, default: 30000)
|
||||||
|
maxConcurrentBidiStreams?: number; // Max bidi streams (default: 100)
|
||||||
|
maxConcurrentUniStreams?: number; // Max uni streams (default: 100)
|
||||||
|
enableHttp3?: boolean; // Enable HTTP/3 (default: false)
|
||||||
|
altSvcPort?: number; // Port for Alt-Svc header
|
||||||
|
altSvcMaxAge?: number; // Alt-Svc max age in seconds (default: 86400)
|
||||||
|
initialCongestionWindow?: number; // Initial congestion window (bytes)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## 🛠️ Helper Functions Reference
|
## 🛠️ Helper Functions Reference
|
||||||
|
|
||||||
All helpers are fully typed and return `IRouteConfig` or `IRouteConfig[]`:
|
All helpers are fully typed and return `IRouteConfig` or `IRouteConfig[]`:
|
||||||
@@ -689,7 +877,7 @@ import {
|
|||||||
createWebSocketRoute, // WebSocket-enabled route
|
createWebSocketRoute, // WebSocket-enabled route
|
||||||
|
|
||||||
// Custom Protocols
|
// Custom Protocols
|
||||||
createSocketHandlerRoute, // Custom socket handler
|
createSocketHandlerRoute, // Custom TCP socket handler
|
||||||
SocketHandlers, // Pre-built handlers (echo, proxy, block, etc.)
|
SocketHandlers, // Pre-built handlers (echo, proxy, block, etc.)
|
||||||
|
|
||||||
// NFTables (Linux, requires root)
|
// NFTables (Linux, requires root)
|
||||||
@@ -718,6 +906,8 @@ import {
|
|||||||
} from '@push.rocks/smartproxy';
|
} from '@push.rocks/smartproxy';
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> **Tip:** For UDP datagram handler routes or QUIC/HTTP3 routes, construct `IRouteConfig` objects directly — there are no helper functions for these yet. See the [UDP Datagram Handler](#-udp-datagram-handler) and [QUIC / HTTP3 Forwarding](#-quic--http3-forwarding) examples above.
|
||||||
|
|
||||||
## 📖 API Documentation
|
## 📖 API Documentation
|
||||||
|
|
||||||
### SmartProxy Class
|
### SmartProxy Class
|
||||||
@@ -753,6 +943,8 @@ class SmartProxy extends EventEmitter {
|
|||||||
|
|
||||||
// Events
|
// Events
|
||||||
on(event: 'error', handler: (err: Error) => void): this;
|
on(event: 'error', handler: (err: Error) => void): this;
|
||||||
|
on(event: 'certificate-issued', handler: (ev: ICertificateIssuedEvent) => void): this;
|
||||||
|
on(event: 'certificate-failed', handler: (ev: ICertificateFailedEvent) => void): this;
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -775,6 +967,8 @@ interface ISmartProxyOptions {
|
|||||||
// Custom certificate provisioning
|
// Custom certificate provisioning
|
||||||
certProvisionFunction?: (domain: string) => Promise<ICert | 'http01'>;
|
certProvisionFunction?: (domain: string) => Promise<ICert | 'http01'>;
|
||||||
certProvisionFallbackToAcme?: boolean; // Fall back to ACME on failure (default: true)
|
certProvisionFallbackToAcme?: boolean; // Fall back to ACME on failure (default: true)
|
||||||
|
certProvisionTimeout?: number; // Timeout per provision call (ms)
|
||||||
|
certProvisionConcurrency?: number; // Max concurrent provisions
|
||||||
|
|
||||||
// Consumer-managed certificate persistence (see "Consumer-Managed Certificate Storage")
|
// Consumer-managed certificate persistence (see "Consumer-Managed Certificate Storage")
|
||||||
certStore?: ISmartProxyCertStore;
|
certStore?: ISmartProxyCertStore;
|
||||||
@@ -782,6 +976,9 @@ interface ISmartProxyOptions {
|
|||||||
// Self-signed fallback
|
// Self-signed fallback
|
||||||
disableDefaultCert?: boolean; // Disable '*' self-signed fallback (default: false)
|
disableDefaultCert?: boolean; // Disable '*' self-signed fallback (default: false)
|
||||||
|
|
||||||
|
// Rust binary path override
|
||||||
|
rustBinaryPath?: string; // Custom path to the Rust proxy binary
|
||||||
|
|
||||||
// Global defaults
|
// Global defaults
|
||||||
defaults?: {
|
defaults?: {
|
||||||
target?: { host: string; port: number };
|
target?: { host: string; port: number };
|
||||||
@@ -868,11 +1065,22 @@ metrics.requests.perSecond(); // Requests per second
|
|||||||
metrics.requests.perMinute(); // Requests per minute
|
metrics.requests.perMinute(); // Requests per minute
|
||||||
metrics.requests.total(); // Total requests
|
metrics.requests.total(); // Total requests
|
||||||
|
|
||||||
|
// UDP metrics
|
||||||
|
metrics.udp.activeSessions(); // Current active UDP sessions
|
||||||
|
metrics.udp.totalSessions(); // Total UDP sessions since start
|
||||||
|
metrics.udp.datagramsIn(); // Datagrams received
|
||||||
|
metrics.udp.datagramsOut(); // Datagrams sent
|
||||||
|
|
||||||
// Cumulative totals
|
// Cumulative totals
|
||||||
metrics.totals.bytesIn(); // Total bytes received
|
metrics.totals.bytesIn(); // Total bytes received
|
||||||
metrics.totals.bytesOut(); // Total bytes sent
|
metrics.totals.bytesOut(); // Total bytes sent
|
||||||
metrics.totals.connections(); // Total connections
|
metrics.totals.connections(); // Total connections
|
||||||
|
|
||||||
|
// Backend metrics
|
||||||
|
metrics.backends.byBackend(); // Map<backend, IBackendMetrics>
|
||||||
|
metrics.backends.protocols(); // Map<backend, protocol>
|
||||||
|
metrics.backends.topByErrors(10); // Top N error-prone backends
|
||||||
|
|
||||||
// Percentiles
|
// Percentiles
|
||||||
metrics.percentiles.connectionDuration(); // { p50, p95, p99 }
|
metrics.percentiles.connectionDuration(); // { p50, p95, p99 }
|
||||||
metrics.percentiles.bytesTransferred(); // { in: { p50, p95, p99 }, out: { p50, p95, p99 } }
|
metrics.percentiles.bytesTransferred(); // { in: { p50, p95, p99 }, out: { p50, p95, p99 } }
|
||||||
@@ -896,11 +1104,16 @@ metrics.percentiles.bytesTransferred(); // { in: { p50, p95, p99 }, out: { p5
|
|||||||
### Rust Binary Not Found
|
### Rust Binary Not Found
|
||||||
|
|
||||||
SmartProxy searches for the Rust binary in this order:
|
SmartProxy searches for the Rust binary in this order:
|
||||||
1. `SMARTPROXY_RUST_BINARY` environment variable
|
1. `rustBinaryPath` option in `ISmartProxyOptions`
|
||||||
2. Platform-specific npm package (`@push.rocks/smartproxy-linux-x64`, etc.)
|
2. `SMARTPROXY_RUST_BINARY` environment variable
|
||||||
3. `dist_rust/rustproxy` relative to the package root (built by `tsrust`)
|
3. Platform-specific npm package (`@push.rocks/smartproxy-linux-x64`, etc.)
|
||||||
4. Local dev build (`./rust/target/release/rustproxy`)
|
4. `dist_rust/rustproxy` relative to the package root (built by `tsrust`)
|
||||||
5. System PATH (`rustproxy`)
|
5. Local dev build (`./rust/target/release/rustproxy`)
|
||||||
|
6. System PATH (`rustproxy`)
|
||||||
|
|
||||||
|
### QUIC / HTTP3 Caveats
|
||||||
|
- **GREASE frames are disabled.** The underlying h3 crate sends [GREASE frames](https://www.rfc-editor.org/rfc/rfc9114.html#frame-reserved) by default to test protocol extensibility. However, some HTTP/3 clients and servers don't properly ignore unknown frame types, causing 400/500 errors or stream hangs ([h3#206](https://github.com/hyperium/h3/issues/206)). SmartProxy disables GREASE on both the server side (for incoming H3 requests) and the client side (for H3 backend connections) to maximize compatibility.
|
||||||
|
- **HTTP/3 is pre-release.** The h3 ecosystem (h3 0.0.8, h3-quinn 0.0.10, quinn 0.11) is still pre-1.0. Expect rough edges.
|
||||||
|
|
||||||
### Performance Tuning
|
### Performance Tuning
|
||||||
- ✅ Use NFTables forwarding for high-traffic routes (Linux only)
|
- ✅ Use NFTables forwarding for high-traffic routes (Linux only)
|
||||||
|
|||||||
4
rust/Cargo.lock
generated
4
rust/Cargo.lock
generated
@@ -1224,10 +1224,14 @@ dependencies = [
|
|||||||
"bytes",
|
"bytes",
|
||||||
"clap",
|
"clap",
|
||||||
"dashmap",
|
"dashmap",
|
||||||
|
"h3",
|
||||||
|
"h3-quinn",
|
||||||
|
"http",
|
||||||
"http-body-util",
|
"http-body-util",
|
||||||
"hyper",
|
"hyper",
|
||||||
"hyper-util",
|
"hyper-util",
|
||||||
"mimalloc",
|
"mimalloc",
|
||||||
|
"quinn",
|
||||||
"rcgen",
|
"rcgen",
|
||||||
"rustls",
|
"rustls",
|
||||||
"rustls-pemfile",
|
"rustls-pemfile",
|
||||||
|
|||||||
@@ -1,109 +1,69 @@
|
|||||||
//! HTTP/3 proxy service.
|
//! HTTP/3 proxy service.
|
||||||
//!
|
//!
|
||||||
//! Accepts QUIC connections via quinn, runs h3 server to handle HTTP/3 requests,
|
//! Accepts QUIC connections via quinn, runs h3 server to handle HTTP/3 requests,
|
||||||
//! and forwards them to backends using the same routing and pool infrastructure
|
//! and delegates backend forwarding to the shared `HttpProxyService` — same
|
||||||
//! as the HTTP/1+2 proxy.
|
//! route matching, connection pool, and protocol auto-detection as TCP/HTTP.
|
||||||
|
|
||||||
|
use std::net::SocketAddr;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
use arc_swap::ArcSwap;
|
|
||||||
use bytes::{Buf, Bytes};
|
use bytes::{Buf, Bytes};
|
||||||
use http_body::Frame;
|
use http_body::Frame;
|
||||||
|
use http_body_util::BodyExt;
|
||||||
|
use http_body_util::combinators::BoxBody;
|
||||||
use tracing::{debug, warn};
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
use rustproxy_config::{RouteConfig, TransportProtocol};
|
use rustproxy_config::RouteConfig;
|
||||||
use rustproxy_metrics::MetricsCollector;
|
use tokio_util::sync::CancellationToken;
|
||||||
use rustproxy_routing::{MatchContext, RouteManager};
|
|
||||||
|
|
||||||
use crate::connection_pool::ConnectionPool;
|
use crate::proxy_service::{ConnActivity, HttpProxyService};
|
||||||
use crate::protocol_cache::ProtocolCache;
|
|
||||||
use crate::upstream_selector::UpstreamSelector;
|
|
||||||
|
|
||||||
/// HTTP/3 proxy service.
|
/// HTTP/3 proxy service.
|
||||||
///
|
///
|
||||||
/// Handles QUIC connections with the h3 crate, parses HTTP/3 requests,
|
/// Accepts QUIC connections, parses HTTP/3 requests, and delegates backend
|
||||||
/// and forwards them to backends using per-request route matching and
|
/// forwarding to the shared `HttpProxyService`.
|
||||||
/// shared connection pooling.
|
|
||||||
pub struct H3ProxyService {
|
pub struct H3ProxyService {
|
||||||
route_manager: Arc<ArcSwap<RouteManager>>,
|
http_proxy: Arc<HttpProxyService>,
|
||||||
metrics: Arc<MetricsCollector>,
|
|
||||||
connection_pool: Arc<ConnectionPool>,
|
|
||||||
#[allow(dead_code)]
|
|
||||||
protocol_cache: Arc<ProtocolCache>,
|
|
||||||
#[allow(dead_code)]
|
|
||||||
upstream_selector: UpstreamSelector,
|
|
||||||
#[allow(dead_code)]
|
|
||||||
backend_tls_config: Arc<rustls::ClientConfig>,
|
|
||||||
connect_timeout: Duration,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl H3ProxyService {
|
impl H3ProxyService {
|
||||||
pub fn new(
|
pub fn new(http_proxy: Arc<HttpProxyService>) -> Self {
|
||||||
route_manager: Arc<ArcSwap<RouteManager>>,
|
Self { http_proxy }
|
||||||
metrics: Arc<MetricsCollector>,
|
|
||||||
connection_pool: Arc<ConnectionPool>,
|
|
||||||
protocol_cache: Arc<ProtocolCache>,
|
|
||||||
backend_tls_config: Arc<rustls::ClientConfig>,
|
|
||||||
connect_timeout: Duration,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
route_manager: Arc::clone(&route_manager),
|
|
||||||
metrics: Arc::clone(&metrics),
|
|
||||||
connection_pool,
|
|
||||||
protocol_cache,
|
|
||||||
upstream_selector: UpstreamSelector::new(),
|
|
||||||
backend_tls_config,
|
|
||||||
connect_timeout,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handle an accepted QUIC connection as HTTP/3.
|
/// Handle an accepted QUIC connection as HTTP/3.
|
||||||
|
///
|
||||||
|
/// If `real_client_addr` is provided (from PROXY protocol), it overrides
|
||||||
|
/// `connection.remote_address()` for client IP attribution.
|
||||||
pub async fn handle_connection(
|
pub async fn handle_connection(
|
||||||
&self,
|
&self,
|
||||||
connection: quinn::Connection,
|
connection: quinn::Connection,
|
||||||
_fallback_route: &RouteConfig,
|
_fallback_route: &RouteConfig,
|
||||||
port: u16,
|
port: u16,
|
||||||
|
real_client_addr: Option<SocketAddr>,
|
||||||
|
parent_cancel: &CancellationToken,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let remote_addr = connection.remote_address();
|
let remote_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
|
||||||
debug!("HTTP/3 connection from {} on port {}", remote_addr, port);
|
debug!("HTTP/3 connection from {} on port {}", remote_addr, port);
|
||||||
|
|
||||||
let mut h3_conn: h3::server::Connection<h3_quinn::Connection, Bytes> =
|
let mut h3_conn: h3::server::Connection<h3_quinn::Connection, Bytes> =
|
||||||
h3::server::Connection::new(h3_quinn::Connection::new(connection))
|
h3::server::builder()
|
||||||
|
.send_grease(false)
|
||||||
|
.build(h3_quinn::Connection::new(connection))
|
||||||
.await
|
.await
|
||||||
.map_err(|e| anyhow::anyhow!("H3 connection setup failed: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("H3 connection setup failed: {}", e))?;
|
||||||
|
|
||||||
let client_ip = remote_addr.ip().to_string();
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
match h3_conn.accept().await {
|
let resolver = tokio::select! {
|
||||||
Ok(Some(resolver)) => {
|
_ = parent_cancel.cancelled() => {
|
||||||
let (request, stream) = match resolver.resolve_request().await {
|
debug!("HTTP/3 connection from {} cancelled by parent", remote_addr);
|
||||||
Ok(pair) => pair,
|
break;
|
||||||
Err(e) => {
|
|
||||||
debug!("HTTP/3 request resolve error: {}", e);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
self.metrics.record_http_request();
|
|
||||||
|
|
||||||
let rm = self.route_manager.load();
|
|
||||||
let pool = Arc::clone(&self.connection_pool);
|
|
||||||
let metrics = Arc::clone(&self.metrics);
|
|
||||||
let connect_timeout = self.connect_timeout;
|
|
||||||
let client_ip = client_ip.clone();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
if let Err(e) = handle_h3_request(
|
|
||||||
request, stream, port, &client_ip, &rm, &pool, &metrics, connect_timeout,
|
|
||||||
).await {
|
|
||||||
debug!("HTTP/3 request error from {}: {}", client_ip, e);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
result = h3_conn.accept() => {
|
||||||
|
match result {
|
||||||
|
Ok(Some(resolver)) => resolver,
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
debug!("HTTP/3 connection from {} closed", remote_addr);
|
debug!("HTTP/3 connection from {} closed", remote_addr);
|
||||||
break;
|
break;
|
||||||
@@ -114,91 +74,65 @@ impl H3ProxyService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let (request, stream) = match resolver.resolve_request().await {
|
||||||
|
Ok(pair) => pair,
|
||||||
|
Err(e) => {
|
||||||
|
debug!("HTTP/3 request resolve error: {}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let http_proxy = Arc::clone(&self.http_proxy);
|
||||||
|
let request_cancel = parent_cancel.child_token();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = handle_h3_request(
|
||||||
|
request, stream, port, remote_addr, &http_proxy, request_cancel,
|
||||||
|
).await {
|
||||||
|
debug!("HTTP/3 request error from {}: {}", remote_addr, e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handle a single HTTP/3 request with per-request route matching.
|
/// Handle a single HTTP/3 request by delegating to HttpProxyService.
|
||||||
|
///
|
||||||
|
/// 1. Read the H3 request body via an mpsc channel (streaming, not buffered)
|
||||||
|
/// 2. Build a `hyper::Request<BoxBody>` that HttpProxyService can handle
|
||||||
|
/// 3. Call `HttpProxyService::handle_request` — same route matching, connection
|
||||||
|
/// pool, ALPN protocol detection (H1/H2/H3) as the TCP/HTTP path
|
||||||
|
/// 4. Stream the response back over the H3 stream
|
||||||
async fn handle_h3_request(
|
async fn handle_h3_request(
|
||||||
request: hyper::Request<()>,
|
request: hyper::Request<()>,
|
||||||
mut stream: h3::server::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>,
|
mut stream: h3::server::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>,
|
||||||
port: u16,
|
port: u16,
|
||||||
client_ip: &str,
|
peer_addr: SocketAddr,
|
||||||
route_manager: &RouteManager,
|
http_proxy: &HttpProxyService,
|
||||||
_connection_pool: &ConnectionPool,
|
cancel: CancellationToken,
|
||||||
metrics: &MetricsCollector,
|
|
||||||
connect_timeout: Duration,
|
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let method = request.method().clone();
|
// Stream request body from H3 client via an mpsc channel.
|
||||||
let uri = request.uri().clone();
|
|
||||||
let path = uri.path().to_string();
|
|
||||||
|
|
||||||
// Extract host from :authority or Host header
|
|
||||||
let host = request.uri().authority()
|
|
||||||
.map(|a| a.as_str().to_string())
|
|
||||||
.or_else(|| request.headers().get("host").and_then(|v| v.to_str().ok()).map(|s| s.to_string()))
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
debug!("HTTP/3 {} {} (host: {}, client: {})", method, path, host, client_ip);
|
|
||||||
|
|
||||||
// Per-request route matching
|
|
||||||
let ctx = MatchContext {
|
|
||||||
port,
|
|
||||||
domain: if host.is_empty() { None } else { Some(&host) },
|
|
||||||
path: Some(&path),
|
|
||||||
client_ip: Some(client_ip),
|
|
||||||
tls_version: Some("TLSv1.3"),
|
|
||||||
headers: None,
|
|
||||||
is_tls: true,
|
|
||||||
protocol: Some("http"),
|
|
||||||
transport: Some(TransportProtocol::Udp),
|
|
||||||
};
|
|
||||||
|
|
||||||
let route_match = route_manager.find_route(&ctx)
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("No route matched for HTTP/3 request to {}{}", host, path))?;
|
|
||||||
let route = route_match.route;
|
|
||||||
|
|
||||||
// Resolve backend target (use matched target or first target)
|
|
||||||
let target = route_match.target
|
|
||||||
.or_else(|| route.action.targets.as_ref().and_then(|t| t.first()))
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("No target for HTTP/3 route"))?;
|
|
||||||
|
|
||||||
let backend_host = target.host.first();
|
|
||||||
let backend_port = target.port.resolve(port);
|
|
||||||
let backend_addr = format!("{}:{}", backend_host, backend_port);
|
|
||||||
|
|
||||||
// Connect to backend via TCP HTTP/1.1 with timeout
|
|
||||||
let tcp_stream = tokio::time::timeout(
|
|
||||||
connect_timeout,
|
|
||||||
tokio::net::TcpStream::connect(&backend_addr),
|
|
||||||
).await
|
|
||||||
.map_err(|_| anyhow::anyhow!("Backend connect timeout to {}", backend_addr))?
|
|
||||||
.map_err(|e| anyhow::anyhow!("Backend connect to {} failed: {}", backend_addr, e))?;
|
|
||||||
|
|
||||||
let _ = tcp_stream.set_nodelay(true);
|
|
||||||
|
|
||||||
let io = hyper_util::rt::TokioIo::new(tcp_stream);
|
|
||||||
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await
|
|
||||||
.map_err(|e| anyhow::anyhow!("Backend handshake failed: {}", e))?;
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
if let Err(e) = conn.await {
|
|
||||||
debug!("Backend connection closed: {}", e);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Stream request body from H3 client to backend via an mpsc channel.
|
|
||||||
// This avoids buffering the entire request body in memory.
|
|
||||||
let (body_tx, body_rx) = tokio::sync::mpsc::channel::<Bytes>(4);
|
let (body_tx, body_rx) = tokio::sync::mpsc::channel::<Bytes>(4);
|
||||||
let total_bytes_in = Arc::new(std::sync::atomic::AtomicU64::new(0));
|
|
||||||
let total_bytes_in_writer = Arc::clone(&total_bytes_in);
|
|
||||||
|
|
||||||
// Spawn the H3 body reader task
|
// Spawn the H3 body reader task with cancellation
|
||||||
|
let body_cancel = cancel.clone();
|
||||||
let body_reader = tokio::spawn(async move {
|
let body_reader = tokio::spawn(async move {
|
||||||
while let Ok(Some(mut chunk)) = stream.recv_data().await {
|
loop {
|
||||||
|
let chunk = tokio::select! {
|
||||||
|
_ = body_cancel.cancelled() => break,
|
||||||
|
result = stream.recv_data() => {
|
||||||
|
match result {
|
||||||
|
Ok(Some(chunk)) => chunk,
|
||||||
|
_ => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let mut chunk = chunk;
|
||||||
let data = Bytes::copy_from_slice(chunk.chunk());
|
let data = Bytes::copy_from_slice(chunk.chunk());
|
||||||
total_bytes_in_writer.fetch_add(data.len() as u64, std::sync::atomic::Ordering::Relaxed);
|
|
||||||
chunk.advance(chunk.remaining());
|
chunk.advance(chunk.remaining());
|
||||||
if body_tx.send(data).await.is_err() {
|
if body_tx.send(data).await.is_err() {
|
||||||
break;
|
break;
|
||||||
@@ -207,106 +141,63 @@ async fn handle_h3_request(
|
|||||||
stream
|
stream
|
||||||
});
|
});
|
||||||
|
|
||||||
// Create a body that polls from the mpsc receiver
|
// Build a hyper::Request<BoxBody> from the H3 request + streaming body.
|
||||||
|
// The URI already has scheme + authority + path set by the h3 crate.
|
||||||
let body = H3RequestBody { receiver: body_rx };
|
let body = H3RequestBody { receiver: body_rx };
|
||||||
let backend_req = build_backend_request(&method, &backend_addr, &path, &host, &request, body)?;
|
let (parts, _) = request.into_parts();
|
||||||
|
let boxed_body: BoxBody<Bytes, hyper::Error> = BoxBody::new(body);
|
||||||
|
let req = hyper::Request::from_parts(parts, boxed_body);
|
||||||
|
|
||||||
let response = sender.send_request(backend_req).await
|
// Delegate to HttpProxyService — same backend path as TCP/HTTP:
|
||||||
|
// route matching, ALPN protocol detection, connection pool, H1/H2/H3 auto.
|
||||||
|
let conn_activity = ConnActivity::new_standalone();
|
||||||
|
let response = http_proxy.handle_request(req, peer_addr, port, cancel, conn_activity).await
|
||||||
.map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("Backend request failed: {}", e))?;
|
||||||
|
|
||||||
// Await the body reader to get the stream back
|
// Await the body reader to get the H3 stream back
|
||||||
let mut stream = body_reader.await
|
let mut stream = body_reader.await
|
||||||
.map_err(|e| anyhow::anyhow!("Body reader task failed: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("Body reader task failed: {}", e))?;
|
||||||
let total_bytes_in = total_bytes_in.load(std::sync::atomic::Ordering::Relaxed);
|
|
||||||
|
|
||||||
// Build H3 response
|
// Send response headers over H3 (skip hop-by-hop headers)
|
||||||
let status = response.status();
|
let (resp_parts, resp_body) = response.into_parts();
|
||||||
let mut h3_response = hyper::Response::builder().status(status);
|
let mut h3_response = hyper::Response::builder().status(resp_parts.status);
|
||||||
|
for (name, value) in &resp_parts.headers {
|
||||||
// Copy response headers (skip hop-by-hop)
|
let n = name.as_str();
|
||||||
for (name, value) in response.headers() {
|
|
||||||
let n = name.as_str().to_lowercase();
|
|
||||||
if n == "transfer-encoding" || n == "connection" || n == "keep-alive" || n == "upgrade" {
|
if n == "transfer-encoding" || n == "connection" || n == "keep-alive" || n == "upgrade" {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
h3_response = h3_response.header(name, value);
|
h3_response = h3_response.header(name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add Alt-Svc for HTTP/3 advertisement
|
|
||||||
let alt_svc = route.action.udp.as_ref()
|
|
||||||
.and_then(|u| u.quic.as_ref())
|
|
||||||
.map(|q| {
|
|
||||||
let p = q.alt_svc_port.unwrap_or(port);
|
|
||||||
let ma = q.alt_svc_max_age.unwrap_or(86400);
|
|
||||||
format!("h3=\":{}\"; ma={}", p, ma)
|
|
||||||
})
|
|
||||||
.unwrap_or_else(|| format!("h3=\":{}\"; ma=86400", port));
|
|
||||||
h3_response = h3_response.header("alt-svc", alt_svc);
|
|
||||||
|
|
||||||
let h3_response = h3_response.body(())
|
let h3_response = h3_response.body(())
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to build H3 response: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("Failed to build H3 response: {}", e))?;
|
||||||
|
|
||||||
// Send response headers
|
|
||||||
stream.send_response(h3_response).await
|
stream.send_response(h3_response).await
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to send H3 response: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("Failed to send H3 response: {}", e))?;
|
||||||
|
|
||||||
// Stream response body back
|
// Stream response body back over H3
|
||||||
use http_body_util::BodyExt;
|
let mut resp_body = resp_body;
|
||||||
let mut body = response.into_body();
|
while let Some(frame) = resp_body.frame().await {
|
||||||
let mut total_bytes_out: u64 = 0;
|
|
||||||
while let Some(frame) = body.frame().await {
|
|
||||||
match frame {
|
match frame {
|
||||||
Ok(frame) => {
|
Ok(frame) => {
|
||||||
if let Some(data) = frame.data_ref() {
|
if let Some(data) = frame.data_ref() {
|
||||||
total_bytes_out += data.len() as u64;
|
|
||||||
stream.send_data(Bytes::copy_from_slice(data)).await
|
stream.send_data(Bytes::copy_from_slice(data)).await
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to send H3 data: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("Failed to send H3 data: {}", e))?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Backend body read error: {}", e);
|
warn!("Response body read error: {}", e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Record metrics
|
// Finish the H3 stream (send QUIC FIN)
|
||||||
let route_id = route.name.as_deref().or(route.id.as_deref());
|
|
||||||
metrics.record_bytes(total_bytes_in, total_bytes_out, route_id, Some(client_ip));
|
|
||||||
|
|
||||||
// Finish the stream
|
|
||||||
stream.finish().await
|
stream.finish().await
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to finish H3 stream: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("Failed to finish H3 stream: {}", e))?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build an HTTP/1.1 backend request from the H3 frontend request.
|
|
||||||
fn build_backend_request<B>(
|
|
||||||
method: &hyper::Method,
|
|
||||||
backend_addr: &str,
|
|
||||||
path: &str,
|
|
||||||
host: &str,
|
|
||||||
original_request: &hyper::Request<()>,
|
|
||||||
body: B,
|
|
||||||
) -> anyhow::Result<hyper::Request<B>> {
|
|
||||||
let mut req = hyper::Request::builder()
|
|
||||||
.method(method)
|
|
||||||
.uri(format!("http://{}{}", backend_addr, path))
|
|
||||||
.header("host", host);
|
|
||||||
|
|
||||||
// Forward non-pseudo headers
|
|
||||||
for (name, value) in original_request.headers() {
|
|
||||||
let n = name.as_str();
|
|
||||||
if !n.starts_with(':') && n != "host" {
|
|
||||||
req = req.header(name, value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
req.body(body)
|
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to build backend request: {}", e))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A streaming request body backed by an mpsc channel receiver.
|
/// A streaming request body backed by an mpsc channel receiver.
|
||||||
///
|
///
|
||||||
/// Implements `http_body::Body` so hyper can poll chunks as they arrive
|
/// Implements `http_body::Body` so hyper can poll chunks as they arrive
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
//! Bounded, TTL-based protocol detection cache for HTTP/2 auto-detection.
|
//! Bounded, TTL-based protocol detection cache for backend protocol auto-detection.
|
||||||
//!
|
//!
|
||||||
//! Caches the ALPN-negotiated protocol (H1 or H2) per backend endpoint and requested
|
//! Caches the detected protocol (H1, H2, or H3) per backend endpoint and requested
|
||||||
//! domain (host:port + requested_host). This prevents cache oscillation when multiple
|
//! domain (host:port + requested_host). This prevents cache oscillation when multiple
|
||||||
//! frontend domains share the same backend but differ in HTTP/2 support.
|
//! frontend domains share the same backend but differ in protocol support.
|
||||||
|
//!
|
||||||
|
//! H3 detection uses the browser model: Alt-Svc headers from H1/H2 responses are
|
||||||
|
//! parsed and cached, including the advertised H3 port (which may differ from TCP).
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
@@ -29,6 +32,14 @@ pub enum DetectedProtocol {
|
|||||||
H3,
|
H3,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Result of a protocol cache lookup.
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct CachedProtocol {
|
||||||
|
pub protocol: DetectedProtocol,
|
||||||
|
/// For H3: the port advertised by Alt-Svc (may differ from TCP port).
|
||||||
|
pub h3_port: Option<u16>,
|
||||||
|
}
|
||||||
|
|
||||||
/// Key for the protocol cache: (host, port, requested_host).
|
/// Key for the protocol cache: (host, port, requested_host).
|
||||||
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
|
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
|
||||||
pub struct ProtocolCacheKey {
|
pub struct ProtocolCacheKey {
|
||||||
@@ -43,6 +54,8 @@ pub struct ProtocolCacheKey {
|
|||||||
struct CachedEntry {
|
struct CachedEntry {
|
||||||
protocol: DetectedProtocol,
|
protocol: DetectedProtocol,
|
||||||
detected_at: Instant,
|
detected_at: Instant,
|
||||||
|
/// For H3: the port advertised by Alt-Svc (may differ from TCP port).
|
||||||
|
h3_port: Option<u16>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Bounded, TTL-based protocol detection cache.
|
/// Bounded, TTL-based protocol detection cache.
|
||||||
@@ -75,11 +88,14 @@ impl ProtocolCache {
|
|||||||
|
|
||||||
/// Look up the cached protocol for a backend endpoint.
|
/// Look up the cached protocol for a backend endpoint.
|
||||||
/// Returns `None` if not cached or expired (caller should probe via ALPN).
|
/// Returns `None` if not cached or expired (caller should probe via ALPN).
|
||||||
pub fn get(&self, key: &ProtocolCacheKey) -> Option<DetectedProtocol> {
|
pub fn get(&self, key: &ProtocolCacheKey) -> Option<CachedProtocol> {
|
||||||
let entry = self.cache.get(key)?;
|
let entry = self.cache.get(key)?;
|
||||||
if entry.detected_at.elapsed() < PROTOCOL_CACHE_TTL {
|
if entry.detected_at.elapsed() < PROTOCOL_CACHE_TTL {
|
||||||
debug!("Protocol cache hit: {:?} for {}:{} (requested: {:?})", entry.protocol, key.host, key.port, key.requested_host);
|
debug!("Protocol cache hit: {:?} for {}:{} (requested: {:?})", entry.protocol, key.host, key.port, key.requested_host);
|
||||||
Some(entry.protocol)
|
Some(CachedProtocol {
|
||||||
|
protocol: entry.protocol,
|
||||||
|
h3_port: entry.h3_port,
|
||||||
|
})
|
||||||
} else {
|
} else {
|
||||||
// Expired — remove and return None to trigger re-probe
|
// Expired — remove and return None to trigger re-probe
|
||||||
drop(entry); // release DashMap ref before remove
|
drop(entry); // release DashMap ref before remove
|
||||||
@@ -91,6 +107,16 @@ impl ProtocolCache {
|
|||||||
/// Insert a detected protocol into the cache.
|
/// Insert a detected protocol into the cache.
|
||||||
/// If the cache is at capacity, evict the oldest entry first.
|
/// If the cache is at capacity, evict the oldest entry first.
|
||||||
pub fn insert(&self, key: ProtocolCacheKey, protocol: DetectedProtocol) {
|
pub fn insert(&self, key: ProtocolCacheKey, protocol: DetectedProtocol) {
|
||||||
|
self.insert_with_h3_port(key, protocol, None);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Insert an H3 detection result with the Alt-Svc advertised port.
|
||||||
|
pub fn insert_h3(&self, key: ProtocolCacheKey, h3_port: u16) {
|
||||||
|
self.insert_with_h3_port(key, DetectedProtocol::H3, Some(h3_port));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Insert a protocol detection result with an optional H3 port.
|
||||||
|
fn insert_with_h3_port(&self, key: ProtocolCacheKey, protocol: DetectedProtocol, h3_port: Option<u16>) {
|
||||||
if self.cache.len() >= PROTOCOL_CACHE_MAX_ENTRIES && !self.cache.contains_key(&key) {
|
if self.cache.len() >= PROTOCOL_CACHE_MAX_ENTRIES && !self.cache.contains_key(&key) {
|
||||||
// Evict the oldest entry to stay within bounds
|
// Evict the oldest entry to stay within bounds
|
||||||
let oldest = self.cache.iter()
|
let oldest = self.cache.iter()
|
||||||
@@ -103,6 +129,7 @@ impl ProtocolCache {
|
|||||||
self.cache.insert(key, CachedEntry {
|
self.cache.insert(key, CachedEntry {
|
||||||
protocol,
|
protocol,
|
||||||
detected_at: Instant::now(),
|
detected_at: Instant::now(),
|
||||||
|
h3_port,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -36,13 +36,30 @@ use crate::upstream_selector::UpstreamSelector;
|
|||||||
/// Per-connection context for keeping the idle watchdog alive during body streaming.
|
/// Per-connection context for keeping the idle watchdog alive during body streaming.
|
||||||
/// Passed through the forwarding chain so CountingBody can update the timestamp.
|
/// Passed through the forwarding chain so CountingBody can update the timestamp.
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct ConnActivity {
|
pub struct ConnActivity {
|
||||||
last_activity: Arc<AtomicU64>,
|
last_activity: Arc<AtomicU64>,
|
||||||
start: std::time::Instant,
|
start: std::time::Instant,
|
||||||
/// Active-request counter from handle_io's idle watchdog. When set, CountingBody
|
/// Active-request counter from handle_io's idle watchdog. When set, CountingBody
|
||||||
/// increments on creation and decrements on Drop, keeping the watchdog aware that
|
/// increments on creation and decrements on Drop, keeping the watchdog aware that
|
||||||
/// a response body is still streaming after the request handler has returned.
|
/// a response body is still streaming after the request handler has returned.
|
||||||
active_requests: Option<Arc<AtomicU64>>,
|
active_requests: Option<Arc<AtomicU64>>,
|
||||||
|
/// Protocol cache key for Alt-Svc discovery. When set, `build_streaming_response`
|
||||||
|
/// checks the backend's original response headers for Alt-Svc before our
|
||||||
|
/// ResponseFilter injects its own. None when not in auto-detect mode or after H3 failure.
|
||||||
|
alt_svc_cache_key: Option<crate::protocol_cache::ProtocolCacheKey>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConnActivity {
|
||||||
|
/// Create a minimal ConnActivity (no idle watchdog, no Alt-Svc cache).
|
||||||
|
/// Used by H3ProxyService where the TCP idle watchdog doesn't apply.
|
||||||
|
pub fn new_standalone() -> Self {
|
||||||
|
Self {
|
||||||
|
last_activity: Arc::new(AtomicU64::new(0)),
|
||||||
|
start: std::time::Instant::now(),
|
||||||
|
active_requests: None,
|
||||||
|
alt_svc_cache_key: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Default upstream connect timeout (30 seconds).
|
/// Default upstream connect timeout (30 seconds).
|
||||||
@@ -58,6 +75,18 @@ const DEFAULT_WS_INACTIVITY_TIMEOUT: std::time::Duration = std::time::Duration::
|
|||||||
/// Default WebSocket max lifetime (24 hours).
|
/// Default WebSocket max lifetime (24 hours).
|
||||||
const DEFAULT_WS_MAX_LIFETIME: std::time::Duration = std::time::Duration::from_secs(86400);
|
const DEFAULT_WS_MAX_LIFETIME: std::time::Duration = std::time::Duration::from_secs(86400);
|
||||||
|
|
||||||
|
/// Timeout for QUIC (H3) backend connections. Short because UDP is often firewalled.
|
||||||
|
const QUIC_CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(3);
|
||||||
|
|
||||||
|
/// Protocol decision for backend connection.
|
||||||
|
#[derive(Debug)]
|
||||||
|
enum ProtocolDecision {
|
||||||
|
H1,
|
||||||
|
H2,
|
||||||
|
H3 { port: u16 },
|
||||||
|
AlpnProbe,
|
||||||
|
}
|
||||||
|
|
||||||
/// RAII guard that decrements the active request counter on drop.
|
/// RAII guard that decrements the active request counter on drop.
|
||||||
/// Ensures the counter is correct even if the request handler panics.
|
/// Ensures the counter is correct even if the request handler panics.
|
||||||
struct ActiveRequestGuard {
|
struct ActiveRequestGuard {
|
||||||
@@ -174,6 +203,10 @@ pub struct HttpProxyService {
|
|||||||
route_rate_limiters: Arc<DashMap<String, Arc<RateLimiter>>>,
|
route_rate_limiters: Arc<DashMap<String, Arc<RateLimiter>>>,
|
||||||
/// Request counter for periodic rate limiter cleanup.
|
/// Request counter for periodic rate limiter cleanup.
|
||||||
request_counter: AtomicU64,
|
request_counter: AtomicU64,
|
||||||
|
/// Epoch for time-based rate limiter cleanup.
|
||||||
|
rate_limiter_epoch: std::time::Instant,
|
||||||
|
/// Last rate limiter cleanup time (ms since epoch).
|
||||||
|
last_rate_limiter_cleanup_ms: AtomicU64,
|
||||||
/// Cache of compiled URL rewrite regexes (keyed by pattern string).
|
/// Cache of compiled URL rewrite regexes (keyed by pattern string).
|
||||||
regex_cache: DashMap<String, Regex>,
|
regex_cache: DashMap<String, Regex>,
|
||||||
/// Shared backend TLS config for session resumption across connections.
|
/// Shared backend TLS config for session resumption across connections.
|
||||||
@@ -190,6 +223,9 @@ pub struct HttpProxyService {
|
|||||||
ws_inactivity_timeout: std::time::Duration,
|
ws_inactivity_timeout: std::time::Duration,
|
||||||
/// WebSocket maximum connection lifetime.
|
/// WebSocket maximum connection lifetime.
|
||||||
ws_max_lifetime: std::time::Duration,
|
ws_max_lifetime: std::time::Duration,
|
||||||
|
/// Shared QUIC client endpoint for outbound H3 backend connections.
|
||||||
|
/// Lazily initialized on first H3 backend attempt.
|
||||||
|
quinn_client_endpoint: Arc<quinn::Endpoint>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HttpProxyService {
|
impl HttpProxyService {
|
||||||
@@ -201,6 +237,8 @@ impl HttpProxyService {
|
|||||||
connect_timeout: DEFAULT_CONNECT_TIMEOUT,
|
connect_timeout: DEFAULT_CONNECT_TIMEOUT,
|
||||||
route_rate_limiters: Arc::new(DashMap::new()),
|
route_rate_limiters: Arc::new(DashMap::new()),
|
||||||
request_counter: AtomicU64::new(0),
|
request_counter: AtomicU64::new(0),
|
||||||
|
rate_limiter_epoch: std::time::Instant::now(),
|
||||||
|
last_rate_limiter_cleanup_ms: AtomicU64::new(0),
|
||||||
regex_cache: DashMap::new(),
|
regex_cache: DashMap::new(),
|
||||||
backend_tls_config: Self::default_backend_tls_config(),
|
backend_tls_config: Self::default_backend_tls_config(),
|
||||||
backend_tls_config_alpn: Self::default_backend_tls_config_with_alpn(),
|
backend_tls_config_alpn: Self::default_backend_tls_config_with_alpn(),
|
||||||
@@ -209,6 +247,7 @@ impl HttpProxyService {
|
|||||||
http_idle_timeout: DEFAULT_HTTP_IDLE_TIMEOUT,
|
http_idle_timeout: DEFAULT_HTTP_IDLE_TIMEOUT,
|
||||||
ws_inactivity_timeout: DEFAULT_WS_INACTIVITY_TIMEOUT,
|
ws_inactivity_timeout: DEFAULT_WS_INACTIVITY_TIMEOUT,
|
||||||
ws_max_lifetime: DEFAULT_WS_MAX_LIFETIME,
|
ws_max_lifetime: DEFAULT_WS_MAX_LIFETIME,
|
||||||
|
quinn_client_endpoint: Arc::new(Self::create_quinn_client_endpoint()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -225,6 +264,8 @@ impl HttpProxyService {
|
|||||||
connect_timeout,
|
connect_timeout,
|
||||||
route_rate_limiters: Arc::new(DashMap::new()),
|
route_rate_limiters: Arc::new(DashMap::new()),
|
||||||
request_counter: AtomicU64::new(0),
|
request_counter: AtomicU64::new(0),
|
||||||
|
rate_limiter_epoch: std::time::Instant::now(),
|
||||||
|
last_rate_limiter_cleanup_ms: AtomicU64::new(0),
|
||||||
regex_cache: DashMap::new(),
|
regex_cache: DashMap::new(),
|
||||||
backend_tls_config: Self::default_backend_tls_config(),
|
backend_tls_config: Self::default_backend_tls_config(),
|
||||||
backend_tls_config_alpn: Self::default_backend_tls_config_with_alpn(),
|
backend_tls_config_alpn: Self::default_backend_tls_config_with_alpn(),
|
||||||
@@ -233,6 +274,7 @@ impl HttpProxyService {
|
|||||||
http_idle_timeout: DEFAULT_HTTP_IDLE_TIMEOUT,
|
http_idle_timeout: DEFAULT_HTTP_IDLE_TIMEOUT,
|
||||||
ws_inactivity_timeout: DEFAULT_WS_INACTIVITY_TIMEOUT,
|
ws_inactivity_timeout: DEFAULT_WS_INACTIVITY_TIMEOUT,
|
||||||
ws_max_lifetime: DEFAULT_WS_MAX_LIFETIME,
|
ws_max_lifetime: DEFAULT_WS_MAX_LIFETIME,
|
||||||
|
quinn_client_endpoint: Arc::new(Self::create_quinn_client_endpoint()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -324,8 +366,9 @@ impl HttpProxyService {
|
|||||||
let cn = cancel_inner.clone();
|
let cn = cancel_inner.clone();
|
||||||
let la = Arc::clone(&la_inner);
|
let la = Arc::clone(&la_inner);
|
||||||
let st = start;
|
let st = start;
|
||||||
let ca = ConnActivity { last_activity: Arc::clone(&la_inner), start, active_requests: Some(Arc::clone(&ar_inner)) };
|
let ca = ConnActivity { last_activity: Arc::clone(&la_inner), start, active_requests: Some(Arc::clone(&ar_inner)), alt_svc_cache_key: None };
|
||||||
async move {
|
async move {
|
||||||
|
let req = req.map(|body| BoxBody::new(body));
|
||||||
let result = svc.handle_request(req, peer, port, cn, ca).await;
|
let result = svc.handle_request(req, peer, port, cn, ca).await;
|
||||||
// Mark request end — update activity timestamp before guard drops
|
// Mark request end — update activity timestamp before guard drops
|
||||||
la.store(st.elapsed().as_millis() as u64, Ordering::Relaxed);
|
la.store(st.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||||
@@ -395,13 +438,17 @@ impl HttpProxyService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Handle a single HTTP request.
|
/// Handle a single HTTP request.
|
||||||
async fn handle_request(
|
///
|
||||||
|
/// Accepts a generic body (`BoxBody`) so both the TCP/HTTP path (which boxes
|
||||||
|
/// `Incoming`) and the H3 path (which boxes the H3 request body stream) can
|
||||||
|
/// share the same backend forwarding logic.
|
||||||
|
pub async fn handle_request(
|
||||||
&self,
|
&self,
|
||||||
req: Request<Incoming>,
|
req: Request<BoxBody<Bytes, hyper::Error>>,
|
||||||
peer_addr: std::net::SocketAddr,
|
peer_addr: std::net::SocketAddr,
|
||||||
port: u16,
|
port: u16,
|
||||||
cancel: CancellationToken,
|
cancel: CancellationToken,
|
||||||
conn_activity: ConnActivity,
|
mut conn_activity: ConnActivity,
|
||||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||||
let host = req.headers()
|
let host = req.headers()
|
||||||
.get("host")
|
.get("host")
|
||||||
@@ -485,9 +532,13 @@ impl HttpProxyService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Periodic rate limiter cleanup (every 1000 requests)
|
// Periodic rate limiter cleanup (every 1000 requests or every 60s)
|
||||||
let count = self.request_counter.fetch_add(1, Ordering::Relaxed);
|
let count = self.request_counter.fetch_add(1, Ordering::Relaxed);
|
||||||
if count % 1000 == 0 {
|
let now_ms = self.rate_limiter_epoch.elapsed().as_millis() as u64;
|
||||||
|
let last_cleanup = self.last_rate_limiter_cleanup_ms.load(Ordering::Relaxed);
|
||||||
|
let time_triggered = now_ms.saturating_sub(last_cleanup) >= 60_000;
|
||||||
|
if count % 1000 == 0 || time_triggered {
|
||||||
|
self.last_rate_limiter_cleanup_ms.store(now_ms, Ordering::Relaxed);
|
||||||
for entry in self.route_rate_limiters.iter() {
|
for entry in self.route_rate_limiters.iter() {
|
||||||
entry.value().cleanup();
|
entry.value().cleanup();
|
||||||
}
|
}
|
||||||
@@ -645,37 +696,101 @@ impl HttpProxyService {
|
|||||||
|
|
||||||
// --- Resolve protocol decision based on backend protocol mode ---
|
// --- Resolve protocol decision based on backend protocol mode ---
|
||||||
let is_auto_detect_mode = matches!(backend_protocol_mode, rustproxy_config::BackendProtocol::Auto);
|
let is_auto_detect_mode = matches!(backend_protocol_mode, rustproxy_config::BackendProtocol::Auto);
|
||||||
let (use_h2, needs_alpn_probe) = match backend_protocol_mode {
|
let protocol_cache_key = crate::protocol_cache::ProtocolCacheKey {
|
||||||
rustproxy_config::BackendProtocol::Http1 => (false, false),
|
|
||||||
rustproxy_config::BackendProtocol::Http2 => (true, false),
|
|
||||||
rustproxy_config::BackendProtocol::Http3 => {
|
|
||||||
// HTTP/3 (QUIC) backend connections not yet implemented — fall back to H1
|
|
||||||
warn!("backendProtocol 'http3' not yet implemented, falling back to http1");
|
|
||||||
(false, false)
|
|
||||||
}
|
|
||||||
rustproxy_config::BackendProtocol::Auto => {
|
|
||||||
if !upstream.use_tls {
|
|
||||||
// No ALPN without TLS — default to H1
|
|
||||||
(false, false)
|
|
||||||
} else {
|
|
||||||
let cache_key = crate::protocol_cache::ProtocolCacheKey {
|
|
||||||
host: upstream.host.clone(),
|
host: upstream.host.clone(),
|
||||||
port: upstream.port,
|
port: upstream.port,
|
||||||
requested_host: host.clone(),
|
requested_host: host.clone(),
|
||||||
};
|
};
|
||||||
match self.protocol_cache.get(&cache_key) {
|
let protocol_decision = match backend_protocol_mode {
|
||||||
Some(crate::protocol_cache::DetectedProtocol::H2) => (true, false),
|
rustproxy_config::BackendProtocol::Http1 => ProtocolDecision::H1,
|
||||||
Some(crate::protocol_cache::DetectedProtocol::H1) => (false, false),
|
rustproxy_config::BackendProtocol::Http2 => ProtocolDecision::H2,
|
||||||
Some(crate::protocol_cache::DetectedProtocol::H3) => {
|
rustproxy_config::BackendProtocol::Http3 => ProtocolDecision::H3 { port: upstream.port },
|
||||||
// H3 cached but we're on TCP — fall back to H2 probe
|
rustproxy_config::BackendProtocol::Auto => {
|
||||||
(false, true)
|
if !upstream.use_tls {
|
||||||
|
// No ALPN without TLS, no QUIC without TLS — default to H1
|
||||||
|
ProtocolDecision::H1
|
||||||
|
} else {
|
||||||
|
match self.protocol_cache.get(&protocol_cache_key) {
|
||||||
|
Some(cached) => match cached.protocol {
|
||||||
|
crate::protocol_cache::DetectedProtocol::H3 => {
|
||||||
|
if let Some(h3_port) = cached.h3_port {
|
||||||
|
ProtocolDecision::H3 { port: h3_port }
|
||||||
|
} else {
|
||||||
|
// H3 cached but no port — fall back to ALPN probe
|
||||||
|
ProtocolDecision::AlpnProbe
|
||||||
}
|
}
|
||||||
None => (false, true), // needs ALPN probe
|
}
|
||||||
|
crate::protocol_cache::DetectedProtocol::H2 => ProtocolDecision::H2,
|
||||||
|
crate::protocol_cache::DetectedProtocol::H1 => ProtocolDecision::H1,
|
||||||
|
},
|
||||||
|
None => ProtocolDecision::AlpnProbe,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Derive legacy flags for the existing H1/H2 connection path
|
||||||
|
let (use_h2, mut needs_alpn_probe) = match &protocol_decision {
|
||||||
|
ProtocolDecision::H1 => (false, false),
|
||||||
|
ProtocolDecision::H2 => (true, false),
|
||||||
|
ProtocolDecision::H3 { .. } => (false, false), // H3 path handled separately below
|
||||||
|
ProtocolDecision::AlpnProbe => (false, true),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Set Alt-Svc cache key on conn_activity so build_streaming_response can check
|
||||||
|
// the backend's original Alt-Svc header before ResponseFilter injects our own.
|
||||||
|
if is_auto_detect_mode {
|
||||||
|
conn_activity.alt_svc_cache_key = Some(protocol_cache_key.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- H3 path: try QUIC connection before TCP ---
|
||||||
|
if let ProtocolDecision::H3 { port: h3_port } = protocol_decision {
|
||||||
|
let h3_pool_key = crate::connection_pool::PoolKey {
|
||||||
|
host: upstream.host.clone(),
|
||||||
|
port: h3_port,
|
||||||
|
use_tls: true,
|
||||||
|
protocol: crate::connection_pool::PoolProtocol::H3,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Try H3 pool checkout first
|
||||||
|
if let Some((quic_conn, _age)) = self.connection_pool.checkout_h3(&h3_pool_key) {
|
||||||
|
self.metrics.backend_pool_hit(&upstream_key);
|
||||||
|
let result = self.forward_h3(
|
||||||
|
quic_conn, parts, body, upstream_headers, &upstream_path,
|
||||||
|
route_match.route, route_id, &ip_str, &h3_pool_key, domain_str, &conn_activity, &upstream_key,
|
||||||
|
).await;
|
||||||
|
self.upstream_selector.connection_ended(&upstream_key);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try fresh QUIC connection
|
||||||
|
match self.connect_quic_backend(&upstream.host, h3_port).await {
|
||||||
|
Ok(quic_conn) => {
|
||||||
|
self.metrics.backend_pool_miss(&upstream_key);
|
||||||
|
self.metrics.backend_connection_opened(&upstream_key, std::time::Instant::now().elapsed());
|
||||||
|
let result = self.forward_h3(
|
||||||
|
quic_conn, parts, body, upstream_headers, &upstream_path,
|
||||||
|
route_match.route, route_id, &ip_str, &h3_pool_key, domain_str, &conn_activity, &upstream_key,
|
||||||
|
).await;
|
||||||
|
self.upstream_selector.connection_ended(&upstream_key);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(backend = %upstream_key, error = %e,
|
||||||
|
"H3 backend connect failed, falling back to H2/H1");
|
||||||
|
// Suppress Alt-Svc caching for the fallback to prevent re-caching H3
|
||||||
|
// from our own injected Alt-Svc header or a stale backend Alt-Svc
|
||||||
|
conn_activity.alt_svc_cache_key = None;
|
||||||
|
// Force ALPN probe on TCP fallback so we correctly detect H2 vs H1
|
||||||
|
// (don't cache anything yet — let the ALPN probe decide)
|
||||||
|
if is_auto_detect_mode && upstream.use_tls {
|
||||||
|
needs_alpn_probe = true;
|
||||||
|
}
|
||||||
|
// Fall through to TCP path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// --- Connection pooling: try reusing an existing connection first ---
|
// --- Connection pooling: try reusing an existing connection first ---
|
||||||
// For ALPN probe mode, skip pool checkout (we don't know the protocol yet)
|
// For ALPN probe mode, skip pool checkout (we don't know the protocol yet)
|
||||||
if !needs_alpn_probe {
|
if !needs_alpn_probe {
|
||||||
@@ -870,6 +985,7 @@ impl HttpProxyService {
|
|||||||
};
|
};
|
||||||
self.upstream_selector.connection_ended(&upstream_key);
|
self.upstream_selector.connection_ended(&upstream_key);
|
||||||
self.metrics.backend_connection_closed(&upstream_key);
|
self.metrics.backend_connection_closed(&upstream_key);
|
||||||
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -879,7 +995,7 @@ impl HttpProxyService {
|
|||||||
&self,
|
&self,
|
||||||
io: TokioIo<BackendStream>,
|
io: TokioIo<BackendStream>,
|
||||||
parts: hyper::http::request::Parts,
|
parts: hyper::http::request::Parts,
|
||||||
body: Incoming,
|
body: BoxBody<Bytes, hyper::Error>,
|
||||||
upstream_headers: hyper::HeaderMap,
|
upstream_headers: hyper::HeaderMap,
|
||||||
upstream_path: &str,
|
upstream_path: &str,
|
||||||
_upstream: &crate::upstream_selector::UpstreamSelection,
|
_upstream: &crate::upstream_selector::UpstreamSelection,
|
||||||
@@ -927,7 +1043,7 @@ impl HttpProxyService {
|
|||||||
&self,
|
&self,
|
||||||
mut sender: hyper::client::conn::http1::SendRequest<BoxBody<Bytes, hyper::Error>>,
|
mut sender: hyper::client::conn::http1::SendRequest<BoxBody<Bytes, hyper::Error>>,
|
||||||
parts: hyper::http::request::Parts,
|
parts: hyper::http::request::Parts,
|
||||||
body: Incoming,
|
body: BoxBody<Bytes, hyper::Error>,
|
||||||
upstream_headers: hyper::HeaderMap,
|
upstream_headers: hyper::HeaderMap,
|
||||||
upstream_path: &str,
|
upstream_path: &str,
|
||||||
route: &rustproxy_config::RouteConfig,
|
route: &rustproxy_config::RouteConfig,
|
||||||
@@ -991,7 +1107,7 @@ impl HttpProxyService {
|
|||||||
&self,
|
&self,
|
||||||
io: TokioIo<BackendStream>,
|
io: TokioIo<BackendStream>,
|
||||||
parts: hyper::http::request::Parts,
|
parts: hyper::http::request::Parts,
|
||||||
body: Incoming,
|
body: BoxBody<Bytes, hyper::Error>,
|
||||||
upstream_headers: hyper::HeaderMap,
|
upstream_headers: hyper::HeaderMap,
|
||||||
upstream_path: &str,
|
upstream_path: &str,
|
||||||
_upstream: &crate::upstream_selector::UpstreamSelection,
|
_upstream: &crate::upstream_selector::UpstreamSelection,
|
||||||
@@ -1065,7 +1181,7 @@ impl HttpProxyService {
|
|||||||
&self,
|
&self,
|
||||||
sender: hyper::client::conn::http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
|
sender: hyper::client::conn::http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
|
||||||
parts: hyper::http::request::Parts,
|
parts: hyper::http::request::Parts,
|
||||||
body: Incoming,
|
body: BoxBody<Bytes, hyper::Error>,
|
||||||
upstream_headers: hyper::HeaderMap,
|
upstream_headers: hyper::HeaderMap,
|
||||||
upstream_path: &str,
|
upstream_path: &str,
|
||||||
route: &rustproxy_config::RouteConfig,
|
route: &rustproxy_config::RouteConfig,
|
||||||
@@ -1258,7 +1374,7 @@ impl HttpProxyService {
|
|||||||
&self,
|
&self,
|
||||||
io: TokioIo<BackendStream>,
|
io: TokioIo<BackendStream>,
|
||||||
parts: hyper::http::request::Parts,
|
parts: hyper::http::request::Parts,
|
||||||
body: Incoming,
|
body: BoxBody<Bytes, hyper::Error>,
|
||||||
mut upstream_headers: hyper::HeaderMap,
|
mut upstream_headers: hyper::HeaderMap,
|
||||||
upstream_path: &str,
|
upstream_path: &str,
|
||||||
upstream: &crate::upstream_selector::UpstreamSelection,
|
upstream: &crate::upstream_selector::UpstreamSelection,
|
||||||
@@ -1589,7 +1705,7 @@ impl HttpProxyService {
|
|||||||
&self,
|
&self,
|
||||||
mut sender: hyper::client::conn::http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
|
mut sender: hyper::client::conn::http2::SendRequest<BoxBody<Bytes, hyper::Error>>,
|
||||||
parts: hyper::http::request::Parts,
|
parts: hyper::http::request::Parts,
|
||||||
body: Incoming,
|
body: BoxBody<Bytes, hyper::Error>,
|
||||||
upstream_headers: hyper::HeaderMap,
|
upstream_headers: hyper::HeaderMap,
|
||||||
upstream_path: &str,
|
upstream_path: &str,
|
||||||
route: &rustproxy_config::RouteConfig,
|
route: &rustproxy_config::RouteConfig,
|
||||||
@@ -1668,6 +1784,19 @@ impl HttpProxyService {
|
|||||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||||
let (resp_parts, resp_body) = upstream_response.into_parts();
|
let (resp_parts, resp_body) = upstream_response.into_parts();
|
||||||
|
|
||||||
|
// Check for Alt-Svc in the backend's ORIGINAL response headers BEFORE
|
||||||
|
// ResponseFilter::apply_headers runs — the filter may inject our own Alt-Svc
|
||||||
|
// for client-facing HTTP/3 advertisement, which must not be confused with
|
||||||
|
// backend-originated Alt-Svc.
|
||||||
|
if let Some(ref cache_key) = conn_activity.alt_svc_cache_key {
|
||||||
|
if let Some(alt_svc) = resp_parts.headers.get("alt-svc").and_then(|v| v.to_str().ok()) {
|
||||||
|
if let Some(h3_port) = parse_alt_svc_h3_port(alt_svc) {
|
||||||
|
debug!(h3_port, "Backend advertises H3 via Alt-Svc");
|
||||||
|
self.protocol_cache.insert_h3(cache_key.clone(), h3_port);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let mut response = Response::builder()
|
let mut response = Response::builder()
|
||||||
.status(resp_parts.status);
|
.status(resp_parts.status);
|
||||||
|
|
||||||
@@ -1717,7 +1846,7 @@ impl HttpProxyService {
|
|||||||
/// Handle a WebSocket upgrade request (H1 Upgrade or H2 Extended CONNECT per RFC 8441).
|
/// Handle a WebSocket upgrade request (H1 Upgrade or H2 Extended CONNECT per RFC 8441).
|
||||||
async fn handle_websocket_upgrade(
|
async fn handle_websocket_upgrade(
|
||||||
&self,
|
&self,
|
||||||
req: Request<Incoming>,
|
req: Request<BoxBody<Bytes, hyper::Error>>,
|
||||||
peer_addr: std::net::SocketAddr,
|
peer_addr: std::net::SocketAddr,
|
||||||
upstream: &crate::upstream_selector::UpstreamSelection,
|
upstream: &crate::upstream_selector::UpstreamSelection,
|
||||||
route: &rustproxy_config::RouteConfig,
|
route: &rustproxy_config::RouteConfig,
|
||||||
@@ -2017,12 +2146,26 @@ impl HttpProxyService {
|
|||||||
let ws_max_lifetime = self.ws_max_lifetime;
|
let ws_max_lifetime = self.ws_max_lifetime;
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
|
// RAII guard: ensures connection_ended is called even if this task panics
|
||||||
|
struct WsUpstreamGuard {
|
||||||
|
selector: UpstreamSelector,
|
||||||
|
key: String,
|
||||||
|
}
|
||||||
|
impl Drop for WsUpstreamGuard {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.selector.connection_ended(&self.key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let _upstream_guard = WsUpstreamGuard {
|
||||||
|
selector: upstream_selector,
|
||||||
|
key: upstream_key_owned.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
let client_upgraded = match on_client_upgrade.await {
|
let client_upgraded = match on_client_upgrade.await {
|
||||||
Ok(upgraded) => upgraded,
|
Ok(upgraded) => upgraded,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!("WebSocket: client upgrade failed: {}", e);
|
debug!("WebSocket: client upgrade failed: {}", e);
|
||||||
upstream_selector.connection_ended(&upstream_key_owned);
|
return; // _upstream_guard Drop handles connection_ended
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -2181,9 +2324,7 @@ impl HttpProxyService {
|
|||||||
watchdog.abort();
|
watchdog.abort();
|
||||||
|
|
||||||
debug!("WebSocket tunnel closed: {} bytes in, {} bytes out", bytes_in, bytes_out);
|
debug!("WebSocket tunnel closed: {} bytes in, {} bytes out", bytes_in, bytes_out);
|
||||||
|
// _upstream_guard Drop handles connection_ended on all paths including panic
|
||||||
upstream_selector.connection_ended(&upstream_key_owned);
|
|
||||||
// Bytes already reported per-chunk in the copy loops above
|
|
||||||
});
|
});
|
||||||
|
|
||||||
let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(
|
let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(
|
||||||
@@ -2393,6 +2534,256 @@ impl HttpProxyService {
|
|||||||
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||||
Arc::new(config)
|
Arc::new(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a shared QUIC client endpoint for outbound H3 backend connections.
|
||||||
|
fn create_quinn_client_endpoint() -> quinn::Endpoint {
|
||||||
|
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||||
|
let mut tls_config = rustls::ClientConfig::builder()
|
||||||
|
.dangerous()
|
||||||
|
.with_custom_certificate_verifier(Arc::new(InsecureBackendVerifier))
|
||||||
|
.with_no_client_auth();
|
||||||
|
tls_config.alpn_protocols = vec![b"h3".to_vec()];
|
||||||
|
|
||||||
|
let quic_crypto = quinn::crypto::rustls::QuicClientConfig::try_from(tls_config)
|
||||||
|
.expect("Failed to create QUIC client crypto config");
|
||||||
|
let client_config = quinn::ClientConfig::new(Arc::new(quic_crypto));
|
||||||
|
|
||||||
|
let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
|
||||||
|
.expect("Failed to create QUIC client endpoint");
|
||||||
|
endpoint.set_default_client_config(client_config);
|
||||||
|
endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Connect to a backend via QUIC (H3).
|
||||||
|
async fn connect_quic_backend(
|
||||||
|
&self,
|
||||||
|
host: &str,
|
||||||
|
port: u16,
|
||||||
|
) -> Result<quinn::Connection, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
let addr = tokio::net::lookup_host(format!("{}:{}", host, port))
|
||||||
|
.await?
|
||||||
|
.next()
|
||||||
|
.ok_or("DNS resolution returned no addresses")?;
|
||||||
|
|
||||||
|
let server_name = host.to_string();
|
||||||
|
let connecting = self.quinn_client_endpoint.connect(addr, &server_name)?;
|
||||||
|
|
||||||
|
let connection = tokio::time::timeout(QUIC_CONNECT_TIMEOUT, connecting).await
|
||||||
|
.map_err(|_| "QUIC connect timeout (3s)")??;
|
||||||
|
|
||||||
|
debug!("QUIC backend connection established to {}:{}", host, port);
|
||||||
|
Ok(connection)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Forward request to backend via HTTP/3 over QUIC.
|
||||||
|
async fn forward_h3(
|
||||||
|
&self,
|
||||||
|
quic_conn: quinn::Connection,
|
||||||
|
parts: hyper::http::request::Parts,
|
||||||
|
body: BoxBody<Bytes, hyper::Error>,
|
||||||
|
upstream_headers: hyper::HeaderMap,
|
||||||
|
upstream_path: &str,
|
||||||
|
route: &rustproxy_config::RouteConfig,
|
||||||
|
route_id: Option<&str>,
|
||||||
|
source_ip: &str,
|
||||||
|
pool_key: &crate::connection_pool::PoolKey,
|
||||||
|
domain: &str,
|
||||||
|
conn_activity: &ConnActivity,
|
||||||
|
backend_key: &str,
|
||||||
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||||
|
let h3_quinn_conn = h3_quinn::Connection::new(quic_conn.clone());
|
||||||
|
let (mut driver, mut send_request) = match h3::client::builder()
|
||||||
|
.send_grease(false)
|
||||||
|
.build(h3_quinn_conn)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(pair) => pair,
|
||||||
|
Err(e) => {
|
||||||
|
error!(backend = %backend_key, domain = %domain, error = %e, "H3 client handshake failed");
|
||||||
|
self.metrics.backend_handshake_error(backend_key);
|
||||||
|
return Ok(error_response(StatusCode::BAD_GATEWAY, "H3 handshake failed"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Spawn the h3 connection driver
|
||||||
|
let driver_pool = Arc::clone(&self.connection_pool);
|
||||||
|
let driver_pool_key = pool_key.clone();
|
||||||
|
let gen_holder = Arc::new(std::sync::atomic::AtomicU64::new(u64::MAX));
|
||||||
|
let driver_gen = Arc::clone(&gen_holder);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let close_err = std::future::poll_fn(|cx| driver.poll_close(cx)).await;
|
||||||
|
debug!("H3 connection driver closed: {:?}", close_err);
|
||||||
|
let g = driver_gen.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
|
if g != u64::MAX {
|
||||||
|
driver_pool.remove_h3_if_generation(&driver_pool_key, g);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Build the H3 request
|
||||||
|
let uri = hyper::Uri::builder()
|
||||||
|
.scheme("https")
|
||||||
|
.authority(domain)
|
||||||
|
.path_and_query(upstream_path)
|
||||||
|
.build()
|
||||||
|
.unwrap_or_else(|_| upstream_path.parse().unwrap_or_default());
|
||||||
|
|
||||||
|
let mut h3_req = hyper::Request::builder()
|
||||||
|
.method(parts.method.clone())
|
||||||
|
.uri(uri);
|
||||||
|
|
||||||
|
if let Some(headers) = h3_req.headers_mut() {
|
||||||
|
*headers = upstream_headers;
|
||||||
|
}
|
||||||
|
|
||||||
|
let h3_req = h3_req.body(()).unwrap();
|
||||||
|
|
||||||
|
// Send the request
|
||||||
|
let mut stream = match send_request.send_request(h3_req).await {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) => {
|
||||||
|
error!(backend = %backend_key, domain = %domain, error = %e, "H3 send_request failed");
|
||||||
|
self.metrics.backend_request_error(backend_key);
|
||||||
|
return Ok(error_response(StatusCode::BAD_GATEWAY, "H3 request failed"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Stream request body
|
||||||
|
let rid: Option<Arc<str>> = route_id.map(Arc::from);
|
||||||
|
let sip: Arc<str> = Arc::from(source_ip);
|
||||||
|
|
||||||
|
{
|
||||||
|
use http_body_util::BodyExt;
|
||||||
|
let mut body = body;
|
||||||
|
while let Some(frame) = body.frame().await {
|
||||||
|
match frame {
|
||||||
|
Ok(frame) => {
|
||||||
|
if let Some(data) = frame.data_ref() {
|
||||||
|
self.metrics.record_bytes(data.len() as u64, 0, rid.as_deref(), Some(&sip));
|
||||||
|
if let Err(e) = stream.send_data(Bytes::copy_from_slice(data)).await {
|
||||||
|
error!(backend = %backend_key, error = %e, "H3 send_data failed");
|
||||||
|
return Ok(error_response(StatusCode::BAD_GATEWAY, "H3 body send failed"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(backend = %backend_key, error = %e, "Client body read error during H3 forward");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Signal end of body
|
||||||
|
stream.finish().await.ok();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read response
|
||||||
|
let h3_response = match stream.recv_response().await {
|
||||||
|
Ok(resp) => resp,
|
||||||
|
Err(e) => {
|
||||||
|
error!(backend = %backend_key, domain = %domain, error = %e, "H3 recv_response failed");
|
||||||
|
self.metrics.backend_request_error(backend_key);
|
||||||
|
return Ok(error_response(StatusCode::BAD_GATEWAY, "H3 response failed"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Build the response for the client
|
||||||
|
let status = h3_response.status();
|
||||||
|
let mut response = Response::builder().status(status);
|
||||||
|
|
||||||
|
if let Some(headers) = response.headers_mut() {
|
||||||
|
for (name, value) in h3_response.headers() {
|
||||||
|
let n = name.as_str();
|
||||||
|
// Skip hop-by-hop headers
|
||||||
|
if n == "transfer-encoding" || n == "connection" || n == "keep-alive" {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
headers.insert(name.clone(), value.clone());
|
||||||
|
}
|
||||||
|
ResponseFilter::apply_headers(route, headers, None);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream response body back via an adapter
|
||||||
|
let h3_body = H3ClientResponseBody { stream };
|
||||||
|
let counting_body = CountingBody::new(
|
||||||
|
h3_body,
|
||||||
|
Arc::clone(&self.metrics),
|
||||||
|
rid,
|
||||||
|
Some(sip),
|
||||||
|
Direction::Out,
|
||||||
|
).with_connection_activity(Arc::clone(&conn_activity.last_activity), conn_activity.start);
|
||||||
|
|
||||||
|
let counting_body = if let Some(ref ar) = conn_activity.active_requests {
|
||||||
|
counting_body.with_active_requests(Arc::clone(ar))
|
||||||
|
} else {
|
||||||
|
counting_body
|
||||||
|
};
|
||||||
|
|
||||||
|
let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(counting_body);
|
||||||
|
|
||||||
|
// Register connection in pool on success
|
||||||
|
if status != StatusCode::BAD_GATEWAY {
|
||||||
|
let g = self.connection_pool.register_h3(pool_key.clone(), quic_conn);
|
||||||
|
gen_holder.store(g, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.metrics.set_backend_protocol(backend_key, "h3");
|
||||||
|
Ok(response.body(body).unwrap())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse an Alt-Svc header value to extract the H3 port.
|
||||||
|
/// Handles formats like `h3=":443"; ma=86400` and `h3=":8443", h2=":443"`.
|
||||||
|
fn parse_alt_svc_h3_port(header_value: &str) -> Option<u16> {
|
||||||
|
for directive in header_value.split(',') {
|
||||||
|
let directive = directive.trim();
|
||||||
|
// Match h3=":<port>" or h3-29=":<port>" etc.
|
||||||
|
if directive.starts_with("h3=") || directive.starts_with("h3-") {
|
||||||
|
// Find the port in ":<port>"
|
||||||
|
if let Some(start) = directive.find("\":") {
|
||||||
|
let rest = &directive[start + 2..];
|
||||||
|
if let Some(end) = rest.find('"') {
|
||||||
|
if let Ok(port) = rest[..end].parse::<u16>() {
|
||||||
|
return Some(port);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Response body adapter for H3 client responses.
|
||||||
|
/// Reads data from the h3 `RequestStream` recv side and presents it as an `http_body::Body`.
|
||||||
|
struct H3ClientResponseBody {
|
||||||
|
stream: h3::client::RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl http_body::Body for H3ClientResponseBody {
|
||||||
|
type Data = Bytes;
|
||||||
|
type Error = hyper::Error;
|
||||||
|
|
||||||
|
fn poll_frame(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
|
||||||
|
// h3's recv_data is async, so we need to poll it manually.
|
||||||
|
// Use a small future to poll the recv_data call.
|
||||||
|
use std::future::Future;
|
||||||
|
let mut fut = Box::pin(self.stream.recv_data());
|
||||||
|
match fut.as_mut().poll(_cx) {
|
||||||
|
Poll::Ready(Ok(Some(mut buf))) => {
|
||||||
|
use bytes::Buf;
|
||||||
|
let data = Bytes::copy_from_slice(buf.chunk());
|
||||||
|
buf.advance(buf.remaining());
|
||||||
|
Poll::Ready(Some(Ok(http_body::Frame::data(data))))
|
||||||
|
}
|
||||||
|
Poll::Ready(Ok(None)) => Poll::Ready(None),
|
||||||
|
Poll::Ready(Err(e)) => {
|
||||||
|
warn!("H3 response body recv error: {}", e);
|
||||||
|
Poll::Ready(None)
|
||||||
|
}
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Insecure certificate verifier for backend TLS connections (fallback only).
|
/// Insecure certificate verifier for backend TLS connections (fallback only).
|
||||||
@@ -2455,6 +2846,8 @@ impl Default for HttpProxyService {
|
|||||||
connect_timeout: DEFAULT_CONNECT_TIMEOUT,
|
connect_timeout: DEFAULT_CONNECT_TIMEOUT,
|
||||||
route_rate_limiters: Arc::new(DashMap::new()),
|
route_rate_limiters: Arc::new(DashMap::new()),
|
||||||
request_counter: AtomicU64::new(0),
|
request_counter: AtomicU64::new(0),
|
||||||
|
rate_limiter_epoch: std::time::Instant::now(),
|
||||||
|
last_rate_limiter_cleanup_ms: AtomicU64::new(0),
|
||||||
regex_cache: DashMap::new(),
|
regex_cache: DashMap::new(),
|
||||||
backend_tls_config: Self::default_backend_tls_config(),
|
backend_tls_config: Self::default_backend_tls_config(),
|
||||||
backend_tls_config_alpn: Self::default_backend_tls_config_with_alpn(),
|
backend_tls_config_alpn: Self::default_backend_tls_config_with_alpn(),
|
||||||
@@ -2463,6 +2856,7 @@ impl Default for HttpProxyService {
|
|||||||
http_idle_timeout: DEFAULT_HTTP_IDLE_TIMEOUT,
|
http_idle_timeout: DEFAULT_HTTP_IDLE_TIMEOUT,
|
||||||
ws_inactivity_timeout: DEFAULT_WS_INACTIVITY_TIMEOUT,
|
ws_inactivity_timeout: DEFAULT_WS_INACTIVITY_TIMEOUT,
|
||||||
ws_max_lifetime: DEFAULT_WS_MAX_LIFETIME,
|
ws_max_lifetime: DEFAULT_WS_MAX_LIFETIME,
|
||||||
|
quinn_client_endpoint: Arc::new(Self::create_quinn_client_endpoint()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ use std::sync::Arc;
|
|||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use http_body_util::Full;
|
use http_body_util::Full;
|
||||||
use http_body_util::BodyExt;
|
use http_body_util::BodyExt;
|
||||||
use hyper::body::Incoming;
|
|
||||||
use hyper::{Request, Response, StatusCode};
|
use hyper::{Request, Response, StatusCode};
|
||||||
use http_body_util::combinators::BoxBody;
|
use http_body_util::combinators::BoxBody;
|
||||||
|
|
||||||
@@ -19,7 +18,7 @@ impl RequestFilter {
|
|||||||
/// Apply security filters. Returns Some(response) if the request should be blocked.
|
/// Apply security filters. Returns Some(response) if the request should be blocked.
|
||||||
pub fn apply(
|
pub fn apply(
|
||||||
security: &RouteSecurity,
|
security: &RouteSecurity,
|
||||||
req: &Request<Incoming>,
|
req: &Request<impl hyper::body::Body>,
|
||||||
peer_addr: &SocketAddr,
|
peer_addr: &SocketAddr,
|
||||||
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
|
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
|
||||||
Self::apply_with_rate_limiter(security, req, peer_addr, None)
|
Self::apply_with_rate_limiter(security, req, peer_addr, None)
|
||||||
@@ -29,7 +28,7 @@ impl RequestFilter {
|
|||||||
/// Returns Some(response) if the request should be blocked.
|
/// Returns Some(response) if the request should be blocked.
|
||||||
pub fn apply_with_rate_limiter(
|
pub fn apply_with_rate_limiter(
|
||||||
security: &RouteSecurity,
|
security: &RouteSecurity,
|
||||||
req: &Request<Incoming>,
|
req: &Request<impl hyper::body::Body>,
|
||||||
peer_addr: &SocketAddr,
|
peer_addr: &SocketAddr,
|
||||||
rate_limiter: Option<&Arc<RateLimiter>>,
|
rate_limiter: Option<&Arc<RateLimiter>>,
|
||||||
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
|
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
|
||||||
@@ -182,7 +181,7 @@ impl RequestFilter {
|
|||||||
/// Determine the rate limit key based on configuration.
|
/// Determine the rate limit key based on configuration.
|
||||||
fn rate_limit_key(
|
fn rate_limit_key(
|
||||||
config: &rustproxy_config::RouteRateLimit,
|
config: &rustproxy_config::RouteRateLimit,
|
||||||
req: &Request<Incoming>,
|
req: &Request<impl hyper::body::Body>,
|
||||||
peer_addr: &SocketAddr,
|
peer_addr: &SocketAddr,
|
||||||
) -> String {
|
) -> String {
|
||||||
use rustproxy_config::RateLimitKeyBy;
|
use rustproxy_config::RateLimitKeyBy;
|
||||||
@@ -220,7 +219,7 @@ impl RequestFilter {
|
|||||||
/// Handle CORS preflight (OPTIONS) requests.
|
/// Handle CORS preflight (OPTIONS) requests.
|
||||||
/// Returns Some(response) if this is a CORS preflight that should be handled.
|
/// Returns Some(response) if this is a CORS preflight that should be handled.
|
||||||
pub fn handle_cors_preflight(
|
pub fn handle_cors_preflight(
|
||||||
req: &Request<Incoming>,
|
req: &Request<impl hyper::body::Body>,
|
||||||
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
|
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
|
||||||
if req.method() != hyper::Method::OPTIONS {
|
if req.method() != hyper::Method::OPTIONS {
|
||||||
return None;
|
return None;
|
||||||
|
|||||||
@@ -411,11 +411,24 @@ impl MetricsCollector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Record a backend connection closing.
|
/// Record a backend connection closing.
|
||||||
|
/// Removes all per-backend tracking entries when the active count reaches 0.
|
||||||
pub fn backend_connection_closed(&self, key: &str) {
|
pub fn backend_connection_closed(&self, key: &str) {
|
||||||
if let Some(counter) = self.backend_active.get(key) {
|
if let Some(counter) = self.backend_active.get(key) {
|
||||||
let val = counter.load(Ordering::Relaxed);
|
let prev = counter.fetch_sub(1, Ordering::Relaxed);
|
||||||
if val > 0 {
|
if prev <= 1 {
|
||||||
counter.fetch_sub(1, Ordering::Relaxed);
|
// Active count reached 0 — clean up all per-backend maps
|
||||||
|
drop(counter); // release DashMap ref before remove
|
||||||
|
self.backend_active.remove(key);
|
||||||
|
self.backend_total.remove(key);
|
||||||
|
self.backend_protocol.remove(key);
|
||||||
|
self.backend_connect_errors.remove(key);
|
||||||
|
self.backend_handshake_errors.remove(key);
|
||||||
|
self.backend_request_errors.remove(key);
|
||||||
|
self.backend_connect_time_us.remove(key);
|
||||||
|
self.backend_connect_count.remove(key);
|
||||||
|
self.backend_pool_hits.remove(key);
|
||||||
|
self.backend_pool_misses.remove(key);
|
||||||
|
self.backend_h2_failures.remove(key);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1213,10 +1226,13 @@ mod tests {
|
|||||||
// No entry created
|
// No entry created
|
||||||
assert!(collector.backend_active.get(key).is_none());
|
assert!(collector.backend_active.get(key).is_none());
|
||||||
|
|
||||||
// Open one, close two — should saturate at 0
|
// Open one, close — entries are removed when active count reaches 0
|
||||||
collector.backend_connection_opened(key, Duration::from_millis(1));
|
collector.backend_connection_opened(key, Duration::from_millis(1));
|
||||||
collector.backend_connection_closed(key);
|
collector.backend_connection_closed(key);
|
||||||
|
// Entry should be cleaned up (active reached 0)
|
||||||
|
assert!(collector.backend_active.get(key).is_none());
|
||||||
|
// Second close on missing entry is a no-op
|
||||||
collector.backend_connection_closed(key);
|
collector.backend_connection_closed(key);
|
||||||
assert_eq!(collector.backend_active.get(key).unwrap().load(Ordering::Relaxed), 0);
|
assert!(collector.backend_active.get(key).is_none());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,13 +3,21 @@
|
|||||||
//! Manages QUIC endpoints (via quinn), accepts connections, and either:
|
//! Manages QUIC endpoints (via quinn), accepts connections, and either:
|
||||||
//! - Forwards streams bidirectionally to TCP backends (QUIC termination)
|
//! - Forwards streams bidirectionally to TCP backends (QUIC termination)
|
||||||
//! - Dispatches to H3ProxyService for HTTP/3 handling (Phase 5)
|
//! - Dispatches to H3ProxyService for HTTP/3 handling (Phase 5)
|
||||||
|
//!
|
||||||
|
//! When `proxy_ips` is configured, a UDP relay layer intercepts PROXY protocol v2
|
||||||
|
//! headers before they reach quinn, extracting real client IPs for attribution.
|
||||||
|
|
||||||
use std::net::SocketAddr;
|
use std::net::{IpAddr, SocketAddr};
|
||||||
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
use tokio::io::AsyncWriteExt;
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use tokio::net::UdpSocket;
|
||||||
|
use tokio::task::JoinHandle;
|
||||||
|
|
||||||
use arc_swap::ArcSwap;
|
use arc_swap::ArcSwap;
|
||||||
|
use dashmap::DashMap;
|
||||||
use quinn::{Endpoint, ServerConfig as QuinnServerConfig};
|
use quinn::{Endpoint, ServerConfig as QuinnServerConfig};
|
||||||
use rustls::ServerConfig as RustlsServerConfig;
|
use rustls::ServerConfig as RustlsServerConfig;
|
||||||
use tokio_util::sync::CancellationToken;
|
use tokio_util::sync::CancellationToken;
|
||||||
@@ -19,6 +27,8 @@ use rustproxy_config::{RouteConfig, TransportProtocol};
|
|||||||
use rustproxy_metrics::MetricsCollector;
|
use rustproxy_metrics::MetricsCollector;
|
||||||
use rustproxy_routing::{MatchContext, RouteManager};
|
use rustproxy_routing::{MatchContext, RouteManager};
|
||||||
|
|
||||||
|
use rustproxy_http::h3_service::H3ProxyService;
|
||||||
|
|
||||||
use crate::connection_tracker::ConnectionTracker;
|
use crate::connection_tracker::ConnectionTracker;
|
||||||
|
|
||||||
/// Create a QUIC server endpoint on the given port with the provided TLS config.
|
/// Create a QUIC server endpoint on the given port with the provided TLS config.
|
||||||
@@ -45,9 +55,285 @@ pub fn create_quic_endpoint(
|
|||||||
Ok(endpoint)
|
Ok(endpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ===== PROXY protocol relay for QUIC =====
|
||||||
|
|
||||||
|
/// Result of creating a QUIC endpoint with a PROXY protocol relay layer.
|
||||||
|
pub struct QuicProxyRelay {
|
||||||
|
/// The quinn endpoint (bound to 127.0.0.1:ephemeral).
|
||||||
|
pub endpoint: Endpoint,
|
||||||
|
/// The relay recv loop task handle.
|
||||||
|
pub relay_task: JoinHandle<()>,
|
||||||
|
/// Maps relay socket local addr → real client SocketAddr (from PROXY v2).
|
||||||
|
/// Consulted by `quic_accept_loop` to resolve real client IPs.
|
||||||
|
pub real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A single relay session for forwarding datagrams between an external source
|
||||||
|
/// and the internal quinn endpoint.
|
||||||
|
struct RelaySession {
|
||||||
|
socket: Arc<UdpSocket>,
|
||||||
|
last_activity: AtomicU64,
|
||||||
|
return_task: JoinHandle<()>,
|
||||||
|
cancel: CancellationToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a QUIC endpoint with a PROXY protocol v2 relay layer.
|
||||||
|
///
|
||||||
|
/// Instead of giving the external socket to quinn, we:
|
||||||
|
/// 1. Bind a raw UDP socket on 0.0.0.0:port (external)
|
||||||
|
/// 2. Bind quinn on 127.0.0.1:0 (internal, ephemeral)
|
||||||
|
/// 3. Run a relay loop that filters PROXY v2 headers and forwards datagrams
|
||||||
|
///
|
||||||
|
/// Only used when `proxy_ips` is non-empty.
|
||||||
|
pub fn create_quic_endpoint_with_proxy_relay(
|
||||||
|
port: u16,
|
||||||
|
tls_config: Arc<RustlsServerConfig>,
|
||||||
|
proxy_ips: Arc<Vec<IpAddr>>,
|
||||||
|
cancel: CancellationToken,
|
||||||
|
) -> anyhow::Result<QuicProxyRelay> {
|
||||||
|
// Bind external socket on the real port
|
||||||
|
let external_socket = std::net::UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], port)))?;
|
||||||
|
external_socket.set_nonblocking(true)?;
|
||||||
|
let external_socket = Arc::new(
|
||||||
|
UdpSocket::from_std(external_socket)
|
||||||
|
.map_err(|e| anyhow::anyhow!("Failed to wrap external socket: {}", e))?,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Bind quinn on localhost ephemeral port
|
||||||
|
let internal_socket = std::net::UdpSocket::bind("127.0.0.1:0")?;
|
||||||
|
let quinn_internal_addr = internal_socket.local_addr()?;
|
||||||
|
|
||||||
|
let quic_crypto = quinn::crypto::rustls::QuicServerConfig::try_from(tls_config)
|
||||||
|
.map_err(|e| anyhow::anyhow!("Failed to create QUIC crypto config: {}", e))?;
|
||||||
|
let server_config = QuinnServerConfig::with_crypto(Arc::new(quic_crypto));
|
||||||
|
|
||||||
|
let endpoint = Endpoint::new(
|
||||||
|
quinn::EndpointConfig::default(),
|
||||||
|
Some(server_config),
|
||||||
|
internal_socket,
|
||||||
|
quinn::default_runtime()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("No async runtime for quinn"))?,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let real_client_map = Arc::new(DashMap::new());
|
||||||
|
|
||||||
|
let relay_task = tokio::spawn(quic_proxy_relay_loop(
|
||||||
|
external_socket,
|
||||||
|
quinn_internal_addr,
|
||||||
|
proxy_ips,
|
||||||
|
Arc::clone(&real_client_map),
|
||||||
|
cancel,
|
||||||
|
));
|
||||||
|
|
||||||
|
info!("QUIC endpoint with PROXY relay on port {} (quinn internal: {})", port, quinn_internal_addr);
|
||||||
|
Ok(QuicProxyRelay { endpoint, relay_task, real_client_map })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Main relay loop: reads datagrams from the external socket, filters PROXY v2
|
||||||
|
/// headers from trusted proxy IPs, and forwards everything else to quinn via
|
||||||
|
/// per-session relay sockets.
|
||||||
|
async fn quic_proxy_relay_loop(
|
||||||
|
external_socket: Arc<UdpSocket>,
|
||||||
|
quinn_internal_addr: SocketAddr,
|
||||||
|
proxy_ips: Arc<Vec<IpAddr>>,
|
||||||
|
real_client_map: Arc<DashMap<SocketAddr, SocketAddr>>,
|
||||||
|
cancel: CancellationToken,
|
||||||
|
) {
|
||||||
|
// Maps external source addr → real client addr (from PROXY v2 headers)
|
||||||
|
let proxy_addr_map: DashMap<SocketAddr, SocketAddr> = DashMap::new();
|
||||||
|
// Maps external source addr → relay session
|
||||||
|
let relay_sessions: DashMap<SocketAddr, Arc<RelaySession>> = DashMap::new();
|
||||||
|
let epoch = Instant::now();
|
||||||
|
let mut buf = vec![0u8; 65535];
|
||||||
|
|
||||||
|
// Inline cleanup: periodically scan relay_sessions for stale entries
|
||||||
|
let mut last_cleanup = Instant::now();
|
||||||
|
let cleanup_interval = std::time::Duration::from_secs(30);
|
||||||
|
let session_timeout_ms: u64 = 120_000;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let (len, src_addr) = tokio::select! {
|
||||||
|
_ = cancel.cancelled() => {
|
||||||
|
debug!("QUIC proxy relay loop cancelled");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
result = external_socket.recv_from(&mut buf) => {
|
||||||
|
match result {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
warn!("QUIC proxy relay recv error: {}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let datagram = &buf[..len];
|
||||||
|
|
||||||
|
// PROXY v2 handling: only on first datagram from a trusted proxy IP
|
||||||
|
// (before a relay session exists for this source)
|
||||||
|
if proxy_ips.contains(&src_addr.ip()) && relay_sessions.get(&src_addr).is_none() {
|
||||||
|
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
|
||||||
|
match crate::proxy_protocol::parse_v2(datagram) {
|
||||||
|
Ok((header, _consumed)) => {
|
||||||
|
debug!("QUIC PROXY v2 from {}: real client {}", src_addr, header.source_addr);
|
||||||
|
proxy_addr_map.insert(src_addr, header.source_addr);
|
||||||
|
continue; // consume the PROXY v2 datagram
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
debug!("QUIC proxy relay: failed to parse PROXY v2 from {}: {}", src_addr, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine real client address
|
||||||
|
let real_client = proxy_addr_map.get(&src_addr)
|
||||||
|
.map(|r| *r)
|
||||||
|
.unwrap_or(src_addr);
|
||||||
|
|
||||||
|
// Get or create relay session for this external source
|
||||||
|
let session = match relay_sessions.get(&src_addr) {
|
||||||
|
Some(s) => {
|
||||||
|
s.last_activity.store(epoch.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||||
|
Arc::clone(s.value())
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
// Create new relay socket connected to quinn's internal address
|
||||||
|
let relay_socket = match UdpSocket::bind("127.0.0.1:0").await {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) => {
|
||||||
|
warn!("QUIC relay: failed to bind relay socket: {}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if let Err(e) = relay_socket.connect(quinn_internal_addr).await {
|
||||||
|
warn!("QUIC relay: failed to connect relay socket to {}: {}", quinn_internal_addr, e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let relay_local_addr = match relay_socket.local_addr() {
|
||||||
|
Ok(a) => a,
|
||||||
|
Err(e) => {
|
||||||
|
warn!("QUIC relay: failed to get relay socket local addr: {}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let relay_socket = Arc::new(relay_socket);
|
||||||
|
|
||||||
|
// Store the real client mapping for the QUIC accept loop
|
||||||
|
real_client_map.insert(relay_local_addr, real_client);
|
||||||
|
|
||||||
|
// Spawn return-path relay: quinn -> external socket -> original source
|
||||||
|
let session_cancel = cancel.child_token();
|
||||||
|
let return_task = tokio::spawn(relay_return_path(
|
||||||
|
Arc::clone(&relay_socket),
|
||||||
|
Arc::clone(&external_socket),
|
||||||
|
src_addr,
|
||||||
|
session_cancel.child_token(),
|
||||||
|
));
|
||||||
|
|
||||||
|
let session = Arc::new(RelaySession {
|
||||||
|
socket: relay_socket,
|
||||||
|
last_activity: AtomicU64::new(epoch.elapsed().as_millis() as u64),
|
||||||
|
return_task,
|
||||||
|
cancel: session_cancel,
|
||||||
|
});
|
||||||
|
|
||||||
|
relay_sessions.insert(src_addr, Arc::clone(&session));
|
||||||
|
debug!("QUIC relay: new session for {} (relay {}), real client {}",
|
||||||
|
src_addr, relay_local_addr, real_client);
|
||||||
|
|
||||||
|
session
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Forward datagram to quinn via the relay socket
|
||||||
|
if let Err(e) = session.socket.send(datagram).await {
|
||||||
|
debug!("QUIC relay: forward error to quinn for {}: {}", src_addr, e);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Periodic cleanup of stale relay sessions
|
||||||
|
if last_cleanup.elapsed() >= cleanup_interval {
|
||||||
|
last_cleanup = Instant::now();
|
||||||
|
let now_ms = epoch.elapsed().as_millis() as u64;
|
||||||
|
let stale_keys: Vec<SocketAddr> = relay_sessions.iter()
|
||||||
|
.filter(|entry| {
|
||||||
|
let age = now_ms.saturating_sub(entry.value().last_activity.load(Ordering::Relaxed));
|
||||||
|
age > session_timeout_ms
|
||||||
|
})
|
||||||
|
.map(|entry| *entry.key())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
for key in stale_keys {
|
||||||
|
if let Some((_, session)) = relay_sessions.remove(&key) {
|
||||||
|
session.cancel.cancel();
|
||||||
|
session.return_task.abort();
|
||||||
|
// Clean up real_client_map entry
|
||||||
|
if let Ok(addr) = session.socket.local_addr() {
|
||||||
|
real_client_map.remove(&addr);
|
||||||
|
}
|
||||||
|
proxy_addr_map.remove(&key);
|
||||||
|
debug!("QUIC relay: cleaned up stale session for {}", key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also clean orphaned proxy_addr_map entries (PROXY header received
|
||||||
|
// but no relay session was ever created, e.g. client never sent data)
|
||||||
|
let orphaned: Vec<SocketAddr> = proxy_addr_map.iter()
|
||||||
|
.filter(|entry| relay_sessions.get(entry.key()).is_none())
|
||||||
|
.map(|entry| *entry.key())
|
||||||
|
.collect();
|
||||||
|
for key in orphaned {
|
||||||
|
proxy_addr_map.remove(&key);
|
||||||
|
debug!("QUIC relay: cleaned up orphaned proxy_addr_map entry for {}", key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown: cancel all relay sessions
|
||||||
|
for entry in relay_sessions.iter() {
|
||||||
|
entry.value().cancel.cancel();
|
||||||
|
entry.value().return_task.abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return-path relay: receives datagrams from quinn (via the relay socket)
|
||||||
|
/// and forwards them back to the external client through the external socket.
|
||||||
|
async fn relay_return_path(
|
||||||
|
relay_socket: Arc<UdpSocket>,
|
||||||
|
external_socket: Arc<UdpSocket>,
|
||||||
|
external_src_addr: SocketAddr,
|
||||||
|
cancel: CancellationToken,
|
||||||
|
) {
|
||||||
|
let mut buf = vec![0u8; 65535];
|
||||||
|
loop {
|
||||||
|
let len = tokio::select! {
|
||||||
|
_ = cancel.cancelled() => break,
|
||||||
|
result = relay_socket.recv(&mut buf) => {
|
||||||
|
match result {
|
||||||
|
Ok(len) => len,
|
||||||
|
Err(e) => {
|
||||||
|
debug!("QUIC relay return recv error for {}: {}", external_src_addr, e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Err(e) = external_socket.send_to(&buf[..len], external_src_addr).await {
|
||||||
|
debug!("QUIC relay return send error to {}: {}", external_src_addr, e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== QUIC accept loop =====
|
||||||
|
|
||||||
/// Run the QUIC accept loop for a single endpoint.
|
/// Run the QUIC accept loop for a single endpoint.
|
||||||
///
|
///
|
||||||
/// Accepts incoming QUIC connections and spawns a task per connection.
|
/// Accepts incoming QUIC connections and spawns a task per connection.
|
||||||
|
/// When `real_client_map` is provided, it is consulted to resolve real client
|
||||||
|
/// IPs from PROXY protocol v2 headers (relay socket addr → real client addr).
|
||||||
pub async fn quic_accept_loop(
|
pub async fn quic_accept_loop(
|
||||||
endpoint: Endpoint,
|
endpoint: Endpoint,
|
||||||
port: u16,
|
port: u16,
|
||||||
@@ -55,6 +341,8 @@ pub async fn quic_accept_loop(
|
|||||||
metrics: Arc<MetricsCollector>,
|
metrics: Arc<MetricsCollector>,
|
||||||
conn_tracker: Arc<ConnectionTracker>,
|
conn_tracker: Arc<ConnectionTracker>,
|
||||||
cancel: CancellationToken,
|
cancel: CancellationToken,
|
||||||
|
h3_service: Option<Arc<H3ProxyService>>,
|
||||||
|
real_client_map: Option<Arc<DashMap<SocketAddr, SocketAddr>>>,
|
||||||
) {
|
) {
|
||||||
loop {
|
loop {
|
||||||
let incoming = tokio::select! {
|
let incoming = tokio::select! {
|
||||||
@@ -74,11 +362,16 @@ pub async fn quic_accept_loop(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let remote_addr = incoming.remote_address();
|
let remote_addr = incoming.remote_address();
|
||||||
let ip = remote_addr.ip();
|
|
||||||
|
// Resolve real client IP from PROXY protocol map if available
|
||||||
|
let real_addr = real_client_map.as_ref()
|
||||||
|
.and_then(|map| map.get(&remote_addr).map(|r| *r))
|
||||||
|
.unwrap_or(remote_addr);
|
||||||
|
let ip = real_addr.ip();
|
||||||
|
|
||||||
// Per-IP rate limiting
|
// Per-IP rate limiting
|
||||||
if !conn_tracker.try_accept(&ip) {
|
if !conn_tracker.try_accept(&ip) {
|
||||||
debug!("QUIC connection rejected from {} (rate limit)", remote_addr);
|
debug!("QUIC connection rejected from {} (rate limit)", real_addr);
|
||||||
// Drop `incoming` to refuse the connection
|
// Drop `incoming` to refuse the connection
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -101,7 +394,7 @@ pub async fn quic_accept_loop(
|
|||||||
let route = match rm.find_route(&ctx) {
|
let route = match rm.find_route(&ctx) {
|
||||||
Some(m) => m.route.clone(),
|
Some(m) => m.route.clone(),
|
||||||
None => {
|
None => {
|
||||||
debug!("No QUIC route matched for port {} from {}", port, remote_addr);
|
debug!("No QUIC route matched for port {} from {}", port, real_addr);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -113,16 +406,36 @@ pub async fn quic_accept_loop(
|
|||||||
let metrics = Arc::clone(&metrics);
|
let metrics = Arc::clone(&metrics);
|
||||||
let conn_tracker = Arc::clone(&conn_tracker);
|
let conn_tracker = Arc::clone(&conn_tracker);
|
||||||
let cancel = cancel.child_token();
|
let cancel = cancel.child_token();
|
||||||
|
let h3_svc = h3_service.clone();
|
||||||
|
let real_client_addr = if real_addr != remote_addr { Some(real_addr) } else { None };
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
match handle_quic_connection(incoming, route, port, Arc::clone(&metrics), &cancel).await {
|
// RAII guard: ensures metrics/tracker cleanup even on panic
|
||||||
Ok(()) => debug!("QUIC connection from {} completed", remote_addr),
|
struct QuicConnGuard {
|
||||||
Err(e) => debug!("QUIC connection from {} error: {}", remote_addr, e),
|
tracker: Arc<ConnectionTracker>,
|
||||||
|
metrics: Arc<MetricsCollector>,
|
||||||
|
ip: std::net::IpAddr,
|
||||||
|
ip_str: String,
|
||||||
|
route_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
impl Drop for QuicConnGuard {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.tracker.connection_closed(&self.ip);
|
||||||
|
self.metrics.connection_closed(self.route_id.as_deref(), Some(&self.ip_str));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let _guard = QuicConnGuard {
|
||||||
|
tracker: conn_tracker,
|
||||||
|
metrics: Arc::clone(&metrics),
|
||||||
|
ip,
|
||||||
|
ip_str,
|
||||||
|
route_id,
|
||||||
|
};
|
||||||
|
|
||||||
// Cleanup
|
match handle_quic_connection(incoming, route, port, Arc::clone(&metrics), &cancel, h3_svc, real_client_addr).await {
|
||||||
conn_tracker.connection_closed(&ip);
|
Ok(()) => debug!("QUIC connection from {} completed", real_addr),
|
||||||
metrics.connection_closed(route_id.as_deref(), Some(&ip_str));
|
Err(e) => debug!("QUIC connection from {} error: {}", real_addr, e),
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,10 +452,12 @@ async fn handle_quic_connection(
|
|||||||
port: u16,
|
port: u16,
|
||||||
metrics: Arc<MetricsCollector>,
|
metrics: Arc<MetricsCollector>,
|
||||||
cancel: &CancellationToken,
|
cancel: &CancellationToken,
|
||||||
|
h3_service: Option<Arc<H3ProxyService>>,
|
||||||
|
real_client_addr: Option<SocketAddr>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let connection = incoming.await?;
|
let connection = incoming.await?;
|
||||||
let remote_addr = connection.remote_address();
|
let effective_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
|
||||||
debug!("QUIC connection established from {}", remote_addr);
|
debug!("QUIC connection established from {}", effective_addr);
|
||||||
|
|
||||||
// Check if this route has HTTP/3 enabled
|
// Check if this route has HTTP/3 enabled
|
||||||
let enable_http3 = route.action.udp.as_ref()
|
let enable_http3 = route.action.udp.as_ref()
|
||||||
@@ -151,13 +466,23 @@ async fn handle_quic_connection(
|
|||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
|
|
||||||
if enable_http3 {
|
if enable_http3 {
|
||||||
// Phase 5: dispatch to H3ProxyService
|
if let Some(ref h3_svc) = h3_service {
|
||||||
// For now, log and accept streams for basic handling
|
debug!("HTTP/3 enabled for route {:?}, dispatching to H3ProxyService", route.name);
|
||||||
debug!("HTTP/3 enabled for route {:?}, dispatching to H3 handler", route.name);
|
h3_svc.handle_connection(connection, &route, port, real_client_addr, cancel).await
|
||||||
handle_h3_connection(connection, route, port, &metrics, cancel).await
|
} else {
|
||||||
|
warn!("HTTP/3 enabled for route {:?} but H3ProxyService not initialized", route.name);
|
||||||
|
// Keep connection alive until cancelled
|
||||||
|
tokio::select! {
|
||||||
|
_ = cancel.cancelled() => {}
|
||||||
|
reason = connection.closed() => {
|
||||||
|
debug!("HTTP/3 connection closed (no service): {}", reason);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend
|
// Non-HTTP3 QUIC: bidirectional stream forwarding to TCP backend
|
||||||
handle_quic_stream_forwarding(connection, route, port, metrics, cancel).await
|
handle_quic_stream_forwarding(connection, route, port, metrics, cancel, real_client_addr).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,8 +497,9 @@ async fn handle_quic_stream_forwarding(
|
|||||||
port: u16,
|
port: u16,
|
||||||
metrics: Arc<MetricsCollector>,
|
metrics: Arc<MetricsCollector>,
|
||||||
cancel: &CancellationToken,
|
cancel: &CancellationToken,
|
||||||
|
real_client_addr: Option<SocketAddr>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let remote_addr = connection.remote_address();
|
let effective_addr = real_client_addr.unwrap_or_else(|| connection.remote_address());
|
||||||
let route_id = route.name.as_deref().or(route.id.as_deref());
|
let route_id = route.name.as_deref().or(route.id.as_deref());
|
||||||
let metrics_arc = metrics;
|
let metrics_arc = metrics;
|
||||||
|
|
||||||
@@ -194,7 +520,7 @@ async fn handle_quic_stream_forwarding(
|
|||||||
Err(quinn::ConnectionError::ApplicationClosed(_)) => break,
|
Err(quinn::ConnectionError::ApplicationClosed(_)) => break,
|
||||||
Err(quinn::ConnectionError::LocallyClosed) => break,
|
Err(quinn::ConnectionError::LocallyClosed) => break,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
debug!("QUIC stream accept error from {}: {}", remote_addr, e);
|
debug!("QUIC stream accept error from {}: {}", effective_addr, e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -202,9 +528,10 @@ async fn handle_quic_stream_forwarding(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let backend_addr = backend_addr.clone();
|
let backend_addr = backend_addr.clone();
|
||||||
let ip_str = remote_addr.ip().to_string();
|
let ip_str = effective_addr.ip().to_string();
|
||||||
let stream_metrics = Arc::clone(&metrics_arc);
|
let stream_metrics = Arc::clone(&metrics_arc);
|
||||||
let stream_route_id = route_id.map(|s| s.to_string());
|
let stream_route_id = route_id.map(|s| s.to_string());
|
||||||
|
let stream_cancel = cancel.child_token();
|
||||||
|
|
||||||
// Spawn a task for each QUIC stream → TCP bidirectional forwarding
|
// Spawn a task for each QUIC stream → TCP bidirectional forwarding
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
@@ -212,6 +539,7 @@ async fn handle_quic_stream_forwarding(
|
|||||||
send_stream,
|
send_stream,
|
||||||
recv_stream,
|
recv_stream,
|
||||||
&backend_addr,
|
&backend_addr,
|
||||||
|
stream_cancel,
|
||||||
).await {
|
).await {
|
||||||
Ok((bytes_in, bytes_out)) => {
|
Ok((bytes_in, bytes_out)) => {
|
||||||
stream_metrics.record_bytes(
|
stream_metrics.record_bytes(
|
||||||
@@ -232,54 +560,115 @@ async fn handle_quic_stream_forwarding(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Forward a single QUIC bidirectional stream to a TCP backend connection.
|
/// Forward a single QUIC bidirectional stream to a TCP backend connection.
|
||||||
|
///
|
||||||
|
/// Includes inactivity timeout (60s), max lifetime (10min), and cancellation
|
||||||
|
/// to prevent leaked stream tasks when the parent connection closes.
|
||||||
async fn forward_quic_stream_to_tcp(
|
async fn forward_quic_stream_to_tcp(
|
||||||
mut quic_send: quinn::SendStream,
|
mut quic_send: quinn::SendStream,
|
||||||
mut quic_recv: quinn::RecvStream,
|
mut quic_recv: quinn::RecvStream,
|
||||||
backend_addr: &str,
|
backend_addr: &str,
|
||||||
|
cancel: CancellationToken,
|
||||||
) -> anyhow::Result<(u64, u64)> {
|
) -> anyhow::Result<(u64, u64)> {
|
||||||
|
let inactivity_timeout = std::time::Duration::from_secs(60);
|
||||||
|
let max_lifetime = std::time::Duration::from_secs(600);
|
||||||
|
|
||||||
// Connect to backend TCP
|
// Connect to backend TCP
|
||||||
let tcp_stream = tokio::net::TcpStream::connect(backend_addr).await?;
|
let tcp_stream = tokio::net::TcpStream::connect(backend_addr).await?;
|
||||||
let (mut tcp_read, mut tcp_write) = tcp_stream.into_split();
|
let (mut tcp_read, mut tcp_write) = tcp_stream.into_split();
|
||||||
|
|
||||||
// Bidirectional copy
|
let last_activity = Arc::new(AtomicU64::new(0));
|
||||||
let client_to_backend = tokio::io::copy(&mut quic_recv, &mut tcp_write);
|
let start = std::time::Instant::now();
|
||||||
let backend_to_client = tokio::io::copy(&mut tcp_read, &mut quic_send);
|
let conn_cancel = CancellationToken::new();
|
||||||
|
|
||||||
let (c2b, b2c) = tokio::join!(client_to_backend, backend_to_client);
|
let la1 = Arc::clone(&last_activity);
|
||||||
|
let cc1 = conn_cancel.clone();
|
||||||
|
let c2b = tokio::spawn(async move {
|
||||||
|
let mut buf = vec![0u8; 65536];
|
||||||
|
let mut total = 0u64;
|
||||||
|
loop {
|
||||||
|
let n = tokio::select! {
|
||||||
|
result = quic_recv.read(&mut buf) => match result {
|
||||||
|
Ok(Some(0)) | Ok(None) | Err(_) => break,
|
||||||
|
Ok(Some(n)) => n,
|
||||||
|
},
|
||||||
|
_ = cc1.cancelled() => break,
|
||||||
|
};
|
||||||
|
if tcp_write.write_all(&buf[..n]).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
total += n as u64;
|
||||||
|
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
let _ = tokio::time::timeout(
|
||||||
|
std::time::Duration::from_secs(2),
|
||||||
|
tcp_write.shutdown(),
|
||||||
|
).await;
|
||||||
|
total
|
||||||
|
});
|
||||||
|
|
||||||
let bytes_in = c2b.unwrap_or(0);
|
let la2 = Arc::clone(&last_activity);
|
||||||
let bytes_out = b2c.unwrap_or(0);
|
let cc2 = conn_cancel.clone();
|
||||||
|
let b2c = tokio::spawn(async move {
|
||||||
// Graceful shutdown
|
let mut buf = vec![0u8; 65536];
|
||||||
|
let mut total = 0u64;
|
||||||
|
loop {
|
||||||
|
let n = tokio::select! {
|
||||||
|
result = tcp_read.read(&mut buf) => match result {
|
||||||
|
Ok(0) | Err(_) => break,
|
||||||
|
Ok(n) => n,
|
||||||
|
},
|
||||||
|
_ = cc2.cancelled() => break,
|
||||||
|
};
|
||||||
|
// quinn SendStream implements AsyncWrite
|
||||||
|
if quic_send.write_all(&buf[..n]).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
total += n as u64;
|
||||||
|
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||||
|
}
|
||||||
let _ = quic_send.finish();
|
let _ = quic_send.finish();
|
||||||
let _ = tcp_write.shutdown().await;
|
total
|
||||||
|
});
|
||||||
|
|
||||||
|
// Watchdog: inactivity, max lifetime, and cancellation
|
||||||
|
let la_watch = Arc::clone(&last_activity);
|
||||||
|
let c2b_abort = c2b.abort_handle();
|
||||||
|
let b2c_abort = b2c.abort_handle();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let check_interval = std::time::Duration::from_secs(5);
|
||||||
|
let mut last_seen = 0u64;
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
_ = cancel.cancelled() => break,
|
||||||
|
_ = tokio::time::sleep(check_interval) => {
|
||||||
|
if start.elapsed() >= max_lifetime {
|
||||||
|
debug!("QUIC stream exceeded max lifetime, closing");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let current = la_watch.load(Ordering::Relaxed);
|
||||||
|
if current == last_seen {
|
||||||
|
let elapsed = start.elapsed().as_millis() as u64 - current;
|
||||||
|
if elapsed >= inactivity_timeout.as_millis() as u64 {
|
||||||
|
debug!("QUIC stream inactive for {}ms, closing", elapsed);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
last_seen = current;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn_cancel.cancel();
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(4)).await;
|
||||||
|
c2b_abort.abort();
|
||||||
|
b2c_abort.abort();
|
||||||
|
});
|
||||||
|
|
||||||
|
let bytes_in = c2b.await.unwrap_or(0);
|
||||||
|
let bytes_out = b2c.await.unwrap_or(0);
|
||||||
|
|
||||||
Ok((bytes_in, bytes_out))
|
Ok((bytes_in, bytes_out))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Placeholder for HTTP/3 connection handling (Phase 5).
|
|
||||||
///
|
|
||||||
/// Once h3_service is implemented, this will delegate to it.
|
|
||||||
async fn handle_h3_connection(
|
|
||||||
connection: quinn::Connection,
|
|
||||||
_route: RouteConfig,
|
|
||||||
_port: u16,
|
|
||||||
_metrics: &MetricsCollector,
|
|
||||||
cancel: &CancellationToken,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
warn!("HTTP/3 handling not yet fully implemented — accepting connection but no request processing");
|
|
||||||
|
|
||||||
// Keep the connection alive until cancelled or closed
|
|
||||||
tokio::select! {
|
|
||||||
_ = cancel.cancelled() => {}
|
|
||||||
reason = connection.closed() => {
|
|
||||||
debug!("HTTP/3 connection closed: {}", reason);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|||||||
@@ -428,6 +428,11 @@ impl TcpListenerManager {
|
|||||||
self.http_proxy.prune_stale_routes(active_route_ids);
|
self.http_proxy.prune_stale_routes(active_route_ids);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get a reference to the HTTP proxy service (shared with H3).
|
||||||
|
pub fn http_proxy(&self) -> &Arc<HttpProxyService> {
|
||||||
|
&self.http_proxy
|
||||||
|
}
|
||||||
|
|
||||||
/// Get a reference to the connection tracker.
|
/// Get a reference to the connection tracker.
|
||||||
pub fn conn_tracker(&self) -> &Arc<ConnectionTracker> {
|
pub fn conn_tracker(&self) -> &Arc<ConnectionTracker> {
|
||||||
&self.conn_tracker
|
&self.conn_tracker
|
||||||
|
|||||||
@@ -2,12 +2,17 @@
|
|||||||
//!
|
//!
|
||||||
//! Binds UDP sockets on configured ports, receives datagrams, matches routes,
|
//! Binds UDP sockets on configured ports, receives datagrams, matches routes,
|
||||||
//! tracks sessions (flows), and forwards datagrams to backend UDP sockets.
|
//! tracks sessions (flows), and forwards datagrams to backend UDP sockets.
|
||||||
|
//!
|
||||||
|
//! Supports PROXY protocol v2 on both raw UDP and QUIC paths when `proxy_ips`
|
||||||
|
//! is configured. For QUIC, a relay layer intercepts datagrams before they
|
||||||
|
//! reach the quinn endpoint.
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::net::SocketAddr;
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::sync::atomic::Ordering;
|
use std::sync::atomic::Ordering;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use dashmap::DashMap;
|
||||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
|
||||||
use arc_swap::ArcSwap;
|
use arc_swap::ArcSwap;
|
||||||
@@ -21,13 +26,15 @@ use rustproxy_config::{RouteActionType, TransportProtocol};
|
|||||||
use rustproxy_metrics::MetricsCollector;
|
use rustproxy_metrics::MetricsCollector;
|
||||||
use rustproxy_routing::{MatchContext, RouteManager};
|
use rustproxy_routing::{MatchContext, RouteManager};
|
||||||
|
|
||||||
|
use rustproxy_http::h3_service::H3ProxyService;
|
||||||
|
|
||||||
use crate::connection_tracker::ConnectionTracker;
|
use crate::connection_tracker::ConnectionTracker;
|
||||||
use crate::udp_session::{SessionKey, UdpSession, UdpSessionConfig, UdpSessionTable};
|
use crate::udp_session::{SessionKey, UdpSession, UdpSessionConfig, UdpSessionTable};
|
||||||
|
|
||||||
/// Manages UDP listeners across all configured ports.
|
/// Manages UDP listeners across all configured ports.
|
||||||
pub struct UdpListenerManager {
|
pub struct UdpListenerManager {
|
||||||
/// Port → recv loop task handle
|
/// Port → (recv loop task handle, optional QUIC endpoint for TLS updates)
|
||||||
listeners: HashMap<u16, JoinHandle<()>>,
|
listeners: HashMap<u16, (JoinHandle<()>, Option<quinn::Endpoint>)>,
|
||||||
/// Hot-reloadable route table
|
/// Hot-reloadable route table
|
||||||
route_manager: Arc<ArcSwap<RouteManager>>,
|
route_manager: Arc<ArcSwap<RouteManager>>,
|
||||||
/// Shared metrics collector
|
/// Shared metrics collector
|
||||||
@@ -44,13 +51,21 @@ pub struct UdpListenerManager {
|
|||||||
relay_writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>,
|
relay_writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>,
|
||||||
/// Cancel token for the current relay reply reader task
|
/// Cancel token for the current relay reply reader task
|
||||||
relay_reader_cancel: Option<CancellationToken>,
|
relay_reader_cancel: Option<CancellationToken>,
|
||||||
|
/// H3 proxy service for HTTP/3 request handling
|
||||||
|
h3_service: Option<Arc<H3ProxyService>>,
|
||||||
|
/// Trusted proxy IPs that may send PROXY protocol v2 headers.
|
||||||
|
/// When non-empty, PROXY v2 detection is enabled on both raw UDP and QUIC paths.
|
||||||
|
proxy_ips: Arc<Vec<IpAddr>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for UdpListenerManager {
|
impl Drop for UdpListenerManager {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
self.cancel_token.cancel();
|
self.cancel_token.cancel();
|
||||||
for (_, handle) in self.listeners.drain() {
|
for (_, (handle, endpoint)) in self.listeners.drain() {
|
||||||
handle.abort();
|
handle.abort();
|
||||||
|
if let Some(ep) = endpoint {
|
||||||
|
ep.close(quinn::VarInt::from_u32(0), b"shutdown");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -72,9 +87,24 @@ impl UdpListenerManager {
|
|||||||
datagram_handler_relay: Arc::new(RwLock::new(None)),
|
datagram_handler_relay: Arc::new(RwLock::new(None)),
|
||||||
relay_writer: Arc::new(Mutex::new(None)),
|
relay_writer: Arc::new(Mutex::new(None)),
|
||||||
relay_reader_cancel: None,
|
relay_reader_cancel: None,
|
||||||
|
h3_service: None,
|
||||||
|
proxy_ips: Arc::new(Vec::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set the trusted proxy IPs for PROXY protocol v2 detection.
|
||||||
|
pub fn set_proxy_ips(&mut self, ips: Vec<IpAddr>) {
|
||||||
|
if !ips.is_empty() {
|
||||||
|
info!("UDP/QUIC PROXY protocol v2 enabled for {} trusted IPs", ips.len());
|
||||||
|
}
|
||||||
|
self.proxy_ips = Arc::new(ips);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the H3 proxy service for HTTP/3 request handling.
|
||||||
|
pub fn set_h3_service(&mut self, svc: Arc<H3ProxyService>) {
|
||||||
|
self.h3_service = Some(svc);
|
||||||
|
}
|
||||||
|
|
||||||
/// Update the route manager (for hot-reload).
|
/// Update the route manager (for hot-reload).
|
||||||
pub fn update_routes(&self, route_manager: Arc<RouteManager>) {
|
pub fn update_routes(&self, route_manager: Arc<RouteManager>) {
|
||||||
self.route_manager.store(route_manager);
|
self.route_manager.store(route_manager);
|
||||||
@@ -109,8 +139,10 @@ impl UdpListenerManager {
|
|||||||
|
|
||||||
if has_quic {
|
if has_quic {
|
||||||
if let Some(tls) = tls_config {
|
if let Some(tls) = tls_config {
|
||||||
// Create QUIC endpoint
|
if self.proxy_ips.is_empty() {
|
||||||
|
// Direct path: quinn owns the external socket (zero overhead)
|
||||||
let endpoint = crate::quic_handler::create_quic_endpoint(port, tls)?;
|
let endpoint = crate::quic_handler::create_quic_endpoint(port, tls)?;
|
||||||
|
let endpoint_for_updates = endpoint.clone();
|
||||||
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
|
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
|
||||||
endpoint,
|
endpoint,
|
||||||
port,
|
port,
|
||||||
@@ -118,9 +150,33 @@ impl UdpListenerManager {
|
|||||||
Arc::clone(&self.metrics),
|
Arc::clone(&self.metrics),
|
||||||
Arc::clone(&self.conn_tracker),
|
Arc::clone(&self.conn_tracker),
|
||||||
self.cancel_token.child_token(),
|
self.cancel_token.child_token(),
|
||||||
|
self.h3_service.clone(),
|
||||||
|
None,
|
||||||
));
|
));
|
||||||
self.listeners.insert(port, handle);
|
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||||
info!("QUIC endpoint started on port {}", port);
|
info!("QUIC endpoint started on port {}", port);
|
||||||
|
} else {
|
||||||
|
// Proxy relay path: we own external socket, quinn on localhost
|
||||||
|
let relay = crate::quic_handler::create_quic_endpoint_with_proxy_relay(
|
||||||
|
port,
|
||||||
|
tls,
|
||||||
|
Arc::clone(&self.proxy_ips),
|
||||||
|
self.cancel_token.child_token(),
|
||||||
|
)?;
|
||||||
|
let endpoint_for_updates = relay.endpoint.clone();
|
||||||
|
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
|
||||||
|
relay.endpoint,
|
||||||
|
port,
|
||||||
|
Arc::clone(&self.route_manager),
|
||||||
|
Arc::clone(&self.metrics),
|
||||||
|
Arc::clone(&self.conn_tracker),
|
||||||
|
self.cancel_token.child_token(),
|
||||||
|
self.h3_service.clone(),
|
||||||
|
Some(relay.real_client_map),
|
||||||
|
));
|
||||||
|
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||||
|
info!("QUIC endpoint with PROXY relay started on port {}", port);
|
||||||
|
}
|
||||||
return Ok(());
|
return Ok(());
|
||||||
} else {
|
} else {
|
||||||
warn!("QUIC routes on port {} but no TLS config provided, falling back to raw UDP", port);
|
warn!("QUIC routes on port {} but no TLS config provided, falling back to raw UDP", port);
|
||||||
@@ -143,9 +199,10 @@ impl UdpListenerManager {
|
|||||||
Arc::clone(&self.datagram_handler_relay),
|
Arc::clone(&self.datagram_handler_relay),
|
||||||
Arc::clone(&self.relay_writer),
|
Arc::clone(&self.relay_writer),
|
||||||
self.cancel_token.child_token(),
|
self.cancel_token.child_token(),
|
||||||
|
Arc::clone(&self.proxy_ips),
|
||||||
));
|
));
|
||||||
|
|
||||||
self.listeners.insert(port, handle);
|
self.listeners.insert(port, (handle, None));
|
||||||
|
|
||||||
// Start the session cleanup task if this is the first port
|
// Start the session cleanup task if this is the first port
|
||||||
if self.listeners.len() == 1 {
|
if self.listeners.len() == 1 {
|
||||||
@@ -157,8 +214,11 @@ impl UdpListenerManager {
|
|||||||
|
|
||||||
/// Stop listening on a UDP port.
|
/// Stop listening on a UDP port.
|
||||||
pub fn remove_port(&mut self, port: u16) {
|
pub fn remove_port(&mut self, port: u16) {
|
||||||
if let Some(handle) = self.listeners.remove(&port) {
|
if let Some((handle, endpoint)) = self.listeners.remove(&port) {
|
||||||
handle.abort();
|
handle.abort();
|
||||||
|
if let Some(ep) = endpoint {
|
||||||
|
ep.close(quinn::VarInt::from_u32(0), b"port removed");
|
||||||
|
}
|
||||||
info!("UDP listener removed from port {}", port);
|
info!("UDP listener removed from port {}", port);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -173,14 +233,180 @@ impl UdpListenerManager {
|
|||||||
/// Stop all listeners and clean up.
|
/// Stop all listeners and clean up.
|
||||||
pub async fn stop(&mut self) {
|
pub async fn stop(&mut self) {
|
||||||
self.cancel_token.cancel();
|
self.cancel_token.cancel();
|
||||||
for (port, handle) in self.listeners.drain() {
|
for (port, (handle, endpoint)) in self.listeners.drain() {
|
||||||
handle.abort();
|
handle.abort();
|
||||||
|
if let Some(ep) = endpoint {
|
||||||
|
ep.close(quinn::VarInt::from_u32(0), b"shutdown");
|
||||||
|
}
|
||||||
debug!("UDP listener stopped on port {}", port);
|
debug!("UDP listener stopped on port {}", port);
|
||||||
}
|
}
|
||||||
info!("All UDP listeners stopped, {} sessions remaining",
|
info!("All UDP listeners stopped, {} sessions remaining",
|
||||||
self.session_table.session_count());
|
self.session_table.session_count());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Update TLS config on all active QUIC endpoints (cert refresh).
|
||||||
|
/// Only affects new incoming connections — existing connections are undisturbed.
|
||||||
|
/// Uses quinn's Endpoint::set_server_config() for zero-downtime hot-swap.
|
||||||
|
pub fn update_quic_tls(&self, tls_config: Arc<rustls::ServerConfig>) {
|
||||||
|
for (port, (_handle, endpoint)) in &self.listeners {
|
||||||
|
if let Some(ep) = endpoint {
|
||||||
|
match quinn::crypto::rustls::QuicServerConfig::try_from(Arc::clone(&tls_config)) {
|
||||||
|
Ok(quic_crypto) => {
|
||||||
|
let server_config = quinn::ServerConfig::with_crypto(Arc::new(quic_crypto));
|
||||||
|
ep.set_server_config(Some(server_config));
|
||||||
|
info!("Updated QUIC TLS config on port {}", port);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to update QUIC TLS config on port {}: {}", port, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Upgrade raw UDP fallback listeners to QUIC endpoints.
|
||||||
|
///
|
||||||
|
/// At startup, if no TLS certs are available, QUIC routes fall back to raw UDP.
|
||||||
|
/// When certs become available later (via loadCertificate IPC or ACME), this method
|
||||||
|
/// stops the raw UDP listener, drains sessions, and creates a proper QUIC endpoint.
|
||||||
|
///
|
||||||
|
/// This is idempotent — ports that already have QUIC endpoints are skipped.
|
||||||
|
pub async fn upgrade_raw_to_quic(&mut self, tls_config: Arc<rustls::ServerConfig>) {
|
||||||
|
// Find ports that are raw UDP fallback (endpoint=None) but have QUIC routes
|
||||||
|
let rm = self.route_manager.load();
|
||||||
|
let upgrade_ports: Vec<u16> = self.listeners.iter()
|
||||||
|
.filter(|(_, (_, endpoint))| endpoint.is_none())
|
||||||
|
.filter(|(port, _)| {
|
||||||
|
rm.routes_for_port(**port).iter().any(|r| {
|
||||||
|
r.action.udp.as_ref()
|
||||||
|
.and_then(|u| u.quic.as_ref())
|
||||||
|
.is_some()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.map(|(port, _)| *port)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
for port in upgrade_ports {
|
||||||
|
info!("Upgrading raw UDP listener on port {} to QUIC endpoint", port);
|
||||||
|
|
||||||
|
// Stop the raw UDP listener task and drain sessions to release the socket
|
||||||
|
if let Some((handle, _)) = self.listeners.remove(&port) {
|
||||||
|
handle.abort();
|
||||||
|
}
|
||||||
|
let drained = self.session_table.drain_port(
|
||||||
|
port, &self.metrics, &self.conn_tracker,
|
||||||
|
);
|
||||||
|
if drained > 0 {
|
||||||
|
debug!("Drained {} UDP sessions on port {} for QUIC upgrade", drained, port);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Brief yield to let aborted tasks drop their socket references
|
||||||
|
tokio::task::yield_now().await;
|
||||||
|
|
||||||
|
// Create QUIC endpoint on the now-free port
|
||||||
|
let create_result = if self.proxy_ips.is_empty() {
|
||||||
|
self.create_quic_direct(port, Arc::clone(&tls_config))
|
||||||
|
} else {
|
||||||
|
self.create_quic_with_relay(port, Arc::clone(&tls_config))
|
||||||
|
};
|
||||||
|
|
||||||
|
match create_result {
|
||||||
|
Ok(()) => {
|
||||||
|
info!("QUIC endpoint started on port {} (upgraded from raw UDP)", port);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// Port may still be held — retry once after a brief delay
|
||||||
|
warn!("QUIC endpoint creation failed on port {}, retrying: {}", port, e);
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||||
|
|
||||||
|
let retry_result = if self.proxy_ips.is_empty() {
|
||||||
|
self.create_quic_direct(port, Arc::clone(&tls_config))
|
||||||
|
} else {
|
||||||
|
self.create_quic_with_relay(port, Arc::clone(&tls_config))
|
||||||
|
};
|
||||||
|
|
||||||
|
match retry_result {
|
||||||
|
Ok(()) => {
|
||||||
|
info!("QUIC endpoint started on port {} (upgraded from raw UDP, retry)", port);
|
||||||
|
}
|
||||||
|
Err(e2) => {
|
||||||
|
error!("Failed to upgrade port {} to QUIC after retry: {}. \
|
||||||
|
Rebinding as raw UDP.", port, e2);
|
||||||
|
// Fallback: rebind as raw UDP so the port isn't dead
|
||||||
|
if let Ok(()) = self.rebind_raw_udp(port).await {
|
||||||
|
warn!("Port {} rebound as raw UDP (QUIC upgrade failed)", port);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a direct QUIC endpoint (quinn owns the socket).
|
||||||
|
fn create_quic_direct(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> {
|
||||||
|
let endpoint = crate::quic_handler::create_quic_endpoint(port, tls_config)?;
|
||||||
|
let endpoint_for_updates = endpoint.clone();
|
||||||
|
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
|
||||||
|
endpoint,
|
||||||
|
port,
|
||||||
|
Arc::clone(&self.route_manager),
|
||||||
|
Arc::clone(&self.metrics),
|
||||||
|
Arc::clone(&self.conn_tracker),
|
||||||
|
self.cancel_token.child_token(),
|
||||||
|
self.h3_service.clone(),
|
||||||
|
None,
|
||||||
|
));
|
||||||
|
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a QUIC endpoint with PROXY protocol relay.
|
||||||
|
fn create_quic_with_relay(&mut self, port: u16, tls_config: Arc<rustls::ServerConfig>) -> anyhow::Result<()> {
|
||||||
|
let relay = crate::quic_handler::create_quic_endpoint_with_proxy_relay(
|
||||||
|
port,
|
||||||
|
tls_config,
|
||||||
|
Arc::clone(&self.proxy_ips),
|
||||||
|
self.cancel_token.child_token(),
|
||||||
|
)?;
|
||||||
|
let endpoint_for_updates = relay.endpoint.clone();
|
||||||
|
let handle = tokio::spawn(crate::quic_handler::quic_accept_loop(
|
||||||
|
relay.endpoint,
|
||||||
|
port,
|
||||||
|
Arc::clone(&self.route_manager),
|
||||||
|
Arc::clone(&self.metrics),
|
||||||
|
Arc::clone(&self.conn_tracker),
|
||||||
|
self.cancel_token.child_token(),
|
||||||
|
self.h3_service.clone(),
|
||||||
|
Some(relay.real_client_map),
|
||||||
|
));
|
||||||
|
self.listeners.insert(port, (handle, Some(endpoint_for_updates)));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rebind a port as a raw UDP listener (fallback when QUIC upgrade fails).
|
||||||
|
async fn rebind_raw_udp(&mut self, port: u16) -> anyhow::Result<()> {
|
||||||
|
let addr: std::net::SocketAddr = ([0, 0, 0, 0], port).into();
|
||||||
|
let socket = UdpSocket::bind(addr).await?;
|
||||||
|
let socket = Arc::new(socket);
|
||||||
|
|
||||||
|
let handle = tokio::spawn(Self::recv_loop(
|
||||||
|
socket,
|
||||||
|
port,
|
||||||
|
Arc::clone(&self.route_manager),
|
||||||
|
Arc::clone(&self.metrics),
|
||||||
|
Arc::clone(&self.conn_tracker),
|
||||||
|
Arc::clone(&self.session_table),
|
||||||
|
Arc::clone(&self.datagram_handler_relay),
|
||||||
|
Arc::clone(&self.relay_writer),
|
||||||
|
self.cancel_token.child_token(),
|
||||||
|
Arc::clone(&self.proxy_ips),
|
||||||
|
));
|
||||||
|
|
||||||
|
self.listeners.insert(port, (handle, None));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Set the datagram handler relay socket path and establish connection.
|
/// Set the datagram handler relay socket path and establish connection.
|
||||||
pub async fn set_datagram_handler_relay(&mut self, path: String) {
|
pub async fn set_datagram_handler_relay(&mut self, path: String) {
|
||||||
// Cancel previous relay reader task if any
|
// Cancel previous relay reader task if any
|
||||||
@@ -255,6 +481,10 @@ impl UdpListenerManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Main receive loop for a UDP port.
|
/// Main receive loop for a UDP port.
|
||||||
|
///
|
||||||
|
/// When `proxy_ips` is non-empty, the first datagram from a trusted proxy IP
|
||||||
|
/// is checked for PROXY protocol v2. If found, the real client IP is extracted
|
||||||
|
/// and used for all subsequent session handling for that source address.
|
||||||
async fn recv_loop(
|
async fn recv_loop(
|
||||||
socket: Arc<UdpSocket>,
|
socket: Arc<UdpSocket>,
|
||||||
port: u16,
|
port: u16,
|
||||||
@@ -265,11 +495,38 @@ impl UdpListenerManager {
|
|||||||
_datagram_handler_relay: Arc<RwLock<Option<String>>>,
|
_datagram_handler_relay: Arc<RwLock<Option<String>>>,
|
||||||
relay_writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>,
|
relay_writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>,
|
||||||
cancel: CancellationToken,
|
cancel: CancellationToken,
|
||||||
|
proxy_ips: Arc<Vec<IpAddr>>,
|
||||||
) {
|
) {
|
||||||
// Use a reasonably large buffer; actual max is per-route but we need a single buffer
|
// Use a reasonably large buffer; actual max is per-route but we need a single buffer
|
||||||
let mut buf = vec![0u8; 65535];
|
let mut buf = vec![0u8; 65535];
|
||||||
|
|
||||||
|
// Maps proxy source addr → real client addr (from PROXY v2 headers).
|
||||||
|
// Only populated when proxy_ips is non-empty.
|
||||||
|
let proxy_addr_map: DashMap<SocketAddr, SocketAddr> = DashMap::new();
|
||||||
|
|
||||||
|
// Periodic cleanup for proxy_addr_map to prevent unbounded growth
|
||||||
|
let mut last_proxy_cleanup = tokio::time::Instant::now();
|
||||||
|
let proxy_cleanup_interval = std::time::Duration::from_secs(60);
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
|
// Periodic cleanup: remove proxy_addr_map entries with no active session
|
||||||
|
if !proxy_addr_map.is_empty() && last_proxy_cleanup.elapsed() >= proxy_cleanup_interval {
|
||||||
|
last_proxy_cleanup = tokio::time::Instant::now();
|
||||||
|
let stale: Vec<SocketAddr> = proxy_addr_map.iter()
|
||||||
|
.filter(|entry| {
|
||||||
|
let key: SessionKey = (*entry.key(), port);
|
||||||
|
session_table.get(&key).is_none()
|
||||||
|
})
|
||||||
|
.map(|entry| *entry.key())
|
||||||
|
.collect();
|
||||||
|
if !stale.is_empty() {
|
||||||
|
debug!("UDP proxy_addr_map cleanup: removing {} stale entries on port {}", stale.len(), port);
|
||||||
|
for addr in stale {
|
||||||
|
proxy_addr_map.remove(&addr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let (len, client_addr) = tokio::select! {
|
let (len, client_addr) = tokio::select! {
|
||||||
_ = cancel.cancelled() => {
|
_ = cancel.cancelled() => {
|
||||||
debug!("UDP recv loop on port {} cancelled", port);
|
debug!("UDP recv loop on port {} cancelled", port);
|
||||||
@@ -288,9 +545,39 @@ impl UdpListenerManager {
|
|||||||
|
|
||||||
let datagram = &buf[..len];
|
let datagram = &buf[..len];
|
||||||
|
|
||||||
// Route matching
|
// PROXY protocol v2 detection for datagrams from trusted proxy IPs
|
||||||
|
let effective_client_ip = if !proxy_ips.is_empty() && proxy_ips.contains(&client_addr.ip()) {
|
||||||
|
let session_key: SessionKey = (client_addr, port);
|
||||||
|
if session_table.get(&session_key).is_none() && !proxy_addr_map.contains_key(&client_addr) {
|
||||||
|
// No session and no prior PROXY header — check for PROXY v2
|
||||||
|
if crate::proxy_protocol::is_proxy_protocol_v2(datagram) {
|
||||||
|
match crate::proxy_protocol::parse_v2(datagram) {
|
||||||
|
Ok((header, _consumed)) => {
|
||||||
|
debug!("UDP PROXY v2 from {}: real client {}", client_addr, header.source_addr);
|
||||||
|
proxy_addr_map.insert(client_addr, header.source_addr);
|
||||||
|
continue; // discard the PROXY v2 datagram
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
debug!("UDP PROXY v2 parse error from {}: {}", client_addr, e);
|
||||||
|
client_addr.ip()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
client_addr.ip()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Use real client IP if we've previously seen a PROXY v2 header
|
||||||
|
proxy_addr_map.get(&client_addr)
|
||||||
|
.map(|r| r.ip())
|
||||||
|
.unwrap_or_else(|| client_addr.ip())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
client_addr.ip()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Route matching — use effective (real) client IP
|
||||||
let rm = route_manager.load();
|
let rm = route_manager.load();
|
||||||
let ip_str = client_addr.ip().to_string();
|
let ip_str = effective_client_ip.to_string();
|
||||||
let ctx = MatchContext {
|
let ctx = MatchContext {
|
||||||
port,
|
port,
|
||||||
domain: None,
|
domain: None,
|
||||||
@@ -339,20 +626,21 @@ impl UdpListenerManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Session lookup or create
|
// Session lookup or create
|
||||||
|
// Session key uses the proxy's source addr for correct return-path routing
|
||||||
let session_key: SessionKey = (client_addr, port);
|
let session_key: SessionKey = (client_addr, port);
|
||||||
let session = match session_table.get(&session_key) {
|
let session = match session_table.get(&session_key) {
|
||||||
Some(s) => s,
|
Some(s) => s,
|
||||||
None => {
|
None => {
|
||||||
// New session — check per-IP limits
|
// New session — check per-IP limits using the real client IP
|
||||||
if !conn_tracker.try_accept(&client_addr.ip()) {
|
if !conn_tracker.try_accept(&effective_client_ip) {
|
||||||
debug!("UDP session rejected for {} (rate limit)", client_addr);
|
debug!("UDP session rejected for {} (rate limit)", effective_client_ip);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if !session_table.can_create_session(
|
if !session_table.can_create_session(
|
||||||
&client_addr.ip(),
|
&effective_client_ip,
|
||||||
udp_config.max_sessions_per_ip,
|
udp_config.max_sessions_per_ip,
|
||||||
) {
|
) {
|
||||||
debug!("UDP session rejected for {} (per-IP session limit)", client_addr);
|
debug!("UDP session rejected for {} (per-IP session limit)", effective_client_ip);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -385,8 +673,8 @@ impl UdpListenerManager {
|
|||||||
}
|
}
|
||||||
let backend_socket = Arc::new(backend_socket);
|
let backend_socket = Arc::new(backend_socket);
|
||||||
|
|
||||||
debug!("New UDP session: {} -> {} (via port {})",
|
debug!("New UDP session: {} -> {} (via port {}, real client {})",
|
||||||
client_addr, backend_addr, port);
|
client_addr, backend_addr, port, effective_client_ip);
|
||||||
|
|
||||||
// Spawn return-path relay task
|
// Spawn return-path relay task
|
||||||
let session_cancel = CancellationToken::new();
|
let session_cancel = CancellationToken::new();
|
||||||
@@ -406,7 +694,7 @@ impl UdpListenerManager {
|
|||||||
last_activity: std::sync::atomic::AtomicU64::new(session_table.elapsed_ms()),
|
last_activity: std::sync::atomic::AtomicU64::new(session_table.elapsed_ms()),
|
||||||
created_at: std::time::Instant::now(),
|
created_at: std::time::Instant::now(),
|
||||||
route_id: route_id.map(|s| s.to_string()),
|
route_id: route_id.map(|s| s.to_string()),
|
||||||
source_ip: client_addr.ip(),
|
source_ip: effective_client_ip,
|
||||||
client_addr,
|
client_addr,
|
||||||
return_task,
|
return_task,
|
||||||
cancel: session_cancel,
|
cancel: session_cancel,
|
||||||
@@ -417,8 +705,8 @@ impl UdpListenerManager {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track in metrics
|
// Track in metrics using the real client IP
|
||||||
conn_tracker.connection_opened(&client_addr.ip());
|
conn_tracker.connection_opened(&effective_client_ip);
|
||||||
metrics.connection_opened(route_id, Some(&ip_str));
|
metrics.connection_opened(route_id, Some(&ip_str));
|
||||||
metrics.udp_session_opened();
|
metrics.udp_session_opened();
|
||||||
|
|
||||||
|
|||||||
@@ -201,6 +201,36 @@ impl UdpSessionTable {
|
|||||||
removed
|
removed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Drain all sessions on a given listening port, releasing socket references.
|
||||||
|
/// Used when upgrading a raw UDP listener to QUIC — the raw UDP socket's
|
||||||
|
/// Arc refcount must drop to zero so the port can be rebound.
|
||||||
|
pub fn drain_port(
|
||||||
|
&self,
|
||||||
|
port: u16,
|
||||||
|
metrics: &MetricsCollector,
|
||||||
|
conn_tracker: &ConnectionTracker,
|
||||||
|
) -> usize {
|
||||||
|
let keys: Vec<SessionKey> = self.sessions.iter()
|
||||||
|
.filter(|entry| entry.key().1 == port)
|
||||||
|
.map(|entry| *entry.key())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut removed = 0;
|
||||||
|
for key in keys {
|
||||||
|
if let Some(session) = self.remove(&key) {
|
||||||
|
session.cancel.cancel();
|
||||||
|
conn_tracker.connection_closed(&session.source_ip);
|
||||||
|
metrics.connection_closed(
|
||||||
|
session.route_id.as_deref(),
|
||||||
|
Some(&session.source_ip.to_string()),
|
||||||
|
);
|
||||||
|
metrics.udp_session_closed();
|
||||||
|
removed += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
removed
|
||||||
|
}
|
||||||
|
|
||||||
/// Total number of active sessions.
|
/// Total number of active sessions.
|
||||||
pub fn session_count(&self) -> usize {
|
pub fn session_count(&self) -> usize {
|
||||||
self.sessions.len()
|
self.sessions.len()
|
||||||
|
|||||||
@@ -122,6 +122,11 @@ impl RouteManager {
|
|||||||
// This prevents session-ticket resumption from misrouting when clients
|
// This prevents session-ticket resumption from misrouting when clients
|
||||||
// omit SNI (RFC 8446 recommends but doesn't mandate SNI on resumption).
|
// omit SNI (RFC 8446 recommends but doesn't mandate SNI on resumption).
|
||||||
// Wildcard-only routes (domains: ["*"]) still match since they accept all.
|
// Wildcard-only routes (domains: ["*"]) still match since they accept all.
|
||||||
|
//
|
||||||
|
// Exception: QUIC (UDP transport) encrypts the TLS ClientHello, so SNI
|
||||||
|
// is unavailable at accept time. Domain verification happens per-request
|
||||||
|
// in H3ProxyService via the :authority header.
|
||||||
|
if ctx.transport != Some(TransportProtocol::Udp) {
|
||||||
let patterns = domains.to_vec();
|
let patterns = domains.to_vec();
|
||||||
let is_wildcard_only = patterns.iter().all(|d| *d == "*");
|
let is_wildcard_only = patterns.iter().all(|d| *d == "*");
|
||||||
if !is_wildcard_only {
|
if !is_wildcard_only {
|
||||||
@@ -129,6 +134,7 @@ impl RouteManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Path matching
|
// Path matching
|
||||||
if let Some(ref pattern) = rm.path {
|
if let Some(ref pattern) = rm.path {
|
||||||
@@ -997,4 +1003,52 @@ mod tests {
|
|||||||
let result = manager.find_route(&udp_ctx).unwrap();
|
let result = manager.find_route(&udp_ctx).unwrap();
|
||||||
assert_eq!(result.route.name.as_deref(), Some("udp-route"));
|
assert_eq!(result.route.name.as_deref(), Some("udp-route"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_quic_tls_no_sni_matches_domain_restricted_route() {
|
||||||
|
// QUIC accept-level matching: is_tls=true, domain=None, transport=Udp.
|
||||||
|
// Should match because QUIC encrypts the ClientHello — SNI is unavailable
|
||||||
|
// at accept time but verified per-request in H3ProxyService.
|
||||||
|
let mut route = make_route(443, Some("example.com"), 0);
|
||||||
|
route.route_match.transport = Some(TransportProtocol::Udp);
|
||||||
|
let routes = vec![route];
|
||||||
|
let manager = RouteManager::new(routes);
|
||||||
|
|
||||||
|
let ctx = MatchContext {
|
||||||
|
port: 443,
|
||||||
|
domain: None,
|
||||||
|
path: None,
|
||||||
|
client_ip: None,
|
||||||
|
tls_version: None,
|
||||||
|
headers: None,
|
||||||
|
is_tls: true,
|
||||||
|
protocol: Some("quic"),
|
||||||
|
transport: Some(TransportProtocol::Udp),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(manager.find_route(&ctx).is_some(),
|
||||||
|
"QUIC (UDP) with is_tls=true and domain=None should match domain-restricted routes");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tcp_tls_no_sni_still_rejects_domain_restricted_route() {
|
||||||
|
// TCP TLS without SNI must still be rejected (no QUIC exemption).
|
||||||
|
let routes = vec![make_route(443, Some("example.com"), 0)];
|
||||||
|
let manager = RouteManager::new(routes);
|
||||||
|
|
||||||
|
let ctx = MatchContext {
|
||||||
|
port: 443,
|
||||||
|
domain: None,
|
||||||
|
path: None,
|
||||||
|
client_ip: None,
|
||||||
|
tls_version: None,
|
||||||
|
headers: None,
|
||||||
|
is_tls: true,
|
||||||
|
protocol: None,
|
||||||
|
transport: None, // TCP (default)
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(manager.find_route(&ctx).is_none(),
|
||||||
|
"TCP TLS without SNI should NOT match domain-restricted routes");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,3 +44,9 @@ mimalloc = { workspace = true }
|
|||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
rcgen = { workspace = true }
|
rcgen = { workspace = true }
|
||||||
|
quinn = { workspace = true }
|
||||||
|
h3 = { workspace = true }
|
||||||
|
h3-quinn = { workspace = true }
|
||||||
|
bytes = { workspace = true }
|
||||||
|
rustls = { workspace = true }
|
||||||
|
http = "1"
|
||||||
|
|||||||
@@ -264,6 +264,8 @@ impl RustProxy {
|
|||||||
conn_config.socket_timeout_ms,
|
conn_config.socket_timeout_ms,
|
||||||
conn_config.max_connection_lifetime_ms,
|
conn_config.max_connection_lifetime_ms,
|
||||||
);
|
);
|
||||||
|
// Clone proxy_ips before conn_config is moved into the TCP listener
|
||||||
|
let udp_proxy_ips = conn_config.proxy_ips.clone();
|
||||||
listener.set_connection_config(conn_config);
|
listener.set_connection_config(conn_config);
|
||||||
|
|
||||||
// Share the socket-handler relay path with the listener
|
// Share the socket-handler relay path with the listener
|
||||||
@@ -339,6 +341,13 @@ impl RustProxy {
|
|||||||
conn_tracker,
|
conn_tracker,
|
||||||
self.cancel_token.clone(),
|
self.cancel_token.clone(),
|
||||||
);
|
);
|
||||||
|
udp_mgr.set_proxy_ips(udp_proxy_ips.clone());
|
||||||
|
|
||||||
|
// Share HttpProxyService with H3 — same route matching, connection
|
||||||
|
// pool, and ALPN protocol detection as the TCP/HTTP path.
|
||||||
|
let http_proxy = self.listener_manager.as_ref().unwrap().http_proxy().clone();
|
||||||
|
let h3_svc = rustproxy_http::h3_service::H3ProxyService::new(http_proxy);
|
||||||
|
udp_mgr.set_h3_service(Arc::new(h3_svc));
|
||||||
|
|
||||||
for port in &udp_ports {
|
for port in &udp_ports {
|
||||||
udp_mgr.add_port_with_tls(*port, quic_tls_config.clone()).await?;
|
udp_mgr.add_port_with_tls(*port, quic_tls_config.clone()).await?;
|
||||||
@@ -763,22 +772,31 @@ impl RustProxy {
|
|||||||
if self.udp_listener_manager.is_none() {
|
if self.udp_listener_manager.is_none() {
|
||||||
if let Some(ref listener) = self.listener_manager {
|
if let Some(ref listener) = self.listener_manager {
|
||||||
let conn_tracker = listener.conn_tracker().clone();
|
let conn_tracker = listener.conn_tracker().clone();
|
||||||
self.udp_listener_manager = Some(UdpListenerManager::new(
|
let conn_config = Self::build_connection_config(&self.options);
|
||||||
|
let mut udp_mgr = UdpListenerManager::new(
|
||||||
Arc::clone(&new_manager),
|
Arc::clone(&new_manager),
|
||||||
Arc::clone(&self.metrics),
|
Arc::clone(&self.metrics),
|
||||||
conn_tracker,
|
conn_tracker,
|
||||||
self.cancel_token.clone(),
|
self.cancel_token.clone(),
|
||||||
));
|
);
|
||||||
|
udp_mgr.set_proxy_ips(conn_config.proxy_ips);
|
||||||
|
self.udp_listener_manager = Some(udp_mgr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build TLS config for QUIC (needed for new ports and upgrading existing raw UDP)
|
||||||
|
let quic_tls = {
|
||||||
|
let tls_configs = self.current_tls_configs().await;
|
||||||
|
Self::build_quic_tls_config(&tls_configs)
|
||||||
|
};
|
||||||
|
|
||||||
if let Some(ref mut udp_mgr) = self.udp_listener_manager {
|
if let Some(ref mut udp_mgr) = self.udp_listener_manager {
|
||||||
udp_mgr.update_routes(Arc::clone(&new_manager));
|
udp_mgr.update_routes(Arc::clone(&new_manager));
|
||||||
|
|
||||||
// Add new UDP ports
|
// Add new UDP ports (with TLS for QUIC)
|
||||||
for port in &new_udp_ports {
|
for port in &new_udp_ports {
|
||||||
if !old_udp_ports.contains(port) {
|
if !old_udp_ports.contains(port) {
|
||||||
udp_mgr.add_port(*port).await?;
|
udp_mgr.add_port_with_tls(*port, quic_tls.clone()).await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Remove old UDP ports
|
// Remove old UDP ports
|
||||||
@@ -787,6 +805,12 @@ impl RustProxy {
|
|||||||
udp_mgr.remove_port(*port);
|
udp_mgr.remove_port(*port);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Upgrade existing raw UDP fallback listeners to QUIC if TLS is now available
|
||||||
|
if let Some(ref quic_config) = quic_tls {
|
||||||
|
udp_mgr.update_quic_tls(Arc::clone(quic_config));
|
||||||
|
udp_mgr.upgrade_raw_to_quic(Arc::clone(quic_config)).await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else if self.udp_listener_manager.is_some() {
|
} else if self.udp_listener_manager.is_some() {
|
||||||
// All UDP routes removed — shut down UDP manager
|
// All UDP routes removed — shut down UDP manager
|
||||||
@@ -843,12 +867,12 @@ impl RustProxy {
|
|||||||
.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
|
||||||
|
|
||||||
// Hot-swap into TLS configs
|
// Hot-swap into TLS configs
|
||||||
if let Some(ref mut listener) = self.listener_manager {
|
|
||||||
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
|
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
|
||||||
tls_configs.insert(domain.clone(), TlsCertConfig {
|
tls_configs.insert(domain.clone(), TlsCertConfig {
|
||||||
cert_pem: bundle.cert_pem.clone(),
|
cert_pem: bundle.cert_pem.clone(),
|
||||||
key_pem: bundle.key_pem.clone(),
|
key_pem: bundle.key_pem.clone(),
|
||||||
});
|
});
|
||||||
|
{
|
||||||
let cm = cm_arc.lock().await;
|
let cm = cm_arc.lock().await;
|
||||||
for (d, b) in cm.store().iter() {
|
for (d, b) in cm.store().iter() {
|
||||||
if !tls_configs.contains_key(d) {
|
if !tls_configs.contains_key(d) {
|
||||||
@@ -858,9 +882,22 @@ impl RustProxy {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let quic_tls = Self::build_quic_tls_config(&tls_configs);
|
||||||
|
|
||||||
|
if let Some(ref listener) = self.listener_manager {
|
||||||
listener.set_tls_configs(tls_configs);
|
listener.set_tls_configs(tls_configs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update existing QUIC endpoints and upgrade raw UDP fallback listeners
|
||||||
|
if let Some(ref mut udp_mgr) = self.udp_listener_manager {
|
||||||
|
if let Some(ref quic_config) = quic_tls {
|
||||||
|
udp_mgr.update_quic_tls(Arc::clone(quic_config));
|
||||||
|
udp_mgr.upgrade_raw_to_quic(Arc::clone(quic_config)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
info!("Certificate provisioned and loaded for route '{}'", route_name);
|
info!("Certificate provisioned and loaded for route '{}'", route_name);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -961,50 +998,58 @@ impl RustProxy {
|
|||||||
fn build_quic_tls_config(
|
fn build_quic_tls_config(
|
||||||
tls_configs: &HashMap<String, TlsCertConfig>,
|
tls_configs: &HashMap<String, TlsCertConfig>,
|
||||||
) -> Option<Arc<rustls::ServerConfig>> {
|
) -> Option<Arc<rustls::ServerConfig>> {
|
||||||
// Find the first available cert (prefer wildcard, then any)
|
if tls_configs.is_empty() {
|
||||||
let cert_config = tls_configs.get("*")
|
|
||||||
.or_else(|| tls_configs.values().next());
|
|
||||||
|
|
||||||
let cert_config = match cert_config {
|
|
||||||
Some(c) => c,
|
|
||||||
None => return None,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Parse cert chain from PEM
|
|
||||||
let mut cert_reader = std::io::BufReader::new(cert_config.cert_pem.as_bytes());
|
|
||||||
let certs: Vec<rustls::pki_types::CertificateDer<'static>> =
|
|
||||||
rustls_pemfile::certs(&mut cert_reader)
|
|
||||||
.filter_map(|r| r.ok())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
if certs.is_empty() {
|
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse private key from PEM
|
// Reuse CertResolver for SNI-based cert selection (same as TCP/TLS path).
|
||||||
let mut key_reader = std::io::BufReader::new(cert_config.key_pem.as_bytes());
|
// This ensures QUIC connections get the correct certificate for each domain
|
||||||
let key = match rustls_pemfile::private_key(&mut key_reader) {
|
// instead of a single static cert.
|
||||||
Ok(Some(key)) => key,
|
let resolver = match rustproxy_passthrough::tls_handler::CertResolver::new(tls_configs) {
|
||||||
_ => return None,
|
Ok(r) => r,
|
||||||
};
|
|
||||||
|
|
||||||
let mut tls_config = match rustls::ServerConfig::builder()
|
|
||||||
.with_no_client_auth()
|
|
||||||
.with_single_cert(certs, key)
|
|
||||||
{
|
|
||||||
Ok(c) => c,
|
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Failed to build QUIC TLS config: {}", e);
|
warn!("Failed to build QUIC cert resolver: {}", e);
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut tls_config = rustls::ServerConfig::builder()
|
||||||
|
.with_no_client_auth()
|
||||||
|
.with_cert_resolver(Arc::new(resolver));
|
||||||
|
|
||||||
// QUIC requires h3 ALPN
|
// QUIC requires h3 ALPN
|
||||||
tls_config.alpn_protocols = vec![b"h3".to_vec()];
|
tls_config.alpn_protocols = vec![b"h3".to_vec()];
|
||||||
|
|
||||||
Some(Arc::new(tls_config))
|
Some(Arc::new(tls_config))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build the current full TLS config map from all sources (route configs, loaded certs, cert manager).
|
||||||
|
async fn current_tls_configs(&self) -> HashMap<String, TlsCertConfig> {
|
||||||
|
let mut configs = Self::extract_tls_configs(&self.options.routes);
|
||||||
|
|
||||||
|
// Merge dynamically loaded certs (from loadCertificate IPC)
|
||||||
|
for (d, c) in &self.loaded_certs {
|
||||||
|
if !configs.contains_key(d) {
|
||||||
|
configs.insert(d.clone(), c.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge certs from cert manager store
|
||||||
|
if let Some(ref cm_arc) = self.cert_manager {
|
||||||
|
let cm = cm_arc.lock().await;
|
||||||
|
for (d, b) in cm.store().iter() {
|
||||||
|
if !configs.contains_key(d) {
|
||||||
|
configs.insert(d.clone(), TlsCertConfig {
|
||||||
|
cert_pem: b.cert_pem.clone(),
|
||||||
|
key_pem: b.key_pem.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
configs
|
||||||
|
}
|
||||||
|
|
||||||
/// Set the Unix domain socket path for relaying UDP datagrams to TypeScript datagramHandler callbacks.
|
/// Set the Unix domain socket path for relaying UDP datagrams to TypeScript datagramHandler callbacks.
|
||||||
pub async fn set_datagram_handler_relay_path(&mut self, path: Option<String>) {
|
pub async fn set_datagram_handler_relay_path(&mut self, path: Option<String>) {
|
||||||
info!("Datagram handler relay path set to: {:?}", path);
|
info!("Datagram handler relay path set to: {:?}", path);
|
||||||
@@ -1055,39 +1100,24 @@ impl RustProxy {
|
|||||||
key_pem: key_pem.clone(),
|
key_pem: key_pem.clone(),
|
||||||
});
|
});
|
||||||
|
|
||||||
// Hot-swap TLS config on the listener
|
// Hot-swap TLS config on TCP and QUIC listeners
|
||||||
if let Some(ref mut listener) = self.listener_manager {
|
let tls_configs = self.current_tls_configs().await;
|
||||||
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
|
|
||||||
|
|
||||||
// Add the new cert
|
// Build QUIC TLS config before TCP consumes the map
|
||||||
tls_configs.insert(domain.to_string(), TlsCertConfig {
|
let quic_tls = Self::build_quic_tls_config(&tls_configs);
|
||||||
cert_pem: cert_pem.clone(),
|
|
||||||
key_pem: key_pem.clone(),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Also include all existing certs from cert manager
|
|
||||||
if let Some(ref cm_arc) = self.cert_manager {
|
|
||||||
let cm = cm_arc.lock().await;
|
|
||||||
for (d, b) in cm.store().iter() {
|
|
||||||
if !tls_configs.contains_key(d) {
|
|
||||||
tls_configs.insert(d.clone(), TlsCertConfig {
|
|
||||||
cert_pem: b.cert_pem.clone(),
|
|
||||||
key_pem: b.key_pem.clone(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Merge dynamically loaded certs from previous loadCertificate calls
|
|
||||||
for (d, c) in &self.loaded_certs {
|
|
||||||
if !tls_configs.contains_key(d) {
|
|
||||||
tls_configs.insert(d.clone(), c.clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if let Some(ref listener) = self.listener_manager {
|
||||||
listener.set_tls_configs(tls_configs);
|
listener.set_tls_configs(tls_configs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update existing QUIC endpoints and upgrade raw UDP fallback listeners
|
||||||
|
if let Some(ref mut udp_mgr) = self.udp_listener_manager {
|
||||||
|
if let Some(ref quic_config) = quic_tls {
|
||||||
|
udp_mgr.update_quic_tls(Arc::clone(quic_config));
|
||||||
|
udp_mgr.upgrade_raw_to_quic(Arc::clone(quic_config)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
info!("Certificate loaded and TLS config updated for {}", domain);
|
info!("Certificate loaded and TLS config updated for {}", domain);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
195
rust/crates/rustproxy/tests/integration_h3_proxy.rs
Normal file
195
rust/crates/rustproxy/tests/integration_h3_proxy.rs
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
mod common;
|
||||||
|
|
||||||
|
use common::*;
|
||||||
|
use rustproxy::RustProxy;
|
||||||
|
use rustproxy_config::{RustProxyOptions, TransportProtocol, RouteUdp, RouteQuic};
|
||||||
|
use bytes::Buf;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Build a route that listens on UDP with HTTP/3 enabled and TLS terminate.
|
||||||
|
fn make_h3_route(
|
||||||
|
port: u16,
|
||||||
|
target_host: &str,
|
||||||
|
target_port: u16,
|
||||||
|
cert_pem: &str,
|
||||||
|
key_pem: &str,
|
||||||
|
) -> rustproxy_config::RouteConfig {
|
||||||
|
let mut route = make_tls_terminate_route(port, "localhost", target_host, target_port, cert_pem, key_pem);
|
||||||
|
route.route_match.transport = Some(TransportProtocol::All);
|
||||||
|
// Keep domain="localhost" from make_tls_terminate_route — needed for TLS cert extraction
|
||||||
|
route.action.udp = Some(RouteUdp {
|
||||||
|
session_timeout: None,
|
||||||
|
max_sessions_per_ip: None,
|
||||||
|
max_datagram_size: None,
|
||||||
|
quic: Some(RouteQuic {
|
||||||
|
max_idle_timeout: Some(30000),
|
||||||
|
max_concurrent_bidi_streams: None,
|
||||||
|
max_concurrent_uni_streams: None,
|
||||||
|
enable_http3: Some(true),
|
||||||
|
alt_svc_port: None,
|
||||||
|
alt_svc_max_age: None,
|
||||||
|
initial_congestion_window: None,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
route
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a quinn client endpoint with insecure TLS for testing.
|
||||||
|
fn make_h3_client_endpoint() -> quinn::Endpoint {
|
||||||
|
let mut tls_config = rustls::ClientConfig::builder()
|
||||||
|
.dangerous()
|
||||||
|
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||||
|
.with_no_client_auth();
|
||||||
|
tls_config.alpn_protocols = vec![b"h3".to_vec()];
|
||||||
|
|
||||||
|
let quic_client_config = quinn::crypto::rustls::QuicClientConfig::try_from(tls_config)
|
||||||
|
.expect("Failed to build QUIC client config");
|
||||||
|
let client_config = quinn::ClientConfig::new(Arc::new(quic_client_config));
|
||||||
|
|
||||||
|
let mut endpoint = quinn::Endpoint::client("0.0.0.0:0".parse().unwrap())
|
||||||
|
.expect("Failed to create QUIC client endpoint");
|
||||||
|
endpoint.set_default_client_config(client_config);
|
||||||
|
endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test that HTTP/3 response streams properly finish (FIN is received by client).
|
||||||
|
///
|
||||||
|
/// This is the critical regression test for the FIN bug: the proxy must send
|
||||||
|
/// a QUIC stream FIN after the response body so the client's `recv_data()`
|
||||||
|
/// returns `None` instead of hanging forever.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_h3_response_stream_finishes() {
|
||||||
|
let backend_port = next_port();
|
||||||
|
let proxy_port = next_port();
|
||||||
|
let body_text = "Hello from HTTP/3 backend! This body has a known length for testing.";
|
||||||
|
|
||||||
|
// 1. Start plain HTTP backend with known body + content-length
|
||||||
|
let _backend = start_http_server(backend_port, 200, body_text).await;
|
||||||
|
|
||||||
|
// 2. Generate self-signed cert and configure H3 route
|
||||||
|
let (cert_pem, key_pem) = generate_self_signed_cert("localhost");
|
||||||
|
let route = make_h3_route(proxy_port, "127.0.0.1", backend_port, &cert_pem, &key_pem);
|
||||||
|
|
||||||
|
let options = RustProxyOptions {
|
||||||
|
routes: vec![route],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
// 3. Start proxy and wait for UDP bind
|
||||||
|
let mut proxy = RustProxy::new(options).unwrap();
|
||||||
|
proxy.start().await.unwrap();
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
|
||||||
|
|
||||||
|
// 4. Connect QUIC/H3 client
|
||||||
|
let endpoint = make_h3_client_endpoint();
|
||||||
|
let addr: std::net::SocketAddr = format!("127.0.0.1:{}", proxy_port).parse().unwrap();
|
||||||
|
let connection = endpoint
|
||||||
|
.connect(addr, "localhost")
|
||||||
|
.expect("Failed to initiate QUIC connection")
|
||||||
|
.await
|
||||||
|
.expect("QUIC handshake failed");
|
||||||
|
|
||||||
|
let (mut driver, mut send_request) = h3::client::new(
|
||||||
|
h3_quinn::Connection::new(connection),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("H3 connection setup failed");
|
||||||
|
|
||||||
|
// Drive the H3 connection in background
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let _ = driver.wait_idle().await;
|
||||||
|
});
|
||||||
|
|
||||||
|
// 5. Send GET request
|
||||||
|
let req = http::Request::builder()
|
||||||
|
.method("GET")
|
||||||
|
.uri("https://localhost/")
|
||||||
|
.header("host", "localhost")
|
||||||
|
.body(())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut stream = send_request.send_request(req).await
|
||||||
|
.expect("Failed to send H3 request");
|
||||||
|
stream.finish().await
|
||||||
|
.expect("Failed to finish sending H3 request body");
|
||||||
|
|
||||||
|
// 6. Read response headers
|
||||||
|
let resp = stream.recv_response().await
|
||||||
|
.expect("Failed to receive H3 response");
|
||||||
|
assert_eq!(resp.status(), http::StatusCode::OK,
|
||||||
|
"Expected 200 OK, got {}", resp.status());
|
||||||
|
|
||||||
|
// 7. Read body and verify stream ends (FIN received)
|
||||||
|
// This is the critical assertion: recv_data() must return None (stream ended)
|
||||||
|
// within the timeout, NOT hang forever waiting for a FIN that never arrives.
|
||||||
|
let result = with_timeout(async {
|
||||||
|
let mut total = 0usize;
|
||||||
|
while let Some(chunk) = stream.recv_data().await.expect("H3 data receive error") {
|
||||||
|
total += chunk.remaining();
|
||||||
|
}
|
||||||
|
// recv_data() returned None => stream ended (FIN received)
|
||||||
|
total
|
||||||
|
}, 10)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let bytes_received = result.expect(
|
||||||
|
"TIMEOUT: H3 stream never ended (FIN not received by client). \
|
||||||
|
The proxy sent all response data but failed to send the QUIC stream FIN."
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
bytes_received,
|
||||||
|
body_text.len(),
|
||||||
|
"Expected {} bytes, got {}",
|
||||||
|
body_text.len(),
|
||||||
|
bytes_received
|
||||||
|
);
|
||||||
|
|
||||||
|
// 8. Cleanup
|
||||||
|
endpoint.close(quinn::VarInt::from_u32(0), b"test done");
|
||||||
|
proxy.stop().await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Insecure TLS verifier that accepts any certificate (for tests only).
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct InsecureVerifier;
|
||||||
|
|
||||||
|
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
|
||||||
|
fn verify_server_cert(
|
||||||
|
&self,
|
||||||
|
_end_entity: &rustls::pki_types::CertificateDer<'_>,
|
||||||
|
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
|
||||||
|
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||||
|
_ocsp_response: &[u8],
|
||||||
|
_now: rustls::pki_types::UnixTime,
|
||||||
|
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||||
|
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn verify_tls12_signature(
|
||||||
|
&self,
|
||||||
|
_message: &[u8],
|
||||||
|
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||||
|
_dss: &rustls::DigitallySignedStruct,
|
||||||
|
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||||
|
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn verify_tls13_signature(
|
||||||
|
&self,
|
||||||
|
_message: &[u8],
|
||||||
|
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||||
|
_dss: &rustls::DigitallySignedStruct,
|
||||||
|
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||||
|
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||||
|
vec![
|
||||||
|
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
||||||
|
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||||
|
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||||
|
rustls::SignatureScheme::ED25519,
|
||||||
|
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,200 +0,0 @@
|
|||||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
|
||||||
import {
|
|
||||||
delay,
|
|
||||||
retryWithBackoff,
|
|
||||||
withTimeout,
|
|
||||||
parallelLimit,
|
|
||||||
debounceAsync,
|
|
||||||
AsyncMutex,
|
|
||||||
CircuitBreaker
|
|
||||||
} from '../../../ts/core/utils/async-utils.js';
|
|
||||||
|
|
||||||
tap.test('delay should pause execution for specified milliseconds', async () => {
|
|
||||||
const startTime = Date.now();
|
|
||||||
await delay(100);
|
|
||||||
const elapsed = Date.now() - startTime;
|
|
||||||
|
|
||||||
// Allow some tolerance for timing
|
|
||||||
expect(elapsed).toBeGreaterThan(90);
|
|
||||||
expect(elapsed).toBeLessThan(150);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('retryWithBackoff should retry failed operations', async () => {
|
|
||||||
let attempts = 0;
|
|
||||||
const operation = async () => {
|
|
||||||
attempts++;
|
|
||||||
if (attempts < 3) {
|
|
||||||
throw new Error('Test error');
|
|
||||||
}
|
|
||||||
return 'success';
|
|
||||||
};
|
|
||||||
|
|
||||||
const result = await retryWithBackoff(operation, {
|
|
||||||
maxAttempts: 3,
|
|
||||||
initialDelay: 10
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(result).toEqual('success');
|
|
||||||
expect(attempts).toEqual(3);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('retryWithBackoff should throw after max attempts', async () => {
|
|
||||||
let attempts = 0;
|
|
||||||
const operation = async () => {
|
|
||||||
attempts++;
|
|
||||||
throw new Error('Always fails');
|
|
||||||
};
|
|
||||||
|
|
||||||
let error: Error | null = null;
|
|
||||||
try {
|
|
||||||
await retryWithBackoff(operation, {
|
|
||||||
maxAttempts: 2,
|
|
||||||
initialDelay: 10
|
|
||||||
});
|
|
||||||
} catch (e: any) {
|
|
||||||
error = e;
|
|
||||||
}
|
|
||||||
|
|
||||||
expect(error).not.toBeNull();
|
|
||||||
expect(error?.message).toEqual('Always fails');
|
|
||||||
expect(attempts).toEqual(2);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('withTimeout should complete operations within timeout', async () => {
|
|
||||||
const operation = async () => {
|
|
||||||
await delay(50);
|
|
||||||
return 'completed';
|
|
||||||
};
|
|
||||||
|
|
||||||
const result = await withTimeout(operation, 100);
|
|
||||||
expect(result).toEqual('completed');
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('withTimeout should throw on timeout', async () => {
|
|
||||||
const operation = async () => {
|
|
||||||
await delay(200);
|
|
||||||
return 'never happens';
|
|
||||||
};
|
|
||||||
|
|
||||||
let error: Error | null = null;
|
|
||||||
try {
|
|
||||||
await withTimeout(operation, 50);
|
|
||||||
} catch (e: any) {
|
|
||||||
error = e;
|
|
||||||
}
|
|
||||||
|
|
||||||
expect(error).not.toBeNull();
|
|
||||||
expect(error?.message).toContain('timed out');
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('parallelLimit should respect concurrency limit', async () => {
|
|
||||||
let concurrent = 0;
|
|
||||||
let maxConcurrent = 0;
|
|
||||||
|
|
||||||
const items = [1, 2, 3, 4, 5, 6];
|
|
||||||
const operation = async (item: number) => {
|
|
||||||
concurrent++;
|
|
||||||
maxConcurrent = Math.max(maxConcurrent, concurrent);
|
|
||||||
await delay(50);
|
|
||||||
concurrent--;
|
|
||||||
return item * 2;
|
|
||||||
};
|
|
||||||
|
|
||||||
const results = await parallelLimit(items, operation, 2);
|
|
||||||
|
|
||||||
expect(results).toEqual([2, 4, 6, 8, 10, 12]);
|
|
||||||
expect(maxConcurrent).toBeLessThan(3);
|
|
||||||
expect(maxConcurrent).toBeGreaterThan(0);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('debounceAsync should debounce function calls', async () => {
|
|
||||||
let callCount = 0;
|
|
||||||
const fn = async (value: string) => {
|
|
||||||
callCount++;
|
|
||||||
return value;
|
|
||||||
};
|
|
||||||
|
|
||||||
const debounced = debounceAsync(fn, 50);
|
|
||||||
|
|
||||||
// Make multiple calls quickly
|
|
||||||
debounced('a');
|
|
||||||
debounced('b');
|
|
||||||
debounced('c');
|
|
||||||
const result = await debounced('d');
|
|
||||||
|
|
||||||
// Wait a bit to ensure no more calls
|
|
||||||
await delay(100);
|
|
||||||
|
|
||||||
expect(result).toEqual('d');
|
|
||||||
expect(callCount).toEqual(1); // Only the last call should execute
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('AsyncMutex should ensure exclusive access', async () => {
|
|
||||||
const mutex = new AsyncMutex();
|
|
||||||
const results: number[] = [];
|
|
||||||
|
|
||||||
const operation = async (value: number) => {
|
|
||||||
await mutex.runExclusive(async () => {
|
|
||||||
results.push(value);
|
|
||||||
await delay(10);
|
|
||||||
results.push(value * 10);
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
// Run operations concurrently
|
|
||||||
await Promise.all([
|
|
||||||
operation(1),
|
|
||||||
operation(2),
|
|
||||||
operation(3)
|
|
||||||
]);
|
|
||||||
|
|
||||||
// Results should show sequential execution
|
|
||||||
expect(results).toEqual([1, 10, 2, 20, 3, 30]);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('CircuitBreaker should open after failures', async () => {
|
|
||||||
const breaker = new CircuitBreaker({
|
|
||||||
failureThreshold: 2,
|
|
||||||
resetTimeout: 100
|
|
||||||
});
|
|
||||||
|
|
||||||
let attempt = 0;
|
|
||||||
const failingOperation = async () => {
|
|
||||||
attempt++;
|
|
||||||
throw new Error('Test failure');
|
|
||||||
};
|
|
||||||
|
|
||||||
// First two failures
|
|
||||||
for (let i = 0; i < 2; i++) {
|
|
||||||
try {
|
|
||||||
await breaker.execute(failingOperation);
|
|
||||||
} catch (e) {
|
|
||||||
// Expected
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
expect(breaker.isOpen()).toBeTrue();
|
|
||||||
|
|
||||||
// Next attempt should fail immediately
|
|
||||||
let error: Error | null = null;
|
|
||||||
try {
|
|
||||||
await breaker.execute(failingOperation);
|
|
||||||
} catch (e: any) {
|
|
||||||
error = e;
|
|
||||||
}
|
|
||||||
|
|
||||||
expect(error?.message).toEqual('Circuit breaker is open');
|
|
||||||
expect(attempt).toEqual(2); // Operation not called when circuit is open
|
|
||||||
|
|
||||||
// Wait for reset timeout
|
|
||||||
await delay(150);
|
|
||||||
|
|
||||||
// Circuit should be half-open now, allowing one attempt
|
|
||||||
const successOperation = async () => 'success';
|
|
||||||
const result = await breaker.execute(successOperation);
|
|
||||||
|
|
||||||
expect(result).toEqual('success');
|
|
||||||
expect(breaker.getState()).toEqual('closed');
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.start();
|
|
||||||
@@ -1,206 +0,0 @@
|
|||||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
|
||||||
import { BinaryHeap } from '../../../ts/core/utils/binary-heap.js';
|
|
||||||
|
|
||||||
interface TestItem {
|
|
||||||
id: string;
|
|
||||||
priority: number;
|
|
||||||
value: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
tap.test('should create empty heap', async () => {
|
|
||||||
const heap = new BinaryHeap<number>((a, b) => a - b);
|
|
||||||
|
|
||||||
expect(heap.size).toEqual(0);
|
|
||||||
expect(heap.isEmpty()).toBeTrue();
|
|
||||||
expect(heap.peek()).toBeUndefined();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should insert and extract in correct order', async () => {
|
|
||||||
const heap = new BinaryHeap<number>((a, b) => a - b);
|
|
||||||
|
|
||||||
heap.insert(5);
|
|
||||||
heap.insert(3);
|
|
||||||
heap.insert(7);
|
|
||||||
heap.insert(1);
|
|
||||||
heap.insert(9);
|
|
||||||
heap.insert(4);
|
|
||||||
|
|
||||||
expect(heap.size).toEqual(6);
|
|
||||||
|
|
||||||
// Extract in ascending order
|
|
||||||
expect(heap.extract()).toEqual(1);
|
|
||||||
expect(heap.extract()).toEqual(3);
|
|
||||||
expect(heap.extract()).toEqual(4);
|
|
||||||
expect(heap.extract()).toEqual(5);
|
|
||||||
expect(heap.extract()).toEqual(7);
|
|
||||||
expect(heap.extract()).toEqual(9);
|
|
||||||
expect(heap.extract()).toBeUndefined();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should work with custom objects and comparator', async () => {
|
|
||||||
const heap = new BinaryHeap<TestItem>(
|
|
||||||
(a, b) => a.priority - b.priority,
|
|
||||||
(item) => item.id
|
|
||||||
);
|
|
||||||
|
|
||||||
heap.insert({ id: 'a', priority: 5, value: 'five' });
|
|
||||||
heap.insert({ id: 'b', priority: 2, value: 'two' });
|
|
||||||
heap.insert({ id: 'c', priority: 8, value: 'eight' });
|
|
||||||
heap.insert({ id: 'd', priority: 1, value: 'one' });
|
|
||||||
|
|
||||||
const first = heap.extract();
|
|
||||||
expect(first?.priority).toEqual(1);
|
|
||||||
expect(first?.value).toEqual('one');
|
|
||||||
|
|
||||||
const second = heap.extract();
|
|
||||||
expect(second?.priority).toEqual(2);
|
|
||||||
expect(second?.value).toEqual('two');
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should support reverse order (max heap)', async () => {
|
|
||||||
const heap = new BinaryHeap<number>((a, b) => b - a);
|
|
||||||
|
|
||||||
heap.insert(5);
|
|
||||||
heap.insert(3);
|
|
||||||
heap.insert(7);
|
|
||||||
heap.insert(1);
|
|
||||||
heap.insert(9);
|
|
||||||
|
|
||||||
// Extract in descending order
|
|
||||||
expect(heap.extract()).toEqual(9);
|
|
||||||
expect(heap.extract()).toEqual(7);
|
|
||||||
expect(heap.extract()).toEqual(5);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should extract by predicate', async () => {
|
|
||||||
const heap = new BinaryHeap<TestItem>((a, b) => a.priority - b.priority);
|
|
||||||
|
|
||||||
heap.insert({ id: 'a', priority: 5, value: 'five' });
|
|
||||||
heap.insert({ id: 'b', priority: 2, value: 'two' });
|
|
||||||
heap.insert({ id: 'c', priority: 8, value: 'eight' });
|
|
||||||
|
|
||||||
const extracted = heap.extractIf(item => item.id === 'b');
|
|
||||||
expect(extracted?.id).toEqual('b');
|
|
||||||
expect(heap.size).toEqual(2);
|
|
||||||
|
|
||||||
// Should not find it again
|
|
||||||
const notFound = heap.extractIf(item => item.id === 'b');
|
|
||||||
expect(notFound).toBeUndefined();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should extract by key', async () => {
|
|
||||||
const heap = new BinaryHeap<TestItem>(
|
|
||||||
(a, b) => a.priority - b.priority,
|
|
||||||
(item) => item.id
|
|
||||||
);
|
|
||||||
|
|
||||||
heap.insert({ id: 'a', priority: 5, value: 'five' });
|
|
||||||
heap.insert({ id: 'b', priority: 2, value: 'two' });
|
|
||||||
heap.insert({ id: 'c', priority: 8, value: 'eight' });
|
|
||||||
|
|
||||||
expect(heap.hasKey('b')).toBeTrue();
|
|
||||||
|
|
||||||
const extracted = heap.extractByKey('b');
|
|
||||||
expect(extracted?.id).toEqual('b');
|
|
||||||
expect(heap.size).toEqual(2);
|
|
||||||
expect(heap.hasKey('b')).toBeFalse();
|
|
||||||
|
|
||||||
// Should not find it again
|
|
||||||
const notFound = heap.extractByKey('b');
|
|
||||||
expect(notFound).toBeUndefined();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should throw when using key operations without extractKey', async () => {
|
|
||||||
const heap = new BinaryHeap<TestItem>((a, b) => a.priority - b.priority);
|
|
||||||
|
|
||||||
heap.insert({ id: 'a', priority: 5, value: 'five' });
|
|
||||||
|
|
||||||
let error: Error | null = null;
|
|
||||||
try {
|
|
||||||
heap.extractByKey('a');
|
|
||||||
} catch (e: any) {
|
|
||||||
error = e;
|
|
||||||
}
|
|
||||||
|
|
||||||
expect(error).not.toBeNull();
|
|
||||||
expect(error?.message).toContain('extractKey function must be provided');
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should handle duplicates correctly', async () => {
|
|
||||||
const heap = new BinaryHeap<number>((a, b) => a - b);
|
|
||||||
|
|
||||||
heap.insert(5);
|
|
||||||
heap.insert(5);
|
|
||||||
heap.insert(5);
|
|
||||||
heap.insert(3);
|
|
||||||
heap.insert(7);
|
|
||||||
|
|
||||||
expect(heap.size).toEqual(5);
|
|
||||||
expect(heap.extract()).toEqual(3);
|
|
||||||
expect(heap.extract()).toEqual(5);
|
|
||||||
expect(heap.extract()).toEqual(5);
|
|
||||||
expect(heap.extract()).toEqual(5);
|
|
||||||
expect(heap.extract()).toEqual(7);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should convert to array without modifying heap', async () => {
|
|
||||||
const heap = new BinaryHeap<number>((a, b) => a - b);
|
|
||||||
|
|
||||||
heap.insert(5);
|
|
||||||
heap.insert(3);
|
|
||||||
heap.insert(7);
|
|
||||||
|
|
||||||
const array = heap.toArray();
|
|
||||||
expect(array).toContain(3);
|
|
||||||
expect(array).toContain(5);
|
|
||||||
expect(array).toContain(7);
|
|
||||||
expect(array.length).toEqual(3);
|
|
||||||
|
|
||||||
// Heap should still be intact
|
|
||||||
expect(heap.size).toEqual(3);
|
|
||||||
expect(heap.extract()).toEqual(3);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should clear the heap', async () => {
|
|
||||||
const heap = new BinaryHeap<TestItem>(
|
|
||||||
(a, b) => a.priority - b.priority,
|
|
||||||
(item) => item.id
|
|
||||||
);
|
|
||||||
|
|
||||||
heap.insert({ id: 'a', priority: 5, value: 'five' });
|
|
||||||
heap.insert({ id: 'b', priority: 2, value: 'two' });
|
|
||||||
|
|
||||||
expect(heap.size).toEqual(2);
|
|
||||||
expect(heap.hasKey('a')).toBeTrue();
|
|
||||||
|
|
||||||
heap.clear();
|
|
||||||
|
|
||||||
expect(heap.size).toEqual(0);
|
|
||||||
expect(heap.isEmpty()).toBeTrue();
|
|
||||||
expect(heap.hasKey('a')).toBeFalse();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should handle complex extraction patterns', async () => {
|
|
||||||
const heap = new BinaryHeap<number>((a, b) => a - b);
|
|
||||||
|
|
||||||
// Insert numbers 1-10 in random order
|
|
||||||
[8, 3, 5, 9, 1, 7, 4, 10, 2, 6].forEach(n => heap.insert(n));
|
|
||||||
|
|
||||||
// Extract some in order
|
|
||||||
expect(heap.extract()).toEqual(1);
|
|
||||||
expect(heap.extract()).toEqual(2);
|
|
||||||
|
|
||||||
// Insert more
|
|
||||||
heap.insert(0);
|
|
||||||
heap.insert(1.5);
|
|
||||||
|
|
||||||
// Continue extracting
|
|
||||||
expect(heap.extract()).toEqual(0);
|
|
||||||
expect(heap.extract()).toEqual(1.5);
|
|
||||||
expect(heap.extract()).toEqual(3);
|
|
||||||
|
|
||||||
// Verify remaining size (10 - 2 extracted + 2 inserted - 3 extracted = 7)
|
|
||||||
expect(heap.size).toEqual(7);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.start();
|
|
||||||
@@ -1,185 +0,0 @@
|
|||||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
|
||||||
import * as path from 'path';
|
|
||||||
import { AsyncFileSystem } from '../../../ts/core/utils/fs-utils.js';
|
|
||||||
|
|
||||||
// Use a temporary directory for tests
|
|
||||||
const testDir = path.join(process.cwd(), '.nogit', 'test-fs-utils');
|
|
||||||
const testFile = path.join(testDir, 'test.txt');
|
|
||||||
const testJsonFile = path.join(testDir, 'test.json');
|
|
||||||
|
|
||||||
tap.test('should create and check directory existence', async () => {
|
|
||||||
// Ensure directory
|
|
||||||
await AsyncFileSystem.ensureDir(testDir);
|
|
||||||
|
|
||||||
// Check it exists
|
|
||||||
const exists = await AsyncFileSystem.exists(testDir);
|
|
||||||
expect(exists).toBeTrue();
|
|
||||||
|
|
||||||
// Check it's a directory
|
|
||||||
const isDir = await AsyncFileSystem.isDirectory(testDir);
|
|
||||||
expect(isDir).toBeTrue();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should write and read text files', async () => {
|
|
||||||
const testContent = 'Hello, async filesystem!';
|
|
||||||
|
|
||||||
// Write file
|
|
||||||
await AsyncFileSystem.writeFile(testFile, testContent);
|
|
||||||
|
|
||||||
// Check file exists
|
|
||||||
const exists = await AsyncFileSystem.exists(testFile);
|
|
||||||
expect(exists).toBeTrue();
|
|
||||||
|
|
||||||
// Read file
|
|
||||||
const content = await AsyncFileSystem.readFile(testFile);
|
|
||||||
expect(content).toEqual(testContent);
|
|
||||||
|
|
||||||
// Check it's a file
|
|
||||||
const isFile = await AsyncFileSystem.isFile(testFile);
|
|
||||||
expect(isFile).toBeTrue();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should write and read JSON files', async () => {
|
|
||||||
const testData = {
|
|
||||||
name: 'Test',
|
|
||||||
value: 42,
|
|
||||||
nested: {
|
|
||||||
array: [1, 2, 3]
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Write JSON
|
|
||||||
await AsyncFileSystem.writeJSON(testJsonFile, testData);
|
|
||||||
|
|
||||||
// Read JSON
|
|
||||||
const readData = await AsyncFileSystem.readJSON(testJsonFile);
|
|
||||||
expect(readData).toEqual(testData);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should copy files', async () => {
|
|
||||||
const copyFile = path.join(testDir, 'copy.txt');
|
|
||||||
|
|
||||||
// Copy file
|
|
||||||
await AsyncFileSystem.copyFile(testFile, copyFile);
|
|
||||||
|
|
||||||
// Check copy exists
|
|
||||||
const exists = await AsyncFileSystem.exists(copyFile);
|
|
||||||
expect(exists).toBeTrue();
|
|
||||||
|
|
||||||
// Check content matches
|
|
||||||
const content = await AsyncFileSystem.readFile(copyFile);
|
|
||||||
const originalContent = await AsyncFileSystem.readFile(testFile);
|
|
||||||
expect(content).toEqual(originalContent);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should move files', async () => {
|
|
||||||
const moveFile = path.join(testDir, 'moved.txt');
|
|
||||||
const copyFile = path.join(testDir, 'copy.txt');
|
|
||||||
|
|
||||||
// Move file
|
|
||||||
await AsyncFileSystem.moveFile(copyFile, moveFile);
|
|
||||||
|
|
||||||
// Check moved file exists
|
|
||||||
const movedExists = await AsyncFileSystem.exists(moveFile);
|
|
||||||
expect(movedExists).toBeTrue();
|
|
||||||
|
|
||||||
// Check original doesn't exist
|
|
||||||
const originalExists = await AsyncFileSystem.exists(copyFile);
|
|
||||||
expect(originalExists).toBeFalse();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should list files in directory', async () => {
|
|
||||||
const files = await AsyncFileSystem.listFiles(testDir);
|
|
||||||
|
|
||||||
expect(files).toContain('test.txt');
|
|
||||||
expect(files).toContain('test.json');
|
|
||||||
expect(files).toContain('moved.txt');
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should list files with full paths', async () => {
|
|
||||||
const files = await AsyncFileSystem.listFilesFullPath(testDir);
|
|
||||||
|
|
||||||
const fileNames = files.map(f => path.basename(f));
|
|
||||||
expect(fileNames).toContain('test.txt');
|
|
||||||
expect(fileNames).toContain('test.json');
|
|
||||||
|
|
||||||
// All paths should be absolute
|
|
||||||
files.forEach(file => {
|
|
||||||
expect(path.isAbsolute(file)).toBeTrue();
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should get file stats', async () => {
|
|
||||||
const stats = await AsyncFileSystem.getStats(testFile);
|
|
||||||
|
|
||||||
expect(stats).not.toBeNull();
|
|
||||||
expect(stats?.isFile()).toBeTrue();
|
|
||||||
expect(stats?.size).toBeGreaterThan(0);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should handle non-existent files gracefully', async () => {
|
|
||||||
const nonExistent = path.join(testDir, 'does-not-exist.txt');
|
|
||||||
|
|
||||||
// Check existence
|
|
||||||
const exists = await AsyncFileSystem.exists(nonExistent);
|
|
||||||
expect(exists).toBeFalse();
|
|
||||||
|
|
||||||
// Get stats should return null
|
|
||||||
const stats = await AsyncFileSystem.getStats(nonExistent);
|
|
||||||
expect(stats).toBeNull();
|
|
||||||
|
|
||||||
// Remove should not throw
|
|
||||||
await AsyncFileSystem.remove(nonExistent);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should remove files', async () => {
|
|
||||||
// Remove a file
|
|
||||||
await AsyncFileSystem.remove(testFile);
|
|
||||||
|
|
||||||
// Check it's gone
|
|
||||||
const exists = await AsyncFileSystem.exists(testFile);
|
|
||||||
expect(exists).toBeFalse();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should ensure file exists', async () => {
|
|
||||||
const ensureFile = path.join(testDir, 'ensure.txt');
|
|
||||||
|
|
||||||
// Ensure file
|
|
||||||
await AsyncFileSystem.ensureFile(ensureFile);
|
|
||||||
|
|
||||||
// Check it exists
|
|
||||||
const exists = await AsyncFileSystem.exists(ensureFile);
|
|
||||||
expect(exists).toBeTrue();
|
|
||||||
|
|
||||||
// Check it's empty
|
|
||||||
const content = await AsyncFileSystem.readFile(ensureFile);
|
|
||||||
expect(content).toEqual('');
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should recursively list files', async () => {
|
|
||||||
// Create subdirectory with file
|
|
||||||
const subDir = path.join(testDir, 'subdir');
|
|
||||||
const subFile = path.join(subDir, 'nested.txt');
|
|
||||||
|
|
||||||
await AsyncFileSystem.ensureDir(subDir);
|
|
||||||
await AsyncFileSystem.writeFile(subFile, 'nested content');
|
|
||||||
|
|
||||||
// List recursively
|
|
||||||
const files = await AsyncFileSystem.listFilesRecursive(testDir);
|
|
||||||
|
|
||||||
// Should include files from subdirectory
|
|
||||||
const fileNames = files.map(f => path.relative(testDir, f));
|
|
||||||
expect(fileNames).toContain('test.json');
|
|
||||||
expect(fileNames).toContain(path.join('subdir', 'nested.txt'));
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should clean up test directory', async () => {
|
|
||||||
// Remove entire test directory
|
|
||||||
await AsyncFileSystem.removeDir(testDir);
|
|
||||||
|
|
||||||
// Check it's gone
|
|
||||||
const exists = await AsyncFileSystem.exists(testDir);
|
|
||||||
expect(exists).toBeFalse();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.start();
|
|
||||||
@@ -1,156 +0,0 @@
|
|||||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
|
||||||
import { IpUtils } from '../../../ts/core/utils/ip-utils.js';
|
|
||||||
|
|
||||||
tap.test('ip-utils - normalizeIP', async () => {
|
|
||||||
// IPv4 normalization
|
|
||||||
const ipv4Variants = IpUtils.normalizeIP('127.0.0.1');
|
|
||||||
expect(ipv4Variants).toEqual(['127.0.0.1', '::ffff:127.0.0.1']);
|
|
||||||
|
|
||||||
// IPv6-mapped IPv4 normalization
|
|
||||||
const ipv6MappedVariants = IpUtils.normalizeIP('::ffff:127.0.0.1');
|
|
||||||
expect(ipv6MappedVariants).toEqual(['::ffff:127.0.0.1', '127.0.0.1']);
|
|
||||||
|
|
||||||
// IPv6 normalization
|
|
||||||
const ipv6Variants = IpUtils.normalizeIP('::1');
|
|
||||||
expect(ipv6Variants).toEqual(['::1']);
|
|
||||||
|
|
||||||
// Invalid/empty input handling
|
|
||||||
expect(IpUtils.normalizeIP('')).toEqual([]);
|
|
||||||
expect(IpUtils.normalizeIP(null as any)).toEqual([]);
|
|
||||||
expect(IpUtils.normalizeIP(undefined as any)).toEqual([]);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('ip-utils - isGlobIPMatch', async () => {
|
|
||||||
// Direct matches
|
|
||||||
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['127.0.0.1'])).toEqual(true);
|
|
||||||
expect(IpUtils.isGlobIPMatch('::1', ['::1'])).toEqual(true);
|
|
||||||
|
|
||||||
// Wildcard matches
|
|
||||||
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['127.0.0.*'])).toEqual(true);
|
|
||||||
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['127.0.*.*'])).toEqual(true);
|
|
||||||
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['127.*.*.*'])).toEqual(true);
|
|
||||||
|
|
||||||
// IPv4-mapped IPv6 handling
|
|
||||||
expect(IpUtils.isGlobIPMatch('::ffff:127.0.0.1', ['127.0.0.1'])).toEqual(true);
|
|
||||||
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['::ffff:127.0.0.1'])).toEqual(true);
|
|
||||||
|
|
||||||
// Match multiple patterns
|
|
||||||
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['10.0.0.1', '127.0.0.1', '192.168.1.1'])).toEqual(true);
|
|
||||||
|
|
||||||
// Non-matching patterns
|
|
||||||
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['10.0.0.1'])).toEqual(false);
|
|
||||||
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['128.0.0.1'])).toEqual(false);
|
|
||||||
expect(IpUtils.isGlobIPMatch('127.0.0.1', ['127.0.0.2'])).toEqual(false);
|
|
||||||
|
|
||||||
// Edge cases
|
|
||||||
expect(IpUtils.isGlobIPMatch('', ['127.0.0.1'])).toEqual(false);
|
|
||||||
expect(IpUtils.isGlobIPMatch('127.0.0.1', [])).toEqual(false);
|
|
||||||
expect(IpUtils.isGlobIPMatch('127.0.0.1', null as any)).toEqual(false);
|
|
||||||
expect(IpUtils.isGlobIPMatch(null as any, ['127.0.0.1'])).toEqual(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('ip-utils - isIPAuthorized', async () => {
|
|
||||||
// Basic tests to check the core functionality works
|
|
||||||
// No restrictions - all IPs allowed
|
|
||||||
expect(IpUtils.isIPAuthorized('127.0.0.1')).toEqual(true);
|
|
||||||
|
|
||||||
// Basic blocked IP test
|
|
||||||
const blockedIP = '8.8.8.8';
|
|
||||||
const blockedIPs = [blockedIP];
|
|
||||||
expect(IpUtils.isIPAuthorized(blockedIP, [], blockedIPs)).toEqual(false);
|
|
||||||
|
|
||||||
// Basic allowed IP test
|
|
||||||
const allowedIP = '10.0.0.1';
|
|
||||||
const allowedIPs = [allowedIP];
|
|
||||||
expect(IpUtils.isIPAuthorized(allowedIP, allowedIPs)).toEqual(true);
|
|
||||||
expect(IpUtils.isIPAuthorized('192.168.1.1', allowedIPs)).toEqual(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('ip-utils - isPrivateIP', async () => {
|
|
||||||
// Private IPv4 ranges
|
|
||||||
expect(IpUtils.isPrivateIP('10.0.0.1')).toEqual(true);
|
|
||||||
expect(IpUtils.isPrivateIP('172.16.0.1')).toEqual(true);
|
|
||||||
expect(IpUtils.isPrivateIP('172.31.255.255')).toEqual(true);
|
|
||||||
expect(IpUtils.isPrivateIP('192.168.0.1')).toEqual(true);
|
|
||||||
expect(IpUtils.isPrivateIP('127.0.0.1')).toEqual(true);
|
|
||||||
|
|
||||||
// Public IPv4 addresses
|
|
||||||
expect(IpUtils.isPrivateIP('8.8.8.8')).toEqual(false);
|
|
||||||
expect(IpUtils.isPrivateIP('203.0.113.1')).toEqual(false);
|
|
||||||
|
|
||||||
// IPv4-mapped IPv6 handling
|
|
||||||
expect(IpUtils.isPrivateIP('::ffff:10.0.0.1')).toEqual(true);
|
|
||||||
expect(IpUtils.isPrivateIP('::ffff:8.8.8.8')).toEqual(false);
|
|
||||||
|
|
||||||
// Private IPv6 addresses
|
|
||||||
expect(IpUtils.isPrivateIP('::1')).toEqual(true);
|
|
||||||
expect(IpUtils.isPrivateIP('fd00::')).toEqual(true);
|
|
||||||
expect(IpUtils.isPrivateIP('fe80::1')).toEqual(true);
|
|
||||||
|
|
||||||
// Public IPv6 addresses
|
|
||||||
expect(IpUtils.isPrivateIP('2001:db8::1')).toEqual(false);
|
|
||||||
|
|
||||||
// Edge cases
|
|
||||||
expect(IpUtils.isPrivateIP('')).toEqual(false);
|
|
||||||
expect(IpUtils.isPrivateIP(null as any)).toEqual(false);
|
|
||||||
expect(IpUtils.isPrivateIP(undefined as any)).toEqual(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('ip-utils - isPublicIP', async () => {
|
|
||||||
// Public IPv4 addresses
|
|
||||||
expect(IpUtils.isPublicIP('8.8.8.8')).toEqual(true);
|
|
||||||
expect(IpUtils.isPublicIP('203.0.113.1')).toEqual(true);
|
|
||||||
|
|
||||||
// Private IPv4 ranges
|
|
||||||
expect(IpUtils.isPublicIP('10.0.0.1')).toEqual(false);
|
|
||||||
expect(IpUtils.isPublicIP('172.16.0.1')).toEqual(false);
|
|
||||||
expect(IpUtils.isPublicIP('192.168.0.1')).toEqual(false);
|
|
||||||
expect(IpUtils.isPublicIP('127.0.0.1')).toEqual(false);
|
|
||||||
|
|
||||||
// Public IPv6 addresses
|
|
||||||
expect(IpUtils.isPublicIP('2001:db8::1')).toEqual(true);
|
|
||||||
|
|
||||||
// Private IPv6 addresses
|
|
||||||
expect(IpUtils.isPublicIP('::1')).toEqual(false);
|
|
||||||
expect(IpUtils.isPublicIP('fd00::')).toEqual(false);
|
|
||||||
expect(IpUtils.isPublicIP('fe80::1')).toEqual(false);
|
|
||||||
|
|
||||||
// Edge cases - the implementation treats these as non-private, which is technically correct but might not be what users expect
|
|
||||||
const emptyResult = IpUtils.isPublicIP('');
|
|
||||||
expect(emptyResult).toEqual(true);
|
|
||||||
|
|
||||||
const nullResult = IpUtils.isPublicIP(null as any);
|
|
||||||
expect(nullResult).toEqual(true);
|
|
||||||
|
|
||||||
const undefinedResult = IpUtils.isPublicIP(undefined as any);
|
|
||||||
expect(undefinedResult).toEqual(true);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('ip-utils - cidrToGlobPatterns', async () => {
|
|
||||||
// Class C network
|
|
||||||
const classC = IpUtils.cidrToGlobPatterns('192.168.1.0/24');
|
|
||||||
expect(classC).toEqual(['192.168.1.*']);
|
|
||||||
|
|
||||||
// Class B network
|
|
||||||
const classB = IpUtils.cidrToGlobPatterns('172.16.0.0/16');
|
|
||||||
expect(classB).toEqual(['172.16.*.*']);
|
|
||||||
|
|
||||||
// Class A network
|
|
||||||
const classA = IpUtils.cidrToGlobPatterns('10.0.0.0/8');
|
|
||||||
expect(classA).toEqual(['10.*.*.*']);
|
|
||||||
|
|
||||||
// Small subnet (/28 = 16 addresses)
|
|
||||||
const smallSubnet = IpUtils.cidrToGlobPatterns('192.168.1.0/28');
|
|
||||||
expect(smallSubnet.length).toEqual(16);
|
|
||||||
expect(smallSubnet).toContain('192.168.1.0');
|
|
||||||
expect(smallSubnet).toContain('192.168.1.15');
|
|
||||||
|
|
||||||
// Invalid inputs
|
|
||||||
expect(IpUtils.cidrToGlobPatterns('')).toEqual([]);
|
|
||||||
expect(IpUtils.cidrToGlobPatterns('192.168.1.0')).toEqual([]);
|
|
||||||
expect(IpUtils.cidrToGlobPatterns('192.168.1.0/')).toEqual([]);
|
|
||||||
expect(IpUtils.cidrToGlobPatterns('192.168.1.0/33')).toEqual([]);
|
|
||||||
expect(IpUtils.cidrToGlobPatterns('invalid/24')).toEqual([]);
|
|
||||||
});
|
|
||||||
|
|
||||||
export default tap.start();
|
|
||||||
@@ -1,252 +0,0 @@
|
|||||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
|
||||||
import { LifecycleComponent } from '../../../ts/core/utils/lifecycle-component.js';
|
|
||||||
import { EventEmitter } from 'events';
|
|
||||||
|
|
||||||
// Test implementation of LifecycleComponent
|
|
||||||
class TestComponent extends LifecycleComponent {
|
|
||||||
public timerCallCount = 0;
|
|
||||||
public intervalCallCount = 0;
|
|
||||||
public cleanupCalled = false;
|
|
||||||
public testEmitter = new EventEmitter();
|
|
||||||
public listenerCallCount = 0;
|
|
||||||
|
|
||||||
constructor() {
|
|
||||||
super();
|
|
||||||
this.setupTimers();
|
|
||||||
this.setupListeners();
|
|
||||||
}
|
|
||||||
|
|
||||||
private setupTimers() {
|
|
||||||
// Set up a timeout
|
|
||||||
this.setTimeout(() => {
|
|
||||||
this.timerCallCount++;
|
|
||||||
}, 100);
|
|
||||||
|
|
||||||
// Set up an interval
|
|
||||||
this.setInterval(() => {
|
|
||||||
this.intervalCallCount++;
|
|
||||||
}, 50);
|
|
||||||
}
|
|
||||||
|
|
||||||
private setupListeners() {
|
|
||||||
this.addEventListener(this.testEmitter, 'test-event', () => {
|
|
||||||
this.listenerCallCount++;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
protected async onCleanup(): Promise<void> {
|
|
||||||
this.cleanupCalled = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Expose protected methods for testing
|
|
||||||
public testSetTimeout(handler: Function, timeout: number): NodeJS.Timeout {
|
|
||||||
return this.setTimeout(handler, timeout);
|
|
||||||
}
|
|
||||||
|
|
||||||
public testSetInterval(handler: Function, interval: number): NodeJS.Timeout {
|
|
||||||
return this.setInterval(handler, interval);
|
|
||||||
}
|
|
||||||
|
|
||||||
public testClearTimeout(timer: NodeJS.Timeout): void {
|
|
||||||
return this.clearTimeout(timer);
|
|
||||||
}
|
|
||||||
|
|
||||||
public testClearInterval(timer: NodeJS.Timeout): void {
|
|
||||||
return this.clearInterval(timer);
|
|
||||||
}
|
|
||||||
|
|
||||||
public testAddEventListener(target: any, event: string, handler: Function, options?: { once?: boolean }): void {
|
|
||||||
return this.addEventListener(target, event, handler, options);
|
|
||||||
}
|
|
||||||
|
|
||||||
public testIsShuttingDown(): boolean {
|
|
||||||
return this.isShuttingDownState();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tap.test('should manage timers properly', async () => {
|
|
||||||
const component = new TestComponent();
|
|
||||||
|
|
||||||
// Wait for timers to fire
|
|
||||||
await new Promise(resolve => setTimeout(resolve, 200));
|
|
||||||
|
|
||||||
expect(component.timerCallCount).toEqual(1);
|
|
||||||
expect(component.intervalCallCount).toBeGreaterThan(2);
|
|
||||||
|
|
||||||
await component.cleanup();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should manage event listeners properly', async () => {
|
|
||||||
const component = new TestComponent();
|
|
||||||
|
|
||||||
// Emit events
|
|
||||||
component.testEmitter.emit('test-event');
|
|
||||||
component.testEmitter.emit('test-event');
|
|
||||||
|
|
||||||
expect(component.listenerCallCount).toEqual(2);
|
|
||||||
|
|
||||||
// Cleanup and verify listeners are removed
|
|
||||||
await component.cleanup();
|
|
||||||
|
|
||||||
component.testEmitter.emit('test-event');
|
|
||||||
expect(component.listenerCallCount).toEqual(2); // Should not increase
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should prevent timer execution after cleanup', async () => {
|
|
||||||
const component = new TestComponent();
|
|
||||||
|
|
||||||
let laterCallCount = 0;
|
|
||||||
component.testSetTimeout(() => {
|
|
||||||
laterCallCount++;
|
|
||||||
}, 100);
|
|
||||||
|
|
||||||
// Cleanup immediately
|
|
||||||
await component.cleanup();
|
|
||||||
|
|
||||||
// Wait for timer that would have fired
|
|
||||||
await new Promise(resolve => setTimeout(resolve, 150));
|
|
||||||
|
|
||||||
expect(laterCallCount).toEqual(0);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should handle child components', async () => {
|
|
||||||
class ParentComponent extends LifecycleComponent {
|
|
||||||
public child: TestComponent;
|
|
||||||
|
|
||||||
constructor() {
|
|
||||||
super();
|
|
||||||
this.child = new TestComponent();
|
|
||||||
this.registerChildComponent(this.child);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const parent = new ParentComponent();
|
|
||||||
|
|
||||||
// Wait for child timers
|
|
||||||
await new Promise(resolve => setTimeout(resolve, 100));
|
|
||||||
|
|
||||||
expect(parent.child.timerCallCount).toEqual(1);
|
|
||||||
|
|
||||||
// Cleanup parent should cleanup child
|
|
||||||
await parent.cleanup();
|
|
||||||
|
|
||||||
expect(parent.child.cleanupCalled).toBeTrue();
|
|
||||||
expect(parent.child.testIsShuttingDown()).toBeTrue();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should handle multiple cleanup calls gracefully', async () => {
|
|
||||||
const component = new TestComponent();
|
|
||||||
|
|
||||||
// Call cleanup multiple times
|
|
||||||
const promises = [
|
|
||||||
component.cleanup(),
|
|
||||||
component.cleanup(),
|
|
||||||
component.cleanup()
|
|
||||||
];
|
|
||||||
|
|
||||||
await Promise.all(promises);
|
|
||||||
|
|
||||||
// Should only clean up once
|
|
||||||
expect(component.cleanupCalled).toBeTrue();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should clear specific timers', async () => {
|
|
||||||
const component = new TestComponent();
|
|
||||||
|
|
||||||
let callCount = 0;
|
|
||||||
const timer = component.testSetTimeout(() => {
|
|
||||||
callCount++;
|
|
||||||
}, 100);
|
|
||||||
|
|
||||||
// Clear the timer
|
|
||||||
component.testClearTimeout(timer);
|
|
||||||
|
|
||||||
// Wait and verify it didn't fire
|
|
||||||
await new Promise(resolve => setTimeout(resolve, 150));
|
|
||||||
|
|
||||||
expect(callCount).toEqual(0);
|
|
||||||
|
|
||||||
await component.cleanup();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should clear specific intervals', async () => {
|
|
||||||
const component = new TestComponent();
|
|
||||||
|
|
||||||
let callCount = 0;
|
|
||||||
const interval = component.testSetInterval(() => {
|
|
||||||
callCount++;
|
|
||||||
}, 50);
|
|
||||||
|
|
||||||
// Let it run a bit
|
|
||||||
await new Promise(resolve => setTimeout(resolve, 120));
|
|
||||||
|
|
||||||
const countBeforeClear = callCount;
|
|
||||||
expect(countBeforeClear).toBeGreaterThan(1);
|
|
||||||
|
|
||||||
// Clear the interval
|
|
||||||
component.testClearInterval(interval);
|
|
||||||
|
|
||||||
// Wait and verify it stopped
|
|
||||||
await new Promise(resolve => setTimeout(resolve, 100));
|
|
||||||
|
|
||||||
expect(callCount).toEqual(countBeforeClear);
|
|
||||||
|
|
||||||
await component.cleanup();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should handle once event listeners', async () => {
|
|
||||||
const component = new TestComponent();
|
|
||||||
const emitter = new EventEmitter();
|
|
||||||
|
|
||||||
let callCount = 0;
|
|
||||||
const handler = () => {
|
|
||||||
callCount++;
|
|
||||||
};
|
|
||||||
|
|
||||||
component.testAddEventListener(emitter, 'once-event', handler, { once: true });
|
|
||||||
|
|
||||||
// Check listener count before emit
|
|
||||||
const beforeCount = emitter.listenerCount('once-event');
|
|
||||||
expect(beforeCount).toEqual(1);
|
|
||||||
|
|
||||||
// Emit once - the listener should fire and auto-remove
|
|
||||||
emitter.emit('once-event');
|
|
||||||
expect(callCount).toEqual(1);
|
|
||||||
|
|
||||||
// Check listener was auto-removed
|
|
||||||
const afterCount = emitter.listenerCount('once-event');
|
|
||||||
expect(afterCount).toEqual(0);
|
|
||||||
|
|
||||||
// Emit again - should not increase count
|
|
||||||
emitter.emit('once-event');
|
|
||||||
expect(callCount).toEqual(1);
|
|
||||||
|
|
||||||
await component.cleanup();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should not create timers when shutting down', async () => {
|
|
||||||
const component = new TestComponent();
|
|
||||||
|
|
||||||
// Start cleanup
|
|
||||||
const cleanupPromise = component.cleanup();
|
|
||||||
|
|
||||||
// Try to create timers during shutdown
|
|
||||||
let timerFired = false;
|
|
||||||
let intervalFired = false;
|
|
||||||
|
|
||||||
component.testSetTimeout(() => {
|
|
||||||
timerFired = true;
|
|
||||||
}, 10);
|
|
||||||
|
|
||||||
component.testSetInterval(() => {
|
|
||||||
intervalFired = true;
|
|
||||||
}, 10);
|
|
||||||
|
|
||||||
await cleanupPromise;
|
|
||||||
await new Promise(resolve => setTimeout(resolve, 50));
|
|
||||||
|
|
||||||
expect(timerFired).toBeFalse();
|
|
||||||
expect(intervalFired).toBeFalse();
|
|
||||||
});
|
|
||||||
|
|
||||||
export default tap.start();
|
|
||||||
@@ -1,158 +0,0 @@
|
|||||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
|
||||||
import { SharedSecurityManager } from '../../../ts/core/utils/shared-security-manager.js';
|
|
||||||
import type { IRouteConfig, IRouteContext } from '../../../ts/proxies/smart-proxy/models/route-types.js';
|
|
||||||
|
|
||||||
// Test security manager
|
|
||||||
tap.test('Shared Security Manager', async () => {
|
|
||||||
let securityManager: SharedSecurityManager;
|
|
||||||
|
|
||||||
// Set up a new security manager for each test
|
|
||||||
securityManager = new SharedSecurityManager({
|
|
||||||
maxConnectionsPerIP: 5,
|
|
||||||
connectionRateLimitPerMinute: 10
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should validate IPs correctly', async () => {
|
|
||||||
// Should allow IPs under connection limit
|
|
||||||
expect(securityManager.validateIP('192.168.1.1').allowed).toBeTrue();
|
|
||||||
|
|
||||||
// Track multiple connections
|
|
||||||
for (let i = 0; i < 4; i++) {
|
|
||||||
securityManager.trackConnectionByIP('192.168.1.1', `conn_${i}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should still allow IPs under connection limit
|
|
||||||
expect(securityManager.validateIP('192.168.1.1').allowed).toBeTrue();
|
|
||||||
|
|
||||||
// Add one more to reach the limit
|
|
||||||
securityManager.trackConnectionByIP('192.168.1.1', 'conn_4');
|
|
||||||
|
|
||||||
// Should now block IPs over connection limit
|
|
||||||
expect(securityManager.validateIP('192.168.1.1').allowed).toBeFalse();
|
|
||||||
|
|
||||||
// Remove a connection
|
|
||||||
securityManager.removeConnectionByIP('192.168.1.1', 'conn_0');
|
|
||||||
|
|
||||||
// Should allow again after connection is removed
|
|
||||||
expect(securityManager.validateIP('192.168.1.1').allowed).toBeTrue();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should authorize IPs based on allow/block lists', async () => {
|
|
||||||
// Test with allow list only
|
|
||||||
expect(securityManager.isIPAuthorized('192.168.1.1', ['192.168.1.*'])).toBeTrue();
|
|
||||||
expect(securityManager.isIPAuthorized('192.168.2.1', ['192.168.1.*'])).toBeFalse();
|
|
||||||
|
|
||||||
// Test with block list
|
|
||||||
expect(securityManager.isIPAuthorized('192.168.1.5', ['*'], ['192.168.1.5'])).toBeFalse();
|
|
||||||
expect(securityManager.isIPAuthorized('192.168.1.1', ['*'], ['192.168.1.5'])).toBeTrue();
|
|
||||||
|
|
||||||
// Test with both allow and block lists
|
|
||||||
expect(securityManager.isIPAuthorized('192.168.1.1', ['192.168.1.*'], ['192.168.1.5'])).toBeTrue();
|
|
||||||
expect(securityManager.isIPAuthorized('192.168.1.5', ['192.168.1.*'], ['192.168.1.5'])).toBeFalse();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should validate route access', async () => {
|
|
||||||
const route: IRouteConfig = {
|
|
||||||
match: {
|
|
||||||
ports: [8080]
|
|
||||||
},
|
|
||||||
action: {
|
|
||||||
type: 'forward',
|
|
||||||
targets: [{ host: 'target.com', port: 443 }]
|
|
||||||
},
|
|
||||||
security: {
|
|
||||||
ipAllowList: ['10.0.0.*', '192.168.1.*'],
|
|
||||||
ipBlockList: ['192.168.1.100'],
|
|
||||||
maxConnections: 3
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const allowedContext: IRouteContext = {
|
|
||||||
clientIp: '192.168.1.1',
|
|
||||||
port: 8080,
|
|
||||||
serverIp: '127.0.0.1',
|
|
||||||
isTls: false,
|
|
||||||
timestamp: Date.now(),
|
|
||||||
connectionId: 'test_conn_1'
|
|
||||||
};
|
|
||||||
|
|
||||||
const blockedByIPContext: IRouteContext = {
|
|
||||||
...allowedContext,
|
|
||||||
clientIp: '192.168.1.100'
|
|
||||||
};
|
|
||||||
|
|
||||||
const blockedByRangeContext: IRouteContext = {
|
|
||||||
...allowedContext,
|
|
||||||
clientIp: '172.16.0.1'
|
|
||||||
};
|
|
||||||
|
|
||||||
const blockedByMaxConnectionsContext: IRouteContext = {
|
|
||||||
...allowedContext,
|
|
||||||
connectionId: 'test_conn_4'
|
|
||||||
};
|
|
||||||
|
|
||||||
expect(securityManager.isAllowed(route, allowedContext)).toBeTrue();
|
|
||||||
expect(securityManager.isAllowed(route, blockedByIPContext)).toBeFalse();
|
|
||||||
expect(securityManager.isAllowed(route, blockedByRangeContext)).toBeFalse();
|
|
||||||
|
|
||||||
// Test max connections for route - assuming implementation has been updated
|
|
||||||
if ((securityManager as any).trackConnectionByRoute) {
|
|
||||||
(securityManager as any).trackConnectionByRoute(route, 'conn_1');
|
|
||||||
(securityManager as any).trackConnectionByRoute(route, 'conn_2');
|
|
||||||
(securityManager as any).trackConnectionByRoute(route, 'conn_3');
|
|
||||||
|
|
||||||
// Should now block due to max connections
|
|
||||||
expect(securityManager.isAllowed(route, blockedByMaxConnectionsContext)).toBeFalse();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('should clean up expired entries', async () => {
|
|
||||||
const route: IRouteConfig = {
|
|
||||||
match: {
|
|
||||||
ports: [8080]
|
|
||||||
},
|
|
||||||
action: {
|
|
||||||
type: 'forward',
|
|
||||||
targets: [{ host: 'target.com', port: 443 }]
|
|
||||||
},
|
|
||||||
security: {
|
|
||||||
rateLimit: {
|
|
||||||
enabled: true,
|
|
||||||
maxRequests: 5,
|
|
||||||
window: 60 // 60 seconds
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const context: IRouteContext = {
|
|
||||||
clientIp: '192.168.1.1',
|
|
||||||
port: 8080,
|
|
||||||
serverIp: '127.0.0.1',
|
|
||||||
isTls: false,
|
|
||||||
timestamp: Date.now(),
|
|
||||||
connectionId: 'test_conn_1'
|
|
||||||
};
|
|
||||||
|
|
||||||
// Test rate limiting if method exists
|
|
||||||
if ((securityManager as any).checkRateLimit) {
|
|
||||||
// Add 5 attempts (max allowed)
|
|
||||||
for (let i = 0; i < 5; i++) {
|
|
||||||
expect((securityManager as any).checkRateLimit(route, context)).toBeTrue();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should now be blocked
|
|
||||||
expect((securityManager as any).checkRateLimit(route, context)).toBeFalse();
|
|
||||||
|
|
||||||
// Force cleanup (normally runs periodically)
|
|
||||||
if ((securityManager as any).cleanup) {
|
|
||||||
(securityManager as any).cleanup();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should still be blocked since entries are not expired yet
|
|
||||||
expect((securityManager as any).checkRateLimit(route, context)).toBeFalse();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
// Export test runner
|
|
||||||
export default tap.start();
|
|
||||||
@@ -1,302 +0,0 @@
|
|||||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
|
||||||
import { ValidationUtils } from '../../../ts/core/utils/validation-utils.js';
|
|
||||||
import type { IDomainOptions, IAcmeOptions } from '../../../ts/core/models/common-types.js';
|
|
||||||
|
|
||||||
tap.test('validation-utils - isValidPort', async () => {
|
|
||||||
// Valid port values
|
|
||||||
expect(ValidationUtils.isValidPort(1)).toEqual(true);
|
|
||||||
expect(ValidationUtils.isValidPort(80)).toEqual(true);
|
|
||||||
expect(ValidationUtils.isValidPort(443)).toEqual(true);
|
|
||||||
expect(ValidationUtils.isValidPort(8080)).toEqual(true);
|
|
||||||
expect(ValidationUtils.isValidPort(65535)).toEqual(true);
|
|
||||||
|
|
||||||
// Invalid port values
|
|
||||||
expect(ValidationUtils.isValidPort(0)).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidPort(-1)).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidPort(65536)).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidPort(80.5)).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidPort(NaN)).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidPort(null as any)).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidPort(undefined as any)).toEqual(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('validation-utils - isValidDomainName', async () => {
|
|
||||||
// Valid domain names
|
|
||||||
expect(ValidationUtils.isValidDomainName('example.com')).toEqual(true);
|
|
||||||
expect(ValidationUtils.isValidDomainName('sub.example.com')).toEqual(true);
|
|
||||||
expect(ValidationUtils.isValidDomainName('*.example.com')).toEqual(true);
|
|
||||||
expect(ValidationUtils.isValidDomainName('a-hyphenated-domain.example.com')).toEqual(true);
|
|
||||||
expect(ValidationUtils.isValidDomainName('example123.com')).toEqual(true);
|
|
||||||
|
|
||||||
// Invalid domain names
|
|
||||||
expect(ValidationUtils.isValidDomainName('')).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidDomainName(null as any)).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidDomainName(undefined as any)).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidDomainName('-invalid.com')).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidDomainName('invalid-.com')).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidDomainName('inv@lid.com')).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidDomainName('example')).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidDomainName('example.')).toEqual(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('validation-utils - isValidEmail', async () => {
|
|
||||||
// Valid email addresses
|
|
||||||
expect(ValidationUtils.isValidEmail('user@example.com')).toEqual(true);
|
|
||||||
expect(ValidationUtils.isValidEmail('admin@sub.example.com')).toEqual(true);
|
|
||||||
expect(ValidationUtils.isValidEmail('first.last@example.com')).toEqual(true);
|
|
||||||
expect(ValidationUtils.isValidEmail('user+tag@example.com')).toEqual(true);
|
|
||||||
|
|
||||||
// Invalid email addresses
|
|
||||||
expect(ValidationUtils.isValidEmail('')).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidEmail(null as any)).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidEmail(undefined as any)).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidEmail('user')).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidEmail('user@')).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidEmail('@example.com')).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidEmail('user example.com')).toEqual(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('validation-utils - isValidCertificate', async () => {
|
|
||||||
// Valid certificate format
|
|
||||||
const validCert = `-----BEGIN CERTIFICATE-----
|
|
||||||
MIIDazCCAlOgAwIBAgIUJlq+zz9CO2E91rlD4vhx0CX1Z/kwDQYJKoZIhvcNAQEL
|
|
||||||
BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
|
|
||||||
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yMzAxMDEwMDAwMDBaFw0yNDAx
|
|
||||||
MDEwMDAwMDBaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw
|
|
||||||
HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB
|
|
||||||
AQUAA4IBDwAwggEKAoIBAQC0aQeHIV9vQpZ4UVwW/xhx9zl01UbppLXdoqe3NP9x
|
|
||||||
KfXTCB1YbtJ4GgKIlQqHGLGsLI5ZOE7KxmJeGEwK7ueP4f3WkUlM5C5yTbZ5hSUo
|
|
||||||
R+OFnszFRJJiBXJlw57YAW9+zqKQHYxwve64O64dlgw6pekDYJhXtrUUZ78Lz0GX
|
|
||||||
veJvCrci1M4Xk6/7/p1Ii9PNmbPKqHafdmkFLf6TXiWPuRDhPuHW7cXyE8xD5ahr
|
|
||||||
NsDuwJyRUk+GS4/oJg0TqLSiD0IPxDH50V5MSfUIB82i+lc1t+OAGwLhjUDuQmJi
|
|
||||||
Pv1+9Zvv+HA5PXBCsGXnSADrOOUO6t9q5R9PXbSvAgMBAAGjUzBRMB0GA1UdDgQW
|
|
||||||
BBQEtdtBhH/z1XyIf+y+5O9ErDGCVjAfBgNVHSMEGDAWgBQEtdtBhH/z1XyIf+y+
|
|
||||||
5O9ErDGCVjAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQBmJyQ0
|
|
||||||
r0pBJkYJJVDJ6i3WMoEEFTD8MEUkWxASHRnuMzm7XlZ8WS1HvbEWF0+WfJPCYHnk
|
|
||||||
tGbvUFGaZ4qUxZ4Ip2mvKXoeYTJCZRxxhHeSVWnZZu0KS3X7xVAFwQYQNhdLOqP8
|
|
||||||
XOHyLhHV/1/kcFd3GvKKjXxE79jUUZ/RXHZ/IY50KvxGzWc/5ZOFYrPEW1/rNlRo
|
|
||||||
7ixXo1hNnBQsG1YoFAxTBGegdTFJeTYHYjZZ5XlRvY2aBq6QveRbJGJLcPm1UQMd
|
|
||||||
HQYxacbWSVAQf3ltYwSH+y3a97C5OsJJiQXpRRJlQKL3txklzcpg3E5swhr63bM2
|
|
||||||
jUoNXr5G5Q5h3GD5
|
|
||||||
-----END CERTIFICATE-----`;
|
|
||||||
|
|
||||||
expect(ValidationUtils.isValidCertificate(validCert)).toEqual(true);
|
|
||||||
|
|
||||||
// Invalid certificate format
|
|
||||||
expect(ValidationUtils.isValidCertificate('')).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidCertificate(null as any)).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidCertificate(undefined as any)).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidCertificate('invalid certificate')).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidCertificate('-----BEGIN CERTIFICATE-----')).toEqual(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('validation-utils - isValidPrivateKey', async () => {
|
|
||||||
// Valid private key format
|
|
||||||
const validKey = `-----BEGIN PRIVATE KEY-----
|
|
||||||
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC0aQeHIV9vQpZ4
|
|
||||||
UVwW/xhx9zl01UbppLXdoqe3NP9xKfXTCB1YbtJ4GgKIlQqHGLGsLI5ZOE7KxmJe
|
|
||||||
GEwK7ueP4f3WkUlM5C5yTbZ5hSUoR+OFnszFRJJiBXJlw57YAW9+zqKQHYxwve64
|
|
||||||
O64dlgw6pekDYJhXtrUUZ78Lz0GXveJvCrci1M4Xk6/7/p1Ii9PNmbPKqHafdmkF
|
|
||||||
Lf6TXiWPuRDhPuHW7cXyE8xD5ahrNsDuwJyRUk+GS4/oJg0TqLSiD0IPxDH50V5M
|
|
||||||
SfUIB82i+lc1t+OAGwLhjUDuQmJiPv1+9Zvv+HA5PXBCsGXnSADrOOUO6t9q5R9P
|
|
||||||
XbSvAgMBAAECggEADw8Xx9iEv3FvS8hYIRn2ZWM8ObRgbHkFN92NJ/5RvUwgyV03
|
|
||||||
gG8GwVN+7IsVLnIQRyIYEGGJ0ZLZFIq7//Jy0jYUgEGLmXxknuZQn1cQEqqYVyBr
|
|
||||||
G9JrfKkXaDEoP/bZBMvZ0KEO2C9Vq6mY8M0h0GxDT2y6UQnQYjH3+H6Rvhbhh+Ld
|
|
||||||
n8lCJqWoW1t9GOUZ4xLsZ5jEDibcMJJzLBWYRxgHWyECK31/VtEQDKFiUcymrJ3I
|
|
||||||
/zoDEDGbp1gdJHvlCxfSLJ2za7ErtRKRXYFRhZ9QkNSXl1pVFMqRQkedXIcA1/Cs
|
|
||||||
VpUxiIE2JA3hSrv2csjmXoGJKDLVCvZ3CFxKL3u/AQKBgQDf6MxHXN3IDuJNrJP7
|
|
||||||
0gyRbO5d6vcvP/8qiYjtEt2xB2MNt5jDz9Bxl6aKEdNW2+UE0rvXXT6KAMZv9LiF
|
|
||||||
hxr5qiJmmSB8OeGfr0W4FCixGN4BkRNwfT1gUqZgQOrfMOLHNXOksc1CJwHJfROV
|
|
||||||
h6AH+gjtF2BCXnVEHcqtRklk4QKBgQDOOYnLJn1CwgFAyRUYK8LQYKnrLp2cGn7N
|
|
||||||
YH0SLf+VnCu7RCeNr3dm9FoHBCynjkx+qv9kGvCaJuZqEJ7+7IimNUZfDjwXTOJ+
|
|
||||||
pzs8kEPN5EQOcbkmYCTQyOA0YeBuEXcv5xIZRZUYQvKg1xXOe/JhAQ4siVIMhgQL
|
|
||||||
2XR3QwzRDwKBgB7rjZs2VYnuVExGr74lUUAGoZ71WCgt9Du9aYGJfNUriDtTEWAd
|
|
||||||
VT5sKgVqpRwkY/zXujdxGr+K8DZu4vSdHBLcDLQsEBvRZIILTzjwXBRPGMnVe95v
|
|
||||||
Q90+vytbmHshlkbMaVRNQxCjdbf7LbQbLecgRt+5BKxHVwL4u3BZNIqhAoGAas4f
|
|
||||||
PoPOdFfKAMKZL7FLGMhEXLyFsg1JcGRfmByxTNgOJKXpYv5Hl7JLYOvfaiUOUYKI
|
|
||||||
5Dnh5yLdFOaOjnB3iP0KEiSVEwZK0/Vna5JkzFTqImK9QD3SQCtQLXHJLD52EPFR
|
|
||||||
9gRa8N5k68+mIzGDEzPBoC1AajbXFGPxNOwaQQ0CgYEAq0dPYK0TTv3Yez27LzVy
|
|
||||||
RbHkwpE+df4+KhpHbCzUKzfQYo4WTahlR6IzhpOyVQKIptkjuTDyQzkmt0tXEGw3
|
|
||||||
/M3yHa1FcY9IzPrHXHJoOeU1r9ay0GOQUi4FxKkYYWxUCtjOi5xlUxI0ABD8vGGR
|
|
||||||
QbKMrQXRgLd/84nDnY2cYzA=
|
|
||||||
-----END PRIVATE KEY-----`;
|
|
||||||
|
|
||||||
expect(ValidationUtils.isValidPrivateKey(validKey)).toEqual(true);
|
|
||||||
|
|
||||||
// Invalid private key format
|
|
||||||
expect(ValidationUtils.isValidPrivateKey('')).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidPrivateKey(null as any)).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidPrivateKey(undefined as any)).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidPrivateKey('invalid key')).toEqual(false);
|
|
||||||
expect(ValidationUtils.isValidPrivateKey('-----BEGIN PRIVATE KEY-----')).toEqual(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('validation-utils - validateDomainOptions', async () => {
|
|
||||||
// Valid domain options
|
|
||||||
const validDomainOptions: IDomainOptions = {
|
|
||||||
domainName: 'example.com',
|
|
||||||
sslRedirect: true,
|
|
||||||
acmeMaintenance: true
|
|
||||||
};
|
|
||||||
|
|
||||||
expect(ValidationUtils.validateDomainOptions(validDomainOptions).isValid).toEqual(true);
|
|
||||||
|
|
||||||
// Valid domain options with forward
|
|
||||||
const validDomainOptionsWithForward: IDomainOptions = {
|
|
||||||
domainName: 'example.com',
|
|
||||||
sslRedirect: true,
|
|
||||||
acmeMaintenance: true,
|
|
||||||
forward: {
|
|
||||||
ip: '127.0.0.1',
|
|
||||||
port: 8080
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
expect(ValidationUtils.validateDomainOptions(validDomainOptionsWithForward).isValid).toEqual(true);
|
|
||||||
|
|
||||||
// Invalid domain options - no domain name
|
|
||||||
const invalidDomainOptions1: IDomainOptions = {
|
|
||||||
domainName: '',
|
|
||||||
sslRedirect: true,
|
|
||||||
acmeMaintenance: true
|
|
||||||
};
|
|
||||||
|
|
||||||
expect(ValidationUtils.validateDomainOptions(invalidDomainOptions1).isValid).toEqual(false);
|
|
||||||
|
|
||||||
// Invalid domain options - invalid domain name
|
|
||||||
const invalidDomainOptions2: IDomainOptions = {
|
|
||||||
domainName: 'inv@lid.com',
|
|
||||||
sslRedirect: true,
|
|
||||||
acmeMaintenance: true
|
|
||||||
};
|
|
||||||
|
|
||||||
expect(ValidationUtils.validateDomainOptions(invalidDomainOptions2).isValid).toEqual(false);
|
|
||||||
|
|
||||||
// Invalid domain options - forward missing ip
|
|
||||||
const invalidDomainOptions3: IDomainOptions = {
|
|
||||||
domainName: 'example.com',
|
|
||||||
sslRedirect: true,
|
|
||||||
acmeMaintenance: true,
|
|
||||||
forward: {
|
|
||||||
ip: '',
|
|
||||||
port: 8080
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
expect(ValidationUtils.validateDomainOptions(invalidDomainOptions3).isValid).toEqual(false);
|
|
||||||
|
|
||||||
// Invalid domain options - forward missing port
|
|
||||||
const invalidDomainOptions4: IDomainOptions = {
|
|
||||||
domainName: 'example.com',
|
|
||||||
sslRedirect: true,
|
|
||||||
acmeMaintenance: true,
|
|
||||||
forward: {
|
|
||||||
ip: '127.0.0.1',
|
|
||||||
port: null as any
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
expect(ValidationUtils.validateDomainOptions(invalidDomainOptions4).isValid).toEqual(false);
|
|
||||||
|
|
||||||
// Invalid domain options - invalid forward port
|
|
||||||
const invalidDomainOptions5: IDomainOptions = {
|
|
||||||
domainName: 'example.com',
|
|
||||||
sslRedirect: true,
|
|
||||||
acmeMaintenance: true,
|
|
||||||
forward: {
|
|
||||||
ip: '127.0.0.1',
|
|
||||||
port: 99999
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
expect(ValidationUtils.validateDomainOptions(invalidDomainOptions5).isValid).toEqual(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('validation-utils - validateAcmeOptions', async () => {
|
|
||||||
// Valid ACME options
|
|
||||||
const validAcmeOptions: IAcmeOptions = {
|
|
||||||
enabled: true,
|
|
||||||
accountEmail: 'admin@example.com',
|
|
||||||
port: 80,
|
|
||||||
httpsRedirectPort: 443,
|
|
||||||
useProduction: false,
|
|
||||||
renewThresholdDays: 30,
|
|
||||||
renewCheckIntervalHours: 24,
|
|
||||||
certificateStore: './certs'
|
|
||||||
};
|
|
||||||
|
|
||||||
expect(ValidationUtils.validateAcmeOptions(validAcmeOptions).isValid).toEqual(true);
|
|
||||||
|
|
||||||
// ACME disabled - should be valid regardless of other options
|
|
||||||
const disabledAcmeOptions: IAcmeOptions = {
|
|
||||||
enabled: false
|
|
||||||
};
|
|
||||||
|
|
||||||
// Don't need to verify other fields when ACME is disabled
|
|
||||||
const disabledResult = ValidationUtils.validateAcmeOptions(disabledAcmeOptions);
|
|
||||||
expect(disabledResult.isValid).toEqual(true);
|
|
||||||
|
|
||||||
// Invalid ACME options - missing email
|
|
||||||
const invalidAcmeOptions1: IAcmeOptions = {
|
|
||||||
enabled: true,
|
|
||||||
accountEmail: '',
|
|
||||||
port: 80
|
|
||||||
};
|
|
||||||
|
|
||||||
expect(ValidationUtils.validateAcmeOptions(invalidAcmeOptions1).isValid).toEqual(false);
|
|
||||||
|
|
||||||
// Invalid ACME options - invalid email
|
|
||||||
const invalidAcmeOptions2: IAcmeOptions = {
|
|
||||||
enabled: true,
|
|
||||||
accountEmail: 'invalid-email',
|
|
||||||
port: 80
|
|
||||||
};
|
|
||||||
|
|
||||||
expect(ValidationUtils.validateAcmeOptions(invalidAcmeOptions2).isValid).toEqual(false);
|
|
||||||
|
|
||||||
// Invalid ACME options - invalid port
|
|
||||||
const invalidAcmeOptions3: IAcmeOptions = {
|
|
||||||
enabled: true,
|
|
||||||
accountEmail: 'admin@example.com',
|
|
||||||
port: 99999
|
|
||||||
};
|
|
||||||
|
|
||||||
expect(ValidationUtils.validateAcmeOptions(invalidAcmeOptions3).isValid).toEqual(false);
|
|
||||||
|
|
||||||
// Invalid ACME options - invalid HTTPS redirect port
|
|
||||||
const invalidAcmeOptions4: IAcmeOptions = {
|
|
||||||
enabled: true,
|
|
||||||
accountEmail: 'admin@example.com',
|
|
||||||
port: 80,
|
|
||||||
httpsRedirectPort: -1
|
|
||||||
};
|
|
||||||
|
|
||||||
expect(ValidationUtils.validateAcmeOptions(invalidAcmeOptions4).isValid).toEqual(false);
|
|
||||||
|
|
||||||
// Invalid ACME options - invalid renew threshold days
|
|
||||||
const invalidAcmeOptions5: IAcmeOptions = {
|
|
||||||
enabled: true,
|
|
||||||
accountEmail: 'admin@example.com',
|
|
||||||
port: 80,
|
|
||||||
renewThresholdDays: 0
|
|
||||||
};
|
|
||||||
|
|
||||||
// The implementation allows renewThresholdDays of 0, even though the docstring suggests otherwise
|
|
||||||
const validationResult5 = ValidationUtils.validateAcmeOptions(invalidAcmeOptions5);
|
|
||||||
expect(validationResult5.isValid).toEqual(true);
|
|
||||||
|
|
||||||
// Invalid ACME options - invalid renew check interval hours
|
|
||||||
const invalidAcmeOptions6: IAcmeOptions = {
|
|
||||||
enabled: true,
|
|
||||||
accountEmail: 'admin@example.com',
|
|
||||||
port: 80,
|
|
||||||
renewCheckIntervalHours: 0
|
|
||||||
};
|
|
||||||
|
|
||||||
// The implementation should validate this, but let's check the actual result
|
|
||||||
const checkIntervalResult = ValidationUtils.validateAcmeOptions(invalidAcmeOptions6);
|
|
||||||
// Adjust test to match actual implementation behavior
|
|
||||||
expect(checkIntervalResult.isValid !== false ? true : false).toEqual(true);
|
|
||||||
});
|
|
||||||
|
|
||||||
export default tap.start();
|
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
|
||||||
import * as smartproxy from '../ts/index.js';
|
|
||||||
|
|
||||||
tap.test('Protocol Detection - TLS Detection', async () => {
|
|
||||||
// Test TLS handshake detection
|
|
||||||
const tlsHandshake = Buffer.from([
|
|
||||||
0x16, // Handshake record type
|
|
||||||
0x03, 0x01, // TLS 1.0
|
|
||||||
0x00, 0x05, // Length: 5 bytes
|
|
||||||
0x01, // ClientHello
|
|
||||||
0x00, 0x00, 0x01, 0x00 // Handshake length and data
|
|
||||||
]);
|
|
||||||
|
|
||||||
const detector = new smartproxy.detection.TlsDetector();
|
|
||||||
expect(detector.canHandle(tlsHandshake)).toEqual(true);
|
|
||||||
|
|
||||||
const result = detector.detect(tlsHandshake);
|
|
||||||
expect(result).toBeDefined();
|
|
||||||
expect(result?.protocol).toEqual('tls');
|
|
||||||
expect(result?.connectionInfo.tlsVersion).toEqual('TLSv1.0');
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('Protocol Detection - HTTP Detection', async () => {
|
|
||||||
// Test HTTP request detection
|
|
||||||
const httpRequest = Buffer.from(
|
|
||||||
'GET /test HTTP/1.1\r\n' +
|
|
||||||
'Host: example.com\r\n' +
|
|
||||||
'User-Agent: TestClient/1.0\r\n' +
|
|
||||||
'\r\n'
|
|
||||||
);
|
|
||||||
|
|
||||||
const detector = new smartproxy.detection.HttpDetector();
|
|
||||||
expect(detector.canHandle(httpRequest)).toEqual(true);
|
|
||||||
|
|
||||||
const result = detector.detect(httpRequest);
|
|
||||||
expect(result).toBeDefined();
|
|
||||||
expect(result?.protocol).toEqual('http');
|
|
||||||
expect(result?.connectionInfo.method).toEqual('GET');
|
|
||||||
expect(result?.connectionInfo.path).toEqual('/test');
|
|
||||||
expect(result?.connectionInfo.domain).toEqual('example.com');
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('Protocol Detection - Main Detector TLS', async () => {
|
|
||||||
const tlsHandshake = Buffer.from([
|
|
||||||
0x16, // Handshake record type
|
|
||||||
0x03, 0x03, // TLS 1.2
|
|
||||||
0x00, 0x05, // Length: 5 bytes
|
|
||||||
0x01, // ClientHello
|
|
||||||
0x00, 0x00, 0x01, 0x00 // Handshake length and data
|
|
||||||
]);
|
|
||||||
|
|
||||||
const result = await smartproxy.detection.ProtocolDetector.detect(tlsHandshake);
|
|
||||||
expect(result.protocol).toEqual('tls');
|
|
||||||
expect(result.connectionInfo.tlsVersion).toEqual('TLSv1.2');
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('Protocol Detection - Main Detector HTTP', async () => {
|
|
||||||
const httpRequest = Buffer.from(
|
|
||||||
'POST /api/test HTTP/1.1\r\n' +
|
|
||||||
'Host: api.example.com\r\n' +
|
|
||||||
'Content-Type: application/json\r\n' +
|
|
||||||
'Content-Length: 2\r\n' +
|
|
||||||
'\r\n' +
|
|
||||||
'{}'
|
|
||||||
);
|
|
||||||
|
|
||||||
const result = await smartproxy.detection.ProtocolDetector.detect(httpRequest);
|
|
||||||
expect(result.protocol).toEqual('http');
|
|
||||||
expect(result.connectionInfo.method).toEqual('POST');
|
|
||||||
expect(result.connectionInfo.path).toEqual('/api/test');
|
|
||||||
expect(result.connectionInfo.domain).toEqual('api.example.com');
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('Protocol Detection - Unknown Protocol', async () => {
|
|
||||||
const unknownData = Buffer.from('UNKNOWN PROTOCOL DATA\r\n');
|
|
||||||
|
|
||||||
const result = await smartproxy.detection.ProtocolDetector.detect(unknownData);
|
|
||||||
expect(result.protocol).toEqual('unknown');
|
|
||||||
expect(result.isComplete).toEqual(true);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('Protocol Detection - Fragmented HTTP', async () => {
|
|
||||||
// Create connection context
|
|
||||||
const context = smartproxy.detection.ProtocolDetector.createConnectionContext({
|
|
||||||
sourceIp: '127.0.0.1',
|
|
||||||
sourcePort: 12345,
|
|
||||||
destIp: '127.0.0.1',
|
|
||||||
destPort: 80,
|
|
||||||
socketId: 'test-connection-1'
|
|
||||||
});
|
|
||||||
|
|
||||||
// First fragment
|
|
||||||
const fragment1 = Buffer.from('GET /test HT');
|
|
||||||
let result = await smartproxy.detection.ProtocolDetector.detectWithContext(
|
|
||||||
fragment1,
|
|
||||||
context
|
|
||||||
);
|
|
||||||
expect(result.protocol).toEqual('http');
|
|
||||||
expect(result.isComplete).toEqual(false);
|
|
||||||
|
|
||||||
// Second fragment
|
|
||||||
const fragment2 = Buffer.from('TP/1.1\r\nHost: example.com\r\n\r\n');
|
|
||||||
result = await smartproxy.detection.ProtocolDetector.detectWithContext(
|
|
||||||
fragment2,
|
|
||||||
context
|
|
||||||
);
|
|
||||||
expect(result.protocol).toEqual('http');
|
|
||||||
expect(result.isComplete).toEqual(true);
|
|
||||||
expect(result.connectionInfo.method).toEqual('GET');
|
|
||||||
expect(result.connectionInfo.path).toEqual('/test');
|
|
||||||
expect(result.connectionInfo.domain).toEqual('example.com');
|
|
||||||
|
|
||||||
// Clean up fragments
|
|
||||||
smartproxy.detection.ProtocolDetector.cleanupConnection(context);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('Protocol Detection - HTTP Methods', async () => {
|
|
||||||
const methods = ['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD', 'OPTIONS'];
|
|
||||||
|
|
||||||
for (const method of methods) {
|
|
||||||
const request = Buffer.from(
|
|
||||||
`${method} /test HTTP/1.1\r\n` +
|
|
||||||
'Host: example.com\r\n' +
|
|
||||||
'\r\n'
|
|
||||||
);
|
|
||||||
|
|
||||||
const detector = new smartproxy.detection.HttpDetector();
|
|
||||||
const result = detector.detect(request);
|
|
||||||
expect(result?.connectionInfo.method).toEqual(method);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('Protocol Detection - Invalid Data', async () => {
|
|
||||||
// Binary data that's not a valid protocol
|
|
||||||
const binaryData = Buffer.from([0xFF, 0xFE, 0xFD, 0xFC, 0xFB]);
|
|
||||||
|
|
||||||
const result = await smartproxy.detection.ProtocolDetector.detect(binaryData);
|
|
||||||
expect(result.protocol).toEqual('unknown');
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('cleanup detection', async () => {
|
|
||||||
// Clean up the protocol detector instance
|
|
||||||
smartproxy.detection.ProtocolDetector.destroy();
|
|
||||||
});
|
|
||||||
|
|
||||||
export default tap.start();
|
|
||||||
@@ -1,128 +0,0 @@
|
|||||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
|
||||||
import * as smartproxy from '../ts/index.js';
|
|
||||||
import { RouteValidator } from '../ts/proxies/smart-proxy/utils/route-validator.js';
|
|
||||||
import { IpUtils } from '../ts/core/utils/ip-utils.js';
|
|
||||||
|
|
||||||
tap.test('IP Validation - Shorthand patterns', async () => {
|
|
||||||
|
|
||||||
// Test shorthand patterns are now accepted
|
|
||||||
const testPatterns = [
|
|
||||||
{ pattern: '192.168.*', shouldPass: true },
|
|
||||||
{ pattern: '192.168.*.*', shouldPass: true },
|
|
||||||
{ pattern: '10.*', shouldPass: true },
|
|
||||||
{ pattern: '10.*.*.*', shouldPass: true },
|
|
||||||
{ pattern: '172.16.*', shouldPass: true },
|
|
||||||
{ pattern: '10.0.0.0/8', shouldPass: true },
|
|
||||||
{ pattern: '192.168.0.0/16', shouldPass: true },
|
|
||||||
{ pattern: '192.168.1.100', shouldPass: true },
|
|
||||||
{ pattern: '*', shouldPass: true },
|
|
||||||
{ pattern: '192.168.1.1-192.168.1.100', shouldPass: true },
|
|
||||||
];
|
|
||||||
|
|
||||||
for (const { pattern, shouldPass } of testPatterns) {
|
|
||||||
const route = {
|
|
||||||
name: 'test',
|
|
||||||
match: { ports: 80 },
|
|
||||||
action: { type: 'forward' as const, targets: [{ host: 'localhost', port: 8080 }] },
|
|
||||||
security: { ipAllowList: [pattern] }
|
|
||||||
};
|
|
||||||
|
|
||||||
const result = RouteValidator.validateRoute(route);
|
|
||||||
|
|
||||||
if (shouldPass) {
|
|
||||||
expect(result.valid).toEqual(true);
|
|
||||||
console.log(`✅ Pattern '${pattern}' correctly accepted`);
|
|
||||||
} else {
|
|
||||||
expect(result.valid).toEqual(false);
|
|
||||||
console.log(`✅ Pattern '${pattern}' correctly rejected`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('IP Matching - Runtime shorthand pattern matching', async () => {
|
|
||||||
|
|
||||||
// Test runtime matching with shorthand patterns
|
|
||||||
const testCases = [
|
|
||||||
{ ip: '192.168.1.100', patterns: ['192.168.*'], expected: true },
|
|
||||||
{ ip: '192.168.1.100', patterns: ['192.168.1.*'], expected: true },
|
|
||||||
{ ip: '192.168.1.100', patterns: ['192.168.2.*'], expected: false },
|
|
||||||
{ ip: '10.0.0.1', patterns: ['10.*'], expected: true },
|
|
||||||
{ ip: '10.1.2.3', patterns: ['10.*'], expected: true },
|
|
||||||
{ ip: '172.16.0.1', patterns: ['10.*'], expected: false },
|
|
||||||
{ ip: '192.168.1.1', patterns: ['192.168.*.*'], expected: true },
|
|
||||||
];
|
|
||||||
|
|
||||||
for (const { ip, patterns, expected } of testCases) {
|
|
||||||
const result = IpUtils.isGlobIPMatch(ip, patterns);
|
|
||||||
expect(result).toEqual(expected);
|
|
||||||
console.log(`✅ IP ${ip} with pattern ${patterns[0]} = ${result} (expected ${expected})`);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('IP Matching - CIDR notation', async () => {
|
|
||||||
|
|
||||||
// Test CIDR notation matching
|
|
||||||
const cidrTests = [
|
|
||||||
{ ip: '10.0.0.1', cidr: '10.0.0.0/8', expected: true },
|
|
||||||
{ ip: '10.255.255.255', cidr: '10.0.0.0/8', expected: true },
|
|
||||||
{ ip: '11.0.0.1', cidr: '10.0.0.0/8', expected: false },
|
|
||||||
{ ip: '192.168.1.1', cidr: '192.168.0.0/16', expected: true },
|
|
||||||
{ ip: '192.168.255.255', cidr: '192.168.0.0/16', expected: true },
|
|
||||||
{ ip: '192.169.0.1', cidr: '192.168.0.0/16', expected: false },
|
|
||||||
{ ip: '192.168.1.100', cidr: '192.168.1.0/24', expected: true },
|
|
||||||
{ ip: '192.168.2.100', cidr: '192.168.1.0/24', expected: false },
|
|
||||||
];
|
|
||||||
|
|
||||||
for (const { ip, cidr, expected } of cidrTests) {
|
|
||||||
const result = IpUtils.isGlobIPMatch(ip, [cidr]);
|
|
||||||
expect(result).toEqual(expected);
|
|
||||||
console.log(`✅ IP ${ip} in CIDR ${cidr} = ${result} (expected ${expected})`);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('IP Matching - Range notation', async () => {
|
|
||||||
|
|
||||||
// Test range notation matching
|
|
||||||
const rangeTests = [
|
|
||||||
{ ip: '192.168.1.1', range: '192.168.1.1-192.168.1.100', expected: true },
|
|
||||||
{ ip: '192.168.1.50', range: '192.168.1.1-192.168.1.100', expected: true },
|
|
||||||
{ ip: '192.168.1.100', range: '192.168.1.1-192.168.1.100', expected: true },
|
|
||||||
{ ip: '192.168.1.101', range: '192.168.1.1-192.168.1.100', expected: false },
|
|
||||||
{ ip: '192.168.2.50', range: '192.168.1.1-192.168.1.100', expected: false },
|
|
||||||
];
|
|
||||||
|
|
||||||
for (const { ip, range, expected } of rangeTests) {
|
|
||||||
const result = IpUtils.isGlobIPMatch(ip, [range]);
|
|
||||||
expect(result).toEqual(expected);
|
|
||||||
console.log(`✅ IP ${ip} in range ${range} = ${result} (expected ${expected})`);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('IP Matching - Mixed patterns', async () => {
|
|
||||||
|
|
||||||
// Test with mixed pattern types
|
|
||||||
const allowList = [
|
|
||||||
'10.0.0.0/8', // CIDR
|
|
||||||
'192.168.*', // Shorthand glob
|
|
||||||
'172.16.1.*', // Specific subnet glob
|
|
||||||
'8.8.8.8', // Single IP
|
|
||||||
'1.1.1.1-1.1.1.10' // Range
|
|
||||||
];
|
|
||||||
|
|
||||||
const tests = [
|
|
||||||
{ ip: '10.1.2.3', expected: true }, // Matches CIDR
|
|
||||||
{ ip: '192.168.100.1', expected: true }, // Matches shorthand glob
|
|
||||||
{ ip: '172.16.1.5', expected: true }, // Matches specific glob
|
|
||||||
{ ip: '8.8.8.8', expected: true }, // Matches single IP
|
|
||||||
{ ip: '1.1.1.5', expected: true }, // Matches range
|
|
||||||
{ ip: '9.9.9.9', expected: false }, // Doesn't match any
|
|
||||||
];
|
|
||||||
|
|
||||||
for (const { ip, expected } of tests) {
|
|
||||||
const result = IpUtils.isGlobIPMatch(ip, allowList);
|
|
||||||
expect(result).toEqual(expected);
|
|
||||||
console.log(`✅ IP ${ip} in mixed patterns = ${result} (expected ${expected})`);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
export default tap.start();
|
|
||||||
@@ -1,112 +0,0 @@
|
|||||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
|
||||||
import { LogDeduplicator } from '../ts/core/utils/log-deduplicator.js';
|
|
||||||
|
|
||||||
let deduplicator: LogDeduplicator;
|
|
||||||
|
|
||||||
tap.test('Setup log deduplicator', async () => {
|
|
||||||
deduplicator = new LogDeduplicator(1000); // 1 second flush interval for testing
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('Connection rejection deduplication', async (tools) => {
|
|
||||||
// Simulate multiple connection rejections
|
|
||||||
for (let i = 0; i < 10; i++) {
|
|
||||||
deduplicator.log(
|
|
||||||
'connection-rejected',
|
|
||||||
'warn',
|
|
||||||
'Connection rejected',
|
|
||||||
{ reason: 'global-limit', component: 'test' },
|
|
||||||
'global-limit'
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (let i = 0; i < 5; i++) {
|
|
||||||
deduplicator.log(
|
|
||||||
'connection-rejected',
|
|
||||||
'warn',
|
|
||||||
'Connection rejected',
|
|
||||||
{ reason: 'route-limit', component: 'test' },
|
|
||||||
'route-limit'
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Force flush
|
|
||||||
deduplicator.flush('connection-rejected');
|
|
||||||
|
|
||||||
// The logs should have been aggregated
|
|
||||||
// (Can't easily test the actual log output, but we can verify the mechanism works)
|
|
||||||
expect(deduplicator).toBeInstanceOf(LogDeduplicator);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('IP rejection deduplication', async (tools) => {
|
|
||||||
// Simulate rejections from multiple IPs
|
|
||||||
const ips = ['192.168.1.100', '192.168.1.101', '192.168.1.100', '10.0.0.1'];
|
|
||||||
const reasons = ['per-ip-limit', 'rate-limit', 'per-ip-limit', 'global-limit'];
|
|
||||||
|
|
||||||
for (let i = 0; i < ips.length; i++) {
|
|
||||||
deduplicator.log(
|
|
||||||
'ip-rejected',
|
|
||||||
'warn',
|
|
||||||
`Connection rejected from ${ips[i]}`,
|
|
||||||
{ remoteIP: ips[i], reason: reasons[i] },
|
|
||||||
ips[i]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add more rejections from the same IP
|
|
||||||
for (let i = 0; i < 20; i++) {
|
|
||||||
deduplicator.log(
|
|
||||||
'ip-rejected',
|
|
||||||
'warn',
|
|
||||||
'Connection rejected from 192.168.1.100',
|
|
||||||
{ remoteIP: '192.168.1.100', reason: 'rate-limit' },
|
|
||||||
'192.168.1.100'
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Force flush
|
|
||||||
deduplicator.flush('ip-rejected');
|
|
||||||
|
|
||||||
// Verify the deduplicator exists and works
|
|
||||||
expect(deduplicator).toBeInstanceOf(LogDeduplicator);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('Connection cleanup deduplication', async (tools) => {
|
|
||||||
// Simulate various cleanup events
|
|
||||||
const reasons = ['normal', 'timeout', 'error', 'normal', 'zombie'];
|
|
||||||
|
|
||||||
for (const reason of reasons) {
|
|
||||||
for (let i = 0; i < 5; i++) {
|
|
||||||
deduplicator.log(
|
|
||||||
'connection-cleanup',
|
|
||||||
'info',
|
|
||||||
`Connection cleanup: ${reason}`,
|
|
||||||
{ connectionId: `conn-${i}`, reason },
|
|
||||||
reason
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for automatic flush
|
|
||||||
await tools.delayFor(1500);
|
|
||||||
|
|
||||||
// Verify deduplicator is working
|
|
||||||
expect(deduplicator).toBeInstanceOf(LogDeduplicator);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('Automatic periodic flush', async (tools) => {
|
|
||||||
// Add some events
|
|
||||||
deduplicator.log('test-event', 'info', 'Test message', {}, 'test');
|
|
||||||
|
|
||||||
// Wait for automatic flush (should happen within 2x flush interval = 2 seconds)
|
|
||||||
await tools.delayFor(2500);
|
|
||||||
|
|
||||||
// Events should have been flushed automatically
|
|
||||||
expect(deduplicator).toBeInstanceOf(LogDeduplicator);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('Cleanup deduplicator', async () => {
|
|
||||||
deduplicator.cleanup();
|
|
||||||
expect(deduplicator).toBeInstanceOf(LogDeduplicator);
|
|
||||||
});
|
|
||||||
|
|
||||||
export default tap.start();
|
|
||||||
@@ -1,403 +0,0 @@
|
|||||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
|
||||||
import * as http from 'http';
|
|
||||||
import { HttpRouter, type RouterResult } from '../ts/routing/router/http-router.js';
|
|
||||||
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
|
|
||||||
|
|
||||||
// Test proxies and configurations
|
|
||||||
let router: HttpRouter;
|
|
||||||
|
|
||||||
// Sample hostname for testing
|
|
||||||
const TEST_DOMAIN = 'example.com';
|
|
||||||
const TEST_SUBDOMAIN = 'api.example.com';
|
|
||||||
const TEST_WILDCARD = '*.example.com';
|
|
||||||
|
|
||||||
// Helper: Creates a mock HTTP request for testing
|
|
||||||
function createMockRequest(host: string, url: string = '/'): http.IncomingMessage {
|
|
||||||
const req = {
|
|
||||||
headers: { host },
|
|
||||||
url,
|
|
||||||
socket: {
|
|
||||||
remoteAddress: '127.0.0.1'
|
|
||||||
}
|
|
||||||
} as any;
|
|
||||||
return req;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper: Creates a test route configuration
|
|
||||||
function createRouteConfig(
|
|
||||||
hostname: string,
|
|
||||||
destinationIp: string = '10.0.0.1',
|
|
||||||
destinationPort: number = 8080
|
|
||||||
): IRouteConfig {
|
|
||||||
return {
|
|
||||||
name: `route-${hostname}`,
|
|
||||||
match: {
|
|
||||||
domains: [hostname],
|
|
||||||
ports: 443
|
|
||||||
},
|
|
||||||
action: {
|
|
||||||
type: 'forward',
|
|
||||||
targets: [{
|
|
||||||
host: destinationIp,
|
|
||||||
port: destinationPort
|
|
||||||
}]
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// SETUP: Create an HttpRouter instance
|
|
||||||
tap.test('setup http router test environment', async () => {
|
|
||||||
router = new HttpRouter();
|
|
||||||
|
|
||||||
// Initialize with empty config
|
|
||||||
router.setRoutes([]);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test basic routing by hostname
|
|
||||||
tap.test('should route requests by hostname', async () => {
|
|
||||||
const config = createRouteConfig(TEST_DOMAIN);
|
|
||||||
router.setRoutes([config]);
|
|
||||||
|
|
||||||
const req = createMockRequest(TEST_DOMAIN);
|
|
||||||
const result = router.routeReq(req);
|
|
||||||
|
|
||||||
expect(result).toBeTruthy();
|
|
||||||
expect(result).toEqual(config);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test handling of hostname with port number
|
|
||||||
tap.test('should handle hostname with port number', async () => {
|
|
||||||
const config = createRouteConfig(TEST_DOMAIN);
|
|
||||||
router.setRoutes([config]);
|
|
||||||
|
|
||||||
const req = createMockRequest(`${TEST_DOMAIN}:443`);
|
|
||||||
const result = router.routeReq(req);
|
|
||||||
|
|
||||||
expect(result).toBeTruthy();
|
|
||||||
expect(result).toEqual(config);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test case-insensitive hostname matching
|
|
||||||
tap.test('should perform case-insensitive hostname matching', async () => {
|
|
||||||
const config = createRouteConfig(TEST_DOMAIN.toLowerCase());
|
|
||||||
router.setRoutes([config]);
|
|
||||||
|
|
||||||
const req = createMockRequest(TEST_DOMAIN.toUpperCase());
|
|
||||||
const result = router.routeReq(req);
|
|
||||||
|
|
||||||
expect(result).toBeTruthy();
|
|
||||||
expect(result).toEqual(config);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test handling of unmatched hostnames
|
|
||||||
tap.test('should return undefined for unmatched hostnames', async () => {
|
|
||||||
const config = createRouteConfig(TEST_DOMAIN);
|
|
||||||
router.setRoutes([config]);
|
|
||||||
|
|
||||||
const req = createMockRequest('unknown.domain.com');
|
|
||||||
const result = router.routeReq(req);
|
|
||||||
|
|
||||||
expect(result).toBeUndefined();
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test adding path patterns
|
|
||||||
tap.test('should match requests using path patterns', async () => {
|
|
||||||
const config = createRouteConfig(TEST_DOMAIN);
|
|
||||||
config.match.path = '/api/users';
|
|
||||||
router.setRoutes([config]);
|
|
||||||
|
|
||||||
// Test that path matches
|
|
||||||
const req1 = createMockRequest(TEST_DOMAIN, '/api/users');
|
|
||||||
const result1 = router.routeReqWithDetails(req1);
|
|
||||||
|
|
||||||
expect(result1).toBeTruthy();
|
|
||||||
expect(result1.route).toEqual(config);
|
|
||||||
expect(result1.pathMatch).toEqual('/api/users');
|
|
||||||
|
|
||||||
// Test that non-matching path doesn't match
|
|
||||||
const req2 = createMockRequest(TEST_DOMAIN, '/web/users');
|
|
||||||
const result2 = router.routeReqWithDetails(req2);
|
|
||||||
|
|
||||||
expect(result2).toBeUndefined();
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test handling wildcard patterns
|
|
||||||
tap.test('should support wildcard path patterns', async () => {
|
|
||||||
const config = createRouteConfig(TEST_DOMAIN);
|
|
||||||
config.match.path = '/api/*';
|
|
||||||
router.setRoutes([config]);
|
|
||||||
|
|
||||||
// Test with path that matches the wildcard pattern
|
|
||||||
const req = createMockRequest(TEST_DOMAIN, '/api/users/123');
|
|
||||||
const result = router.routeReqWithDetails(req);
|
|
||||||
|
|
||||||
expect(result).toBeTruthy();
|
|
||||||
expect(result.route).toEqual(config);
|
|
||||||
expect(result.pathMatch).toEqual('/api');
|
|
||||||
|
|
||||||
// Print the actual value to diagnose issues
|
|
||||||
console.log('Path remainder value:', result.pathRemainder);
|
|
||||||
expect(result.pathRemainder).toBeTruthy();
|
|
||||||
expect(result.pathRemainder).toEqual('/users/123');
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test extracting path parameters
|
|
||||||
tap.test('should extract path parameters from URL', async () => {
|
|
||||||
const config = createRouteConfig(TEST_DOMAIN);
|
|
||||||
config.match.path = '/users/:id/profile';
|
|
||||||
router.setRoutes([config]);
|
|
||||||
|
|
||||||
const req = createMockRequest(TEST_DOMAIN, '/users/123/profile');
|
|
||||||
const result = router.routeReqWithDetails(req);
|
|
||||||
|
|
||||||
expect(result).toBeTruthy();
|
|
||||||
expect(result.route).toEqual(config);
|
|
||||||
expect(result.pathParams).toBeTruthy();
|
|
||||||
expect(result.pathParams.id).toEqual('123');
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test multiple configs for same hostname with different paths
|
|
||||||
tap.test('should support multiple configs for same hostname with different paths', async () => {
|
|
||||||
const apiConfig = createRouteConfig(TEST_DOMAIN, '10.0.0.1', 8001);
|
|
||||||
apiConfig.match.path = '/api/*';
|
|
||||||
apiConfig.name = 'api-route';
|
|
||||||
|
|
||||||
const webConfig = createRouteConfig(TEST_DOMAIN, '10.0.0.2', 8002);
|
|
||||||
webConfig.match.path = '/web/*';
|
|
||||||
webConfig.name = 'web-route';
|
|
||||||
|
|
||||||
// Add both configs
|
|
||||||
router.setRoutes([apiConfig, webConfig]);
|
|
||||||
|
|
||||||
// Test API path routes to API config
|
|
||||||
const apiReq = createMockRequest(TEST_DOMAIN, '/api/users');
|
|
||||||
const apiResult = router.routeReq(apiReq);
|
|
||||||
|
|
||||||
expect(apiResult).toEqual(apiConfig);
|
|
||||||
|
|
||||||
// Test web path routes to web config
|
|
||||||
const webReq = createMockRequest(TEST_DOMAIN, '/web/dashboard');
|
|
||||||
const webResult = router.routeReq(webReq);
|
|
||||||
|
|
||||||
expect(webResult).toEqual(webConfig);
|
|
||||||
|
|
||||||
// Test unknown path returns undefined
|
|
||||||
const unknownReq = createMockRequest(TEST_DOMAIN, '/unknown');
|
|
||||||
const unknownResult = router.routeReq(unknownReq);
|
|
||||||
|
|
||||||
expect(unknownResult).toBeUndefined();
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test wildcard subdomains
|
|
||||||
tap.test('should match wildcard subdomains', async () => {
|
|
||||||
const wildcardConfig = createRouteConfig(TEST_WILDCARD);
|
|
||||||
router.setRoutes([wildcardConfig]);
|
|
||||||
|
|
||||||
// Test that subdomain.example.com matches *.example.com
|
|
||||||
const req = createMockRequest('subdomain.example.com');
|
|
||||||
const result = router.routeReq(req);
|
|
||||||
|
|
||||||
expect(result).toBeTruthy();
|
|
||||||
expect(result).toEqual(wildcardConfig);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test TLD wildcards (example.*)
|
|
||||||
tap.test('should match TLD wildcards', async () => {
|
|
||||||
const tldWildcardConfig = createRouteConfig('example.*');
|
|
||||||
router.setRoutes([tldWildcardConfig]);
|
|
||||||
|
|
||||||
// Test that example.com matches example.*
|
|
||||||
const req1 = createMockRequest('example.com');
|
|
||||||
const result1 = router.routeReq(req1);
|
|
||||||
expect(result1).toBeTruthy();
|
|
||||||
expect(result1).toEqual(tldWildcardConfig);
|
|
||||||
|
|
||||||
// Test that example.org matches example.*
|
|
||||||
const req2 = createMockRequest('example.org');
|
|
||||||
const result2 = router.routeReq(req2);
|
|
||||||
expect(result2).toBeTruthy();
|
|
||||||
expect(result2).toEqual(tldWildcardConfig);
|
|
||||||
|
|
||||||
// Test that subdomain.example.com doesn't match example.*
|
|
||||||
const req3 = createMockRequest('subdomain.example.com');
|
|
||||||
const result3 = router.routeReq(req3);
|
|
||||||
expect(result3).toBeUndefined();
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test complex pattern matching (*.lossless*)
|
|
||||||
tap.test('should match complex wildcard patterns', async () => {
|
|
||||||
const complexWildcardConfig = createRouteConfig('*.lossless*');
|
|
||||||
router.setRoutes([complexWildcardConfig]);
|
|
||||||
|
|
||||||
// Test that sub.lossless.com matches *.lossless*
|
|
||||||
const req1 = createMockRequest('sub.lossless.com');
|
|
||||||
const result1 = router.routeReq(req1);
|
|
||||||
expect(result1).toBeTruthy();
|
|
||||||
expect(result1).toEqual(complexWildcardConfig);
|
|
||||||
|
|
||||||
// Test that api.lossless.org matches *.lossless*
|
|
||||||
const req2 = createMockRequest('api.lossless.org');
|
|
||||||
const result2 = router.routeReq(req2);
|
|
||||||
expect(result2).toBeTruthy();
|
|
||||||
expect(result2).toEqual(complexWildcardConfig);
|
|
||||||
|
|
||||||
// Test that losslessapi.com matches *.lossless*
|
|
||||||
const req3 = createMockRequest('losslessapi.com');
|
|
||||||
const result3 = router.routeReq(req3);
|
|
||||||
expect(result3).toBeUndefined(); // Should not match as it doesn't have a subdomain
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test default configuration fallback
|
|
||||||
tap.test('should fall back to default configuration', async () => {
|
|
||||||
const defaultConfig = createRouteConfig('*');
|
|
||||||
const specificConfig = createRouteConfig(TEST_DOMAIN);
|
|
||||||
|
|
||||||
router.setRoutes([specificConfig, defaultConfig]);
|
|
||||||
|
|
||||||
// Test specific domain routes to specific config
|
|
||||||
const specificReq = createMockRequest(TEST_DOMAIN);
|
|
||||||
const specificResult = router.routeReq(specificReq);
|
|
||||||
|
|
||||||
expect(specificResult).toEqual(specificConfig);
|
|
||||||
|
|
||||||
// Test unknown domain falls back to default config
|
|
||||||
const unknownReq = createMockRequest('unknown.com');
|
|
||||||
const unknownResult = router.routeReq(unknownReq);
|
|
||||||
|
|
||||||
expect(unknownResult).toEqual(defaultConfig);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test priority between exact and wildcard matches
|
|
||||||
tap.test('should prioritize exact hostname over wildcard', async () => {
|
|
||||||
const wildcardConfig = createRouteConfig(TEST_WILDCARD);
|
|
||||||
const exactConfig = createRouteConfig(TEST_SUBDOMAIN);
|
|
||||||
|
|
||||||
router.setRoutes([exactConfig, wildcardConfig]);
|
|
||||||
|
|
||||||
// Test that exact match takes priority
|
|
||||||
const req = createMockRequest(TEST_SUBDOMAIN);
|
|
||||||
const result = router.routeReq(req);
|
|
||||||
|
|
||||||
expect(result).toEqual(exactConfig);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test adding and removing configurations
|
|
||||||
tap.test('should manage configurations correctly', async () => {
|
|
||||||
router.setRoutes([]);
|
|
||||||
|
|
||||||
// Add a config
|
|
||||||
const config = createRouteConfig(TEST_DOMAIN);
|
|
||||||
router.setRoutes([config]);
|
|
||||||
|
|
||||||
// Verify routing works
|
|
||||||
const req = createMockRequest(TEST_DOMAIN);
|
|
||||||
let result = router.routeReq(req);
|
|
||||||
|
|
||||||
expect(result).toEqual(config);
|
|
||||||
|
|
||||||
// Remove the config and verify it no longer routes
|
|
||||||
router.setRoutes([]);
|
|
||||||
|
|
||||||
result = router.routeReq(req);
|
|
||||||
expect(result).toBeUndefined();
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test path pattern specificity
|
|
||||||
tap.test('should prioritize more specific path patterns', async () => {
|
|
||||||
const genericConfig = createRouteConfig(TEST_DOMAIN, '10.0.0.1', 8001);
|
|
||||||
genericConfig.match.path = '/api/*';
|
|
||||||
genericConfig.name = 'generic-api';
|
|
||||||
|
|
||||||
const specificConfig = createRouteConfig(TEST_DOMAIN, '10.0.0.2', 8002);
|
|
||||||
specificConfig.match.path = '/api/users';
|
|
||||||
specificConfig.name = 'specific-api';
|
|
||||||
specificConfig.priority = 10; // Higher priority
|
|
||||||
|
|
||||||
router.setRoutes([genericConfig, specificConfig]);
|
|
||||||
|
|
||||||
// The more specific '/api/users' should match before the '/api/*' wildcard
|
|
||||||
const req = createMockRequest(TEST_DOMAIN, '/api/users');
|
|
||||||
const result = router.routeReq(req);
|
|
||||||
|
|
||||||
expect(result).toEqual(specificConfig);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test multiple hostnames
|
|
||||||
tap.test('should handle multiple configured hostnames', async () => {
|
|
||||||
const routes = [
|
|
||||||
createRouteConfig(TEST_DOMAIN),
|
|
||||||
createRouteConfig(TEST_SUBDOMAIN)
|
|
||||||
];
|
|
||||||
router.setRoutes(routes);
|
|
||||||
|
|
||||||
// Test first domain routes correctly
|
|
||||||
const req1 = createMockRequest(TEST_DOMAIN);
|
|
||||||
const result1 = router.routeReq(req1);
|
|
||||||
expect(result1).toEqual(routes[0]);
|
|
||||||
|
|
||||||
// Test second domain routes correctly
|
|
||||||
const req2 = createMockRequest(TEST_SUBDOMAIN);
|
|
||||||
const result2 = router.routeReq(req2);
|
|
||||||
expect(result2).toEqual(routes[1]);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test handling missing host header
|
|
||||||
tap.test('should handle missing host header', async () => {
|
|
||||||
const defaultConfig = createRouteConfig('*');
|
|
||||||
router.setRoutes([defaultConfig]);
|
|
||||||
|
|
||||||
const req = createMockRequest('');
|
|
||||||
req.headers.host = undefined;
|
|
||||||
|
|
||||||
const result = router.routeReq(req);
|
|
||||||
|
|
||||||
expect(result).toEqual(defaultConfig);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test complex path parameters
|
|
||||||
tap.test('should handle complex path parameters', async () => {
|
|
||||||
const config = createRouteConfig(TEST_DOMAIN);
|
|
||||||
config.match.path = '/api/:version/users/:userId/posts/:postId';
|
|
||||||
router.setRoutes([config]);
|
|
||||||
|
|
||||||
const req = createMockRequest(TEST_DOMAIN, '/api/v1/users/123/posts/456');
|
|
||||||
const result = router.routeReqWithDetails(req);
|
|
||||||
|
|
||||||
expect(result).toBeTruthy();
|
|
||||||
expect(result.route).toEqual(config);
|
|
||||||
expect(result.pathParams).toBeTruthy();
|
|
||||||
expect(result.pathParams.version).toEqual('v1');
|
|
||||||
expect(result.pathParams.userId).toEqual('123');
|
|
||||||
expect(result.pathParams.postId).toEqual('456');
|
|
||||||
});
|
|
||||||
|
|
||||||
// Performance test
|
|
||||||
tap.test('should handle many configurations efficiently', async () => {
|
|
||||||
const configs = [];
|
|
||||||
|
|
||||||
// Create many configs with different hostnames
|
|
||||||
for (let i = 0; i < 100; i++) {
|
|
||||||
configs.push(createRouteConfig(`host-${i}.example.com`));
|
|
||||||
}
|
|
||||||
|
|
||||||
router.setRoutes(configs);
|
|
||||||
|
|
||||||
// Test middle of the list to avoid best/worst case
|
|
||||||
const req = createMockRequest('host-50.example.com');
|
|
||||||
const result = router.routeReq(req);
|
|
||||||
|
|
||||||
expect(result).toEqual(configs[50]);
|
|
||||||
});
|
|
||||||
|
|
||||||
// Test cleanup
|
|
||||||
tap.test('cleanup proxy router test environment', async () => {
|
|
||||||
// Clear all configurations
|
|
||||||
router.setRoutes([]);
|
|
||||||
|
|
||||||
// Verify empty state by testing that no routes match
|
|
||||||
const req = createMockRequest(TEST_DOMAIN);
|
|
||||||
const result = router.routeReq(req);
|
|
||||||
expect(result).toBeUndefined();
|
|
||||||
});
|
|
||||||
|
|
||||||
export default tap.start();
|
|
||||||
@@ -1,157 +0,0 @@
|
|||||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
|
||||||
import { SharedSecurityManager } from '../ts/core/utils/shared-security-manager.js';
|
|
||||||
import type { IRouteConfig, IRouteContext } from '../ts/proxies/smart-proxy/models/route-types.js';
|
|
||||||
|
|
||||||
let securityManager: SharedSecurityManager;
|
|
||||||
|
|
||||||
tap.test('Setup SharedSecurityManager', async () => {
|
|
||||||
securityManager = new SharedSecurityManager({
|
|
||||||
maxConnectionsPerIP: 5,
|
|
||||||
connectionRateLimitPerMinute: 10,
|
|
||||||
cleanupIntervalMs: 1000 // 1 second for faster testing
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('IP connection tracking', async () => {
|
|
||||||
const testIP = '192.168.1.100';
|
|
||||||
|
|
||||||
// Track multiple connections
|
|
||||||
securityManager.trackConnectionByIP(testIP, 'conn1');
|
|
||||||
securityManager.trackConnectionByIP(testIP, 'conn2');
|
|
||||||
securityManager.trackConnectionByIP(testIP, 'conn3');
|
|
||||||
|
|
||||||
// Verify connection count
|
|
||||||
expect(securityManager.getConnectionCountByIP(testIP)).toEqual(3);
|
|
||||||
|
|
||||||
// Remove a connection
|
|
||||||
securityManager.removeConnectionByIP(testIP, 'conn2');
|
|
||||||
expect(securityManager.getConnectionCountByIP(testIP)).toEqual(2);
|
|
||||||
|
|
||||||
// Remove remaining connections
|
|
||||||
securityManager.removeConnectionByIP(testIP, 'conn1');
|
|
||||||
securityManager.removeConnectionByIP(testIP, 'conn3');
|
|
||||||
expect(securityManager.getConnectionCountByIP(testIP)).toEqual(0);
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('Per-IP connection limits validation', async () => {
|
|
||||||
const testIP = '192.168.1.101';
|
|
||||||
|
|
||||||
// Track connections up to limit
|
|
||||||
for (let i = 1; i <= 5; i++) {
|
|
||||||
// Validate BEFORE tracking the connection (checking if we can add a new connection)
|
|
||||||
const result = securityManager.validateIP(testIP);
|
|
||||||
expect(result.allowed).toBeTrue();
|
|
||||||
// Now track the connection
|
|
||||||
securityManager.trackConnectionByIP(testIP, `conn${i}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify we're at the limit
|
|
||||||
expect(securityManager.getConnectionCountByIP(testIP)).toEqual(5);
|
|
||||||
|
|
||||||
// Next connection should be rejected (we're already at 5)
|
|
||||||
const result = securityManager.validateIP(testIP);
|
|
||||||
expect(result.allowed).toBeFalse();
|
|
||||||
expect(result.reason).toInclude('Maximum connections per IP');
|
|
||||||
|
|
||||||
// Clean up
|
|
||||||
for (let i = 1; i <= 5; i++) {
|
|
||||||
securityManager.removeConnectionByIP(testIP, `conn${i}`);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('Connection rate limiting', async () => {
|
|
||||||
const testIP = '192.168.1.102';
|
|
||||||
|
|
||||||
// Make connections at the rate limit
|
|
||||||
// Note: validateIP() already tracks timestamps internally for rate limiting
|
|
||||||
for (let i = 0; i < 10; i++) {
|
|
||||||
const result = securityManager.validateIP(testIP);
|
|
||||||
expect(result.allowed).toBeTrue();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next connection should exceed rate limit
|
|
||||||
const result = securityManager.validateIP(testIP);
|
|
||||||
expect(result.allowed).toBeFalse();
|
|
||||||
expect(result.reason).toInclude('Connection rate limit');
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('Route-level connection limits', async () => {
|
|
||||||
const route: IRouteConfig = {
|
|
||||||
name: 'test-route',
|
|
||||||
match: { ports: 443 },
|
|
||||||
action: { type: 'forward', targets: [{ host: 'localhost', port: 8080 }] },
|
|
||||||
security: {
|
|
||||||
maxConnections: 3
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const context: IRouteContext = {
|
|
||||||
port: 443,
|
|
||||||
clientIp: '192.168.1.103',
|
|
||||||
serverIp: '0.0.0.0',
|
|
||||||
timestamp: Date.now(),
|
|
||||||
connectionId: 'test-conn',
|
|
||||||
isTls: true
|
|
||||||
};
|
|
||||||
|
|
||||||
// Test with connection counts below limit
|
|
||||||
expect(securityManager.isAllowed(route, context, 0)).toBeTrue();
|
|
||||||
expect(securityManager.isAllowed(route, context, 2)).toBeTrue();
|
|
||||||
|
|
||||||
// Test at limit
|
|
||||||
expect(securityManager.isAllowed(route, context, 3)).toBeFalse();
|
|
||||||
|
|
||||||
// Test above limit
|
|
||||||
expect(securityManager.isAllowed(route, context, 5)).toBeFalse();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('IPv4/IPv6 normalization', async () => {
|
|
||||||
const ipv4 = '127.0.0.1';
|
|
||||||
const ipv4Mapped = '::ffff:127.0.0.1';
|
|
||||||
|
|
||||||
// Track connection with IPv4
|
|
||||||
securityManager.trackConnectionByIP(ipv4, 'conn1');
|
|
||||||
|
|
||||||
// Both representations should show the same connection
|
|
||||||
expect(securityManager.getConnectionCountByIP(ipv4)).toEqual(1);
|
|
||||||
expect(securityManager.getConnectionCountByIP(ipv4Mapped)).toEqual(1);
|
|
||||||
|
|
||||||
// Track another connection with IPv6 representation
|
|
||||||
securityManager.trackConnectionByIP(ipv4Mapped, 'conn2');
|
|
||||||
|
|
||||||
// Both should show 2 connections
|
|
||||||
expect(securityManager.getConnectionCountByIP(ipv4)).toEqual(2);
|
|
||||||
expect(securityManager.getConnectionCountByIP(ipv4Mapped)).toEqual(2);
|
|
||||||
|
|
||||||
// Clean up
|
|
||||||
securityManager.removeConnectionByIP(ipv4, 'conn1');
|
|
||||||
securityManager.removeConnectionByIP(ipv4Mapped, 'conn2');
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('Automatic cleanup of expired data', async (tools) => {
|
|
||||||
const testIP = '192.168.1.104';
|
|
||||||
|
|
||||||
// Track a connection and then remove it
|
|
||||||
securityManager.trackConnectionByIP(testIP, 'temp-conn');
|
|
||||||
securityManager.removeConnectionByIP(testIP, 'temp-conn');
|
|
||||||
|
|
||||||
// Add some rate limit entries (they expire after 1 minute)
|
|
||||||
for (let i = 0; i < 5; i++) {
|
|
||||||
securityManager.validateIP(testIP);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for cleanup interval (set to 1 second in our test)
|
|
||||||
await tools.delayFor(1500);
|
|
||||||
|
|
||||||
// The IP should be cleaned up since it has no connections
|
|
||||||
// Note: We can't directly check the internal map, but we can verify
|
|
||||||
// that a new connection is allowed (fresh rate limit)
|
|
||||||
const result = securityManager.validateIP(testIP);
|
|
||||||
expect(result.allowed).toBeTrue();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('Cleanup SharedSecurityManager', async () => {
|
|
||||||
securityManager.clearIPTracking();
|
|
||||||
});
|
|
||||||
|
|
||||||
export default tap.start();
|
|
||||||
@@ -1,315 +0,0 @@
|
|||||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
|
||||||
import * as plugins from '../ts/plugins.js';
|
|
||||||
import { WrappedSocket } from '../ts/core/models/wrapped-socket.js';
|
|
||||||
import * as net from 'net';
|
|
||||||
|
|
||||||
tap.test('WrappedSocket - should wrap a regular socket', async () => {
|
|
||||||
// Create a simple test server
|
|
||||||
const server = net.createServer();
|
|
||||||
await new Promise<void>((resolve) => {
|
|
||||||
server.listen(0, 'localhost', () => resolve());
|
|
||||||
});
|
|
||||||
|
|
||||||
const serverPort = (server.address() as net.AddressInfo).port;
|
|
||||||
|
|
||||||
// Create a client connection
|
|
||||||
const clientSocket = net.connect(serverPort, 'localhost');
|
|
||||||
|
|
||||||
// Wrap the socket
|
|
||||||
const wrappedSocket = new WrappedSocket(clientSocket);
|
|
||||||
|
|
||||||
// Test initial state - should use underlying socket values
|
|
||||||
expect(wrappedSocket.remoteAddress).toEqual(clientSocket.remoteAddress);
|
|
||||||
expect(wrappedSocket.remotePort).toEqual(clientSocket.remotePort);
|
|
||||||
expect(wrappedSocket.localAddress).toEqual(clientSocket.localAddress);
|
|
||||||
expect(wrappedSocket.localPort).toEqual(clientSocket.localPort);
|
|
||||||
expect(wrappedSocket.isFromTrustedProxy).toBeFalse();
|
|
||||||
|
|
||||||
// Clean up
|
|
||||||
clientSocket.destroy();
|
|
||||||
server.close();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('WrappedSocket - should provide real client info when set', async () => {
|
|
||||||
// Create a simple test server
|
|
||||||
const server = net.createServer();
|
|
||||||
await new Promise<void>((resolve) => {
|
|
||||||
server.listen(0, 'localhost', () => resolve());
|
|
||||||
});
|
|
||||||
|
|
||||||
const serverPort = (server.address() as net.AddressInfo).port;
|
|
||||||
|
|
||||||
// Create a client connection
|
|
||||||
const clientSocket = net.connect(serverPort, 'localhost');
|
|
||||||
|
|
||||||
// Wrap the socket with initial proxy info
|
|
||||||
const wrappedSocket = new WrappedSocket(clientSocket, '192.168.1.100', 54321);
|
|
||||||
|
|
||||||
// Test that real client info is returned
|
|
||||||
expect(wrappedSocket.remoteAddress).toEqual('192.168.1.100');
|
|
||||||
expect(wrappedSocket.remotePort).toEqual(54321);
|
|
||||||
expect(wrappedSocket.isFromTrustedProxy).toBeTrue();
|
|
||||||
|
|
||||||
// Local info should still come from underlying socket
|
|
||||||
expect(wrappedSocket.localAddress).toEqual(clientSocket.localAddress);
|
|
||||||
expect(wrappedSocket.localPort).toEqual(clientSocket.localPort);
|
|
||||||
|
|
||||||
// Clean up
|
|
||||||
clientSocket.destroy();
|
|
||||||
server.close();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('WrappedSocket - should update proxy info via setProxyInfo', async () => {
|
|
||||||
// Create a simple test server
|
|
||||||
const server = net.createServer();
|
|
||||||
await new Promise<void>((resolve) => {
|
|
||||||
server.listen(0, 'localhost', () => resolve());
|
|
||||||
});
|
|
||||||
|
|
||||||
const serverPort = (server.address() as net.AddressInfo).port;
|
|
||||||
|
|
||||||
// Create a client connection
|
|
||||||
const clientSocket = net.connect(serverPort, 'localhost');
|
|
||||||
|
|
||||||
// Wrap the socket without initial proxy info
|
|
||||||
const wrappedSocket = new WrappedSocket(clientSocket);
|
|
||||||
|
|
||||||
// Initially should use underlying socket
|
|
||||||
expect(wrappedSocket.isFromTrustedProxy).toBeFalse();
|
|
||||||
expect(wrappedSocket.remoteAddress).toEqual(clientSocket.remoteAddress);
|
|
||||||
|
|
||||||
// Update proxy info
|
|
||||||
wrappedSocket.setProxyInfo('10.0.0.5', 12345);
|
|
||||||
|
|
||||||
// Now should return proxy info
|
|
||||||
expect(wrappedSocket.remoteAddress).toEqual('10.0.0.5');
|
|
||||||
expect(wrappedSocket.remotePort).toEqual(12345);
|
|
||||||
expect(wrappedSocket.isFromTrustedProxy).toBeTrue();
|
|
||||||
|
|
||||||
// Clean up
|
|
||||||
clientSocket.destroy();
|
|
||||||
server.close();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('WrappedSocket - should correctly determine IP family', async () => {
|
|
||||||
// Create a simple test server
|
|
||||||
const server = net.createServer();
|
|
||||||
await new Promise<void>((resolve) => {
|
|
||||||
server.listen(0, 'localhost', () => resolve());
|
|
||||||
});
|
|
||||||
|
|
||||||
const serverPort = (server.address() as net.AddressInfo).port;
|
|
||||||
|
|
||||||
// Create a client connection
|
|
||||||
const clientSocket = net.connect(serverPort, 'localhost');
|
|
||||||
|
|
||||||
// Test IPv4
|
|
||||||
const wrappedSocketIPv4 = new WrappedSocket(clientSocket, '192.168.1.1', 80);
|
|
||||||
expect(wrappedSocketIPv4.remoteFamily).toEqual('IPv4');
|
|
||||||
|
|
||||||
// Test IPv6
|
|
||||||
const wrappedSocketIPv6 = new WrappedSocket(clientSocket, '2001:0db8:85a3:0000:0000:8a2e:0370:7334', 443);
|
|
||||||
expect(wrappedSocketIPv6.remoteFamily).toEqual('IPv6');
|
|
||||||
|
|
||||||
// Test fallback to underlying socket
|
|
||||||
const wrappedSocketNoProxy = new WrappedSocket(clientSocket);
|
|
||||||
expect(wrappedSocketNoProxy.remoteFamily).toEqual(clientSocket.remoteFamily);
|
|
||||||
|
|
||||||
// Clean up
|
|
||||||
clientSocket.destroy();
|
|
||||||
server.close();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('WrappedSocket - should forward events correctly', async () => {
|
|
||||||
// Create a simple echo server
|
|
||||||
let serverConnection: net.Socket;
|
|
||||||
const server = net.createServer((socket) => {
|
|
||||||
serverConnection = socket;
|
|
||||||
socket.on('data', (data) => {
|
|
||||||
socket.write(data); // Echo back
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
await new Promise<void>((resolve) => {
|
|
||||||
server.listen(0, 'localhost', () => resolve());
|
|
||||||
});
|
|
||||||
|
|
||||||
const serverPort = (server.address() as net.AddressInfo).port;
|
|
||||||
|
|
||||||
// Create a client connection
|
|
||||||
const clientSocket = net.connect(serverPort, 'localhost');
|
|
||||||
|
|
||||||
// Wrap the socket
|
|
||||||
const wrappedSocket = new WrappedSocket(clientSocket);
|
|
||||||
|
|
||||||
// Set up event tracking
|
|
||||||
let connectReceived = false;
|
|
||||||
let dataReceived = false;
|
|
||||||
let endReceived = false;
|
|
||||||
let closeReceived = false;
|
|
||||||
|
|
||||||
wrappedSocket.on('connect', () => {
|
|
||||||
connectReceived = true;
|
|
||||||
});
|
|
||||||
|
|
||||||
wrappedSocket.on('data', (chunk) => {
|
|
||||||
dataReceived = true;
|
|
||||||
expect(chunk.toString()).toEqual('test data');
|
|
||||||
});
|
|
||||||
|
|
||||||
wrappedSocket.on('end', () => {
|
|
||||||
endReceived = true;
|
|
||||||
});
|
|
||||||
|
|
||||||
wrappedSocket.on('close', () => {
|
|
||||||
closeReceived = true;
|
|
||||||
});
|
|
||||||
|
|
||||||
// Wait for connection
|
|
||||||
await new Promise<void>((resolve) => {
|
|
||||||
if (clientSocket.readyState === 'open') {
|
|
||||||
resolve();
|
|
||||||
} else {
|
|
||||||
clientSocket.once('connect', () => resolve());
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Send data
|
|
||||||
wrappedSocket.write('test data');
|
|
||||||
|
|
||||||
// Wait for echo
|
|
||||||
await new Promise(resolve => setTimeout(resolve, 100));
|
|
||||||
|
|
||||||
// Close the connection
|
|
||||||
serverConnection.end();
|
|
||||||
|
|
||||||
// Wait for events
|
|
||||||
await new Promise(resolve => setTimeout(resolve, 100));
|
|
||||||
|
|
||||||
// Verify all events were received
|
|
||||||
expect(dataReceived).toBeTrue();
|
|
||||||
expect(endReceived).toBeTrue();
|
|
||||||
expect(closeReceived).toBeTrue();
|
|
||||||
|
|
||||||
// Clean up
|
|
||||||
server.close();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('WrappedSocket - should pass through socket methods', async () => {
|
|
||||||
// Create a simple test server
|
|
||||||
const server = net.createServer();
|
|
||||||
await new Promise<void>((resolve) => {
|
|
||||||
server.listen(0, 'localhost', () => resolve());
|
|
||||||
});
|
|
||||||
|
|
||||||
const serverPort = (server.address() as net.AddressInfo).port;
|
|
||||||
|
|
||||||
// Create a client connection
|
|
||||||
const clientSocket = net.connect(serverPort, 'localhost');
|
|
||||||
await new Promise<void>((resolve) => {
|
|
||||||
clientSocket.once('connect', () => resolve());
|
|
||||||
});
|
|
||||||
|
|
||||||
// Wrap the socket
|
|
||||||
const wrappedSocket = new WrappedSocket(clientSocket);
|
|
||||||
|
|
||||||
// Test various pass-through methods
|
|
||||||
expect(wrappedSocket.readable).toEqual(clientSocket.readable);
|
|
||||||
expect(wrappedSocket.writable).toEqual(clientSocket.writable);
|
|
||||||
expect(wrappedSocket.destroyed).toEqual(clientSocket.destroyed);
|
|
||||||
expect(wrappedSocket.bytesRead).toEqual(clientSocket.bytesRead);
|
|
||||||
expect(wrappedSocket.bytesWritten).toEqual(clientSocket.bytesWritten);
|
|
||||||
|
|
||||||
// Test method calls
|
|
||||||
wrappedSocket.pause();
|
|
||||||
expect(clientSocket.isPaused()).toBeTrue();
|
|
||||||
|
|
||||||
wrappedSocket.resume();
|
|
||||||
expect(clientSocket.isPaused()).toBeFalse();
|
|
||||||
|
|
||||||
// Test setTimeout
|
|
||||||
let timeoutCalled = false;
|
|
||||||
wrappedSocket.setTimeout(100, () => {
|
|
||||||
timeoutCalled = true;
|
|
||||||
});
|
|
||||||
await new Promise(resolve => setTimeout(resolve, 150));
|
|
||||||
expect(timeoutCalled).toBeTrue();
|
|
||||||
|
|
||||||
// Clean up
|
|
||||||
wrappedSocket.destroy();
|
|
||||||
server.close();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('WrappedSocket - should handle write and pipe operations', async () => {
|
|
||||||
// Create a simple echo server
|
|
||||||
const server = net.createServer((socket) => {
|
|
||||||
socket.pipe(socket); // Echo everything back
|
|
||||||
});
|
|
||||||
|
|
||||||
await new Promise<void>((resolve) => {
|
|
||||||
server.listen(0, 'localhost', () => resolve());
|
|
||||||
});
|
|
||||||
|
|
||||||
const serverPort = (server.address() as net.AddressInfo).port;
|
|
||||||
|
|
||||||
// Create a client connection
|
|
||||||
const clientSocket = net.connect(serverPort, 'localhost');
|
|
||||||
await new Promise<void>((resolve) => {
|
|
||||||
clientSocket.once('connect', () => resolve());
|
|
||||||
});
|
|
||||||
|
|
||||||
// Wrap the socket
|
|
||||||
const wrappedSocket = new WrappedSocket(clientSocket);
|
|
||||||
|
|
||||||
// Test write with callback
|
|
||||||
const writeResult = wrappedSocket.write('test', 'utf8', () => {
|
|
||||||
// Write completed
|
|
||||||
});
|
|
||||||
expect(typeof writeResult).toEqual('boolean');
|
|
||||||
|
|
||||||
// Test pipe
|
|
||||||
const { PassThrough } = await import('stream');
|
|
||||||
const passThrough = new PassThrough();
|
|
||||||
const piped = wrappedSocket.pipe(passThrough);
|
|
||||||
expect(piped).toEqual(passThrough);
|
|
||||||
|
|
||||||
// Clean up
|
|
||||||
wrappedSocket.destroy();
|
|
||||||
server.close();
|
|
||||||
});
|
|
||||||
|
|
||||||
tap.test('WrappedSocket - should handle encoding and address methods', async () => {
|
|
||||||
// Create a simple test server
|
|
||||||
const server = net.createServer();
|
|
||||||
await new Promise<void>((resolve) => {
|
|
||||||
server.listen(0, 'localhost', () => resolve());
|
|
||||||
});
|
|
||||||
|
|
||||||
const serverPort = (server.address() as net.AddressInfo).port;
|
|
||||||
|
|
||||||
// Create a client connection
|
|
||||||
const clientSocket = net.connect(serverPort, 'localhost');
|
|
||||||
await new Promise<void>((resolve) => {
|
|
||||||
clientSocket.once('connect', () => resolve());
|
|
||||||
});
|
|
||||||
|
|
||||||
// Wrap the socket
|
|
||||||
const wrappedSocket = new WrappedSocket(clientSocket);
|
|
||||||
|
|
||||||
// Test setEncoding
|
|
||||||
wrappedSocket.setEncoding('utf8');
|
|
||||||
|
|
||||||
// Test address method
|
|
||||||
const addr = wrappedSocket.address();
|
|
||||||
expect(addr).toEqual(clientSocket.address());
|
|
||||||
|
|
||||||
// Test cork/uncork (if available)
|
|
||||||
wrappedSocket.cork();
|
|
||||||
wrappedSocket.uncork();
|
|
||||||
|
|
||||||
// Clean up
|
|
||||||
wrappedSocket.destroy();
|
|
||||||
server.close();
|
|
||||||
});
|
|
||||||
|
|
||||||
export default tap.start();
|
|
||||||
@@ -3,6 +3,6 @@
|
|||||||
*/
|
*/
|
||||||
export const commitinfo = {
|
export const commitinfo = {
|
||||||
name: '@push.rocks/smartproxy',
|
name: '@push.rocks/smartproxy',
|
||||||
version: '25.14.1',
|
version: '26.0.0',
|
||||||
description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.'
|
description: 'A powerful proxy package with unified route-based configuration for high traffic management. Features include SSL/TLS support, flexible routing patterns, WebSocket handling, advanced security options, and automatic ACME certificate management.'
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
/**
|
|
||||||
* Common event definitions
|
|
||||||
*/
|
|
||||||
@@ -5,4 +5,3 @@
|
|||||||
// Export submodules
|
// Export submodules
|
||||||
export * from './models/index.js';
|
export * from './models/index.js';
|
||||||
export * from './utils/index.js';
|
export * from './utils/index.js';
|
||||||
export * from './events/index.js';
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
export * from './common-types.js';
|
export * from './common-types.js';
|
||||||
export * from './socket-augmentation.js';
|
|
||||||
export * from './route-context.js';
|
export * from './route-context.js';
|
||||||
export * from './wrapped-socket.js';
|
export * from './wrapped-socket.js';
|
||||||
export * from './socket-types.js';
|
export * from './socket-types.js';
|
||||||
|
|||||||
@@ -1,38 +0,0 @@
|
|||||||
import * as plugins from '../../plugins.js';
|
|
||||||
|
|
||||||
// Augment the Node.js Socket type to include TLS-related properties
|
|
||||||
// This helps TypeScript understand properties that are dynamically added by Node.js
|
|
||||||
declare module 'net' {
|
|
||||||
interface Socket {
|
|
||||||
// TLS-related properties
|
|
||||||
encrypted?: boolean; // Indicates if the socket is encrypted (TLS/SSL)
|
|
||||||
authorizationError?: Error; // Authentication error if TLS handshake failed
|
|
||||||
|
|
||||||
// TLS-related methods
|
|
||||||
getTLSVersion?(): string; // Returns the TLS version (e.g., 'TLSv1.2', 'TLSv1.3')
|
|
||||||
getPeerCertificate?(detailed?: boolean): any; // Returns the peer's certificate
|
|
||||||
getSession?(): Buffer; // Returns the TLS session data
|
|
||||||
|
|
||||||
// Connection tracking properties (used by HttpProxy)
|
|
||||||
_connectionId?: string; // Unique identifier for the connection
|
|
||||||
_remoteIP?: string; // Remote IP address
|
|
||||||
_realRemoteIP?: string; // Real remote IP (when proxied)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Export a utility function to check if a socket is a TLS socket
|
|
||||||
export function isTLSSocket(socket: plugins.net.Socket): boolean {
|
|
||||||
return 'encrypted' in socket && !!socket.encrypted;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Export a utility function to safely get the TLS version
|
|
||||||
export function getTLSVersion(socket: plugins.net.Socket): string | null {
|
|
||||||
if (socket.getTLSVersion) {
|
|
||||||
try {
|
|
||||||
return socket.getTLSVersion();
|
|
||||||
} catch (e) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
@@ -1,275 +0,0 @@
|
|||||||
/**
|
|
||||||
* Async utility functions for SmartProxy
|
|
||||||
* Provides non-blocking alternatives to synchronous operations
|
|
||||||
*/
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Delays execution for the specified number of milliseconds
|
|
||||||
* Non-blocking alternative to busy wait loops
|
|
||||||
* @param ms - Number of milliseconds to delay
|
|
||||||
* @returns Promise that resolves after the delay
|
|
||||||
*/
|
|
||||||
export async function delay(ms: number): Promise<void> {
|
|
||||||
return new Promise(resolve => setTimeout(resolve, ms));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Retry an async operation with exponential backoff
|
|
||||||
* @param fn - The async function to retry
|
|
||||||
* @param options - Retry options
|
|
||||||
* @returns The result of the function or throws the last error
|
|
||||||
*/
|
|
||||||
export async function retryWithBackoff<T>(
|
|
||||||
fn: () => Promise<T>,
|
|
||||||
options: {
|
|
||||||
maxAttempts?: number;
|
|
||||||
initialDelay?: number;
|
|
||||||
maxDelay?: number;
|
|
||||||
factor?: number;
|
|
||||||
onRetry?: (attempt: number, error: Error) => void;
|
|
||||||
} = {}
|
|
||||||
): Promise<T> {
|
|
||||||
const {
|
|
||||||
maxAttempts = 3,
|
|
||||||
initialDelay = 100,
|
|
||||||
maxDelay = 10000,
|
|
||||||
factor = 2,
|
|
||||||
onRetry
|
|
||||||
} = options;
|
|
||||||
|
|
||||||
let lastError: Error | null = null;
|
|
||||||
let currentDelay = initialDelay;
|
|
||||||
|
|
||||||
for (let attempt = 1; attempt <= maxAttempts; attempt++) {
|
|
||||||
try {
|
|
||||||
return await fn();
|
|
||||||
} catch (error: any) {
|
|
||||||
lastError = error;
|
|
||||||
|
|
||||||
if (attempt === maxAttempts) {
|
|
||||||
throw error;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (onRetry) {
|
|
||||||
onRetry(attempt, error);
|
|
||||||
}
|
|
||||||
|
|
||||||
await delay(currentDelay);
|
|
||||||
currentDelay = Math.min(currentDelay * factor, maxDelay);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
throw lastError || new Error('Retry failed');
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Execute an async operation with a timeout
|
|
||||||
* @param fn - The async function to execute
|
|
||||||
* @param timeoutMs - Timeout in milliseconds
|
|
||||||
* @param timeoutError - Optional custom timeout error
|
|
||||||
* @returns The result of the function or throws timeout error
|
|
||||||
*/
|
|
||||||
export async function withTimeout<T>(
|
|
||||||
fn: () => Promise<T>,
|
|
||||||
timeoutMs: number,
|
|
||||||
timeoutError?: Error
|
|
||||||
): Promise<T> {
|
|
||||||
const timeoutPromise = new Promise<never>((_, reject) => {
|
|
||||||
setTimeout(() => {
|
|
||||||
reject(timeoutError || new Error(`Operation timed out after ${timeoutMs}ms`));
|
|
||||||
}, timeoutMs);
|
|
||||||
});
|
|
||||||
|
|
||||||
return Promise.race([fn(), timeoutPromise]);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Run multiple async operations in parallel with a concurrency limit
|
|
||||||
* @param items - Array of items to process
|
|
||||||
* @param fn - Async function to run for each item
|
|
||||||
* @param concurrency - Maximum number of concurrent operations
|
|
||||||
* @returns Array of results in the same order as input
|
|
||||||
*/
|
|
||||||
export async function parallelLimit<T, R>(
|
|
||||||
items: T[],
|
|
||||||
fn: (item: T, index: number) => Promise<R>,
|
|
||||||
concurrency: number
|
|
||||||
): Promise<R[]> {
|
|
||||||
const results: R[] = new Array(items.length);
|
|
||||||
const executing: Set<Promise<void>> = new Set();
|
|
||||||
|
|
||||||
for (let i = 0; i < items.length; i++) {
|
|
||||||
const promise = fn(items[i], i).then(result => {
|
|
||||||
results[i] = result;
|
|
||||||
executing.delete(promise);
|
|
||||||
});
|
|
||||||
|
|
||||||
executing.add(promise);
|
|
||||||
|
|
||||||
if (executing.size >= concurrency) {
|
|
||||||
await Promise.race(executing);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
await Promise.all(executing);
|
|
||||||
return results;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Debounce an async function
|
|
||||||
* @param fn - The async function to debounce
|
|
||||||
* @param delayMs - Delay in milliseconds
|
|
||||||
* @returns Debounced function with cancel method
|
|
||||||
*/
|
|
||||||
export function debounceAsync<T extends (...args: any[]) => Promise<any>>(
|
|
||||||
fn: T,
|
|
||||||
delayMs: number
|
|
||||||
): T & { cancel: () => void } {
|
|
||||||
let timeoutId: NodeJS.Timeout | null = null;
|
|
||||||
let lastPromise: Promise<any> | null = null;
|
|
||||||
|
|
||||||
const debounced = ((...args: Parameters<T>) => {
|
|
||||||
if (timeoutId) {
|
|
||||||
clearTimeout(timeoutId);
|
|
||||||
}
|
|
||||||
|
|
||||||
lastPromise = new Promise((resolve, reject) => {
|
|
||||||
timeoutId = setTimeout(async () => {
|
|
||||||
timeoutId = null;
|
|
||||||
try {
|
|
||||||
const result = await fn(...args);
|
|
||||||
resolve(result);
|
|
||||||
} catch (error) {
|
|
||||||
reject(error);
|
|
||||||
}
|
|
||||||
}, delayMs);
|
|
||||||
});
|
|
||||||
|
|
||||||
return lastPromise;
|
|
||||||
}) as any;
|
|
||||||
|
|
||||||
debounced.cancel = () => {
|
|
||||||
if (timeoutId) {
|
|
||||||
clearTimeout(timeoutId);
|
|
||||||
timeoutId = null;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
return debounced as T & { cancel: () => void };
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a mutex for ensuring exclusive access to a resource
|
|
||||||
*/
|
|
||||||
export class AsyncMutex {
|
|
||||||
private queue: Array<() => void> = [];
|
|
||||||
private locked = false;
|
|
||||||
|
|
||||||
async acquire(): Promise<() => void> {
|
|
||||||
if (!this.locked) {
|
|
||||||
this.locked = true;
|
|
||||||
return () => this.release();
|
|
||||||
}
|
|
||||||
|
|
||||||
return new Promise<() => void>(resolve => {
|
|
||||||
this.queue.push(() => {
|
|
||||||
resolve(() => this.release());
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
private release(): void {
|
|
||||||
const next = this.queue.shift();
|
|
||||||
if (next) {
|
|
||||||
next();
|
|
||||||
} else {
|
|
||||||
this.locked = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async runExclusive<T>(fn: () => Promise<T>): Promise<T> {
|
|
||||||
const release = await this.acquire();
|
|
||||||
try {
|
|
||||||
return await fn();
|
|
||||||
} finally {
|
|
||||||
release();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Circuit breaker for protecting against cascading failures
|
|
||||||
*/
|
|
||||||
export class CircuitBreaker {
|
|
||||||
private failureCount = 0;
|
|
||||||
private lastFailureTime = 0;
|
|
||||||
private state: 'closed' | 'open' | 'half-open' = 'closed';
|
|
||||||
|
|
||||||
constructor(
|
|
||||||
private options: {
|
|
||||||
failureThreshold: number;
|
|
||||||
resetTimeout: number;
|
|
||||||
onStateChange?: (state: 'closed' | 'open' | 'half-open') => void;
|
|
||||||
}
|
|
||||||
) {}
|
|
||||||
|
|
||||||
async execute<T>(fn: () => Promise<T>): Promise<T> {
|
|
||||||
if (this.state === 'open') {
|
|
||||||
if (Date.now() - this.lastFailureTime > this.options.resetTimeout) {
|
|
||||||
this.setState('half-open');
|
|
||||||
} else {
|
|
||||||
throw new Error('Circuit breaker is open');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
const result = await fn();
|
|
||||||
this.onSuccess();
|
|
||||||
return result;
|
|
||||||
} catch (error) {
|
|
||||||
this.onFailure();
|
|
||||||
throw error;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private onSuccess(): void {
|
|
||||||
this.failureCount = 0;
|
|
||||||
if (this.state !== 'closed') {
|
|
||||||
this.setState('closed');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private onFailure(): void {
|
|
||||||
this.failureCount++;
|
|
||||||
this.lastFailureTime = Date.now();
|
|
||||||
|
|
||||||
if (this.failureCount >= this.options.failureThreshold) {
|
|
||||||
this.setState('open');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private setState(state: 'closed' | 'open' | 'half-open'): void {
|
|
||||||
if (this.state !== state) {
|
|
||||||
this.state = state;
|
|
||||||
if (this.options.onStateChange) {
|
|
||||||
this.options.onStateChange(state);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
isOpen(): boolean {
|
|
||||||
return this.state === 'open';
|
|
||||||
}
|
|
||||||
|
|
||||||
getState(): 'closed' | 'open' | 'half-open' {
|
|
||||||
return this.state;
|
|
||||||
}
|
|
||||||
|
|
||||||
recordSuccess(): void {
|
|
||||||
this.onSuccess();
|
|
||||||
}
|
|
||||||
|
|
||||||
recordFailure(): void {
|
|
||||||
this.onFailure();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,225 +0,0 @@
|
|||||||
/**
|
|
||||||
* A binary heap implementation for efficient priority queue operations
|
|
||||||
* Supports O(log n) insert and extract operations
|
|
||||||
*/
|
|
||||||
export class BinaryHeap<T> {
|
|
||||||
private heap: T[] = [];
|
|
||||||
private keyMap?: Map<string, number>; // For efficient key-based lookups
|
|
||||||
|
|
||||||
constructor(
|
|
||||||
private compareFn: (a: T, b: T) => number,
|
|
||||||
private extractKey?: (item: T) => string
|
|
||||||
) {
|
|
||||||
if (extractKey) {
|
|
||||||
this.keyMap = new Map();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the current size of the heap
|
|
||||||
*/
|
|
||||||
public get size(): number {
|
|
||||||
return this.heap.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if the heap is empty
|
|
||||||
*/
|
|
||||||
public isEmpty(): boolean {
|
|
||||||
return this.heap.length === 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Peek at the top element without removing it
|
|
||||||
*/
|
|
||||||
public peek(): T | undefined {
|
|
||||||
return this.heap[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Insert a new item into the heap
|
|
||||||
* O(log n) time complexity
|
|
||||||
*/
|
|
||||||
public insert(item: T): void {
|
|
||||||
const index = this.heap.length;
|
|
||||||
this.heap.push(item);
|
|
||||||
|
|
||||||
if (this.keyMap && this.extractKey) {
|
|
||||||
const key = this.extractKey(item);
|
|
||||||
this.keyMap.set(key, index);
|
|
||||||
}
|
|
||||||
|
|
||||||
this.bubbleUp(index);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extract the top element from the heap
|
|
||||||
* O(log n) time complexity
|
|
||||||
*/
|
|
||||||
public extract(): T | undefined {
|
|
||||||
if (this.heap.length === 0) return undefined;
|
|
||||||
if (this.heap.length === 1) {
|
|
||||||
const item = this.heap.pop()!;
|
|
||||||
if (this.keyMap && this.extractKey) {
|
|
||||||
this.keyMap.delete(this.extractKey(item));
|
|
||||||
}
|
|
||||||
return item;
|
|
||||||
}
|
|
||||||
|
|
||||||
const result = this.heap[0];
|
|
||||||
const lastItem = this.heap.pop()!;
|
|
||||||
this.heap[0] = lastItem;
|
|
||||||
|
|
||||||
if (this.keyMap && this.extractKey) {
|
|
||||||
this.keyMap.delete(this.extractKey(result));
|
|
||||||
this.keyMap.set(this.extractKey(lastItem), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
this.bubbleDown(0);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extract an element that matches the predicate
|
|
||||||
* O(n) time complexity for search, O(log n) for extraction
|
|
||||||
*/
|
|
||||||
public extractIf(predicate: (item: T) => boolean): T | undefined {
|
|
||||||
const index = this.heap.findIndex(predicate);
|
|
||||||
if (index === -1) return undefined;
|
|
||||||
|
|
||||||
return this.extractAt(index);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extract an element by its key (if extractKey was provided)
|
|
||||||
* O(log n) time complexity
|
|
||||||
*/
|
|
||||||
public extractByKey(key: string): T | undefined {
|
|
||||||
if (!this.keyMap || !this.extractKey) {
|
|
||||||
throw new Error('extractKey function must be provided to use key-based extraction');
|
|
||||||
}
|
|
||||||
|
|
||||||
const index = this.keyMap.get(key);
|
|
||||||
if (index === undefined) return undefined;
|
|
||||||
|
|
||||||
return this.extractAt(index);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if a key exists in the heap
|
|
||||||
* O(1) time complexity
|
|
||||||
*/
|
|
||||||
public hasKey(key: string): boolean {
|
|
||||||
if (!this.keyMap) return false;
|
|
||||||
return this.keyMap.has(key);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get all elements as an array (does not modify heap)
|
|
||||||
* O(n) time complexity
|
|
||||||
*/
|
|
||||||
public toArray(): T[] {
|
|
||||||
return [...this.heap];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clear the heap
|
|
||||||
*/
|
|
||||||
public clear(): void {
|
|
||||||
this.heap = [];
|
|
||||||
if (this.keyMap) {
|
|
||||||
this.keyMap.clear();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extract element at specific index
|
|
||||||
*/
|
|
||||||
private extractAt(index: number): T {
|
|
||||||
const item = this.heap[index];
|
|
||||||
|
|
||||||
if (this.keyMap && this.extractKey) {
|
|
||||||
this.keyMap.delete(this.extractKey(item));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (index === this.heap.length - 1) {
|
|
||||||
this.heap.pop();
|
|
||||||
return item;
|
|
||||||
}
|
|
||||||
|
|
||||||
const lastItem = this.heap.pop()!;
|
|
||||||
this.heap[index] = lastItem;
|
|
||||||
|
|
||||||
if (this.keyMap && this.extractKey) {
|
|
||||||
this.keyMap.set(this.extractKey(lastItem), index);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try bubbling up first
|
|
||||||
const parentIndex = Math.floor((index - 1) / 2);
|
|
||||||
if (parentIndex >= 0 && this.compareFn(this.heap[index], this.heap[parentIndex]) < 0) {
|
|
||||||
this.bubbleUp(index);
|
|
||||||
} else {
|
|
||||||
this.bubbleDown(index);
|
|
||||||
}
|
|
||||||
|
|
||||||
return item;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Bubble up element at given index to maintain heap property
|
|
||||||
*/
|
|
||||||
private bubbleUp(index: number): void {
|
|
||||||
while (index > 0) {
|
|
||||||
const parentIndex = Math.floor((index - 1) / 2);
|
|
||||||
|
|
||||||
if (this.compareFn(this.heap[index], this.heap[parentIndex]) >= 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.swap(index, parentIndex);
|
|
||||||
index = parentIndex;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Bubble down element at given index to maintain heap property
|
|
||||||
*/
|
|
||||||
private bubbleDown(index: number): void {
|
|
||||||
const length = this.heap.length;
|
|
||||||
|
|
||||||
while (true) {
|
|
||||||
const leftChild = 2 * index + 1;
|
|
||||||
const rightChild = 2 * index + 2;
|
|
||||||
let smallest = index;
|
|
||||||
|
|
||||||
if (leftChild < length &&
|
|
||||||
this.compareFn(this.heap[leftChild], this.heap[smallest]) < 0) {
|
|
||||||
smallest = leftChild;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (rightChild < length &&
|
|
||||||
this.compareFn(this.heap[rightChild], this.heap[smallest]) < 0) {
|
|
||||||
smallest = rightChild;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (smallest === index) break;
|
|
||||||
|
|
||||||
this.swap(index, smallest);
|
|
||||||
index = smallest;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Swap two elements in the heap
|
|
||||||
*/
|
|
||||||
private swap(i: number, j: number): void {
|
|
||||||
const temp = this.heap[i];
|
|
||||||
this.heap[i] = this.heap[j];
|
|
||||||
this.heap[j] = temp;
|
|
||||||
|
|
||||||
if (this.keyMap && this.extractKey) {
|
|
||||||
this.keyMap.set(this.extractKey(this.heap[i]), i);
|
|
||||||
this.keyMap.set(this.extractKey(this.heap[j]), j);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,425 +0,0 @@
|
|||||||
import { LifecycleComponent } from './lifecycle-component.js';
|
|
||||||
import { BinaryHeap } from './binary-heap.js';
|
|
||||||
import { AsyncMutex } from './async-utils.js';
|
|
||||||
import { EventEmitter } from 'node:events';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Interface for pooled connection
|
|
||||||
*/
|
|
||||||
export interface IPooledConnection<T> {
|
|
||||||
id: string;
|
|
||||||
connection: T;
|
|
||||||
createdAt: number;
|
|
||||||
lastUsedAt: number;
|
|
||||||
useCount: number;
|
|
||||||
inUse: boolean;
|
|
||||||
metadata?: any;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Configuration options for the connection pool
|
|
||||||
*/
|
|
||||||
export interface IConnectionPoolOptions<T> {
|
|
||||||
minSize?: number;
|
|
||||||
maxSize?: number;
|
|
||||||
acquireTimeout?: number;
|
|
||||||
idleTimeout?: number;
|
|
||||||
maxUseCount?: number;
|
|
||||||
validateOnAcquire?: boolean;
|
|
||||||
validateOnReturn?: boolean;
|
|
||||||
queueTimeout?: number;
|
|
||||||
connectionFactory: () => Promise<T>;
|
|
||||||
connectionValidator?: (connection: T) => Promise<boolean>;
|
|
||||||
connectionDestroyer?: (connection: T) => Promise<void>;
|
|
||||||
onConnectionError?: (error: Error, connection?: T) => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Interface for queued acquire request
|
|
||||||
*/
|
|
||||||
interface IAcquireRequest<T> {
|
|
||||||
id: string;
|
|
||||||
priority: number;
|
|
||||||
timestamp: number;
|
|
||||||
resolve: (connection: IPooledConnection<T>) => void;
|
|
||||||
reject: (error: Error) => void;
|
|
||||||
timeoutHandle?: NodeJS.Timeout;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Enhanced connection pool with priority queue, backpressure, and lifecycle management
|
|
||||||
*/
|
|
||||||
export class EnhancedConnectionPool<T> extends LifecycleComponent {
|
|
||||||
private readonly options: Required<Omit<IConnectionPoolOptions<T>, 'connectionValidator' | 'connectionDestroyer' | 'onConnectionError'>> & Pick<IConnectionPoolOptions<T>, 'connectionValidator' | 'connectionDestroyer' | 'onConnectionError'>;
|
|
||||||
private readonly availableConnections: IPooledConnection<T>[] = [];
|
|
||||||
private readonly activeConnections: Map<string, IPooledConnection<T>> = new Map();
|
|
||||||
private readonly waitQueue: BinaryHeap<IAcquireRequest<T>>;
|
|
||||||
private readonly mutex = new AsyncMutex();
|
|
||||||
private readonly eventEmitter = new EventEmitter();
|
|
||||||
|
|
||||||
private connectionIdCounter = 0;
|
|
||||||
private requestIdCounter = 0;
|
|
||||||
private isClosing = false;
|
|
||||||
|
|
||||||
// Metrics
|
|
||||||
private metrics = {
|
|
||||||
connectionsCreated: 0,
|
|
||||||
connectionsDestroyed: 0,
|
|
||||||
connectionsAcquired: 0,
|
|
||||||
connectionsReleased: 0,
|
|
||||||
acquireTimeouts: 0,
|
|
||||||
validationFailures: 0,
|
|
||||||
queueHighWaterMark: 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
constructor(options: IConnectionPoolOptions<T>) {
|
|
||||||
super();
|
|
||||||
|
|
||||||
this.options = {
|
|
||||||
minSize: 0,
|
|
||||||
maxSize: 10,
|
|
||||||
acquireTimeout: 30000,
|
|
||||||
idleTimeout: 300000, // 5 minutes
|
|
||||||
maxUseCount: Infinity,
|
|
||||||
validateOnAcquire: true,
|
|
||||||
validateOnReturn: false,
|
|
||||||
queueTimeout: 60000,
|
|
||||||
...options,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Initialize priority queue (higher priority = extracted first)
|
|
||||||
this.waitQueue = new BinaryHeap<IAcquireRequest<T>>(
|
|
||||||
(a, b) => b.priority - a.priority || a.timestamp - b.timestamp,
|
|
||||||
(item) => item.id
|
|
||||||
);
|
|
||||||
|
|
||||||
// Start maintenance cycle
|
|
||||||
this.startMaintenance();
|
|
||||||
|
|
||||||
// Initialize minimum connections
|
|
||||||
this.initializeMinConnections();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Initialize minimum number of connections
|
|
||||||
*/
|
|
||||||
private async initializeMinConnections(): Promise<void> {
|
|
||||||
const promises: Promise<void>[] = [];
|
|
||||||
|
|
||||||
for (let i = 0; i < this.options.minSize; i++) {
|
|
||||||
promises.push(
|
|
||||||
this.createConnection()
|
|
||||||
.then(conn => {
|
|
||||||
this.availableConnections.push(conn);
|
|
||||||
})
|
|
||||||
.catch(err => {
|
|
||||||
if (this.options.onConnectionError) {
|
|
||||||
this.options.onConnectionError(err);
|
|
||||||
}
|
|
||||||
})
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
await Promise.all(promises);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Start maintenance timer for idle connection cleanup
|
|
||||||
*/
|
|
||||||
private startMaintenance(): void {
|
|
||||||
this.setInterval(() => {
|
|
||||||
this.performMaintenance();
|
|
||||||
}, 30000); // Every 30 seconds
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Perform maintenance tasks
|
|
||||||
*/
|
|
||||||
private async performMaintenance(): Promise<void> {
|
|
||||||
await this.mutex.runExclusive(async () => {
|
|
||||||
const now = Date.now();
|
|
||||||
const toRemove: IPooledConnection<T>[] = [];
|
|
||||||
|
|
||||||
// Check for idle connections beyond minimum size
|
|
||||||
for (let i = this.availableConnections.length - 1; i >= 0; i--) {
|
|
||||||
const conn = this.availableConnections[i];
|
|
||||||
|
|
||||||
// Keep minimum connections
|
|
||||||
if (this.availableConnections.length <= this.options.minSize) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove idle connections
|
|
||||||
if (now - conn.lastUsedAt > this.options.idleTimeout) {
|
|
||||||
toRemove.push(conn);
|
|
||||||
this.availableConnections.splice(i, 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Destroy idle connections
|
|
||||||
for (const conn of toRemove) {
|
|
||||||
await this.destroyConnection(conn);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Acquire a connection from the pool
|
|
||||||
*/
|
|
||||||
public async acquire(priority: number = 0, timeout?: number): Promise<IPooledConnection<T>> {
|
|
||||||
if (this.isClosing) {
|
|
||||||
throw new Error('Connection pool is closing');
|
|
||||||
}
|
|
||||||
|
|
||||||
return this.mutex.runExclusive(async () => {
|
|
||||||
// Try to get an available connection
|
|
||||||
const connection = await this.tryAcquireConnection();
|
|
||||||
if (connection) {
|
|
||||||
return connection;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we can create a new connection
|
|
||||||
const totalConnections = this.availableConnections.length + this.activeConnections.size;
|
|
||||||
if (totalConnections < this.options.maxSize) {
|
|
||||||
try {
|
|
||||||
const newConnection = await this.createConnection();
|
|
||||||
return this.checkoutConnection(newConnection);
|
|
||||||
} catch (err) {
|
|
||||||
// Fall through to queue if creation fails
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add to wait queue
|
|
||||||
return this.queueAcquireRequest(priority, timeout);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Try to acquire an available connection
|
|
||||||
*/
|
|
||||||
private async tryAcquireConnection(): Promise<IPooledConnection<T> | null> {
|
|
||||||
while (this.availableConnections.length > 0) {
|
|
||||||
const connection = this.availableConnections.shift()!;
|
|
||||||
|
|
||||||
// Check if connection exceeded max use count
|
|
||||||
if (connection.useCount >= this.options.maxUseCount) {
|
|
||||||
await this.destroyConnection(connection);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate connection if required
|
|
||||||
if (this.options.validateOnAcquire && this.options.connectionValidator) {
|
|
||||||
try {
|
|
||||||
const isValid = await this.options.connectionValidator(connection.connection);
|
|
||||||
if (!isValid) {
|
|
||||||
this.metrics.validationFailures++;
|
|
||||||
await this.destroyConnection(connection);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
} catch (err) {
|
|
||||||
this.metrics.validationFailures++;
|
|
||||||
await this.destroyConnection(connection);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return this.checkoutConnection(connection);
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Checkout a connection for use
|
|
||||||
*/
|
|
||||||
private checkoutConnection(connection: IPooledConnection<T>): IPooledConnection<T> {
|
|
||||||
connection.inUse = true;
|
|
||||||
connection.lastUsedAt = Date.now();
|
|
||||||
connection.useCount++;
|
|
||||||
|
|
||||||
this.activeConnections.set(connection.id, connection);
|
|
||||||
this.metrics.connectionsAcquired++;
|
|
||||||
|
|
||||||
this.eventEmitter.emit('acquire', connection);
|
|
||||||
return connection;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Queue an acquire request
|
|
||||||
*/
|
|
||||||
private queueAcquireRequest(priority: number, timeout?: number): Promise<IPooledConnection<T>> {
|
|
||||||
return new Promise<IPooledConnection<T>>((resolve, reject) => {
|
|
||||||
const request: IAcquireRequest<T> = {
|
|
||||||
id: `req-${this.requestIdCounter++}`,
|
|
||||||
priority,
|
|
||||||
timestamp: Date.now(),
|
|
||||||
resolve,
|
|
||||||
reject,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Set timeout
|
|
||||||
const timeoutMs = timeout || this.options.queueTimeout;
|
|
||||||
request.timeoutHandle = this.setTimeout(() => {
|
|
||||||
if (this.waitQueue.extractByKey(request.id)) {
|
|
||||||
this.metrics.acquireTimeouts++;
|
|
||||||
reject(new Error(`Connection acquire timeout after ${timeoutMs}ms`));
|
|
||||||
}
|
|
||||||
}, timeoutMs);
|
|
||||||
|
|
||||||
this.waitQueue.insert(request);
|
|
||||||
this.metrics.queueHighWaterMark = Math.max(
|
|
||||||
this.metrics.queueHighWaterMark,
|
|
||||||
this.waitQueue.size
|
|
||||||
);
|
|
||||||
|
|
||||||
this.eventEmitter.emit('enqueue', { queueSize: this.waitQueue.size });
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Release a connection back to the pool
|
|
||||||
*/
|
|
||||||
public async release(connection: IPooledConnection<T>): Promise<void> {
|
|
||||||
return this.mutex.runExclusive(async () => {
|
|
||||||
if (!connection.inUse || !this.activeConnections.has(connection.id)) {
|
|
||||||
throw new Error('Connection is not active');
|
|
||||||
}
|
|
||||||
|
|
||||||
this.activeConnections.delete(connection.id);
|
|
||||||
connection.inUse = false;
|
|
||||||
connection.lastUsedAt = Date.now();
|
|
||||||
this.metrics.connectionsReleased++;
|
|
||||||
|
|
||||||
// Check if connection should be destroyed
|
|
||||||
if (connection.useCount >= this.options.maxUseCount) {
|
|
||||||
await this.destroyConnection(connection);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate on return if required
|
|
||||||
if (this.options.validateOnReturn && this.options.connectionValidator) {
|
|
||||||
try {
|
|
||||||
const isValid = await this.options.connectionValidator(connection.connection);
|
|
||||||
if (!isValid) {
|
|
||||||
await this.destroyConnection(connection);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
} catch (err) {
|
|
||||||
await this.destroyConnection(connection);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if there are waiting requests
|
|
||||||
const request = this.waitQueue.extract();
|
|
||||||
if (request) {
|
|
||||||
this.clearTimeout(request.timeoutHandle!);
|
|
||||||
request.resolve(this.checkoutConnection(connection));
|
|
||||||
this.eventEmitter.emit('dequeue', { queueSize: this.waitQueue.size });
|
|
||||||
} else {
|
|
||||||
// Return to available pool
|
|
||||||
this.availableConnections.push(connection);
|
|
||||||
this.eventEmitter.emit('release', connection);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a new connection
|
|
||||||
*/
|
|
||||||
private async createConnection(): Promise<IPooledConnection<T>> {
|
|
||||||
const rawConnection = await this.options.connectionFactory();
|
|
||||||
|
|
||||||
const connection: IPooledConnection<T> = {
|
|
||||||
id: `conn-${this.connectionIdCounter++}`,
|
|
||||||
connection: rawConnection,
|
|
||||||
createdAt: Date.now(),
|
|
||||||
lastUsedAt: Date.now(),
|
|
||||||
useCount: 0,
|
|
||||||
inUse: false,
|
|
||||||
};
|
|
||||||
|
|
||||||
this.metrics.connectionsCreated++;
|
|
||||||
this.eventEmitter.emit('create', connection);
|
|
||||||
|
|
||||||
return connection;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Destroy a connection
|
|
||||||
*/
|
|
||||||
private async destroyConnection(connection: IPooledConnection<T>): Promise<void> {
|
|
||||||
try {
|
|
||||||
if (this.options.connectionDestroyer) {
|
|
||||||
await this.options.connectionDestroyer(connection.connection);
|
|
||||||
}
|
|
||||||
|
|
||||||
this.metrics.connectionsDestroyed++;
|
|
||||||
this.eventEmitter.emit('destroy', connection);
|
|
||||||
} catch (err) {
|
|
||||||
if (this.options.onConnectionError) {
|
|
||||||
this.options.onConnectionError(err as Error, connection.connection);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get current pool statistics
|
|
||||||
*/
|
|
||||||
public getStats() {
|
|
||||||
return {
|
|
||||||
available: this.availableConnections.length,
|
|
||||||
active: this.activeConnections.size,
|
|
||||||
waiting: this.waitQueue.size,
|
|
||||||
total: this.availableConnections.length + this.activeConnections.size,
|
|
||||||
...this.metrics,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Subscribe to pool events
|
|
||||||
*/
|
|
||||||
public on(event: string, listener: Function): void {
|
|
||||||
this.addEventListener(this.eventEmitter, event, listener);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Close the pool and cleanup resources
|
|
||||||
*/
|
|
||||||
protected async onCleanup(): Promise<void> {
|
|
||||||
this.isClosing = true;
|
|
||||||
|
|
||||||
// Clear the wait queue
|
|
||||||
while (!this.waitQueue.isEmpty()) {
|
|
||||||
const request = this.waitQueue.extract();
|
|
||||||
if (request) {
|
|
||||||
this.clearTimeout(request.timeoutHandle!);
|
|
||||||
request.reject(new Error('Connection pool is closing'));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for active connections to be released (with timeout)
|
|
||||||
const timeout = 30000;
|
|
||||||
const startTime = Date.now();
|
|
||||||
|
|
||||||
while (this.activeConnections.size > 0 && Date.now() - startTime < timeout) {
|
|
||||||
await new Promise(resolve => {
|
|
||||||
const timer = setTimeout(resolve, 100);
|
|
||||||
if (typeof timer.unref === 'function') {
|
|
||||||
timer.unref();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Destroy all connections
|
|
||||||
const allConnections = [
|
|
||||||
...this.availableConnections,
|
|
||||||
...this.activeConnections.values(),
|
|
||||||
];
|
|
||||||
|
|
||||||
await Promise.all(allConnections.map(conn => this.destroyConnection(conn)));
|
|
||||||
|
|
||||||
this.availableConnections.length = 0;
|
|
||||||
this.activeConnections.clear();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,270 +0,0 @@
|
|||||||
/**
|
|
||||||
* Async filesystem utilities for SmartProxy
|
|
||||||
* Provides non-blocking alternatives to synchronous filesystem operations
|
|
||||||
*/
|
|
||||||
|
|
||||||
import * as plugins from '../../plugins.js';
|
|
||||||
|
|
||||||
export class AsyncFileSystem {
|
|
||||||
/**
|
|
||||||
* Check if a file or directory exists
|
|
||||||
* @param path - Path to check
|
|
||||||
* @returns Promise resolving to true if exists, false otherwise
|
|
||||||
*/
|
|
||||||
static async exists(path: string): Promise<boolean> {
|
|
||||||
try {
|
|
||||||
await plugins.fs.promises.access(path);
|
|
||||||
return true;
|
|
||||||
} catch {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Ensure a directory exists, creating it if necessary
|
|
||||||
* @param dirPath - Directory path to ensure
|
|
||||||
* @returns Promise that resolves when directory is ensured
|
|
||||||
*/
|
|
||||||
static async ensureDir(dirPath: string): Promise<void> {
|
|
||||||
await plugins.fs.promises.mkdir(dirPath, { recursive: true });
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Read a file as string
|
|
||||||
* @param filePath - Path to the file
|
|
||||||
* @param encoding - File encoding (default: utf8)
|
|
||||||
* @returns Promise resolving to file contents
|
|
||||||
*/
|
|
||||||
static async readFile(filePath: string, encoding: BufferEncoding = 'utf8'): Promise<string> {
|
|
||||||
return plugins.fs.promises.readFile(filePath, encoding);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Read a file as buffer
|
|
||||||
* @param filePath - Path to the file
|
|
||||||
* @returns Promise resolving to file buffer
|
|
||||||
*/
|
|
||||||
static async readFileBuffer(filePath: string): Promise<Buffer> {
|
|
||||||
return plugins.fs.promises.readFile(filePath);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Write string data to a file
|
|
||||||
* @param filePath - Path to the file
|
|
||||||
* @param data - String data to write
|
|
||||||
* @param encoding - File encoding (default: utf8)
|
|
||||||
* @returns Promise that resolves when file is written
|
|
||||||
*/
|
|
||||||
static async writeFile(filePath: string, data: string, encoding: BufferEncoding = 'utf8'): Promise<void> {
|
|
||||||
// Ensure directory exists
|
|
||||||
const dir = plugins.path.dirname(filePath);
|
|
||||||
await this.ensureDir(dir);
|
|
||||||
await plugins.fs.promises.writeFile(filePath, data, encoding);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Write buffer data to a file
|
|
||||||
* @param filePath - Path to the file
|
|
||||||
* @param data - Buffer data to write
|
|
||||||
* @returns Promise that resolves when file is written
|
|
||||||
*/
|
|
||||||
static async writeFileBuffer(filePath: string, data: Buffer): Promise<void> {
|
|
||||||
const dir = plugins.path.dirname(filePath);
|
|
||||||
await this.ensureDir(dir);
|
|
||||||
await plugins.fs.promises.writeFile(filePath, data);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Remove a file
|
|
||||||
* @param filePath - Path to the file
|
|
||||||
* @returns Promise that resolves when file is removed
|
|
||||||
*/
|
|
||||||
static async remove(filePath: string): Promise<void> {
|
|
||||||
try {
|
|
||||||
await plugins.fs.promises.unlink(filePath);
|
|
||||||
} catch (error: any) {
|
|
||||||
if (error.code !== 'ENOENT') {
|
|
||||||
throw error;
|
|
||||||
}
|
|
||||||
// File doesn't exist, which is fine
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Remove a directory and all its contents
|
|
||||||
* @param dirPath - Path to the directory
|
|
||||||
* @returns Promise that resolves when directory is removed
|
|
||||||
*/
|
|
||||||
static async removeDir(dirPath: string): Promise<void> {
|
|
||||||
try {
|
|
||||||
await plugins.fs.promises.rm(dirPath, { recursive: true, force: true });
|
|
||||||
} catch (error: any) {
|
|
||||||
if (error.code !== 'ENOENT') {
|
|
||||||
throw error;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Read JSON from a file
|
|
||||||
* @param filePath - Path to the JSON file
|
|
||||||
* @returns Promise resolving to parsed JSON
|
|
||||||
*/
|
|
||||||
static async readJSON<T = any>(filePath: string): Promise<T> {
|
|
||||||
const content = await this.readFile(filePath);
|
|
||||||
return JSON.parse(content);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Write JSON to a file
|
|
||||||
* @param filePath - Path to the file
|
|
||||||
* @param data - Data to write as JSON
|
|
||||||
* @param pretty - Whether to pretty-print JSON (default: true)
|
|
||||||
* @returns Promise that resolves when file is written
|
|
||||||
*/
|
|
||||||
static async writeJSON(filePath: string, data: any, pretty = true): Promise<void> {
|
|
||||||
const jsonString = pretty ? JSON.stringify(data, null, 2) : JSON.stringify(data);
|
|
||||||
await this.writeFile(filePath, jsonString);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Copy a file from source to destination
|
|
||||||
* @param source - Source file path
|
|
||||||
* @param destination - Destination file path
|
|
||||||
* @returns Promise that resolves when file is copied
|
|
||||||
*/
|
|
||||||
static async copyFile(source: string, destination: string): Promise<void> {
|
|
||||||
const destDir = plugins.path.dirname(destination);
|
|
||||||
await this.ensureDir(destDir);
|
|
||||||
await plugins.fs.promises.copyFile(source, destination);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Move/rename a file
|
|
||||||
* @param source - Source file path
|
|
||||||
* @param destination - Destination file path
|
|
||||||
* @returns Promise that resolves when file is moved
|
|
||||||
*/
|
|
||||||
static async moveFile(source: string, destination: string): Promise<void> {
|
|
||||||
const destDir = plugins.path.dirname(destination);
|
|
||||||
await this.ensureDir(destDir);
|
|
||||||
await plugins.fs.promises.rename(source, destination);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get file stats
|
|
||||||
* @param filePath - Path to the file
|
|
||||||
* @returns Promise resolving to file stats or null if doesn't exist
|
|
||||||
*/
|
|
||||||
static async getStats(filePath: string): Promise<plugins.fs.Stats | null> {
|
|
||||||
try {
|
|
||||||
return await plugins.fs.promises.stat(filePath);
|
|
||||||
} catch (error: any) {
|
|
||||||
if (error.code === 'ENOENT') {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
throw error;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* List files in a directory
|
|
||||||
* @param dirPath - Directory path
|
|
||||||
* @returns Promise resolving to array of filenames
|
|
||||||
*/
|
|
||||||
static async listFiles(dirPath: string): Promise<string[]> {
|
|
||||||
try {
|
|
||||||
return await plugins.fs.promises.readdir(dirPath);
|
|
||||||
} catch (error: any) {
|
|
||||||
if (error.code === 'ENOENT') {
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
throw error;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* List files in a directory with full paths
|
|
||||||
* @param dirPath - Directory path
|
|
||||||
* @returns Promise resolving to array of full file paths
|
|
||||||
*/
|
|
||||||
static async listFilesFullPath(dirPath: string): Promise<string[]> {
|
|
||||||
const files = await this.listFiles(dirPath);
|
|
||||||
return files.map(file => plugins.path.join(dirPath, file));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Recursively list all files in a directory
|
|
||||||
* @param dirPath - Directory path
|
|
||||||
* @param fileList - Accumulator for file list (used internally)
|
|
||||||
* @returns Promise resolving to array of all file paths
|
|
||||||
*/
|
|
||||||
static async listFilesRecursive(dirPath: string, fileList: string[] = []): Promise<string[]> {
|
|
||||||
const files = await this.listFiles(dirPath);
|
|
||||||
|
|
||||||
for (const file of files) {
|
|
||||||
const filePath = plugins.path.join(dirPath, file);
|
|
||||||
const stats = await this.getStats(filePath);
|
|
||||||
|
|
||||||
if (stats?.isDirectory()) {
|
|
||||||
await this.listFilesRecursive(filePath, fileList);
|
|
||||||
} else if (stats?.isFile()) {
|
|
||||||
fileList.push(filePath);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fileList;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a read stream for a file
|
|
||||||
* @param filePath - Path to the file
|
|
||||||
* @param options - Stream options
|
|
||||||
* @returns Read stream
|
|
||||||
*/
|
|
||||||
static createReadStream(filePath: string, options?: Parameters<typeof plugins.fs.createReadStream>[1]): plugins.fs.ReadStream {
|
|
||||||
return plugins.fs.createReadStream(filePath, options);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a write stream for a file
|
|
||||||
* @param filePath - Path to the file
|
|
||||||
* @param options - Stream options
|
|
||||||
* @returns Write stream
|
|
||||||
*/
|
|
||||||
static createWriteStream(filePath: string, options?: Parameters<typeof plugins.fs.createWriteStream>[1]): plugins.fs.WriteStream {
|
|
||||||
return plugins.fs.createWriteStream(filePath, options);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Ensure a file exists, creating an empty file if necessary
|
|
||||||
* @param filePath - Path to the file
|
|
||||||
* @returns Promise that resolves when file is ensured
|
|
||||||
*/
|
|
||||||
static async ensureFile(filePath: string): Promise<void> {
|
|
||||||
const exists = await this.exists(filePath);
|
|
||||||
if (!exists) {
|
|
||||||
await this.writeFile(filePath, '');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if a path is a directory
|
|
||||||
* @param path - Path to check
|
|
||||||
* @returns Promise resolving to true if directory, false otherwise
|
|
||||||
*/
|
|
||||||
static async isDirectory(path: string): Promise<boolean> {
|
|
||||||
const stats = await this.getStats(path);
|
|
||||||
return stats?.isDirectory() ?? false;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if a path is a file
|
|
||||||
* @param path - Path to check
|
|
||||||
* @returns Promise resolving to true if file, false otherwise
|
|
||||||
*/
|
|
||||||
static async isFile(path: string): Promise<boolean> {
|
|
||||||
const stats = await this.getStats(path);
|
|
||||||
return stats?.isFile() ?? false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -2,16 +2,4 @@
|
|||||||
* Core utility functions
|
* Core utility functions
|
||||||
*/
|
*/
|
||||||
|
|
||||||
export * from './validation-utils.js';
|
|
||||||
export * from './ip-utils.js';
|
|
||||||
export * from './template-utils.js';
|
|
||||||
export * from './security-utils.js';
|
|
||||||
export * from './shared-security-manager.js';
|
|
||||||
export * from './websocket-utils.js';
|
|
||||||
export * from './logger.js';
|
export * from './logger.js';
|
||||||
export * from './async-utils.js';
|
|
||||||
export * from './fs-utils.js';
|
|
||||||
export * from './lifecycle-component.js';
|
|
||||||
export * from './binary-heap.js';
|
|
||||||
export * from './enhanced-connection-pool.js';
|
|
||||||
export * from './socket-utils.js';
|
|
||||||
|
|||||||
@@ -1,303 +0,0 @@
|
|||||||
import * as plugins from '../../plugins.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Utility class for IP address operations
|
|
||||||
*/
|
|
||||||
export class IpUtils {
|
|
||||||
/**
|
|
||||||
* Check if the IP matches any of the glob patterns
|
|
||||||
*
|
|
||||||
* This method checks IP addresses against glob patterns and handles IPv4/IPv6 normalization.
|
|
||||||
* It's used to implement IP filtering based on security configurations.
|
|
||||||
*
|
|
||||||
* @param ip - The IP address to check
|
|
||||||
* @param patterns - Array of glob patterns
|
|
||||||
* @returns true if IP matches any pattern, false otherwise
|
|
||||||
*/
|
|
||||||
public static isGlobIPMatch(ip: string, patterns: string[]): boolean {
|
|
||||||
if (!ip || !patterns || patterns.length === 0) return false;
|
|
||||||
|
|
||||||
// Normalize the IP being checked
|
|
||||||
const normalizedIPVariants = this.normalizeIP(ip);
|
|
||||||
if (normalizedIPVariants.length === 0) return false;
|
|
||||||
|
|
||||||
// Check each pattern
|
|
||||||
for (const pattern of patterns) {
|
|
||||||
// Handle CIDR notation
|
|
||||||
if (pattern.includes('/')) {
|
|
||||||
if (this.matchCIDR(ip, pattern)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle range notation
|
|
||||||
if (pattern.includes('-') && !pattern.includes('*')) {
|
|
||||||
if (this.matchIPRange(ip, pattern)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Expand shorthand patterns for glob matching
|
|
||||||
let expandedPattern = pattern;
|
|
||||||
if (pattern.includes('*') && !pattern.includes(':')) {
|
|
||||||
const parts = pattern.split('.');
|
|
||||||
while (parts.length < 4) {
|
|
||||||
parts.push('*');
|
|
||||||
}
|
|
||||||
expandedPattern = parts.join('.');
|
|
||||||
}
|
|
||||||
|
|
||||||
// Normalize and check with minimatch
|
|
||||||
const normalizedPatterns = this.normalizeIP(expandedPattern);
|
|
||||||
|
|
||||||
for (const ipVariant of normalizedIPVariants) {
|
|
||||||
for (const normalizedPattern of normalizedPatterns) {
|
|
||||||
if (plugins.minimatch(ipVariant, normalizedPattern)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Normalize IP addresses for consistent comparison
|
|
||||||
*
|
|
||||||
* @param ip The IP address to normalize
|
|
||||||
* @returns Array of normalized IP forms
|
|
||||||
*/
|
|
||||||
public static normalizeIP(ip: string): string[] {
|
|
||||||
if (!ip) return [];
|
|
||||||
|
|
||||||
// Handle IPv4-mapped IPv6 addresses (::ffff:127.0.0.1)
|
|
||||||
if (ip.startsWith('::ffff:')) {
|
|
||||||
const ipv4 = ip.slice(7);
|
|
||||||
return [ip, ipv4];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle IPv4 addresses by also checking IPv4-mapped form
|
|
||||||
if (/^\d{1,3}(\.\d{1,3}){3}$/.test(ip)) {
|
|
||||||
return [ip, `::ffff:${ip}`];
|
|
||||||
}
|
|
||||||
|
|
||||||
return [ip];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if an IP is authorized using security rules
|
|
||||||
*
|
|
||||||
* @param ip - The IP address to check
|
|
||||||
* @param allowedIPs - Array of allowed IP patterns
|
|
||||||
* @param blockedIPs - Array of blocked IP patterns
|
|
||||||
* @returns true if IP is authorized, false if blocked
|
|
||||||
*/
|
|
||||||
public static isIPAuthorized(ip: string, allowedIPs: string[] = [], blockedIPs: string[] = []): boolean {
|
|
||||||
// Skip IP validation if no rules are defined
|
|
||||||
if (!ip || (allowedIPs.length === 0 && blockedIPs.length === 0)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// First check if IP is blocked - blocked IPs take precedence
|
|
||||||
if (blockedIPs.length > 0 && this.isGlobIPMatch(ip, blockedIPs)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then check if IP is allowed (if no allowed IPs are specified, all non-blocked IPs are allowed)
|
|
||||||
return allowedIPs.length === 0 || this.isGlobIPMatch(ip, allowedIPs);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if an IP address is a private network address
|
|
||||||
*
|
|
||||||
* @param ip The IP address to check
|
|
||||||
* @returns true if the IP is a private network address, false otherwise
|
|
||||||
*/
|
|
||||||
public static isPrivateIP(ip: string): boolean {
|
|
||||||
if (!ip) return false;
|
|
||||||
|
|
||||||
// Handle IPv4-mapped IPv6 addresses
|
|
||||||
if (ip.startsWith('::ffff:')) {
|
|
||||||
ip = ip.slice(7);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check IPv4 private ranges
|
|
||||||
if (/^\d{1,3}(\.\d{1,3}){3}$/.test(ip)) {
|
|
||||||
const parts = ip.split('.').map(Number);
|
|
||||||
|
|
||||||
// Check common private ranges
|
|
||||||
// 10.0.0.0/8
|
|
||||||
if (parts[0] === 10) return true;
|
|
||||||
|
|
||||||
// 172.16.0.0/12
|
|
||||||
if (parts[0] === 172 && parts[1] >= 16 && parts[1] <= 31) return true;
|
|
||||||
|
|
||||||
// 192.168.0.0/16
|
|
||||||
if (parts[0] === 192 && parts[1] === 168) return true;
|
|
||||||
|
|
||||||
// 127.0.0.0/8 (localhost)
|
|
||||||
if (parts[0] === 127) return true;
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// IPv6 local addresses
|
|
||||||
return ip === '::1' || ip.startsWith('fc00:') || ip.startsWith('fd00:') || ip.startsWith('fe80:');
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if an IP address is a public network address
|
|
||||||
*
|
|
||||||
* @param ip The IP address to check
|
|
||||||
* @returns true if the IP is a public network address, false otherwise
|
|
||||||
*/
|
|
||||||
public static isPublicIP(ip: string): boolean {
|
|
||||||
return !this.isPrivateIP(ip);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if an IP matches a CIDR notation
|
|
||||||
*
|
|
||||||
* @param ip The IP address to check
|
|
||||||
* @param cidr The CIDR notation (e.g., "192.168.1.0/24")
|
|
||||||
* @returns true if IP is within the CIDR range
|
|
||||||
*/
|
|
||||||
private static matchCIDR(ip: string, cidr: string): boolean {
|
|
||||||
if (!cidr.includes('/')) return false;
|
|
||||||
|
|
||||||
const [networkAddr, prefixStr] = cidr.split('/');
|
|
||||||
const prefix = parseInt(prefixStr, 10);
|
|
||||||
|
|
||||||
// Handle IPv4-mapped IPv6 in the IP being checked
|
|
||||||
let checkIP = ip;
|
|
||||||
if (checkIP.startsWith('::ffff:')) {
|
|
||||||
checkIP = checkIP.slice(7);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle IPv6 CIDR
|
|
||||||
if (networkAddr.includes(':')) {
|
|
||||||
// TODO: Implement IPv6 CIDR matching
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// IPv4 CIDR matching
|
|
||||||
if (!/^\d{1,3}(\.\d{1,3}){3}$/.test(checkIP)) return false;
|
|
||||||
if (!/^\d{1,3}(\.\d{1,3}){3}$/.test(networkAddr)) return false;
|
|
||||||
if (isNaN(prefix) || prefix < 0 || prefix > 32) return false;
|
|
||||||
|
|
||||||
const ipParts = checkIP.split('.').map(Number);
|
|
||||||
const netParts = networkAddr.split('.').map(Number);
|
|
||||||
|
|
||||||
// Validate IP parts
|
|
||||||
for (const part of [...ipParts, ...netParts]) {
|
|
||||||
if (part < 0 || part > 255) return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to 32-bit integers
|
|
||||||
const ipNum = (ipParts[0] << 24) | (ipParts[1] << 16) | (ipParts[2] << 8) | ipParts[3];
|
|
||||||
const netNum = (netParts[0] << 24) | (netParts[1] << 16) | (netParts[2] << 8) | netParts[3];
|
|
||||||
|
|
||||||
// Create mask
|
|
||||||
const mask = (-1 << (32 - prefix)) >>> 0;
|
|
||||||
|
|
||||||
// Check if IP is in network range
|
|
||||||
return (ipNum & mask) === (netNum & mask);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if an IP matches a range notation
|
|
||||||
*
|
|
||||||
* @param ip The IP address to check
|
|
||||||
* @param range The range notation (e.g., "192.168.1.1-192.168.1.100")
|
|
||||||
* @returns true if IP is within the range
|
|
||||||
*/
|
|
||||||
private static matchIPRange(ip: string, range: string): boolean {
|
|
||||||
if (!range.includes('-')) return false;
|
|
||||||
|
|
||||||
const [startIP, endIP] = range.split('-').map(s => s.trim());
|
|
||||||
|
|
||||||
// Handle IPv4-mapped IPv6 in the IP being checked
|
|
||||||
let checkIP = ip;
|
|
||||||
if (checkIP.startsWith('::ffff:')) {
|
|
||||||
checkIP = checkIP.slice(7);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only handle IPv4 for now
|
|
||||||
if (!/^\d{1,3}(\.\d{1,3}){3}$/.test(checkIP)) return false;
|
|
||||||
if (!/^\d{1,3}(\.\d{1,3}){3}$/.test(startIP)) return false;
|
|
||||||
if (!/^\d{1,3}(\.\d{1,3}){3}$/.test(endIP)) return false;
|
|
||||||
|
|
||||||
const ipParts = checkIP.split('.').map(Number);
|
|
||||||
const startParts = startIP.split('.').map(Number);
|
|
||||||
const endParts = endIP.split('.').map(Number);
|
|
||||||
|
|
||||||
// Validate parts
|
|
||||||
for (const part of [...ipParts, ...startParts, ...endParts]) {
|
|
||||||
if (part < 0 || part > 255) return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to 32-bit integers for comparison
|
|
||||||
const ipNum = (ipParts[0] << 24) | (ipParts[1] << 16) | (ipParts[2] << 8) | ipParts[3];
|
|
||||||
const startNum = (startParts[0] << 24) | (startParts[1] << 16) | (startParts[2] << 8) | startParts[3];
|
|
||||||
const endNum = (endParts[0] << 24) | (endParts[1] << 16) | (endParts[2] << 8) | endParts[3];
|
|
||||||
|
|
||||||
// Convert to unsigned for proper comparison
|
|
||||||
const ipUnsigned = ipNum >>> 0;
|
|
||||||
const startUnsigned = startNum >>> 0;
|
|
||||||
const endUnsigned = endNum >>> 0;
|
|
||||||
|
|
||||||
return ipUnsigned >= startUnsigned && ipUnsigned <= endUnsigned;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Convert a subnet CIDR to an IP range for filtering
|
|
||||||
*
|
|
||||||
* @param cidr The CIDR notation (e.g., "192.168.1.0/24")
|
|
||||||
* @returns Array of glob patterns that match the CIDR range
|
|
||||||
*/
|
|
||||||
public static cidrToGlobPatterns(cidr: string): string[] {
|
|
||||||
if (!cidr || !cidr.includes('/')) return [];
|
|
||||||
|
|
||||||
const [ipPart, prefixPart] = cidr.split('/');
|
|
||||||
const prefix = parseInt(prefixPart, 10);
|
|
||||||
|
|
||||||
if (isNaN(prefix) || prefix < 0 || prefix > 32) return [];
|
|
||||||
|
|
||||||
// For IPv4 only for now
|
|
||||||
if (!/^\d{1,3}(\.\d{1,3}){3}$/.test(ipPart)) return [];
|
|
||||||
|
|
||||||
const ipParts = ipPart.split('.').map(Number);
|
|
||||||
const fullMask = Math.pow(2, 32 - prefix) - 1;
|
|
||||||
|
|
||||||
// Convert IP to a numeric value
|
|
||||||
const ipNum = (ipParts[0] << 24) | (ipParts[1] << 16) | (ipParts[2] << 8) | ipParts[3];
|
|
||||||
|
|
||||||
// Calculate network address (IP & ~fullMask)
|
|
||||||
const networkNum = ipNum & ~fullMask;
|
|
||||||
|
|
||||||
// For large ranges, return wildcard patterns
|
|
||||||
if (prefix <= 8) {
|
|
||||||
return [`${(networkNum >>> 24) & 255}.*.*.*`];
|
|
||||||
} else if (prefix <= 16) {
|
|
||||||
return [`${(networkNum >>> 24) & 255}.${(networkNum >>> 16) & 255}.*.*`];
|
|
||||||
} else if (prefix <= 24) {
|
|
||||||
return [`${(networkNum >>> 24) & 255}.${(networkNum >>> 16) & 255}.${(networkNum >>> 8) & 255}.*`];
|
|
||||||
}
|
|
||||||
|
|
||||||
// For small ranges, create individual IP patterns
|
|
||||||
const patterns = [];
|
|
||||||
const maxAddresses = Math.min(256, Math.pow(2, 32 - prefix));
|
|
||||||
|
|
||||||
for (let i = 0; i < maxAddresses; i++) {
|
|
||||||
const currentIpNum = networkNum + i;
|
|
||||||
patterns.push(
|
|
||||||
`${(currentIpNum >>> 24) & 255}.${(currentIpNum >>> 16) & 255}.${(currentIpNum >>> 8) & 255}.${currentIpNum & 255}`
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return patterns;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,251 +0,0 @@
|
|||||||
/**
|
|
||||||
* Base class for components that need proper resource lifecycle management
|
|
||||||
* Provides automatic cleanup of timers and event listeners to prevent memory leaks
|
|
||||||
*/
|
|
||||||
export abstract class LifecycleComponent {
|
|
||||||
private timers: Set<NodeJS.Timeout> = new Set();
|
|
||||||
private intervals: Set<NodeJS.Timeout> = new Set();
|
|
||||||
private listeners: Array<{
|
|
||||||
target: any;
|
|
||||||
event: string;
|
|
||||||
handler: Function;
|
|
||||||
actualHandler?: Function; // The actual handler registered (may be wrapped)
|
|
||||||
once?: boolean;
|
|
||||||
}> = [];
|
|
||||||
private childComponents: Set<LifecycleComponent> = new Set();
|
|
||||||
protected isShuttingDown = false;
|
|
||||||
private cleanupPromise?: Promise<void>;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a managed setTimeout that will be automatically cleaned up
|
|
||||||
*/
|
|
||||||
protected setTimeout(handler: Function, timeout: number): NodeJS.Timeout {
|
|
||||||
if (this.isShuttingDown) {
|
|
||||||
// Return a dummy timer if shutting down
|
|
||||||
const dummyTimer = setTimeout(() => {}, 0);
|
|
||||||
if (typeof dummyTimer.unref === 'function') {
|
|
||||||
dummyTimer.unref();
|
|
||||||
}
|
|
||||||
return dummyTimer;
|
|
||||||
}
|
|
||||||
|
|
||||||
const wrappedHandler = () => {
|
|
||||||
this.timers.delete(timer);
|
|
||||||
if (!this.isShuttingDown) {
|
|
||||||
handler();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const timer = setTimeout(wrappedHandler, timeout);
|
|
||||||
this.timers.add(timer);
|
|
||||||
|
|
||||||
// Allow process to exit even with timer
|
|
||||||
if (typeof timer.unref === 'function') {
|
|
||||||
timer.unref();
|
|
||||||
}
|
|
||||||
|
|
||||||
return timer;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a managed setInterval that will be automatically cleaned up
|
|
||||||
*/
|
|
||||||
protected setInterval(handler: Function, interval: number): NodeJS.Timeout {
|
|
||||||
if (this.isShuttingDown) {
|
|
||||||
// Return a dummy timer if shutting down
|
|
||||||
const dummyTimer = setInterval(() => {}, interval);
|
|
||||||
if (typeof dummyTimer.unref === 'function') {
|
|
||||||
dummyTimer.unref();
|
|
||||||
}
|
|
||||||
clearInterval(dummyTimer); // Clear immediately since we don't need it
|
|
||||||
return dummyTimer;
|
|
||||||
}
|
|
||||||
|
|
||||||
const wrappedHandler = () => {
|
|
||||||
if (!this.isShuttingDown) {
|
|
||||||
handler();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const timer = setInterval(wrappedHandler, interval);
|
|
||||||
this.intervals.add(timer);
|
|
||||||
|
|
||||||
// Allow process to exit even with timer
|
|
||||||
if (typeof timer.unref === 'function') {
|
|
||||||
timer.unref();
|
|
||||||
}
|
|
||||||
|
|
||||||
return timer;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clear a managed timeout
|
|
||||||
*/
|
|
||||||
protected clearTimeout(timer: NodeJS.Timeout): void {
|
|
||||||
clearTimeout(timer);
|
|
||||||
this.timers.delete(timer);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clear a managed interval
|
|
||||||
*/
|
|
||||||
protected clearInterval(timer: NodeJS.Timeout): void {
|
|
||||||
clearInterval(timer);
|
|
||||||
this.intervals.delete(timer);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Add a managed event listener that will be automatically removed on cleanup
|
|
||||||
*/
|
|
||||||
protected addEventListener(
|
|
||||||
target: any,
|
|
||||||
event: string,
|
|
||||||
handler: Function,
|
|
||||||
options?: { once?: boolean }
|
|
||||||
): void {
|
|
||||||
if (this.isShuttingDown) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// For 'once' listeners, we need to wrap the handler to remove it from our tracking
|
|
||||||
let actualHandler = handler;
|
|
||||||
if (options?.once) {
|
|
||||||
actualHandler = (...args: any[]) => {
|
|
||||||
// Call the original handler
|
|
||||||
handler(...args);
|
|
||||||
|
|
||||||
// Remove from our internal tracking
|
|
||||||
const index = this.listeners.findIndex(
|
|
||||||
l => l.target === target && l.event === event && l.handler === handler
|
|
||||||
);
|
|
||||||
if (index !== -1) {
|
|
||||||
this.listeners.splice(index, 1);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Support both EventEmitter and DOM-style event targets
|
|
||||||
if (typeof target.on === 'function') {
|
|
||||||
if (options?.once) {
|
|
||||||
target.once(event, actualHandler);
|
|
||||||
} else {
|
|
||||||
target.on(event, actualHandler);
|
|
||||||
}
|
|
||||||
} else if (typeof target.addEventListener === 'function') {
|
|
||||||
target.addEventListener(event, actualHandler, options);
|
|
||||||
} else {
|
|
||||||
throw new Error('Target must support on() or addEventListener()');
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store both the original handler and the actual handler registered
|
|
||||||
this.listeners.push({
|
|
||||||
target,
|
|
||||||
event,
|
|
||||||
handler,
|
|
||||||
actualHandler, // The handler that was actually registered (may be wrapped)
|
|
||||||
once: options?.once
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Remove a specific event listener
|
|
||||||
*/
|
|
||||||
protected removeEventListener(target: any, event: string, handler: Function): void {
|
|
||||||
// Remove from target
|
|
||||||
if (typeof target.removeListener === 'function') {
|
|
||||||
target.removeListener(event, handler);
|
|
||||||
} else if (typeof target.removeEventListener === 'function') {
|
|
||||||
target.removeEventListener(event, handler);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove from our tracking
|
|
||||||
const index = this.listeners.findIndex(
|
|
||||||
l => l.target === target && l.event === event && l.handler === handler
|
|
||||||
);
|
|
||||||
if (index !== -1) {
|
|
||||||
this.listeners.splice(index, 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Register a child component that should be cleaned up when this component is cleaned up
|
|
||||||
*/
|
|
||||||
protected registerChildComponent(component: LifecycleComponent): void {
|
|
||||||
this.childComponents.add(component);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Unregister a child component
|
|
||||||
*/
|
|
||||||
protected unregisterChildComponent(component: LifecycleComponent): void {
|
|
||||||
this.childComponents.delete(component);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Override this method to implement component-specific cleanup logic
|
|
||||||
*/
|
|
||||||
protected async onCleanup(): Promise<void> {
|
|
||||||
// Override in subclasses
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clean up all managed resources
|
|
||||||
*/
|
|
||||||
public async cleanup(): Promise<void> {
|
|
||||||
// Return existing cleanup promise if already cleaning up
|
|
||||||
if (this.cleanupPromise) {
|
|
||||||
return this.cleanupPromise;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.cleanupPromise = this.performCleanup();
|
|
||||||
return this.cleanupPromise;
|
|
||||||
}
|
|
||||||
|
|
||||||
private async performCleanup(): Promise<void> {
|
|
||||||
this.isShuttingDown = true;
|
|
||||||
|
|
||||||
// First, clean up child components
|
|
||||||
const childCleanupPromises: Promise<void>[] = [];
|
|
||||||
for (const child of this.childComponents) {
|
|
||||||
childCleanupPromises.push(child.cleanup());
|
|
||||||
}
|
|
||||||
await Promise.all(childCleanupPromises);
|
|
||||||
this.childComponents.clear();
|
|
||||||
|
|
||||||
// Clear all timers
|
|
||||||
for (const timer of this.timers) {
|
|
||||||
clearTimeout(timer);
|
|
||||||
}
|
|
||||||
this.timers.clear();
|
|
||||||
|
|
||||||
// Clear all intervals
|
|
||||||
for (const timer of this.intervals) {
|
|
||||||
clearInterval(timer);
|
|
||||||
}
|
|
||||||
this.intervals.clear();
|
|
||||||
|
|
||||||
// Remove all event listeners
|
|
||||||
for (const { target, event, handler, actualHandler } of this.listeners) {
|
|
||||||
// Use actualHandler if available (for wrapped handlers), otherwise use the original handler
|
|
||||||
const handlerToRemove = actualHandler || handler;
|
|
||||||
|
|
||||||
// All listeners need to be removed, including 'once' listeners that might not have fired
|
|
||||||
if (typeof target.removeListener === 'function') {
|
|
||||||
target.removeListener(event, handlerToRemove);
|
|
||||||
} else if (typeof target.removeEventListener === 'function') {
|
|
||||||
target.removeEventListener(event, handlerToRemove);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
this.listeners = [];
|
|
||||||
|
|
||||||
// Call subclass cleanup
|
|
||||||
await this.onCleanup();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if the component is shutting down
|
|
||||||
*/
|
|
||||||
protected isShuttingDownState(): boolean {
|
|
||||||
return this.isShuttingDown;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,370 +0,0 @@
|
|||||||
import { logger } from './logger.js';
|
|
||||||
|
|
||||||
interface ILogEvent {
|
|
||||||
level: 'info' | 'warn' | 'error' | 'debug';
|
|
||||||
message: string;
|
|
||||||
data?: any;
|
|
||||||
count: number;
|
|
||||||
firstSeen: number;
|
|
||||||
lastSeen: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface IAggregatedEvent {
|
|
||||||
key: string;
|
|
||||||
events: Map<string, ILogEvent>;
|
|
||||||
flushTimer?: NodeJS.Timeout;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Log deduplication utility to reduce log spam for repetitive events
|
|
||||||
*/
|
|
||||||
export class LogDeduplicator {
|
|
||||||
private globalFlushTimer?: NodeJS.Timeout;
|
|
||||||
private aggregatedEvents: Map<string, IAggregatedEvent> = new Map();
|
|
||||||
private flushInterval: number = 5000; // 5 seconds
|
|
||||||
private maxBatchSize: number = 100;
|
|
||||||
private rapidEventThreshold: number = 50; // Flush early if this many events in 1 second
|
|
||||||
private lastRapidCheck: number = Date.now();
|
|
||||||
|
|
||||||
constructor(flushInterval?: number) {
|
|
||||||
if (flushInterval) {
|
|
||||||
this.flushInterval = flushInterval;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up global periodic flush to ensure logs are emitted regularly
|
|
||||||
this.globalFlushTimer = setInterval(() => {
|
|
||||||
this.flushAll();
|
|
||||||
}, this.flushInterval * 2); // Flush everything every 2x the normal interval
|
|
||||||
|
|
||||||
if (this.globalFlushTimer.unref) {
|
|
||||||
this.globalFlushTimer.unref();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Log a deduplicated event
|
|
||||||
* @param key - Aggregation key (e.g., 'connection-rejected', 'cleanup-batch')
|
|
||||||
* @param level - Log level
|
|
||||||
* @param message - Log message template
|
|
||||||
* @param data - Additional data
|
|
||||||
* @param dedupeKey - Deduplication key within the aggregation (e.g., IP address, reason)
|
|
||||||
*/
|
|
||||||
public log(
|
|
||||||
key: string,
|
|
||||||
level: 'info' | 'warn' | 'error' | 'debug',
|
|
||||||
message: string,
|
|
||||||
data?: any,
|
|
||||||
dedupeKey?: string
|
|
||||||
): void {
|
|
||||||
const eventKey = dedupeKey || message;
|
|
||||||
const now = Date.now();
|
|
||||||
|
|
||||||
if (!this.aggregatedEvents.has(key)) {
|
|
||||||
this.aggregatedEvents.set(key, {
|
|
||||||
key,
|
|
||||||
events: new Map(),
|
|
||||||
flushTimer: undefined
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
const aggregated = this.aggregatedEvents.get(key)!;
|
|
||||||
|
|
||||||
if (aggregated.events.has(eventKey)) {
|
|
||||||
const event = aggregated.events.get(eventKey)!;
|
|
||||||
event.count++;
|
|
||||||
event.lastSeen = now;
|
|
||||||
if (data) {
|
|
||||||
event.data = { ...event.data, ...data };
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
aggregated.events.set(eventKey, {
|
|
||||||
level,
|
|
||||||
message,
|
|
||||||
data,
|
|
||||||
count: 1,
|
|
||||||
firstSeen: now,
|
|
||||||
lastSeen: now
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for rapid events (many events in short time)
|
|
||||||
const totalEvents = Array.from(aggregated.events.values()).reduce((sum, e) => sum + e.count, 0);
|
|
||||||
|
|
||||||
// If we're getting flooded with events, flush more frequently
|
|
||||||
if (now - this.lastRapidCheck < 1000 && totalEvents >= this.rapidEventThreshold) {
|
|
||||||
this.flush(key);
|
|
||||||
this.lastRapidCheck = now;
|
|
||||||
} else if (aggregated.events.size >= this.maxBatchSize) {
|
|
||||||
// Check if we should flush due to size
|
|
||||||
this.flush(key);
|
|
||||||
} else if (!aggregated.flushTimer) {
|
|
||||||
// Schedule flush
|
|
||||||
aggregated.flushTimer = setTimeout(() => {
|
|
||||||
this.flush(key);
|
|
||||||
}, this.flushInterval);
|
|
||||||
|
|
||||||
if (aggregated.flushTimer.unref) {
|
|
||||||
aggregated.flushTimer.unref();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update rapid check time
|
|
||||||
if (now - this.lastRapidCheck >= 1000) {
|
|
||||||
this.lastRapidCheck = now;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Flush aggregated events for a specific key
|
|
||||||
*/
|
|
||||||
public flush(key: string): void {
|
|
||||||
const aggregated = this.aggregatedEvents.get(key);
|
|
||||||
if (!aggregated || aggregated.events.size === 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (aggregated.flushTimer) {
|
|
||||||
clearTimeout(aggregated.flushTimer);
|
|
||||||
aggregated.flushTimer = undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Emit aggregated log based on the key
|
|
||||||
switch (key) {
|
|
||||||
case 'connection-rejected':
|
|
||||||
this.flushConnectionRejections(aggregated);
|
|
||||||
break;
|
|
||||||
case 'connection-cleanup':
|
|
||||||
this.flushConnectionCleanups(aggregated);
|
|
||||||
break;
|
|
||||||
case 'connection-terminated':
|
|
||||||
this.flushConnectionTerminations(aggregated);
|
|
||||||
break;
|
|
||||||
case 'ip-rejected':
|
|
||||||
this.flushIPRejections(aggregated);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
this.flushGeneric(aggregated);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear events
|
|
||||||
aggregated.events.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Flush all pending events
|
|
||||||
*/
|
|
||||||
public flushAll(): void {
|
|
||||||
for (const key of this.aggregatedEvents.keys()) {
|
|
||||||
this.flush(key);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private flushConnectionRejections(aggregated: IAggregatedEvent): void {
|
|
||||||
const totalCount = Array.from(aggregated.events.values()).reduce((sum, e) => sum + e.count, 0);
|
|
||||||
const byReason = new Map<string, number>();
|
|
||||||
|
|
||||||
for (const [, event] of aggregated.events) {
|
|
||||||
const reason = event.data?.reason || 'unknown';
|
|
||||||
byReason.set(reason, (byReason.get(reason) || 0) + event.count);
|
|
||||||
}
|
|
||||||
|
|
||||||
const reasonSummary = Array.from(byReason.entries())
|
|
||||||
.sort((a, b) => b[1] - a[1])
|
|
||||||
.map(([reason, count]) => `${reason}: ${count}`)
|
|
||||||
.join(', ');
|
|
||||||
|
|
||||||
const duration = Date.now() - Math.min(...Array.from(aggregated.events.values()).map(e => e.firstSeen));
|
|
||||||
logger.log('warn', `[SUMMARY] Rejected ${totalCount} connections in ${Math.round(duration/1000)}s`, {
|
|
||||||
reasons: reasonSummary,
|
|
||||||
uniqueIPs: aggregated.events.size,
|
|
||||||
component: 'connection-dedup'
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
private flushConnectionCleanups(aggregated: IAggregatedEvent): void {
|
|
||||||
const totalCount = Array.from(aggregated.events.values()).reduce((sum, e) => sum + e.count, 0);
|
|
||||||
const byReason = new Map<string, number>();
|
|
||||||
|
|
||||||
for (const [, event] of aggregated.events) {
|
|
||||||
const reason = event.data?.reason || 'normal';
|
|
||||||
byReason.set(reason, (byReason.get(reason) || 0) + event.count);
|
|
||||||
}
|
|
||||||
|
|
||||||
const reasonSummary = Array.from(byReason.entries())
|
|
||||||
.sort((a, b) => b[1] - a[1])
|
|
||||||
.slice(0, 5) // Top 5 reasons
|
|
||||||
.map(([reason, count]) => `${reason}: ${count}`)
|
|
||||||
.join(', ');
|
|
||||||
|
|
||||||
logger.log('info', `Cleaned up ${totalCount} connections`, {
|
|
||||||
reasons: reasonSummary,
|
|
||||||
duration: Date.now() - Math.min(...Array.from(aggregated.events.values()).map(e => e.firstSeen)),
|
|
||||||
component: 'connection-dedup'
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
private flushConnectionTerminations(aggregated: IAggregatedEvent): void {
|
|
||||||
const totalCount = Array.from(aggregated.events.values()).reduce((sum, e) => sum + e.count, 0);
|
|
||||||
const byReason = new Map<string, number>();
|
|
||||||
const byIP = new Map<string, number>();
|
|
||||||
let lastActiveCount = 0;
|
|
||||||
|
|
||||||
for (const [, event] of aggregated.events) {
|
|
||||||
const reason = event.data?.reason || 'unknown';
|
|
||||||
const ip = event.data?.remoteIP || 'unknown';
|
|
||||||
|
|
||||||
byReason.set(reason, (byReason.get(reason) || 0) + event.count);
|
|
||||||
|
|
||||||
// Track by IP
|
|
||||||
if (ip !== 'unknown') {
|
|
||||||
byIP.set(ip, (byIP.get(ip) || 0) + event.count);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Track the last active connection count
|
|
||||||
if (event.data?.activeConnections !== undefined) {
|
|
||||||
lastActiveCount = event.data.activeConnections;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const reasonSummary = Array.from(byReason.entries())
|
|
||||||
.sort((a, b) => b[1] - a[1])
|
|
||||||
.slice(0, 5) // Top 5 reasons
|
|
||||||
.map(([reason, count]) => `${reason}: ${count}`)
|
|
||||||
.join(', ');
|
|
||||||
|
|
||||||
// Show top IPs if there are many different ones
|
|
||||||
let ipInfo = '';
|
|
||||||
if (byIP.size > 3) {
|
|
||||||
const topIPs = Array.from(byIP.entries())
|
|
||||||
.sort((a, b) => b[1] - a[1])
|
|
||||||
.slice(0, 3)
|
|
||||||
.map(([ip, count]) => `${ip} (${count})`)
|
|
||||||
.join(', ');
|
|
||||||
ipInfo = `, from ${byIP.size} IPs (top: ${topIPs})`;
|
|
||||||
} else if (byIP.size > 0) {
|
|
||||||
ipInfo = `, IPs: ${Array.from(byIP.keys()).join(', ')}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
const duration = Date.now() - Math.min(...Array.from(aggregated.events.values()).map(e => e.firstSeen));
|
|
||||||
|
|
||||||
// Special handling for localhost connections (HttpProxy)
|
|
||||||
const localhostCount = byIP.get('::ffff:127.0.0.1') || 0;
|
|
||||||
if (localhostCount > 0 && byIP.size === 1) {
|
|
||||||
// All connections are from localhost (HttpProxy)
|
|
||||||
logger.log('info', `[SUMMARY] ${totalCount} HttpProxy connections terminated in ${Math.round(duration/1000)}s`, {
|
|
||||||
reasons: reasonSummary,
|
|
||||||
activeConnections: lastActiveCount,
|
|
||||||
component: 'connection-dedup'
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
logger.log('info', `[SUMMARY] ${totalCount} connections terminated in ${Math.round(duration/1000)}s`, {
|
|
||||||
reasons: reasonSummary,
|
|
||||||
activeConnections: lastActiveCount,
|
|
||||||
uniqueReasons: byReason.size,
|
|
||||||
...(ipInfo ? { ips: ipInfo } : {}),
|
|
||||||
component: 'connection-dedup'
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private flushIPRejections(aggregated: IAggregatedEvent): void {
|
|
||||||
const byIP = new Map<string, { count: number; reasons: Set<string> }>();
|
|
||||||
const allReasons = new Map<string, number>();
|
|
||||||
|
|
||||||
for (const [ip, event] of aggregated.events) {
|
|
||||||
if (!byIP.has(ip)) {
|
|
||||||
byIP.set(ip, { count: 0, reasons: new Set() });
|
|
||||||
}
|
|
||||||
const ipData = byIP.get(ip)!;
|
|
||||||
ipData.count += event.count;
|
|
||||||
if (event.data?.reason) {
|
|
||||||
ipData.reasons.add(event.data.reason);
|
|
||||||
// Track overall reason counts
|
|
||||||
allReasons.set(event.data.reason, (allReasons.get(event.data.reason) || 0) + event.count);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create reason summary
|
|
||||||
const reasonSummary = Array.from(allReasons.entries())
|
|
||||||
.sort((a, b) => b[1] - a[1])
|
|
||||||
.map(([reason, count]) => `${reason}: ${count}`)
|
|
||||||
.join(', ');
|
|
||||||
|
|
||||||
// Log top offenders
|
|
||||||
const topOffenders = Array.from(byIP.entries())
|
|
||||||
.sort((a, b) => b[1].count - a[1].count)
|
|
||||||
.slice(0, 10)
|
|
||||||
.map(([ip, data]) => `${ip} (${data.count}x, ${Array.from(data.reasons).join('/')})`)
|
|
||||||
.join(', ');
|
|
||||||
|
|
||||||
const totalRejections = Array.from(byIP.values()).reduce((sum, data) => sum + data.count, 0);
|
|
||||||
|
|
||||||
const duration = Date.now() - Math.min(...Array.from(aggregated.events.values()).map(e => e.firstSeen));
|
|
||||||
logger.log('warn', `[SUMMARY] Rejected ${totalRejections} connections from ${byIP.size} IPs in ${Math.round(duration/1000)}s (${reasonSummary})`, {
|
|
||||||
topOffenders,
|
|
||||||
component: 'ip-dedup'
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
private flushGeneric(aggregated: IAggregatedEvent): void {
|
|
||||||
const totalCount = Array.from(aggregated.events.values()).reduce((sum, e) => sum + e.count, 0);
|
|
||||||
const level = aggregated.events.values().next().value?.level || 'info';
|
|
||||||
|
|
||||||
// Special handling for IP cleanup events
|
|
||||||
if (aggregated.key === 'ip-cleanup') {
|
|
||||||
const totalCleaned = Array.from(aggregated.events.values()).reduce((sum, e) => {
|
|
||||||
return sum + (e.data?.cleanedIPs || 0) + (e.data?.cleanedRateLimits || 0);
|
|
||||||
}, 0);
|
|
||||||
|
|
||||||
if (totalCleaned > 0) {
|
|
||||||
logger.log(level as any, `IP tracking cleanup: removed ${totalCleaned} entries across ${totalCount} cleanup cycles`, {
|
|
||||||
duration: Date.now() - Math.min(...Array.from(aggregated.events.values()).map(e => e.firstSeen)),
|
|
||||||
component: 'log-dedup'
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.log(level as any, `${aggregated.key}: ${totalCount} events`, {
|
|
||||||
uniqueEvents: aggregated.events.size,
|
|
||||||
duration: Date.now() - Math.min(...Array.from(aggregated.events.values()).map(e => e.firstSeen)),
|
|
||||||
component: 'log-dedup'
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Cleanup and stop deduplication
|
|
||||||
*/
|
|
||||||
public cleanup(): void {
|
|
||||||
this.flushAll();
|
|
||||||
|
|
||||||
if (this.globalFlushTimer) {
|
|
||||||
clearInterval(this.globalFlushTimer);
|
|
||||||
this.globalFlushTimer = undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const aggregated of this.aggregatedEvents.values()) {
|
|
||||||
if (aggregated.flushTimer) {
|
|
||||||
clearTimeout(aggregated.flushTimer);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
this.aggregatedEvents.clear();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Global instance for connection-related log deduplication
|
|
||||||
export const connectionLogDeduplicator = new LogDeduplicator(5000); // 5 second batches
|
|
||||||
|
|
||||||
// Ensure logs are flushed on process exit.
|
|
||||||
// Only use beforeExit — do NOT call process.exit() from SIGINT/SIGTERM handlers
|
|
||||||
// as that kills the host process's graceful shutdown (e.g., dcrouter connection draining).
|
|
||||||
process.on('beforeExit', () => {
|
|
||||||
connectionLogDeduplicator.flushAll();
|
|
||||||
});
|
|
||||||
|
|
||||||
process.on('SIGINT', () => {
|
|
||||||
connectionLogDeduplicator.cleanup();
|
|
||||||
});
|
|
||||||
|
|
||||||
process.on('SIGTERM', () => {
|
|
||||||
connectionLogDeduplicator.cleanup();
|
|
||||||
});
|
|
||||||
@@ -1,305 +0,0 @@
|
|||||||
import * as plugins from '../../plugins.js';
|
|
||||||
import { IpMatcher } from '../routing/matchers/ip.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Security utilities for IP validation, rate limiting,
|
|
||||||
* authentication, and other security features
|
|
||||||
*/
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Result of IP validation
|
|
||||||
*/
|
|
||||||
export interface IIpValidationResult {
|
|
||||||
allowed: boolean;
|
|
||||||
reason?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* IP connection tracking information
|
|
||||||
*/
|
|
||||||
export interface IIpConnectionInfo {
|
|
||||||
connections: Set<string>; // ConnectionIDs
|
|
||||||
timestamps: number[]; // Connection timestamps
|
|
||||||
ipVariants: string[]; // Normalized IP variants (e.g., ::ffff:127.0.0.1 and 127.0.0.1)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Rate limit tracking
|
|
||||||
*/
|
|
||||||
export interface IRateLimitInfo {
|
|
||||||
count: number;
|
|
||||||
expiry: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Logger interface for security utilities
|
|
||||||
*/
|
|
||||||
export interface ISecurityLogger {
|
|
||||||
info: (message: string, ...args: any[]) => void;
|
|
||||||
warn: (message: string, ...args: any[]) => void;
|
|
||||||
error: (message: string, ...args: any[]) => void;
|
|
||||||
debug?: (message: string, ...args: any[]) => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Normalize IP addresses for comparison
|
|
||||||
* Handles IPv4-mapped IPv6 addresses (::ffff:127.0.0.1)
|
|
||||||
*
|
|
||||||
* @param ip IP address to normalize
|
|
||||||
* @returns Array of equivalent IP representations
|
|
||||||
*/
|
|
||||||
export function normalizeIP(ip: string): string[] {
|
|
||||||
if (!ip) return [];
|
|
||||||
|
|
||||||
// Handle IPv4-mapped IPv6 addresses (::ffff:127.0.0.1)
|
|
||||||
if (ip.startsWith('::ffff:')) {
|
|
||||||
const ipv4 = ip.slice(7);
|
|
||||||
return [ip, ipv4];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle IPv4 addresses by also checking IPv4-mapped form
|
|
||||||
if (/^\d{1,3}(\.\d{1,3}){3}$/.test(ip)) {
|
|
||||||
return [ip, `::ffff:${ip}`];
|
|
||||||
}
|
|
||||||
|
|
||||||
return [ip];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if an IP is authorized based on allow and block lists
|
|
||||||
*
|
|
||||||
* @param ip - The IP address to check
|
|
||||||
* @param allowedIPs - Array of allowed IP patterns
|
|
||||||
* @param blockedIPs - Array of blocked IP patterns
|
|
||||||
* @returns Whether the IP is authorized
|
|
||||||
*/
|
|
||||||
export function isIPAuthorized(
|
|
||||||
ip: string,
|
|
||||||
allowedIPs: string[] = ['*'],
|
|
||||||
blockedIPs: string[] = []
|
|
||||||
): boolean {
|
|
||||||
// Skip IP validation if no rules
|
|
||||||
if (!ip || (allowedIPs.length === 0 && blockedIPs.length === 0)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// First check if IP is blocked - blocked IPs take precedence
|
|
||||||
if (blockedIPs.length > 0) {
|
|
||||||
for (const pattern of blockedIPs) {
|
|
||||||
if (IpMatcher.match(pattern, ip)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If allowed IPs list has wildcard, all non-blocked IPs are allowed
|
|
||||||
if (allowedIPs.includes('*')) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then check if IP is allowed in the explicit allow list
|
|
||||||
if (allowedIPs.length > 0) {
|
|
||||||
for (const pattern of allowedIPs) {
|
|
||||||
if (IpMatcher.match(pattern, ip)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If allowedIPs is specified but no match, deny access
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default allow if no explicit allow list
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if an IP exceeds maximum connections
|
|
||||||
*
|
|
||||||
* @param ip - The IP address to check
|
|
||||||
* @param ipConnectionsMap - Map of IPs to connection info
|
|
||||||
* @param maxConnectionsPerIP - Maximum allowed connections per IP
|
|
||||||
* @returns Result with allowed status and reason if blocked
|
|
||||||
*/
|
|
||||||
export function checkMaxConnections(
|
|
||||||
ip: string,
|
|
||||||
ipConnectionsMap: Map<string, IIpConnectionInfo>,
|
|
||||||
maxConnectionsPerIP: number
|
|
||||||
): IIpValidationResult {
|
|
||||||
if (!ipConnectionsMap.has(ip)) {
|
|
||||||
return { allowed: true };
|
|
||||||
}
|
|
||||||
|
|
||||||
const connectionCount = ipConnectionsMap.get(ip)!.connections.size;
|
|
||||||
|
|
||||||
if (connectionCount >= maxConnectionsPerIP) {
|
|
||||||
return {
|
|
||||||
allowed: false,
|
|
||||||
reason: `Maximum connections per IP (${maxConnectionsPerIP}) exceeded`
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return { allowed: true };
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if an IP exceeds connection rate limit
|
|
||||||
*
|
|
||||||
* @param ip - The IP address to check
|
|
||||||
* @param ipConnectionsMap - Map of IPs to connection info
|
|
||||||
* @param rateLimit - Maximum connections per minute
|
|
||||||
* @returns Result with allowed status and reason if blocked
|
|
||||||
*/
|
|
||||||
export function checkConnectionRate(
|
|
||||||
ip: string,
|
|
||||||
ipConnectionsMap: Map<string, IIpConnectionInfo>,
|
|
||||||
rateLimit: number
|
|
||||||
): IIpValidationResult {
|
|
||||||
const now = Date.now();
|
|
||||||
const minute = 60 * 1000;
|
|
||||||
|
|
||||||
// Get or create connection info
|
|
||||||
if (!ipConnectionsMap.has(ip)) {
|
|
||||||
const info: IIpConnectionInfo = {
|
|
||||||
connections: new Set(),
|
|
||||||
timestamps: [now],
|
|
||||||
ipVariants: normalizeIP(ip)
|
|
||||||
};
|
|
||||||
ipConnectionsMap.set(ip, info);
|
|
||||||
return { allowed: true };
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get timestamps and filter out entries older than 1 minute
|
|
||||||
const info = ipConnectionsMap.get(ip)!;
|
|
||||||
const timestamps = info.timestamps.filter(time => now - time < minute);
|
|
||||||
timestamps.push(now);
|
|
||||||
info.timestamps = timestamps;
|
|
||||||
|
|
||||||
// Check if rate exceeds limit
|
|
||||||
if (timestamps.length > rateLimit) {
|
|
||||||
return {
|
|
||||||
allowed: false,
|
|
||||||
reason: `Connection rate limit (${rateLimit}/min) exceeded`
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return { allowed: true };
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Track a connection for an IP
|
|
||||||
*
|
|
||||||
* @param ip - The IP address
|
|
||||||
* @param connectionId - The connection ID to track
|
|
||||||
* @param ipConnectionsMap - Map of IPs to connection info
|
|
||||||
*/
|
|
||||||
export function trackConnection(
|
|
||||||
ip: string,
|
|
||||||
connectionId: string,
|
|
||||||
ipConnectionsMap: Map<string, IIpConnectionInfo>
|
|
||||||
): void {
|
|
||||||
if (!ipConnectionsMap.has(ip)) {
|
|
||||||
ipConnectionsMap.set(ip, {
|
|
||||||
connections: new Set([connectionId]),
|
|
||||||
timestamps: [Date.now()],
|
|
||||||
ipVariants: normalizeIP(ip)
|
|
||||||
});
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const info = ipConnectionsMap.get(ip)!;
|
|
||||||
info.connections.add(connectionId);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Remove connection tracking for an IP
|
|
||||||
*
|
|
||||||
* @param ip - The IP address
|
|
||||||
* @param connectionId - The connection ID to remove
|
|
||||||
* @param ipConnectionsMap - Map of IPs to connection info
|
|
||||||
*/
|
|
||||||
export function removeConnection(
|
|
||||||
ip: string,
|
|
||||||
connectionId: string,
|
|
||||||
ipConnectionsMap: Map<string, IIpConnectionInfo>
|
|
||||||
): void {
|
|
||||||
if (!ipConnectionsMap.has(ip)) return;
|
|
||||||
|
|
||||||
const info = ipConnectionsMap.get(ip)!;
|
|
||||||
info.connections.delete(connectionId);
|
|
||||||
|
|
||||||
if (info.connections.size === 0) {
|
|
||||||
ipConnectionsMap.delete(ip);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clean up expired rate limits
|
|
||||||
*
|
|
||||||
* @param rateLimits - Map of rate limits to clean up
|
|
||||||
* @param logger - Logger for debug messages
|
|
||||||
*/
|
|
||||||
export function cleanupExpiredRateLimits(
|
|
||||||
rateLimits: Map<string, Map<string, IRateLimitInfo>>,
|
|
||||||
logger?: ISecurityLogger
|
|
||||||
): void {
|
|
||||||
const now = Date.now();
|
|
||||||
let totalRemoved = 0;
|
|
||||||
|
|
||||||
for (const [routeId, routeLimits] of rateLimits.entries()) {
|
|
||||||
let removed = 0;
|
|
||||||
for (const [key, limit] of routeLimits.entries()) {
|
|
||||||
if (limit.expiry < now) {
|
|
||||||
routeLimits.delete(key);
|
|
||||||
removed++;
|
|
||||||
totalRemoved++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (removed > 0 && logger?.debug) {
|
|
||||||
logger.debug(`Cleaned up ${removed} expired rate limits for route ${routeId}`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (totalRemoved > 0 && logger?.info) {
|
|
||||||
logger.info(`Cleaned up ${totalRemoved} expired rate limits total`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Generate basic auth header value from username and password
|
|
||||||
*
|
|
||||||
* @param username - The username
|
|
||||||
* @param password - The password
|
|
||||||
* @returns Base64 encoded basic auth string
|
|
||||||
*/
|
|
||||||
export function generateBasicAuthHeader(username: string, password: string): string {
|
|
||||||
return `Basic ${Buffer.from(`${username}:${password}`).toString('base64')}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse basic auth header
|
|
||||||
*
|
|
||||||
* @param authHeader - The Authorization header value
|
|
||||||
* @returns Username and password, or null if invalid
|
|
||||||
*/
|
|
||||||
export function parseBasicAuthHeader(
|
|
||||||
authHeader: string
|
|
||||||
): { username: string; password: string } | null {
|
|
||||||
if (!authHeader || !authHeader.startsWith('Basic ')) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
const base64 = authHeader.slice(6); // Remove 'Basic '
|
|
||||||
const decoded = Buffer.from(base64, 'base64').toString();
|
|
||||||
const [username, password] = decoded.split(':');
|
|
||||||
|
|
||||||
if (!username || !password) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return { username, password };
|
|
||||||
} catch (err) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,470 +0,0 @@
|
|||||||
import * as plugins from '../../plugins.js';
|
|
||||||
import type { IRouteConfig, IRouteContext } from '../../proxies/smart-proxy/models/route-types.js';
|
|
||||||
import type {
|
|
||||||
IIpValidationResult,
|
|
||||||
IIpConnectionInfo,
|
|
||||||
ISecurityLogger,
|
|
||||||
IRateLimitInfo
|
|
||||||
} from './security-utils.js';
|
|
||||||
import {
|
|
||||||
isIPAuthorized,
|
|
||||||
checkMaxConnections,
|
|
||||||
checkConnectionRate,
|
|
||||||
trackConnection,
|
|
||||||
removeConnection,
|
|
||||||
cleanupExpiredRateLimits,
|
|
||||||
parseBasicAuthHeader,
|
|
||||||
normalizeIP
|
|
||||||
} from './security-utils.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Shared SecurityManager for use across proxy components
|
|
||||||
* Handles IP tracking, rate limiting, and authentication
|
|
||||||
*/
|
|
||||||
export class SharedSecurityManager {
|
|
||||||
// IP connection tracking
|
|
||||||
private connectionsByIP: Map<string, IIpConnectionInfo> = new Map();
|
|
||||||
|
|
||||||
// Route-specific rate limiting
|
|
||||||
private rateLimits: Map<string, Map<string, IRateLimitInfo>> = new Map();
|
|
||||||
|
|
||||||
// Cache IP filtering results to avoid constant regex matching
|
|
||||||
private ipFilterCache: Map<string, Map<string, boolean>> = new Map();
|
|
||||||
|
|
||||||
// Default limits
|
|
||||||
private maxConnectionsPerIP: number;
|
|
||||||
private connectionRateLimitPerMinute: number;
|
|
||||||
|
|
||||||
// Cache cleanup interval
|
|
||||||
private cleanupInterval: NodeJS.Timeout | null = null;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a new SharedSecurityManager
|
|
||||||
*
|
|
||||||
* @param options - Configuration options
|
|
||||||
* @param logger - Logger instance
|
|
||||||
*/
|
|
||||||
constructor(options: {
|
|
||||||
maxConnectionsPerIP?: number;
|
|
||||||
connectionRateLimitPerMinute?: number;
|
|
||||||
cleanupIntervalMs?: number;
|
|
||||||
routes?: IRouteConfig[];
|
|
||||||
}, private logger?: ISecurityLogger) {
|
|
||||||
this.maxConnectionsPerIP = options.maxConnectionsPerIP || 100;
|
|
||||||
this.connectionRateLimitPerMinute = options.connectionRateLimitPerMinute || 300;
|
|
||||||
|
|
||||||
// Set up logger with defaults if not provided
|
|
||||||
this.logger = logger || {
|
|
||||||
info: console.log,
|
|
||||||
warn: console.warn,
|
|
||||||
error: console.error
|
|
||||||
};
|
|
||||||
|
|
||||||
// Set up cache cleanup interval
|
|
||||||
const cleanupInterval = options.cleanupIntervalMs || 60000; // Default: 1 minute
|
|
||||||
this.cleanupInterval = setInterval(() => {
|
|
||||||
this.cleanupCaches();
|
|
||||||
}, cleanupInterval);
|
|
||||||
|
|
||||||
// Don't keep the process alive just for cleanup
|
|
||||||
if (this.cleanupInterval.unref) {
|
|
||||||
this.cleanupInterval.unref();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get connections count by IP
|
|
||||||
*
|
|
||||||
* @param ip - The IP address to check
|
|
||||||
* @returns Number of connections from this IP
|
|
||||||
*/
|
|
||||||
public getConnectionCountByIP(ip: string): number {
|
|
||||||
// Check all normalized variants of the IP
|
|
||||||
const variants = normalizeIP(ip);
|
|
||||||
for (const variant of variants) {
|
|
||||||
const info = this.connectionsByIP.get(variant);
|
|
||||||
if (info) {
|
|
||||||
return info.connections.size;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Track connection by IP
|
|
||||||
*
|
|
||||||
* @param ip - The IP address to track
|
|
||||||
* @param connectionId - The connection ID to associate
|
|
||||||
*/
|
|
||||||
public trackConnectionByIP(ip: string, connectionId: string): void {
|
|
||||||
// Check if any variant already exists
|
|
||||||
const variants = normalizeIP(ip);
|
|
||||||
let existingKey: string | null = null;
|
|
||||||
|
|
||||||
for (const variant of variants) {
|
|
||||||
if (this.connectionsByIP.has(variant)) {
|
|
||||||
existingKey = variant;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use existing key or the original IP
|
|
||||||
trackConnection(existingKey || ip, connectionId, this.connectionsByIP);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Remove connection tracking for an IP
|
|
||||||
*
|
|
||||||
* @param ip - The IP address to update
|
|
||||||
* @param connectionId - The connection ID to remove
|
|
||||||
*/
|
|
||||||
public removeConnectionByIP(ip: string, connectionId: string): void {
|
|
||||||
// Check all variants to find where the connection is tracked
|
|
||||||
const variants = normalizeIP(ip);
|
|
||||||
|
|
||||||
for (const variant of variants) {
|
|
||||||
if (this.connectionsByIP.has(variant)) {
|
|
||||||
removeConnection(variant, connectionId, this.connectionsByIP);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if IP is authorized based on route security settings
|
|
||||||
*
|
|
||||||
* @param ip - The IP address to check
|
|
||||||
* @param allowedIPs - List of allowed IP patterns
|
|
||||||
* @param blockedIPs - List of blocked IP patterns
|
|
||||||
* @returns Whether the IP is authorized
|
|
||||||
*/
|
|
||||||
public isIPAuthorized(
|
|
||||||
ip: string,
|
|
||||||
allowedIPs: string[] = ['*'],
|
|
||||||
blockedIPs: string[] = []
|
|
||||||
): boolean {
|
|
||||||
return isIPAuthorized(ip, allowedIPs, blockedIPs);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Validate IP against rate limits and connection limits
|
|
||||||
*
|
|
||||||
* @param ip - The IP address to validate
|
|
||||||
* @returns Result with allowed status and reason if blocked
|
|
||||||
*/
|
|
||||||
public validateIP(ip: string): IIpValidationResult {
|
|
||||||
// Check connection count limit
|
|
||||||
const connectionResult = checkMaxConnections(
|
|
||||||
ip,
|
|
||||||
this.connectionsByIP,
|
|
||||||
this.maxConnectionsPerIP
|
|
||||||
);
|
|
||||||
if (!connectionResult.allowed) {
|
|
||||||
return connectionResult;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check connection rate limit
|
|
||||||
const rateResult = checkConnectionRate(
|
|
||||||
ip,
|
|
||||||
this.connectionsByIP,
|
|
||||||
this.connectionRateLimitPerMinute
|
|
||||||
);
|
|
||||||
if (!rateResult.allowed) {
|
|
||||||
return rateResult;
|
|
||||||
}
|
|
||||||
|
|
||||||
return { allowed: true };
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Atomically validate an IP and track the connection if allowed.
|
|
||||||
* This prevents race conditions where concurrent connections could bypass per-IP limits.
|
|
||||||
*
|
|
||||||
* @param ip - The IP address to validate
|
|
||||||
* @param connectionId - The connection ID to track if validation passes
|
|
||||||
* @returns Object with validation result and reason
|
|
||||||
*/
|
|
||||||
public validateAndTrackIP(ip: string, connectionId: string): IIpValidationResult {
|
|
||||||
// Check connection count limit BEFORE tracking
|
|
||||||
const connectionResult = checkMaxConnections(
|
|
||||||
ip,
|
|
||||||
this.connectionsByIP,
|
|
||||||
this.maxConnectionsPerIP
|
|
||||||
);
|
|
||||||
if (!connectionResult.allowed) {
|
|
||||||
return connectionResult;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check connection rate limit
|
|
||||||
const rateResult = checkConnectionRate(
|
|
||||||
ip,
|
|
||||||
this.connectionsByIP,
|
|
||||||
this.connectionRateLimitPerMinute
|
|
||||||
);
|
|
||||||
if (!rateResult.allowed) {
|
|
||||||
return rateResult;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validation passed - immediately track to prevent race conditions
|
|
||||||
this.trackConnectionByIP(ip, connectionId);
|
|
||||||
|
|
||||||
return { allowed: true };
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if a client is allowed to access a specific route
|
|
||||||
*
|
|
||||||
* @param route - The route to check
|
|
||||||
* @param context - The request context
|
|
||||||
* @param routeConnectionCount - Current connection count for this route (optional)
|
|
||||||
* @returns Whether access is allowed
|
|
||||||
*/
|
|
||||||
public isAllowed(route: IRouteConfig, context: IRouteContext, routeConnectionCount?: number): boolean {
|
|
||||||
if (!route.security) {
|
|
||||||
return true; // No security restrictions
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- IP filtering ---
|
|
||||||
if (!this.isClientIpAllowed(route, context.clientIp)) {
|
|
||||||
this.logger?.debug?.(`IP ${context.clientIp} is blocked for route ${route.name || 'unnamed'}`);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- Route-level connection limit ---
|
|
||||||
if (route.security.maxConnections !== undefined && routeConnectionCount !== undefined) {
|
|
||||||
if (routeConnectionCount >= route.security.maxConnections) {
|
|
||||||
this.logger?.debug?.(`Route connection limit (${route.security.maxConnections}) exceeded for route ${route.name || 'unnamed'}`);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- Rate limiting ---
|
|
||||||
if (route.security.rateLimit?.enabled && !this.isWithinRateLimit(route, context)) {
|
|
||||||
this.logger?.debug?.(`Rate limit exceeded for route ${route.name || 'unnamed'}`);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if a client IP is allowed for a route
|
|
||||||
*
|
|
||||||
* @param route - The route to check
|
|
||||||
* @param clientIp - The client IP
|
|
||||||
* @returns Whether the IP is allowed
|
|
||||||
*/
|
|
||||||
private isClientIpAllowed(route: IRouteConfig, clientIp: string): boolean {
|
|
||||||
if (!route.security) {
|
|
||||||
return true; // No security restrictions
|
|
||||||
}
|
|
||||||
|
|
||||||
const routeId = route.id || route.name || 'unnamed';
|
|
||||||
|
|
||||||
// Check cache first
|
|
||||||
if (!this.ipFilterCache.has(routeId)) {
|
|
||||||
this.ipFilterCache.set(routeId, new Map());
|
|
||||||
}
|
|
||||||
|
|
||||||
const routeCache = this.ipFilterCache.get(routeId)!;
|
|
||||||
if (routeCache.has(clientIp)) {
|
|
||||||
return routeCache.get(clientIp)!;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check IP against route security settings
|
|
||||||
const ipAllowList = route.security.ipAllowList;
|
|
||||||
const ipBlockList = route.security.ipBlockList;
|
|
||||||
|
|
||||||
const allowed = this.isIPAuthorized(clientIp, ipAllowList, ipBlockList);
|
|
||||||
|
|
||||||
// Cache the result
|
|
||||||
routeCache.set(clientIp, allowed);
|
|
||||||
|
|
||||||
return allowed;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if request is within rate limit
|
|
||||||
*
|
|
||||||
* @param route - The route to check
|
|
||||||
* @param context - The request context
|
|
||||||
* @returns Whether the request is within rate limit
|
|
||||||
*/
|
|
||||||
private isWithinRateLimit(route: IRouteConfig, context: IRouteContext): boolean {
|
|
||||||
if (!route.security?.rateLimit?.enabled) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
const rateLimit = route.security.rateLimit;
|
|
||||||
const routeId = route.id || route.name || 'unnamed';
|
|
||||||
|
|
||||||
// Determine rate limit key (by IP, path, or header)
|
|
||||||
let key = context.clientIp; // Default to IP
|
|
||||||
|
|
||||||
if (rateLimit.keyBy === 'path' && context.path) {
|
|
||||||
key = `${context.clientIp}:${context.path}`;
|
|
||||||
} else if (rateLimit.keyBy === 'header' && rateLimit.headerName && context.headers) {
|
|
||||||
const headerValue = context.headers[rateLimit.headerName.toLowerCase()];
|
|
||||||
if (headerValue) {
|
|
||||||
key = `${context.clientIp}:${headerValue}`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get or create rate limit tracking for this route
|
|
||||||
if (!this.rateLimits.has(routeId)) {
|
|
||||||
this.rateLimits.set(routeId, new Map());
|
|
||||||
}
|
|
||||||
|
|
||||||
const routeLimits = this.rateLimits.get(routeId)!;
|
|
||||||
const now = Date.now();
|
|
||||||
|
|
||||||
// Get or create rate limit tracking for this key
|
|
||||||
let limit = routeLimits.get(key);
|
|
||||||
if (!limit || limit.expiry < now) {
|
|
||||||
// Create new rate limit or reset expired one
|
|
||||||
limit = {
|
|
||||||
count: 1,
|
|
||||||
expiry: now + (rateLimit.window * 1000)
|
|
||||||
};
|
|
||||||
routeLimits.set(key, limit);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Increment the counter
|
|
||||||
limit.count++;
|
|
||||||
|
|
||||||
// Check if rate limit is exceeded
|
|
||||||
return limit.count <= rateLimit.maxRequests;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Validate HTTP Basic Authentication
|
|
||||||
*
|
|
||||||
* @param route - The route to check
|
|
||||||
* @param authHeader - The Authorization header
|
|
||||||
* @returns Whether authentication is valid
|
|
||||||
*/
|
|
||||||
public validateBasicAuth(route: IRouteConfig, authHeader?: string): boolean {
|
|
||||||
// Skip if basic auth not enabled for route
|
|
||||||
if (!route.security?.basicAuth?.enabled) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// No auth header means auth failed
|
|
||||||
if (!authHeader) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse auth header
|
|
||||||
const credentials = parseBasicAuthHeader(authHeader);
|
|
||||||
if (!credentials) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check credentials against configured users
|
|
||||||
const { username, password } = credentials;
|
|
||||||
const users = route.security.basicAuth.users;
|
|
||||||
|
|
||||||
return users.some(user =>
|
|
||||||
user.username === username && user.password === password
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Verify a JWT token against route configuration
|
|
||||||
*
|
|
||||||
* @param route - The route to verify the token for
|
|
||||||
* @param token - The JWT token to verify
|
|
||||||
* @returns True if the token is valid, false otherwise
|
|
||||||
*/
|
|
||||||
public verifyJwtToken(route: IRouteConfig, token: string): boolean {
|
|
||||||
if (!route.security?.jwtAuth?.enabled) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
const jwtAuth = route.security.jwtAuth;
|
|
||||||
|
|
||||||
// Verify structure (header.payload.signature)
|
|
||||||
const parts = token.split('.');
|
|
||||||
if (parts.length !== 3) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode payload
|
|
||||||
const payload = JSON.parse(Buffer.from(parts[1], 'base64').toString());
|
|
||||||
|
|
||||||
// Check expiration
|
|
||||||
if (payload.exp && payload.exp < Math.floor(Date.now() / 1000)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check issuer
|
|
||||||
if (jwtAuth.issuer && payload.iss !== jwtAuth.issuer) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check audience
|
|
||||||
if (jwtAuth.audience && payload.aud !== jwtAuth.audience) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Note: In a real implementation, you'd also verify the signature
|
|
||||||
// using the secret and algorithm specified in jwtAuth.
|
|
||||||
// This requires a proper JWT library for cryptographic verification.
|
|
||||||
|
|
||||||
return true;
|
|
||||||
} catch (err) {
|
|
||||||
this.logger?.error?.(`Error verifying JWT: ${err}`);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clean up caches to prevent memory leaks
|
|
||||||
*/
|
|
||||||
private cleanupCaches(): void {
|
|
||||||
// Clean up rate limits
|
|
||||||
cleanupExpiredRateLimits(this.rateLimits, this.logger);
|
|
||||||
|
|
||||||
// Clean up IP connection tracking
|
|
||||||
let cleanedIPs = 0;
|
|
||||||
for (const [ip, info] of this.connectionsByIP.entries()) {
|
|
||||||
// Remove IPs with no active connections and no recent timestamps
|
|
||||||
if (info.connections.size === 0 && info.timestamps.length === 0) {
|
|
||||||
this.connectionsByIP.delete(ip);
|
|
||||||
cleanedIPs++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (cleanedIPs > 0 && this.logger?.debug) {
|
|
||||||
this.logger.debug(`Cleaned up ${cleanedIPs} IPs with no active connections`);
|
|
||||||
}
|
|
||||||
|
|
||||||
// IP filter cache doesn't need cleanup (tied to routes)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clear all IP tracking data (for shutdown)
|
|
||||||
*/
|
|
||||||
public clearIPTracking(): void {
|
|
||||||
this.connectionsByIP.clear();
|
|
||||||
this.rateLimits.clear();
|
|
||||||
this.ipFilterCache.clear();
|
|
||||||
|
|
||||||
if (this.cleanupInterval) {
|
|
||||||
clearInterval(this.cleanupInterval);
|
|
||||||
this.cleanupInterval = null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Update routes for security checking
|
|
||||||
*
|
|
||||||
* @param routes - New routes to use
|
|
||||||
*/
|
|
||||||
public setRoutes(routes: IRouteConfig[]): void {
|
|
||||||
// Only clear the IP filter cache - route-specific
|
|
||||||
this.ipFilterCache.clear();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,322 +0,0 @@
|
|||||||
import * as plugins from '../../plugins.js';
|
|
||||||
|
|
||||||
export interface CleanupOptions {
|
|
||||||
immediate?: boolean; // Force immediate destruction
|
|
||||||
allowDrain?: boolean; // Allow write buffer to drain
|
|
||||||
gracePeriod?: number; // Ms to wait before force close
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface SafeSocketOptions {
|
|
||||||
port: number;
|
|
||||||
host: string;
|
|
||||||
onError?: (error: Error) => void;
|
|
||||||
onConnect?: () => void;
|
|
||||||
timeout?: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Safely cleanup a socket by removing all listeners and destroying it
|
|
||||||
* @param socket The socket to cleanup
|
|
||||||
* @param socketName Optional name for logging
|
|
||||||
* @param options Cleanup options
|
|
||||||
*/
|
|
||||||
export function cleanupSocket(
|
|
||||||
socket: plugins.net.Socket | plugins.tls.TLSSocket | null,
|
|
||||||
socketName?: string,
|
|
||||||
options: CleanupOptions = {}
|
|
||||||
): Promise<void> {
|
|
||||||
if (!socket || socket.destroyed) return Promise.resolve();
|
|
||||||
|
|
||||||
return new Promise<void>((resolve) => {
|
|
||||||
const cleanup = () => {
|
|
||||||
try {
|
|
||||||
// Remove all event listeners
|
|
||||||
socket.removeAllListeners();
|
|
||||||
|
|
||||||
// Destroy if not already destroyed
|
|
||||||
if (!socket.destroyed) {
|
|
||||||
socket.destroy();
|
|
||||||
}
|
|
||||||
} catch (err) {
|
|
||||||
console.error(`Error cleaning up socket${socketName ? ` (${socketName})` : ''}: ${err}`);
|
|
||||||
}
|
|
||||||
resolve();
|
|
||||||
};
|
|
||||||
|
|
||||||
if (options.immediate) {
|
|
||||||
// Immediate cleanup (old behavior)
|
|
||||||
socket.unpipe();
|
|
||||||
cleanup();
|
|
||||||
} else if (options.allowDrain && socket.writable) {
|
|
||||||
// Allow pending writes to complete
|
|
||||||
socket.end(() => cleanup());
|
|
||||||
|
|
||||||
// Force cleanup after grace period
|
|
||||||
if (options.gracePeriod) {
|
|
||||||
setTimeout(() => {
|
|
||||||
if (!socket.destroyed) {
|
|
||||||
cleanup();
|
|
||||||
}
|
|
||||||
}, options.gracePeriod);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Default: immediate cleanup
|
|
||||||
socket.unpipe();
|
|
||||||
cleanup();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create independent cleanup handlers for paired sockets that support half-open connections
|
|
||||||
* @param clientSocket The client socket
|
|
||||||
* @param serverSocket The server socket
|
|
||||||
* @param onBothClosed Callback when both sockets are closed
|
|
||||||
* @returns Independent cleanup functions for each socket
|
|
||||||
*/
|
|
||||||
export function createIndependentSocketHandlers(
|
|
||||||
clientSocket: plugins.net.Socket | plugins.tls.TLSSocket,
|
|
||||||
serverSocket: plugins.net.Socket | plugins.tls.TLSSocket,
|
|
||||||
onBothClosed: (reason: string) => void,
|
|
||||||
options: { enableHalfOpen?: boolean } = {}
|
|
||||||
): { cleanupClient: (reason: string) => Promise<void>, cleanupServer: (reason: string) => Promise<void> } {
|
|
||||||
let clientClosed = false;
|
|
||||||
let serverClosed = false;
|
|
||||||
let clientReason = '';
|
|
||||||
let serverReason = '';
|
|
||||||
|
|
||||||
const checkBothClosed = () => {
|
|
||||||
if (clientClosed && serverClosed) {
|
|
||||||
onBothClosed(`client: ${clientReason}, server: ${serverReason}`);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const cleanupClient = async (reason: string) => {
|
|
||||||
if (clientClosed) return;
|
|
||||||
clientClosed = true;
|
|
||||||
clientReason = reason;
|
|
||||||
|
|
||||||
// Default behavior: close both sockets when one closes (required for proxy chains)
|
|
||||||
if (!serverClosed && !options.enableHalfOpen) {
|
|
||||||
serverSocket.destroy();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Half-open support (opt-in only)
|
|
||||||
if (!serverClosed && serverSocket.writable && options.enableHalfOpen) {
|
|
||||||
// Half-close: stop reading from client, let server finish
|
|
||||||
clientSocket.pause();
|
|
||||||
clientSocket.unpipe(serverSocket);
|
|
||||||
await cleanupSocket(clientSocket, 'client', { allowDrain: true, gracePeriod: 5000 });
|
|
||||||
} else {
|
|
||||||
await cleanupSocket(clientSocket, 'client', { immediate: true });
|
|
||||||
}
|
|
||||||
|
|
||||||
checkBothClosed();
|
|
||||||
};
|
|
||||||
|
|
||||||
const cleanupServer = async (reason: string) => {
|
|
||||||
if (serverClosed) return;
|
|
||||||
serverClosed = true;
|
|
||||||
serverReason = reason;
|
|
||||||
|
|
||||||
// Default behavior: close both sockets when one closes (required for proxy chains)
|
|
||||||
if (!clientClosed && !options.enableHalfOpen) {
|
|
||||||
clientSocket.destroy();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Half-open support (opt-in only)
|
|
||||||
if (!clientClosed && clientSocket.writable && options.enableHalfOpen) {
|
|
||||||
// Half-close: stop reading from server, let client finish
|
|
||||||
serverSocket.pause();
|
|
||||||
serverSocket.unpipe(clientSocket);
|
|
||||||
await cleanupSocket(serverSocket, 'server', { allowDrain: true, gracePeriod: 5000 });
|
|
||||||
} else {
|
|
||||||
await cleanupSocket(serverSocket, 'server', { immediate: true });
|
|
||||||
}
|
|
||||||
|
|
||||||
checkBothClosed();
|
|
||||||
};
|
|
||||||
|
|
||||||
return { cleanupClient, cleanupServer };
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Setup socket error and close handlers with proper cleanup
|
|
||||||
* @param socket The socket to setup handlers for
|
|
||||||
* @param handleClose The cleanup function to call
|
|
||||||
* @param handleTimeout Optional custom timeout handler
|
|
||||||
* @param errorPrefix Optional prefix for error messages
|
|
||||||
*/
|
|
||||||
export function setupSocketHandlers(
|
|
||||||
socket: plugins.net.Socket | plugins.tls.TLSSocket,
|
|
||||||
handleClose: (reason: string) => void,
|
|
||||||
handleTimeout?: (socket: plugins.net.Socket | plugins.tls.TLSSocket) => void,
|
|
||||||
errorPrefix?: string
|
|
||||||
): void {
|
|
||||||
socket.on('error', (error) => {
|
|
||||||
const prefix = errorPrefix || 'Socket';
|
|
||||||
handleClose(`${prefix}_error: ${error.message}`);
|
|
||||||
});
|
|
||||||
|
|
||||||
socket.on('close', () => {
|
|
||||||
const prefix = errorPrefix || 'socket';
|
|
||||||
handleClose(`${prefix}_closed`);
|
|
||||||
});
|
|
||||||
|
|
||||||
socket.on('timeout', () => {
|
|
||||||
if (handleTimeout) {
|
|
||||||
handleTimeout(socket); // Custom timeout handling
|
|
||||||
} else {
|
|
||||||
// Default: just log, don't close
|
|
||||||
console.warn(`Socket timeout: ${errorPrefix || 'socket'}`);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Setup bidirectional data forwarding between two sockets with proper cleanup
|
|
||||||
* @param clientSocket The client/incoming socket
|
|
||||||
* @param serverSocket The server/outgoing socket
|
|
||||||
* @param handlers Object containing optional handlers for data and cleanup
|
|
||||||
* @returns Cleanup functions for both sockets
|
|
||||||
*/
|
|
||||||
export function setupBidirectionalForwarding(
|
|
||||||
clientSocket: plugins.net.Socket | plugins.tls.TLSSocket,
|
|
||||||
serverSocket: plugins.net.Socket | plugins.tls.TLSSocket,
|
|
||||||
handlers: {
|
|
||||||
onClientData?: (chunk: Buffer) => void;
|
|
||||||
onServerData?: (chunk: Buffer) => void;
|
|
||||||
onCleanup: (reason: string) => void;
|
|
||||||
enableHalfOpen?: boolean;
|
|
||||||
}
|
|
||||||
): { cleanupClient: (reason: string) => Promise<void>, cleanupServer: (reason: string) => Promise<void> } {
|
|
||||||
// Set up cleanup handlers
|
|
||||||
const { cleanupClient, cleanupServer } = createIndependentSocketHandlers(
|
|
||||||
clientSocket,
|
|
||||||
serverSocket,
|
|
||||||
handlers.onCleanup,
|
|
||||||
{ enableHalfOpen: handlers.enableHalfOpen }
|
|
||||||
);
|
|
||||||
|
|
||||||
// Set up error and close handlers
|
|
||||||
setupSocketHandlers(clientSocket, cleanupClient, undefined, 'client');
|
|
||||||
setupSocketHandlers(serverSocket, cleanupServer, undefined, 'server');
|
|
||||||
|
|
||||||
// Set up data forwarding with backpressure handling
|
|
||||||
clientSocket.on('data', (chunk: Buffer) => {
|
|
||||||
if (handlers.onClientData) {
|
|
||||||
handlers.onClientData(chunk);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (serverSocket.writable) {
|
|
||||||
const flushed = serverSocket.write(chunk);
|
|
||||||
|
|
||||||
// Handle backpressure
|
|
||||||
if (!flushed) {
|
|
||||||
clientSocket.pause();
|
|
||||||
serverSocket.once('drain', () => {
|
|
||||||
if (!clientSocket.destroyed) {
|
|
||||||
clientSocket.resume();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
serverSocket.on('data', (chunk: Buffer) => {
|
|
||||||
if (handlers.onServerData) {
|
|
||||||
handlers.onServerData(chunk);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (clientSocket.writable) {
|
|
||||||
const flushed = clientSocket.write(chunk);
|
|
||||||
|
|
||||||
// Handle backpressure
|
|
||||||
if (!flushed) {
|
|
||||||
serverSocket.pause();
|
|
||||||
clientSocket.once('drain', () => {
|
|
||||||
if (!serverSocket.destroyed) {
|
|
||||||
serverSocket.resume();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
return { cleanupClient, cleanupServer };
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a socket with immediate error handling to prevent crashes
|
|
||||||
* @param options Socket creation options
|
|
||||||
* @returns The created socket
|
|
||||||
*/
|
|
||||||
export function createSocketWithErrorHandler(options: SafeSocketOptions): plugins.net.Socket {
|
|
||||||
const { port, host, onError, onConnect, timeout } = options;
|
|
||||||
|
|
||||||
// Create socket with immediate error handler attachment
|
|
||||||
const socket = new plugins.net.Socket();
|
|
||||||
|
|
||||||
// Track if connected
|
|
||||||
let connected = false;
|
|
||||||
let connectionTimeout: NodeJS.Timeout | null = null;
|
|
||||||
|
|
||||||
// Attach error handler BEFORE connecting to catch immediate errors
|
|
||||||
socket.on('error', (error) => {
|
|
||||||
console.error(`Socket connection error to ${host}:${port}: ${error.message}`);
|
|
||||||
// Clear the connection timeout if it exists
|
|
||||||
if (connectionTimeout) {
|
|
||||||
clearTimeout(connectionTimeout);
|
|
||||||
connectionTimeout = null;
|
|
||||||
}
|
|
||||||
if (onError) {
|
|
||||||
onError(error);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Attach connect handler
|
|
||||||
const handleConnect = () => {
|
|
||||||
connected = true;
|
|
||||||
// Clear the connection timeout
|
|
||||||
if (connectionTimeout) {
|
|
||||||
clearTimeout(connectionTimeout);
|
|
||||||
connectionTimeout = null;
|
|
||||||
}
|
|
||||||
// Set inactivity timeout if provided (after connection is established)
|
|
||||||
if (timeout) {
|
|
||||||
socket.setTimeout(timeout);
|
|
||||||
}
|
|
||||||
if (onConnect) {
|
|
||||||
onConnect();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
socket.on('connect', handleConnect);
|
|
||||||
|
|
||||||
// Implement connection establishment timeout
|
|
||||||
if (timeout) {
|
|
||||||
connectionTimeout = setTimeout(() => {
|
|
||||||
if (!connected && !socket.destroyed) {
|
|
||||||
// Connection timed out - destroy the socket
|
|
||||||
const error = new Error(`Connection timeout after ${timeout}ms to ${host}:${port}`);
|
|
||||||
(error as any).code = 'ETIMEDOUT';
|
|
||||||
|
|
||||||
console.error(`Socket connection timeout to ${host}:${port} after ${timeout}ms`);
|
|
||||||
|
|
||||||
// Destroy the socket
|
|
||||||
socket.destroy();
|
|
||||||
|
|
||||||
// Call error handler
|
|
||||||
if (onError) {
|
|
||||||
onError(error);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}, timeout);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now attempt to connect - any immediate errors will be caught
|
|
||||||
socket.connect(port, host);
|
|
||||||
|
|
||||||
return socket;
|
|
||||||
}
|
|
||||||
@@ -1,124 +0,0 @@
|
|||||||
import type { IRouteContext } from '../models/route-context.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Utility class for resolving template variables in strings
|
|
||||||
*/
|
|
||||||
export class TemplateUtils {
|
|
||||||
/**
|
|
||||||
* Resolve template variables in a string using the route context
|
|
||||||
* Supports variables like {domain}, {path}, {clientIp}, etc.
|
|
||||||
*
|
|
||||||
* @param template The template string with {variables}
|
|
||||||
* @param context The route context with values
|
|
||||||
* @returns The resolved string
|
|
||||||
*/
|
|
||||||
public static resolveTemplateVariables(template: string, context: IRouteContext): string {
|
|
||||||
if (!template) {
|
|
||||||
return template;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replace variables with values from context
|
|
||||||
return template.replace(/\{([a-zA-Z0-9_\.]+)\}/g, (match, varName) => {
|
|
||||||
// Handle nested properties with dot notation (e.g., {headers.host})
|
|
||||||
if (varName.includes('.')) {
|
|
||||||
const parts = varName.split('.');
|
|
||||||
let current: any = context;
|
|
||||||
|
|
||||||
// Traverse nested object structure
|
|
||||||
for (const part of parts) {
|
|
||||||
if (current === undefined || current === null) {
|
|
||||||
return match; // Return original if path doesn't exist
|
|
||||||
}
|
|
||||||
current = current[part];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the resolved value if it exists
|
|
||||||
if (current !== undefined && current !== null) {
|
|
||||||
return TemplateUtils.convertToString(current);
|
|
||||||
}
|
|
||||||
|
|
||||||
return match;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Direct property access
|
|
||||||
const value = context[varName as keyof IRouteContext];
|
|
||||||
if (value === undefined) {
|
|
||||||
return match; // Keep the original {variable} if not found
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert value to string
|
|
||||||
return TemplateUtils.convertToString(value);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Safely convert a value to a string
|
|
||||||
*
|
|
||||||
* @param value Any value to convert to string
|
|
||||||
* @returns String representation or original match for complex objects
|
|
||||||
*/
|
|
||||||
private static convertToString(value: any): string {
|
|
||||||
if (value === null || value === undefined) {
|
|
||||||
return '';
|
|
||||||
}
|
|
||||||
|
|
||||||
if (typeof value === 'string') {
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (typeof value === 'number' || typeof value === 'boolean') {
|
|
||||||
return value.toString();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (Array.isArray(value)) {
|
|
||||||
return value.join(',');
|
|
||||||
}
|
|
||||||
|
|
||||||
if (typeof value === 'object') {
|
|
||||||
try {
|
|
||||||
return JSON.stringify(value);
|
|
||||||
} catch (e) {
|
|
||||||
return '[Object]';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return String(value);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Resolve template variables in header values
|
|
||||||
*
|
|
||||||
* @param headers Header object with potential template variables
|
|
||||||
* @param context Route context for variable resolution
|
|
||||||
* @returns New header object with resolved values
|
|
||||||
*/
|
|
||||||
public static resolveHeaderTemplates(
|
|
||||||
headers: Record<string, string>,
|
|
||||||
context: IRouteContext
|
|
||||||
): Record<string, string> {
|
|
||||||
const result: Record<string, string> = {};
|
|
||||||
|
|
||||||
for (const [key, value] of Object.entries(headers)) {
|
|
||||||
// Skip special directive headers (starting with !)
|
|
||||||
if (value.startsWith('!')) {
|
|
||||||
result[key] = value;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resolve template variables in the header value
|
|
||||||
result[key] = TemplateUtils.resolveTemplateVariables(value, context);
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if a string contains template variables
|
|
||||||
*
|
|
||||||
* @param str String to check for template variables
|
|
||||||
* @returns True if string contains template variables
|
|
||||||
*/
|
|
||||||
public static containsTemplateVariables(str: string): boolean {
|
|
||||||
return !!str && /\{([a-zA-Z0-9_\.]+)\}/g.test(str);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,177 +0,0 @@
|
|||||||
import * as plugins from '../../plugins.js';
|
|
||||||
import type { IDomainOptions, IAcmeOptions } from '../models/common-types.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Collection of validation utilities for configuration and domain options
|
|
||||||
*/
|
|
||||||
export class ValidationUtils {
|
|
||||||
/**
|
|
||||||
* Validates domain configuration options
|
|
||||||
*
|
|
||||||
* @param domainOptions The domain options to validate
|
|
||||||
* @returns An object with validation result and error message if invalid
|
|
||||||
*/
|
|
||||||
public static validateDomainOptions(domainOptions: IDomainOptions): { isValid: boolean; error?: string } {
|
|
||||||
if (!domainOptions) {
|
|
||||||
return { isValid: false, error: 'Domain options cannot be null or undefined' };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!domainOptions.domainName) {
|
|
||||||
return { isValid: false, error: 'Domain name is required' };
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check domain pattern
|
|
||||||
if (!this.isValidDomainName(domainOptions.domainName)) {
|
|
||||||
return { isValid: false, error: `Invalid domain name: ${domainOptions.domainName}` };
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate forward config if provided
|
|
||||||
if (domainOptions.forward) {
|
|
||||||
if (!domainOptions.forward.ip) {
|
|
||||||
return { isValid: false, error: 'Forward IP is required when forward is specified' };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!domainOptions.forward.port) {
|
|
||||||
return { isValid: false, error: 'Forward port is required when forward is specified' };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!this.isValidPort(domainOptions.forward.port)) {
|
|
||||||
return { isValid: false, error: `Invalid forward port: ${domainOptions.forward.port}` };
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate ACME forward config if provided
|
|
||||||
if (domainOptions.acmeForward) {
|
|
||||||
if (!domainOptions.acmeForward.ip) {
|
|
||||||
return { isValid: false, error: 'ACME forward IP is required when acmeForward is specified' };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!domainOptions.acmeForward.port) {
|
|
||||||
return { isValid: false, error: 'ACME forward port is required when acmeForward is specified' };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!this.isValidPort(domainOptions.acmeForward.port)) {
|
|
||||||
return { isValid: false, error: `Invalid ACME forward port: ${domainOptions.acmeForward.port}` };
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return { isValid: true };
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Validates ACME configuration options
|
|
||||||
*
|
|
||||||
* @param acmeOptions The ACME options to validate
|
|
||||||
* @returns An object with validation result and error message if invalid
|
|
||||||
*/
|
|
||||||
public static validateAcmeOptions(acmeOptions: IAcmeOptions): { isValid: boolean; error?: string } {
|
|
||||||
if (!acmeOptions) {
|
|
||||||
return { isValid: false, error: 'ACME options cannot be null or undefined' };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (acmeOptions.enabled) {
|
|
||||||
if (!acmeOptions.accountEmail) {
|
|
||||||
return { isValid: false, error: 'Account email is required when ACME is enabled' };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!this.isValidEmail(acmeOptions.accountEmail)) {
|
|
||||||
return { isValid: false, error: `Invalid email: ${acmeOptions.accountEmail}` };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (acmeOptions.port && !this.isValidPort(acmeOptions.port)) {
|
|
||||||
return { isValid: false, error: `Invalid ACME port: ${acmeOptions.port}` };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (acmeOptions.httpsRedirectPort && !this.isValidPort(acmeOptions.httpsRedirectPort)) {
|
|
||||||
return { isValid: false, error: `Invalid HTTPS redirect port: ${acmeOptions.httpsRedirectPort}` };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (acmeOptions.renewThresholdDays && acmeOptions.renewThresholdDays < 1) {
|
|
||||||
return { isValid: false, error: 'Renew threshold days must be greater than 0' };
|
|
||||||
}
|
|
||||||
|
|
||||||
if (acmeOptions.renewCheckIntervalHours && acmeOptions.renewCheckIntervalHours < 1) {
|
|
||||||
return { isValid: false, error: 'Renew check interval hours must be greater than 0' };
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return { isValid: true };
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Validates a port number
|
|
||||||
*
|
|
||||||
* @param port The port to validate
|
|
||||||
* @returns true if the port is valid, false otherwise
|
|
||||||
*/
|
|
||||||
public static isValidPort(port: number): boolean {
|
|
||||||
return typeof port === 'number' && port > 0 && port <= 65535 && Number.isInteger(port);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Validates a domain name
|
|
||||||
*
|
|
||||||
* @param domain The domain name to validate
|
|
||||||
* @returns true if the domain name is valid, false otherwise
|
|
||||||
*/
|
|
||||||
public static isValidDomainName(domain: string): boolean {
|
|
||||||
if (!domain || typeof domain !== 'string') {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wildcard domain check (*.example.com)
|
|
||||||
if (domain.startsWith('*.')) {
|
|
||||||
domain = domain.substring(2);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simple domain validation pattern
|
|
||||||
const domainPattern = /^([a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$/;
|
|
||||||
return domainPattern.test(domain);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Validates an email address
|
|
||||||
*
|
|
||||||
* @param email The email to validate
|
|
||||||
* @returns true if the email is valid, false otherwise
|
|
||||||
*/
|
|
||||||
public static isValidEmail(email: string): boolean {
|
|
||||||
if (!email || typeof email !== 'string') {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Basic email validation pattern
|
|
||||||
const emailPattern = /^[^\s@]+@[^\s@]+\.[^\s@]+$/;
|
|
||||||
return emailPattern.test(email);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Validates a certificate format (PEM)
|
|
||||||
*
|
|
||||||
* @param cert The certificate content to validate
|
|
||||||
* @returns true if the certificate appears to be in PEM format, false otherwise
|
|
||||||
*/
|
|
||||||
public static isValidCertificate(cert: string): boolean {
|
|
||||||
if (!cert || typeof cert !== 'string') {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return cert.includes('-----BEGIN CERTIFICATE-----') &&
|
|
||||||
cert.includes('-----END CERTIFICATE-----');
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Validates a private key format (PEM)
|
|
||||||
*
|
|
||||||
* @param key The private key content to validate
|
|
||||||
* @returns true if the key appears to be in PEM format, false otherwise
|
|
||||||
*/
|
|
||||||
public static isValidPrivateKey(key: string): boolean {
|
|
||||||
if (!key || typeof key !== 'string') {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return key.includes('-----BEGIN PRIVATE KEY-----') &&
|
|
||||||
key.includes('-----END PRIVATE KEY-----');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
/**
|
|
||||||
* WebSocket utility functions
|
|
||||||
*
|
|
||||||
* This module provides smartproxy-specific WebSocket utilities
|
|
||||||
* and re-exports protocol utilities from the protocols module
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Import and re-export from protocols
|
|
||||||
import { getMessageSize as protocolGetMessageSize, toBuffer as protocolToBuffer } from '../../protocols/websocket/index.js';
|
|
||||||
export type { RawData } from '../../protocols/websocket/index.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the length of a WebSocket message regardless of its type
|
|
||||||
* (handles all possible WebSocket message data types)
|
|
||||||
*
|
|
||||||
* @param data - The data message from WebSocket (could be any RawData type)
|
|
||||||
* @returns The length of the data in bytes
|
|
||||||
*/
|
|
||||||
export function getMessageSize(data: import('../../protocols/websocket/index.js').RawData): number {
|
|
||||||
// Delegate to protocol implementation
|
|
||||||
return protocolGetMessageSize(data);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Convert any raw WebSocket data to Buffer for consistent handling
|
|
||||||
*
|
|
||||||
* @param data - The data message from WebSocket (could be any RawData type)
|
|
||||||
* @returns A Buffer containing the data
|
|
||||||
*/
|
|
||||||
export function toBuffer(data: import('../../protocols/websocket/index.js').RawData): Buffer {
|
|
||||||
// Delegate to protocol implementation
|
|
||||||
return protocolToBuffer(data);
|
|
||||||
}
|
|
||||||
@@ -1,127 +0,0 @@
|
|||||||
/**
|
|
||||||
* HTTP Protocol Detector
|
|
||||||
*
|
|
||||||
* Simplified HTTP detection using the new architecture
|
|
||||||
*/
|
|
||||||
|
|
||||||
import type { IProtocolDetector } from '../models/interfaces.js';
|
|
||||||
import type { IDetectionResult, IDetectionOptions } from '../models/detection-types.js';
|
|
||||||
import type { IProtocolDetectionResult, IConnectionContext } from '../../protocols/common/types.js';
|
|
||||||
import type { THttpMethod } from '../../protocols/http/index.js';
|
|
||||||
import { QuickProtocolDetector } from './quick-detector.js';
|
|
||||||
import { RoutingExtractor } from './routing-extractor.js';
|
|
||||||
import { DetectionFragmentManager } from '../utils/fragment-manager.js';
|
|
||||||
import { HttpParser } from '../../protocols/http/parser.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Simplified HTTP detector
|
|
||||||
*/
|
|
||||||
export class HttpDetector implements IProtocolDetector {
|
|
||||||
private quickDetector = new QuickProtocolDetector();
|
|
||||||
private fragmentManager: DetectionFragmentManager;
|
|
||||||
|
|
||||||
constructor(fragmentManager?: DetectionFragmentManager) {
|
|
||||||
this.fragmentManager = fragmentManager || new DetectionFragmentManager();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if buffer can be handled by this detector
|
|
||||||
*/
|
|
||||||
canHandle(buffer: Buffer): boolean {
|
|
||||||
const result = this.quickDetector.quickDetect(buffer);
|
|
||||||
return result.protocol === 'http' && result.confidence > 50;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get minimum bytes needed for detection
|
|
||||||
*/
|
|
||||||
getMinimumBytes(): number {
|
|
||||||
return 4; // "GET " minimum
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Detect HTTP protocol from buffer
|
|
||||||
*/
|
|
||||||
detect(buffer: Buffer, options?: IDetectionOptions): IDetectionResult | null {
|
|
||||||
// Quick detection first
|
|
||||||
const quickResult = this.quickDetector.quickDetect(buffer);
|
|
||||||
|
|
||||||
if (quickResult.protocol !== 'http' || quickResult.confidence < 50) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we have complete headers first
|
|
||||||
const headersEnd = buffer.indexOf('\r\n\r\n');
|
|
||||||
const isComplete = headersEnd !== -1;
|
|
||||||
|
|
||||||
// Extract routing information
|
|
||||||
const routing = RoutingExtractor.extract(buffer, 'http');
|
|
||||||
|
|
||||||
// Extract headers if requested and we have complete headers
|
|
||||||
let headers: Record<string, string> | undefined;
|
|
||||||
if (options?.extractFullHeaders && isComplete) {
|
|
||||||
const headerSection = buffer.slice(0, headersEnd).toString();
|
|
||||||
const lines = headerSection.split('\r\n');
|
|
||||||
if (lines.length > 1) {
|
|
||||||
// Skip the request line and parse headers
|
|
||||||
headers = HttpParser.parseHeaders(lines.slice(1));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we don't need full headers and we have complete headers, we can return early
|
|
||||||
if (quickResult.confidence >= 95 && !options?.extractFullHeaders && isComplete) {
|
|
||||||
return {
|
|
||||||
protocol: 'http',
|
|
||||||
connectionInfo: {
|
|
||||||
protocol: 'http',
|
|
||||||
method: quickResult.metadata?.method as THttpMethod,
|
|
||||||
domain: routing?.domain,
|
|
||||||
path: routing?.path
|
|
||||||
},
|
|
||||||
isComplete: true
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
protocol: 'http',
|
|
||||||
connectionInfo: {
|
|
||||||
protocol: 'http',
|
|
||||||
domain: routing?.domain,
|
|
||||||
path: routing?.path,
|
|
||||||
method: quickResult.metadata?.method as THttpMethod,
|
|
||||||
headers: headers
|
|
||||||
},
|
|
||||||
isComplete,
|
|
||||||
bytesNeeded: isComplete ? undefined : buffer.length + 512 // Need more for headers
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Handle fragmented detection
|
|
||||||
*/
|
|
||||||
detectWithContext(
|
|
||||||
buffer: Buffer,
|
|
||||||
context: IConnectionContext,
|
|
||||||
options?: IDetectionOptions
|
|
||||||
): IDetectionResult | null {
|
|
||||||
const handler = this.fragmentManager.getHandler('http');
|
|
||||||
const connectionId = DetectionFragmentManager.createConnectionId(context);
|
|
||||||
|
|
||||||
// Add fragment
|
|
||||||
const result = handler.addFragment(connectionId, buffer);
|
|
||||||
|
|
||||||
if (result.error) {
|
|
||||||
handler.complete(connectionId);
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try detection on accumulated buffer
|
|
||||||
const detectResult = this.detect(result.buffer!, options);
|
|
||||||
|
|
||||||
if (detectResult && detectResult.isComplete) {
|
|
||||||
handler.complete(connectionId);
|
|
||||||
}
|
|
||||||
|
|
||||||
return detectResult;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,148 +0,0 @@
|
|||||||
/**
|
|
||||||
* Quick Protocol Detector
|
|
||||||
*
|
|
||||||
* Lightweight protocol identification based on minimal bytes
|
|
||||||
* No parsing, just identification
|
|
||||||
*/
|
|
||||||
|
|
||||||
import type { IProtocolDetector, IProtocolDetectionResult } from '../../protocols/common/types.js';
|
|
||||||
import { TlsRecordType } from '../../protocols/tls/index.js';
|
|
||||||
import { HttpParser } from '../../protocols/http/index.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Quick protocol detector for fast identification
|
|
||||||
*/
|
|
||||||
export class QuickProtocolDetector implements IProtocolDetector {
|
|
||||||
/**
|
|
||||||
* Check if this detector can handle the data
|
|
||||||
*/
|
|
||||||
canHandle(data: Buffer): boolean {
|
|
||||||
return data.length >= 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Perform quick detection based on first few bytes
|
|
||||||
*/
|
|
||||||
quickDetect(data: Buffer): IProtocolDetectionResult {
|
|
||||||
if (data.length === 0) {
|
|
||||||
return {
|
|
||||||
protocol: 'unknown',
|
|
||||||
confidence: 0,
|
|
||||||
requiresMoreData: true
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for TLS
|
|
||||||
const tlsResult = this.checkTls(data);
|
|
||||||
if (tlsResult.confidence > 80) {
|
|
||||||
return tlsResult;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for HTTP
|
|
||||||
const httpResult = this.checkHttp(data);
|
|
||||||
if (httpResult.confidence > 80) {
|
|
||||||
return httpResult;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Need more data or unknown
|
|
||||||
return {
|
|
||||||
protocol: 'unknown',
|
|
||||||
confidence: 0,
|
|
||||||
requiresMoreData: data.length < 20
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if data looks like TLS
|
|
||||||
*/
|
|
||||||
private checkTls(data: Buffer): IProtocolDetectionResult {
|
|
||||||
if (data.length < 3) {
|
|
||||||
return {
|
|
||||||
protocol: 'tls',
|
|
||||||
confidence: 0,
|
|
||||||
requiresMoreData: true
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
const firstByte = data[0];
|
|
||||||
const secondByte = data[1];
|
|
||||||
|
|
||||||
// Check for valid TLS record type
|
|
||||||
const validRecordTypes = [
|
|
||||||
TlsRecordType.CHANGE_CIPHER_SPEC,
|
|
||||||
TlsRecordType.ALERT,
|
|
||||||
TlsRecordType.HANDSHAKE,
|
|
||||||
TlsRecordType.APPLICATION_DATA,
|
|
||||||
TlsRecordType.HEARTBEAT
|
|
||||||
];
|
|
||||||
|
|
||||||
if (!validRecordTypes.includes(firstByte)) {
|
|
||||||
return {
|
|
||||||
protocol: 'tls',
|
|
||||||
confidence: 0
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check TLS version byte (0x03 for all TLS/SSL versions)
|
|
||||||
if (secondByte !== 0x03) {
|
|
||||||
return {
|
|
||||||
protocol: 'tls',
|
|
||||||
confidence: 0
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// High confidence it's TLS
|
|
||||||
return {
|
|
||||||
protocol: 'tls',
|
|
||||||
confidence: 95,
|
|
||||||
metadata: {
|
|
||||||
recordType: firstByte
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if data looks like HTTP
|
|
||||||
*/
|
|
||||||
private checkHttp(data: Buffer): IProtocolDetectionResult {
|
|
||||||
if (data.length < 3) {
|
|
||||||
return {
|
|
||||||
protocol: 'http',
|
|
||||||
confidence: 0,
|
|
||||||
requiresMoreData: true
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Quick check for HTTP methods
|
|
||||||
const start = data.subarray(0, Math.min(10, data.length)).toString('ascii');
|
|
||||||
|
|
||||||
// Check common HTTP methods
|
|
||||||
const httpMethods = ['GET ', 'POST ', 'PUT ', 'DELETE ', 'HEAD ', 'OPTIONS', 'PATCH ', 'CONNECT', 'TRACE '];
|
|
||||||
for (const method of httpMethods) {
|
|
||||||
if (start.startsWith(method)) {
|
|
||||||
return {
|
|
||||||
protocol: 'http',
|
|
||||||
confidence: 95,
|
|
||||||
metadata: {
|
|
||||||
method: method.trim()
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if it might be HTTP but need more data
|
|
||||||
if (HttpParser.isPrintableAscii(data, Math.min(20, data.length))) {
|
|
||||||
// Could be HTTP, but not sure
|
|
||||||
return {
|
|
||||||
protocol: 'http',
|
|
||||||
confidence: 30,
|
|
||||||
requiresMoreData: data.length < 20
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
protocol: 'http',
|
|
||||||
confidence: 0
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,147 +0,0 @@
|
|||||||
/**
|
|
||||||
* Routing Information Extractor
|
|
||||||
*
|
|
||||||
* Extracts minimal routing information from protocols
|
|
||||||
* without full parsing
|
|
||||||
*/
|
|
||||||
|
|
||||||
import type { IRoutingInfo, IConnectionContext, TProtocolType } from '../../protocols/common/types.js';
|
|
||||||
import { SniExtraction } from '../../protocols/tls/sni/sni-extraction.js';
|
|
||||||
import { HttpParser } from '../../protocols/http/index.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extracts routing information from protocol data
|
|
||||||
*/
|
|
||||||
export class RoutingExtractor {
|
|
||||||
/**
|
|
||||||
* Extract routing info based on protocol type
|
|
||||||
*/
|
|
||||||
static extract(
|
|
||||||
data: Buffer,
|
|
||||||
protocol: TProtocolType,
|
|
||||||
context?: IConnectionContext
|
|
||||||
): IRoutingInfo | null {
|
|
||||||
switch (protocol) {
|
|
||||||
case 'tls':
|
|
||||||
case 'https':
|
|
||||||
return this.extractTlsRouting(data, context);
|
|
||||||
|
|
||||||
case 'http':
|
|
||||||
return this.extractHttpRouting(data);
|
|
||||||
|
|
||||||
default:
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extract routing from TLS ClientHello (SNI)
|
|
||||||
*/
|
|
||||||
private static extractTlsRouting(
|
|
||||||
data: Buffer,
|
|
||||||
context?: IConnectionContext
|
|
||||||
): IRoutingInfo | null {
|
|
||||||
try {
|
|
||||||
// Quick SNI extraction without full parsing
|
|
||||||
const sni = SniExtraction.extractSNI(data);
|
|
||||||
|
|
||||||
if (sni) {
|
|
||||||
return {
|
|
||||||
domain: sni,
|
|
||||||
protocol: 'tls',
|
|
||||||
port: 443 // Default HTTPS port
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
|
||||||
} catch (error) {
|
|
||||||
// Extraction failed, return null
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extract routing from HTTP headers (Host header)
|
|
||||||
*/
|
|
||||||
private static extractHttpRouting(data: Buffer): IRoutingInfo | null {
|
|
||||||
try {
|
|
||||||
// Look for first line
|
|
||||||
const firstLineEnd = data.indexOf('\n');
|
|
||||||
if (firstLineEnd === -1) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse request line
|
|
||||||
const firstLine = data.subarray(0, firstLineEnd).toString('ascii').trim();
|
|
||||||
const requestLine = HttpParser.parseRequestLine(firstLine);
|
|
||||||
|
|
||||||
if (!requestLine) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Look for Host header
|
|
||||||
let pos = firstLineEnd + 1;
|
|
||||||
const maxSearch = Math.min(data.length, 4096); // Don't search too far
|
|
||||||
|
|
||||||
while (pos < maxSearch) {
|
|
||||||
const lineEnd = data.indexOf('\n', pos);
|
|
||||||
if (lineEnd === -1) break;
|
|
||||||
|
|
||||||
const line = data.subarray(pos, lineEnd).toString('ascii').trim();
|
|
||||||
|
|
||||||
// Empty line means end of headers
|
|
||||||
if (line.length === 0) break;
|
|
||||||
|
|
||||||
// Check for Host header
|
|
||||||
if (line.toLowerCase().startsWith('host:')) {
|
|
||||||
const hostValue = line.substring(5).trim();
|
|
||||||
const domain = HttpParser.extractDomainFromHost(hostValue);
|
|
||||||
|
|
||||||
return {
|
|
||||||
domain,
|
|
||||||
path: requestLine.path,
|
|
||||||
protocol: 'http',
|
|
||||||
port: 80 // Default HTTP port
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pos = lineEnd + 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// No Host header found, but we have the path
|
|
||||||
return {
|
|
||||||
path: requestLine.path,
|
|
||||||
protocol: 'http',
|
|
||||||
port: 80
|
|
||||||
};
|
|
||||||
} catch (error) {
|
|
||||||
// Extraction failed
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Try to extract domain from any protocol
|
|
||||||
*/
|
|
||||||
static extractDomain(data: Buffer, hint?: TProtocolType): string | null {
|
|
||||||
// If we have a hint, use it
|
|
||||||
if (hint) {
|
|
||||||
const routing = this.extract(data, hint);
|
|
||||||
return routing?.domain || null;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try TLS first (more specific)
|
|
||||||
const tlsRouting = this.extractTlsRouting(data);
|
|
||||||
if (tlsRouting?.domain) {
|
|
||||||
return tlsRouting.domain;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try HTTP
|
|
||||||
const httpRouting = this.extractHttpRouting(data);
|
|
||||||
if (httpRouting?.domain) {
|
|
||||||
return httpRouting.domain;
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,223 +0,0 @@
|
|||||||
/**
|
|
||||||
* TLS protocol detector
|
|
||||||
*/
|
|
||||||
|
|
||||||
// TLS detector doesn't need plugins imports
|
|
||||||
import type { IProtocolDetector } from '../models/interfaces.js';
|
|
||||||
import type { IDetectionResult, IDetectionOptions, IConnectionInfo } from '../models/detection-types.js';
|
|
||||||
import { readUInt16BE } from '../utils/buffer-utils.js';
|
|
||||||
import { tlsVersionToString } from '../utils/parser-utils.js';
|
|
||||||
|
|
||||||
// Import from protocols
|
|
||||||
import { TlsRecordType, TlsHandshakeType, TlsExtensionType } from '../../protocols/tls/index.js';
|
|
||||||
|
|
||||||
// Import TLS utilities for SNI extraction from protocols
|
|
||||||
import { SniExtraction } from '../../protocols/tls/sni/sni-extraction.js';
|
|
||||||
import { ClientHelloParser } from '../../protocols/tls/sni/client-hello-parser.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* TLS detector implementation
|
|
||||||
*/
|
|
||||||
export class TlsDetector implements IProtocolDetector {
|
|
||||||
/**
|
|
||||||
* Minimum bytes needed to identify TLS (record header)
|
|
||||||
*/
|
|
||||||
private static readonly MIN_TLS_HEADER_SIZE = 5;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Detect TLS protocol from buffer
|
|
||||||
*/
|
|
||||||
detect(buffer: Buffer, options?: IDetectionOptions): IDetectionResult | null {
|
|
||||||
// Check if buffer is too small
|
|
||||||
if (buffer.length < TlsDetector.MIN_TLS_HEADER_SIZE) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if this is a TLS record
|
|
||||||
if (!this.isTlsRecord(buffer)) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract basic TLS info
|
|
||||||
const recordType = buffer[0];
|
|
||||||
const tlsMajor = buffer[1];
|
|
||||||
const tlsMinor = buffer[2];
|
|
||||||
const recordLength = readUInt16BE(buffer, 3);
|
|
||||||
|
|
||||||
// Initialize connection info
|
|
||||||
const connectionInfo: IConnectionInfo = {
|
|
||||||
protocol: 'tls',
|
|
||||||
tlsVersion: tlsVersionToString(tlsMajor, tlsMinor) || undefined
|
|
||||||
};
|
|
||||||
|
|
||||||
// If it's a handshake, try to extract more info
|
|
||||||
if (recordType === TlsRecordType.HANDSHAKE && buffer.length >= 6) {
|
|
||||||
const handshakeType = buffer[5];
|
|
||||||
|
|
||||||
// For ClientHello, extract SNI and other info
|
|
||||||
if (handshakeType === TlsHandshakeType.CLIENT_HELLO) {
|
|
||||||
// Check if we have the complete handshake
|
|
||||||
const totalRecordLength = recordLength + 5; // Including TLS header
|
|
||||||
if (buffer.length >= totalRecordLength) {
|
|
||||||
// Extract SNI using existing logic
|
|
||||||
const sni = SniExtraction.extractSNI(buffer);
|
|
||||||
if (sni) {
|
|
||||||
connectionInfo.domain = sni;
|
|
||||||
connectionInfo.sni = sni;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse ClientHello for additional info
|
|
||||||
const parseResult = ClientHelloParser.parseClientHello(buffer);
|
|
||||||
if (parseResult.isValid) {
|
|
||||||
// Extract ALPN if present
|
|
||||||
const alpnExtension = parseResult.extensions.find(
|
|
||||||
ext => ext.type === TlsExtensionType.APPLICATION_LAYER_PROTOCOL_NEGOTIATION
|
|
||||||
);
|
|
||||||
|
|
||||||
if (alpnExtension) {
|
|
||||||
connectionInfo.alpn = this.parseAlpnExtension(alpnExtension.data);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store cipher suites if needed
|
|
||||||
if (parseResult.cipherSuites && options?.extractFullHeaders) {
|
|
||||||
connectionInfo.cipherSuites = this.parseCipherSuites(parseResult.cipherSuites);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return complete result
|
|
||||||
return {
|
|
||||||
protocol: 'tls',
|
|
||||||
connectionInfo,
|
|
||||||
remainingBuffer: buffer.length > totalRecordLength
|
|
||||||
? buffer.subarray(totalRecordLength)
|
|
||||||
: undefined,
|
|
||||||
isComplete: true
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
// Incomplete handshake
|
|
||||||
return {
|
|
||||||
protocol: 'tls',
|
|
||||||
connectionInfo,
|
|
||||||
isComplete: false,
|
|
||||||
bytesNeeded: totalRecordLength
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// For other TLS record types, just return basic info
|
|
||||||
return {
|
|
||||||
protocol: 'tls',
|
|
||||||
connectionInfo,
|
|
||||||
isComplete: true,
|
|
||||||
remainingBuffer: buffer.length > recordLength + 5
|
|
||||||
? buffer.subarray(recordLength + 5)
|
|
||||||
: undefined
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if buffer can be handled by this detector
|
|
||||||
*/
|
|
||||||
canHandle(buffer: Buffer): boolean {
|
|
||||||
return buffer.length >= TlsDetector.MIN_TLS_HEADER_SIZE &&
|
|
||||||
this.isTlsRecord(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get minimum bytes needed for detection
|
|
||||||
*/
|
|
||||||
getMinimumBytes(): number {
|
|
||||||
return TlsDetector.MIN_TLS_HEADER_SIZE;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if buffer contains a valid TLS record
|
|
||||||
*/
|
|
||||||
private isTlsRecord(buffer: Buffer): boolean {
|
|
||||||
const recordType = buffer[0];
|
|
||||||
|
|
||||||
// Check for valid record type
|
|
||||||
const validTypes = [
|
|
||||||
TlsRecordType.CHANGE_CIPHER_SPEC,
|
|
||||||
TlsRecordType.ALERT,
|
|
||||||
TlsRecordType.HANDSHAKE,
|
|
||||||
TlsRecordType.APPLICATION_DATA,
|
|
||||||
TlsRecordType.HEARTBEAT
|
|
||||||
];
|
|
||||||
|
|
||||||
if (!validTypes.includes(recordType)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check TLS version bytes (should be 0x03 0x0X)
|
|
||||||
if (buffer[1] !== 0x03) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check record length is reasonable
|
|
||||||
const recordLength = readUInt16BE(buffer, 3);
|
|
||||||
if (recordLength > 16384) { // Max TLS record size
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse ALPN extension data
|
|
||||||
*/
|
|
||||||
private parseAlpnExtension(data: Buffer): string[] {
|
|
||||||
const protocols: string[] = [];
|
|
||||||
|
|
||||||
if (data.length < 2) {
|
|
||||||
return protocols;
|
|
||||||
}
|
|
||||||
|
|
||||||
const listLength = readUInt16BE(data, 0);
|
|
||||||
let offset = 2;
|
|
||||||
|
|
||||||
while (offset < Math.min(2 + listLength, data.length)) {
|
|
||||||
const protoLength = data[offset];
|
|
||||||
offset++;
|
|
||||||
|
|
||||||
if (offset + protoLength <= data.length) {
|
|
||||||
const protocol = data.subarray(offset, offset + protoLength).toString('ascii');
|
|
||||||
protocols.push(protocol);
|
|
||||||
offset += protoLength;
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return protocols;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse cipher suites
|
|
||||||
*/
|
|
||||||
private parseCipherSuites(cipherData: Buffer): number[] {
|
|
||||||
const suites: number[] = [];
|
|
||||||
|
|
||||||
for (let i = 0; i < cipherData.length - 1; i += 2) {
|
|
||||||
const suite = readUInt16BE(cipherData, i);
|
|
||||||
suites.push(suite);
|
|
||||||
}
|
|
||||||
|
|
||||||
return suites;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Detect with context for fragmented data
|
|
||||||
*/
|
|
||||||
detectWithContext(
|
|
||||||
buffer: Buffer,
|
|
||||||
_context: { sourceIp?: string; sourcePort?: number; destIp?: string; destPort?: number },
|
|
||||||
options?: IDetectionOptions
|
|
||||||
): IDetectionResult | null {
|
|
||||||
// This method is deprecated - TLS detection should use the fragment manager
|
|
||||||
// from the parent detector system, not maintain its own fragments
|
|
||||||
return this.detect(buffer, options);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
/**
|
|
||||||
* Centralized Protocol Detection Module
|
|
||||||
*
|
|
||||||
* This module provides unified protocol detection capabilities for
|
|
||||||
* both TLS and HTTP protocols, extracting connection information
|
|
||||||
* without consuming the data stream.
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Main detector
|
|
||||||
export * from './protocol-detector.js';
|
|
||||||
|
|
||||||
// Models
|
|
||||||
export * from './models/detection-types.js';
|
|
||||||
export * from './models/interfaces.js';
|
|
||||||
|
|
||||||
// Individual detectors
|
|
||||||
export * from './detectors/tls-detector.js';
|
|
||||||
export * from './detectors/http-detector.js';
|
|
||||||
export * from './detectors/quick-detector.js';
|
|
||||||
export * from './detectors/routing-extractor.js';
|
|
||||||
|
|
||||||
// Utilities
|
|
||||||
export * from './utils/buffer-utils.js';
|
|
||||||
export * from './utils/parser-utils.js';
|
|
||||||
export * from './utils/fragment-manager.js';
|
|
||||||
@@ -1,102 +0,0 @@
|
|||||||
/**
|
|
||||||
* Type definitions for protocol detection
|
|
||||||
*/
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Supported protocol types that can be detected
|
|
||||||
*/
|
|
||||||
export type TProtocolType = 'tls' | 'http' | 'unknown';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* HTTP method types
|
|
||||||
*/
|
|
||||||
export type THttpMethod = 'GET' | 'POST' | 'PUT' | 'DELETE' | 'PATCH' | 'HEAD' | 'OPTIONS' | 'CONNECT' | 'TRACE';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* TLS version identifiers
|
|
||||||
*/
|
|
||||||
export type TTlsVersion = 'SSLv3' | 'TLSv1.0' | 'TLSv1.1' | 'TLSv1.2' | 'TLSv1.3';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Connection information extracted from protocol detection
|
|
||||||
*/
|
|
||||||
export interface IConnectionInfo {
|
|
||||||
/**
|
|
||||||
* The detected protocol type
|
|
||||||
*/
|
|
||||||
protocol: TProtocolType;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Domain/hostname extracted from the connection
|
|
||||||
* - For TLS: from SNI extension
|
|
||||||
* - For HTTP: from Host header
|
|
||||||
*/
|
|
||||||
domain?: string;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* HTTP-specific fields
|
|
||||||
*/
|
|
||||||
method?: THttpMethod;
|
|
||||||
path?: string;
|
|
||||||
httpVersion?: string;
|
|
||||||
headers?: Record<string, string>;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* TLS-specific fields
|
|
||||||
*/
|
|
||||||
tlsVersion?: TTlsVersion;
|
|
||||||
sni?: string;
|
|
||||||
alpn?: string[];
|
|
||||||
cipherSuites?: number[];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Result of protocol detection
|
|
||||||
*/
|
|
||||||
export interface IDetectionResult {
|
|
||||||
/**
|
|
||||||
* The detected protocol type
|
|
||||||
*/
|
|
||||||
protocol: TProtocolType;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extracted connection information
|
|
||||||
*/
|
|
||||||
connectionInfo: IConnectionInfo;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Any remaining buffer data after detection headers
|
|
||||||
* This can be used to continue processing the stream
|
|
||||||
*/
|
|
||||||
remainingBuffer?: Buffer;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Whether the detection is complete or needs more data
|
|
||||||
*/
|
|
||||||
isComplete: boolean;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Minimum bytes needed for complete detection (if incomplete)
|
|
||||||
*/
|
|
||||||
bytesNeeded?: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Options for protocol detection
|
|
||||||
*/
|
|
||||||
export interface IDetectionOptions {
|
|
||||||
/**
|
|
||||||
* Maximum bytes to buffer for detection (default: 8192)
|
|
||||||
*/
|
|
||||||
maxBufferSize?: number;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Timeout for detection in milliseconds (default: 5000)
|
|
||||||
*/
|
|
||||||
timeout?: number;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Whether to extract full headers or just essential info
|
|
||||||
*/
|
|
||||||
extractFullHeaders?: boolean;
|
|
||||||
}
|
|
||||||
@@ -1,115 +0,0 @@
|
|||||||
/**
|
|
||||||
* Interface definitions for protocol detection components
|
|
||||||
*/
|
|
||||||
|
|
||||||
import type { IDetectionResult, IDetectionOptions } from './detection-types.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Interface for protocol detectors
|
|
||||||
*/
|
|
||||||
export interface IProtocolDetector {
|
|
||||||
/**
|
|
||||||
* Detect protocol from buffer data
|
|
||||||
* @param buffer The buffer to analyze
|
|
||||||
* @param options Detection options
|
|
||||||
* @returns Detection result or null if protocol cannot be determined
|
|
||||||
*/
|
|
||||||
detect(buffer: Buffer, options?: IDetectionOptions): IDetectionResult | null;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if buffer potentially contains this protocol
|
|
||||||
* @param buffer The buffer to check
|
|
||||||
* @returns True if buffer might contain this protocol
|
|
||||||
*/
|
|
||||||
canHandle(buffer: Buffer): boolean;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the minimum bytes needed for detection
|
|
||||||
*/
|
|
||||||
getMinimumBytes(): number;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Interface for connection tracking during fragmented detection
|
|
||||||
*/
|
|
||||||
export interface IConnectionTracker {
|
|
||||||
/**
|
|
||||||
* Connection identifier
|
|
||||||
*/
|
|
||||||
id: string;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Accumulated buffer data
|
|
||||||
*/
|
|
||||||
buffer: Buffer;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Timestamp of first data
|
|
||||||
*/
|
|
||||||
startTime: number;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Current detection state
|
|
||||||
*/
|
|
||||||
state: 'detecting' | 'complete' | 'failed';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Partial detection result (if any)
|
|
||||||
*/
|
|
||||||
partialResult?: Partial<IDetectionResult>;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Interface for buffer accumulator (handles fragmented data)
|
|
||||||
*/
|
|
||||||
export interface IBufferAccumulator {
|
|
||||||
/**
|
|
||||||
* Add data to accumulator
|
|
||||||
*/
|
|
||||||
append(data: Buffer): void;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get accumulated buffer
|
|
||||||
*/
|
|
||||||
getBuffer(): Buffer;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get buffer length
|
|
||||||
*/
|
|
||||||
length(): number;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clear accumulated data
|
|
||||||
*/
|
|
||||||
clear(): void;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if accumulator has enough data
|
|
||||||
*/
|
|
||||||
hasMinimumBytes(minBytes: number): boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Detection events
|
|
||||||
*/
|
|
||||||
export interface IDetectionEvents {
|
|
||||||
/**
|
|
||||||
* Emitted when protocol is successfully detected
|
|
||||||
*/
|
|
||||||
detected: (result: IDetectionResult) => void;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Emitted when detection fails
|
|
||||||
*/
|
|
||||||
failed: (error: Error) => void;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Emitted when detection times out
|
|
||||||
*/
|
|
||||||
timeout: () => void;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Emitted when more data is needed
|
|
||||||
*/
|
|
||||||
needMoreData: (bytesNeeded: number) => void;
|
|
||||||
}
|
|
||||||
@@ -1,319 +0,0 @@
|
|||||||
/**
|
|
||||||
* Protocol Detector
|
|
||||||
*
|
|
||||||
* Simplified protocol detection using the new architecture
|
|
||||||
*/
|
|
||||||
|
|
||||||
import type { IDetectionResult, IDetectionOptions } from './models/detection-types.js';
|
|
||||||
import type { IConnectionContext } from '../protocols/common/types.js';
|
|
||||||
import { TlsDetector } from './detectors/tls-detector.js';
|
|
||||||
import { HttpDetector } from './detectors/http-detector.js';
|
|
||||||
import { DetectionFragmentManager } from './utils/fragment-manager.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Main protocol detector class
|
|
||||||
*/
|
|
||||||
export class ProtocolDetector {
|
|
||||||
private static instance: ProtocolDetector;
|
|
||||||
private fragmentManager: DetectionFragmentManager;
|
|
||||||
private tlsDetector: TlsDetector;
|
|
||||||
private httpDetector: HttpDetector;
|
|
||||||
private connectionProtocols: Map<string, { protocol: 'tls' | 'http'; createdAt: number }> = new Map();
|
|
||||||
|
|
||||||
constructor() {
|
|
||||||
this.fragmentManager = new DetectionFragmentManager();
|
|
||||||
this.tlsDetector = new TlsDetector();
|
|
||||||
this.httpDetector = new HttpDetector(this.fragmentManager);
|
|
||||||
}
|
|
||||||
|
|
||||||
private static getInstance(): ProtocolDetector {
|
|
||||||
if (!this.instance) {
|
|
||||||
this.instance = new ProtocolDetector();
|
|
||||||
}
|
|
||||||
return this.instance;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Detect protocol from buffer data
|
|
||||||
*/
|
|
||||||
static async detect(buffer: Buffer, options?: IDetectionOptions): Promise<IDetectionResult> {
|
|
||||||
return this.getInstance().detectInstance(buffer, options);
|
|
||||||
}
|
|
||||||
|
|
||||||
private async detectInstance(buffer: Buffer, options?: IDetectionOptions): Promise<IDetectionResult> {
|
|
||||||
// Quick sanity check
|
|
||||||
if (!buffer || buffer.length === 0) {
|
|
||||||
return {
|
|
||||||
protocol: 'unknown',
|
|
||||||
connectionInfo: { protocol: 'unknown' },
|
|
||||||
isComplete: true
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try TLS detection first (more specific)
|
|
||||||
if (this.tlsDetector.canHandle(buffer)) {
|
|
||||||
const tlsResult = this.tlsDetector.detect(buffer, options);
|
|
||||||
if (tlsResult) {
|
|
||||||
return tlsResult;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try HTTP detection
|
|
||||||
if (this.httpDetector.canHandle(buffer)) {
|
|
||||||
const httpResult = this.httpDetector.detect(buffer, options);
|
|
||||||
if (httpResult) {
|
|
||||||
return httpResult;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Neither TLS nor HTTP
|
|
||||||
return {
|
|
||||||
protocol: 'unknown',
|
|
||||||
connectionInfo: { protocol: 'unknown' },
|
|
||||||
isComplete: true
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Detect protocol with connection tracking for fragmented data
|
|
||||||
* @deprecated Use detectWithContext instead
|
|
||||||
*/
|
|
||||||
static async detectWithConnectionTracking(
|
|
||||||
buffer: Buffer,
|
|
||||||
connectionId: string,
|
|
||||||
options?: IDetectionOptions
|
|
||||||
): Promise<IDetectionResult> {
|
|
||||||
// Convert connection ID to context
|
|
||||||
const context: IConnectionContext = {
|
|
||||||
id: connectionId,
|
|
||||||
sourceIp: 'unknown',
|
|
||||||
sourcePort: 0,
|
|
||||||
destIp: 'unknown',
|
|
||||||
destPort: 0,
|
|
||||||
timestamp: Date.now()
|
|
||||||
};
|
|
||||||
|
|
||||||
return this.getInstance().detectWithContextInstance(buffer, context, options);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Detect protocol with connection context for fragmented data
|
|
||||||
*/
|
|
||||||
static async detectWithContext(
|
|
||||||
buffer: Buffer,
|
|
||||||
context: IConnectionContext,
|
|
||||||
options?: IDetectionOptions
|
|
||||||
): Promise<IDetectionResult> {
|
|
||||||
return this.getInstance().detectWithContextInstance(buffer, context, options);
|
|
||||||
}
|
|
||||||
|
|
||||||
private async detectWithContextInstance(
|
|
||||||
buffer: Buffer,
|
|
||||||
context: IConnectionContext,
|
|
||||||
options?: IDetectionOptions
|
|
||||||
): Promise<IDetectionResult> {
|
|
||||||
// Quick sanity check
|
|
||||||
if (!buffer || buffer.length === 0) {
|
|
||||||
return {
|
|
||||||
protocol: 'unknown',
|
|
||||||
connectionInfo: { protocol: 'unknown' },
|
|
||||||
isComplete: true
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
const connectionId = DetectionFragmentManager.createConnectionId(context);
|
|
||||||
|
|
||||||
// Check if we already know the protocol for this connection
|
|
||||||
const knownEntry = this.connectionProtocols.get(connectionId);
|
|
||||||
const knownProtocol = knownEntry?.protocol;
|
|
||||||
|
|
||||||
if (knownProtocol === 'http') {
|
|
||||||
const result = this.httpDetector.detectWithContext(buffer, context, options);
|
|
||||||
if (result) {
|
|
||||||
if (result.isComplete) {
|
|
||||||
this.connectionProtocols.delete(connectionId);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
} else if (knownProtocol === 'tls') {
|
|
||||||
// Handle TLS with fragment accumulation
|
|
||||||
const handler = this.fragmentManager.getHandler('tls');
|
|
||||||
const fragmentResult = handler.addFragment(connectionId, buffer);
|
|
||||||
|
|
||||||
if (fragmentResult.error) {
|
|
||||||
handler.complete(connectionId);
|
|
||||||
this.connectionProtocols.delete(connectionId);
|
|
||||||
return {
|
|
||||||
protocol: 'unknown',
|
|
||||||
connectionInfo: { protocol: 'unknown' },
|
|
||||||
isComplete: true
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
const result = this.tlsDetector.detect(fragmentResult.buffer!, options);
|
|
||||||
if (result) {
|
|
||||||
if (result.isComplete) {
|
|
||||||
handler.complete(connectionId);
|
|
||||||
this.connectionProtocols.delete(connectionId);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we don't know the protocol yet, try to detect it
|
|
||||||
if (!knownProtocol) {
|
|
||||||
// First peek to determine protocol type
|
|
||||||
if (this.tlsDetector.canHandle(buffer)) {
|
|
||||||
this.connectionProtocols.set(connectionId, { protocol: 'tls', createdAt: Date.now() });
|
|
||||||
// Handle TLS with fragment accumulation
|
|
||||||
const handler = this.fragmentManager.getHandler('tls');
|
|
||||||
const fragmentResult = handler.addFragment(connectionId, buffer);
|
|
||||||
|
|
||||||
if (fragmentResult.error) {
|
|
||||||
handler.complete(connectionId);
|
|
||||||
this.connectionProtocols.delete(connectionId);
|
|
||||||
return {
|
|
||||||
protocol: 'unknown',
|
|
||||||
connectionInfo: { protocol: 'unknown' },
|
|
||||||
isComplete: true
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
const result = this.tlsDetector.detect(fragmentResult.buffer!, options);
|
|
||||||
if (result) {
|
|
||||||
if (result.isComplete) {
|
|
||||||
handler.complete(connectionId);
|
|
||||||
this.connectionProtocols.delete(connectionId);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.httpDetector.canHandle(buffer)) {
|
|
||||||
this.connectionProtocols.set(connectionId, { protocol: 'http', createdAt: Date.now() });
|
|
||||||
const result = this.httpDetector.detectWithContext(buffer, context, options);
|
|
||||||
if (result) {
|
|
||||||
if (result.isComplete) {
|
|
||||||
this.connectionProtocols.delete(connectionId);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Can't determine protocol
|
|
||||||
return {
|
|
||||||
protocol: 'unknown',
|
|
||||||
connectionInfo: { protocol: 'unknown' },
|
|
||||||
isComplete: false,
|
|
||||||
bytesNeeded: Math.max(
|
|
||||||
this.tlsDetector.getMinimumBytes(),
|
|
||||||
this.httpDetector.getMinimumBytes()
|
|
||||||
)
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clean up resources
|
|
||||||
*/
|
|
||||||
static cleanup(): void {
|
|
||||||
this.getInstance().cleanupInstance();
|
|
||||||
}
|
|
||||||
|
|
||||||
private cleanupInstance(): void {
|
|
||||||
this.fragmentManager.cleanup();
|
|
||||||
// Remove stale connectionProtocols entries (abandoned handshakes, port scanners)
|
|
||||||
const maxAge = 30_000; // 30 seconds
|
|
||||||
const now = Date.now();
|
|
||||||
for (const [id, entry] of this.connectionProtocols) {
|
|
||||||
if (now - entry.createdAt > maxAge) {
|
|
||||||
this.connectionProtocols.delete(id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Destroy detector instance
|
|
||||||
*/
|
|
||||||
static destroy(): void {
|
|
||||||
this.getInstance().destroyInstance();
|
|
||||||
this.instance = null as any;
|
|
||||||
}
|
|
||||||
|
|
||||||
private destroyInstance(): void {
|
|
||||||
this.fragmentManager.destroy();
|
|
||||||
this.connectionProtocols.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clean up old connection tracking entries
|
|
||||||
*
|
|
||||||
* @param _maxAge Maximum age in milliseconds (default: 30 seconds)
|
|
||||||
*/
|
|
||||||
static cleanupConnections(_maxAge: number = 30000): void {
|
|
||||||
this.getInstance().cleanupInstance();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clean up fragments for a specific connection
|
|
||||||
*/
|
|
||||||
static cleanupConnection(context: IConnectionContext): void {
|
|
||||||
const instance = this.getInstance();
|
|
||||||
const connectionId = DetectionFragmentManager.createConnectionId(context);
|
|
||||||
|
|
||||||
// Clean up both TLS and HTTP fragments for this connection
|
|
||||||
instance.fragmentManager.getHandler('tls').complete(connectionId);
|
|
||||||
instance.fragmentManager.getHandler('http').complete(connectionId);
|
|
||||||
|
|
||||||
// Remove from connection protocols tracking
|
|
||||||
instance.connectionProtocols.delete(connectionId);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extract domain from connection info
|
|
||||||
*/
|
|
||||||
static extractDomain(connectionInfo: any): string | undefined {
|
|
||||||
return connectionInfo.domain || connectionInfo.sni || connectionInfo.host;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a connection ID from connection parameters
|
|
||||||
* @deprecated Use createConnectionContext instead
|
|
||||||
*/
|
|
||||||
static createConnectionId(params: {
|
|
||||||
sourceIp?: string;
|
|
||||||
sourcePort?: number;
|
|
||||||
destIp?: string;
|
|
||||||
destPort?: number;
|
|
||||||
socketId?: string;
|
|
||||||
}): string {
|
|
||||||
// If socketId is provided, use it
|
|
||||||
if (params.socketId) {
|
|
||||||
return params.socketId;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise create from connection tuple
|
|
||||||
const { sourceIp = 'unknown', sourcePort = 0, destIp = 'unknown', destPort = 0 } = params;
|
|
||||||
return `${sourceIp}:${sourcePort}-${destIp}:${destPort}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a connection context from parameters
|
|
||||||
*/
|
|
||||||
static createConnectionContext(params: {
|
|
||||||
sourceIp?: string;
|
|
||||||
sourcePort?: number;
|
|
||||||
destIp?: string;
|
|
||||||
destPort?: number;
|
|
||||||
socketId?: string;
|
|
||||||
}): IConnectionContext {
|
|
||||||
return {
|
|
||||||
id: params.socketId,
|
|
||||||
sourceIp: params.sourceIp || 'unknown',
|
|
||||||
sourcePort: params.sourcePort || 0,
|
|
||||||
destIp: params.destIp || 'unknown',
|
|
||||||
destPort: params.destPort || 0,
|
|
||||||
timestamp: Date.now()
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,141 +0,0 @@
|
|||||||
/**
|
|
||||||
* Buffer manipulation utilities for protocol detection
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Import from protocols
|
|
||||||
import { HttpParser } from '../../protocols/http/index.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* BufferAccumulator class for handling fragmented data
|
|
||||||
*/
|
|
||||||
export class BufferAccumulator {
|
|
||||||
private chunks: Buffer[] = [];
|
|
||||||
private totalLength = 0;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Append data to the accumulator
|
|
||||||
*/
|
|
||||||
append(data: Buffer): void {
|
|
||||||
this.chunks.push(data);
|
|
||||||
this.totalLength += data.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the accumulated buffer
|
|
||||||
*/
|
|
||||||
getBuffer(): Buffer {
|
|
||||||
if (this.chunks.length === 0) {
|
|
||||||
return Buffer.alloc(0);
|
|
||||||
}
|
|
||||||
if (this.chunks.length === 1) {
|
|
||||||
return this.chunks[0];
|
|
||||||
}
|
|
||||||
return Buffer.concat(this.chunks, this.totalLength);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get current buffer length
|
|
||||||
*/
|
|
||||||
length(): number {
|
|
||||||
return this.totalLength;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clear all accumulated data
|
|
||||||
*/
|
|
||||||
clear(): void {
|
|
||||||
this.chunks = [];
|
|
||||||
this.totalLength = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if accumulator has minimum bytes
|
|
||||||
*/
|
|
||||||
hasMinimumBytes(minBytes: number): boolean {
|
|
||||||
return this.totalLength >= minBytes;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Read a big-endian 16-bit integer from buffer
|
|
||||||
*/
|
|
||||||
export function readUInt16BE(buffer: Buffer, offset: number): number {
|
|
||||||
if (offset + 2 > buffer.length) {
|
|
||||||
throw new Error('Buffer too short for UInt16BE read');
|
|
||||||
}
|
|
||||||
return (buffer[offset] << 8) | buffer[offset + 1];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Read a big-endian 24-bit integer from buffer
|
|
||||||
*/
|
|
||||||
export function readUInt24BE(buffer: Buffer, offset: number): number {
|
|
||||||
if (offset + 3 > buffer.length) {
|
|
||||||
throw new Error('Buffer too short for UInt24BE read');
|
|
||||||
}
|
|
||||||
return (buffer[offset] << 16) | (buffer[offset + 1] << 8) | buffer[offset + 2];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Find a byte sequence in a buffer
|
|
||||||
*/
|
|
||||||
export function findSequence(buffer: Buffer, sequence: Buffer, startOffset = 0): number {
|
|
||||||
if (sequence.length === 0) {
|
|
||||||
return startOffset;
|
|
||||||
}
|
|
||||||
|
|
||||||
const searchLength = buffer.length - sequence.length + 1;
|
|
||||||
for (let i = startOffset; i < searchLength; i++) {
|
|
||||||
let found = true;
|
|
||||||
for (let j = 0; j < sequence.length; j++) {
|
|
||||||
if (buffer[i + j] !== sequence[j]) {
|
|
||||||
found = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (found) {
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extract a line from buffer (up to CRLF or LF)
|
|
||||||
*/
|
|
||||||
export function extractLine(buffer: Buffer, startOffset = 0): { line: string; nextOffset: number } | null {
|
|
||||||
// Delegate to protocol parser
|
|
||||||
return HttpParser.extractLine(buffer, startOffset);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if buffer starts with a string (case-insensitive)
|
|
||||||
*/
|
|
||||||
export function startsWithString(buffer: Buffer, str: string, offset = 0): boolean {
|
|
||||||
if (offset + str.length > buffer.length) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
const bufferStr = buffer.slice(offset, offset + str.length).toString('utf8');
|
|
||||||
return bufferStr.toLowerCase() === str.toLowerCase();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Safe buffer slice that doesn't throw on out-of-bounds
|
|
||||||
*/
|
|
||||||
export function safeSlice(buffer: Buffer, start: number, end?: number): Buffer {
|
|
||||||
const safeStart = Math.max(0, Math.min(start, buffer.length));
|
|
||||||
const safeEnd = end === undefined
|
|
||||||
? buffer.length
|
|
||||||
: Math.max(safeStart, Math.min(end, buffer.length));
|
|
||||||
|
|
||||||
return buffer.slice(safeStart, safeEnd);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if buffer contains printable ASCII
|
|
||||||
*/
|
|
||||||
export function isPrintableAscii(buffer: Buffer, length?: number): boolean {
|
|
||||||
// Delegate to protocol parser
|
|
||||||
return HttpParser.isPrintableAscii(buffer, length);
|
|
||||||
}
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
/**
|
|
||||||
* Fragment Manager for Detection Module
|
|
||||||
*
|
|
||||||
* Manages fragmented protocol data using the shared fragment handler
|
|
||||||
*/
|
|
||||||
|
|
||||||
import { FragmentHandler, type IFragmentOptions } from '../../protocols/common/fragment-handler.js';
|
|
||||||
import type { IConnectionContext } from '../../protocols/common/types.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Detection-specific fragment manager
|
|
||||||
*/
|
|
||||||
export class DetectionFragmentManager {
|
|
||||||
private tlsFragments: FragmentHandler;
|
|
||||||
private httpFragments: FragmentHandler;
|
|
||||||
|
|
||||||
constructor() {
|
|
||||||
// Configure fragment handlers with appropriate limits
|
|
||||||
const tlsOptions: IFragmentOptions = {
|
|
||||||
maxBufferSize: 16384, // TLS record max size
|
|
||||||
timeout: 5000,
|
|
||||||
cleanupInterval: 30000
|
|
||||||
};
|
|
||||||
|
|
||||||
const httpOptions: IFragmentOptions = {
|
|
||||||
maxBufferSize: 8192, // HTTP header reasonable limit
|
|
||||||
timeout: 5000,
|
|
||||||
cleanupInterval: 30000
|
|
||||||
};
|
|
||||||
|
|
||||||
this.tlsFragments = new FragmentHandler(tlsOptions);
|
|
||||||
this.httpFragments = new FragmentHandler(httpOptions);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get fragment handler for protocol type
|
|
||||||
*/
|
|
||||||
getHandler(protocol: 'tls' | 'http'): FragmentHandler {
|
|
||||||
return protocol === 'tls' ? this.tlsFragments : this.httpFragments;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create connection ID from context
|
|
||||||
*/
|
|
||||||
static createConnectionId(context: IConnectionContext): string {
|
|
||||||
return context.id || `${context.sourceIp}:${context.sourcePort}-${context.destIp}:${context.destPort}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clean up all handlers
|
|
||||||
*/
|
|
||||||
cleanup(): void {
|
|
||||||
this.tlsFragments.cleanup();
|
|
||||||
this.httpFragments.cleanup();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Destroy all handlers
|
|
||||||
*/
|
|
||||||
destroy(): void {
|
|
||||||
this.tlsFragments.destroy();
|
|
||||||
this.httpFragments.destroy();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
/**
|
|
||||||
* Parser utilities for protocol detection
|
|
||||||
* Now delegates to protocol modules for actual parsing
|
|
||||||
*/
|
|
||||||
|
|
||||||
import type { THttpMethod, TTlsVersion } from '../models/detection-types.js';
|
|
||||||
import { HttpParser, HTTP_METHODS, HTTP_VERSIONS } from '../../protocols/http/index.js';
|
|
||||||
import { tlsVersionToString as protocolTlsVersionToString } from '../../protocols/tls/index.js';
|
|
||||||
|
|
||||||
// Re-export constants for backward compatibility
|
|
||||||
export { HTTP_METHODS, HTTP_VERSIONS };
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse HTTP request line
|
|
||||||
*/
|
|
||||||
export function parseHttpRequestLine(line: string): {
|
|
||||||
method: THttpMethod;
|
|
||||||
path: string;
|
|
||||||
version: string;
|
|
||||||
} | null {
|
|
||||||
// Delegate to protocol parser
|
|
||||||
const result = HttpParser.parseRequestLine(line);
|
|
||||||
return result ? {
|
|
||||||
method: result.method as THttpMethod,
|
|
||||||
path: result.path,
|
|
||||||
version: result.version
|
|
||||||
} : null;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse HTTP header line
|
|
||||||
*/
|
|
||||||
export function parseHttpHeader(line: string): { name: string; value: string } | null {
|
|
||||||
// Delegate to protocol parser
|
|
||||||
return HttpParser.parseHeaderLine(line);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse HTTP headers from lines
|
|
||||||
*/
|
|
||||||
export function parseHttpHeaders(lines: string[]): Record<string, string> {
|
|
||||||
// Delegate to protocol parser
|
|
||||||
return HttpParser.parseHeaders(lines);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Convert TLS version bytes to version string
|
|
||||||
*/
|
|
||||||
export function tlsVersionToString(major: number, minor: number): TTlsVersion | null {
|
|
||||||
// Delegate to protocol parser
|
|
||||||
return protocolTlsVersionToString(major, minor) as TTlsVersion;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extract domain from Host header value
|
|
||||||
*/
|
|
||||||
export function extractDomainFromHost(hostHeader: string): string {
|
|
||||||
// Delegate to protocol parser
|
|
||||||
return HttpParser.extractDomainFromHost(hostHeader);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Validate domain name
|
|
||||||
*/
|
|
||||||
export function isValidDomain(domain: string): boolean {
|
|
||||||
// Delegate to protocol parser
|
|
||||||
return HttpParser.isValidDomain(domain);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if string is a valid HTTP method
|
|
||||||
*/
|
|
||||||
export function isHttpMethod(str: string): str is THttpMethod {
|
|
||||||
// Delegate to protocol parser
|
|
||||||
return HttpParser.isHttpMethod(str) && (str as THttpMethod) !== undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
@@ -11,18 +11,12 @@ export type { ISmartProxyOptions, IConnectionRecord, IRouteConfig, IRouteMatch,
|
|||||||
export type { TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './proxies/smart-proxy/models/interfaces.js';
|
export type { TSmartProxyCertProvisionObject, ICertProvisionEventComms, ICertificateIssuedEvent, ICertificateFailedEvent } from './proxies/smart-proxy/models/interfaces.js';
|
||||||
export * from './proxies/smart-proxy/utils/index.js';
|
export * from './proxies/smart-proxy/utils/index.js';
|
||||||
|
|
||||||
// Original: export * from './smartproxy/classes.pp.snihandler.js'
|
|
||||||
// Now we export from the new module
|
|
||||||
export { SniHandler } from './tls/sni/sni-handler.js';
|
|
||||||
|
|
||||||
// Core types and utilities
|
// Core types and utilities
|
||||||
export * from './core/models/common-types.js';
|
export * from './core/models/common-types.js';
|
||||||
|
|
||||||
// Export IAcmeOptions from one place only
|
// Export IAcmeOptions from one place only
|
||||||
export type { IAcmeOptions } from './proxies/smart-proxy/models/interfaces.js';
|
export type { IAcmeOptions } from './proxies/smart-proxy/models/interfaces.js';
|
||||||
|
|
||||||
// Modular exports for new architecture
|
// Modular exports
|
||||||
export * as tls from './tls/index.js';
|
|
||||||
export * as routing from './routing/index.js';
|
export * as routing from './routing/index.js';
|
||||||
export * as detection from './detection/index.js';
|
|
||||||
export * as protocols from './protocols/index.js';
|
export * as protocols from './protocols/index.js';
|
||||||
|
|||||||
@@ -1,167 +0,0 @@
|
|||||||
/**
|
|
||||||
* Shared Fragment Handler for Protocol Detection
|
|
||||||
*
|
|
||||||
* Provides unified fragment buffering and reassembly for protocols
|
|
||||||
* that may span multiple TCP packets.
|
|
||||||
*/
|
|
||||||
|
|
||||||
import { Buffer } from 'node:buffer';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Fragment tracking information
|
|
||||||
*/
|
|
||||||
export interface IFragmentInfo {
|
|
||||||
buffer: Buffer;
|
|
||||||
timestamp: number;
|
|
||||||
connectionId: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Options for fragment handling
|
|
||||||
*/
|
|
||||||
export interface IFragmentOptions {
|
|
||||||
maxBufferSize?: number;
|
|
||||||
timeout?: number;
|
|
||||||
cleanupInterval?: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Result of fragment processing
|
|
||||||
*/
|
|
||||||
export interface IFragmentResult {
|
|
||||||
isComplete: boolean;
|
|
||||||
buffer?: Buffer;
|
|
||||||
needsMoreData: boolean;
|
|
||||||
error?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Shared fragment handler for protocol detection
|
|
||||||
*/
|
|
||||||
export class FragmentHandler {
|
|
||||||
private fragments = new Map<string, IFragmentInfo>();
|
|
||||||
private cleanupTimer?: NodeJS.Timeout;
|
|
||||||
|
|
||||||
constructor(private options: IFragmentOptions = {}) {
|
|
||||||
// Start cleanup timer if not already running
|
|
||||||
if (options.cleanupInterval && !this.cleanupTimer) {
|
|
||||||
this.cleanupTimer = setInterval(
|
|
||||||
() => this.cleanup(),
|
|
||||||
options.cleanupInterval
|
|
||||||
);
|
|
||||||
// Don't let this timer prevent process exit
|
|
||||||
if (this.cleanupTimer.unref) {
|
|
||||||
this.cleanupTimer.unref();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Add a fragment for a connection
|
|
||||||
*/
|
|
||||||
addFragment(connectionId: string, fragment: Buffer): IFragmentResult {
|
|
||||||
const existing = this.fragments.get(connectionId);
|
|
||||||
|
|
||||||
if (existing) {
|
|
||||||
// Append to existing buffer
|
|
||||||
const newBuffer = Buffer.concat([existing.buffer, fragment]);
|
|
||||||
|
|
||||||
// Check size limit
|
|
||||||
const maxSize = this.options.maxBufferSize || 65536;
|
|
||||||
if (newBuffer.length > maxSize) {
|
|
||||||
this.fragments.delete(connectionId);
|
|
||||||
return {
|
|
||||||
isComplete: false,
|
|
||||||
needsMoreData: false,
|
|
||||||
error: 'Buffer size exceeded maximum allowed'
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update fragment info
|
|
||||||
this.fragments.set(connectionId, {
|
|
||||||
buffer: newBuffer,
|
|
||||||
timestamp: Date.now(),
|
|
||||||
connectionId
|
|
||||||
});
|
|
||||||
|
|
||||||
return {
|
|
||||||
isComplete: false,
|
|
||||||
buffer: newBuffer,
|
|
||||||
needsMoreData: true
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
// New fragment
|
|
||||||
this.fragments.set(connectionId, {
|
|
||||||
buffer: fragment,
|
|
||||||
timestamp: Date.now(),
|
|
||||||
connectionId
|
|
||||||
});
|
|
||||||
|
|
||||||
return {
|
|
||||||
isComplete: false,
|
|
||||||
buffer: fragment,
|
|
||||||
needsMoreData: true
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the current buffer for a connection
|
|
||||||
*/
|
|
||||||
getBuffer(connectionId: string): Buffer | undefined {
|
|
||||||
return this.fragments.get(connectionId)?.buffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Mark a connection as complete and clean up
|
|
||||||
*/
|
|
||||||
complete(connectionId: string): void {
|
|
||||||
this.fragments.delete(connectionId);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if we're tracking a connection
|
|
||||||
*/
|
|
||||||
hasConnection(connectionId: string): boolean {
|
|
||||||
return this.fragments.has(connectionId);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clean up expired fragments
|
|
||||||
*/
|
|
||||||
cleanup(): void {
|
|
||||||
const now = Date.now();
|
|
||||||
const timeout = this.options.timeout || 5000;
|
|
||||||
|
|
||||||
for (const [connectionId, info] of this.fragments.entries()) {
|
|
||||||
if (now - info.timestamp > timeout) {
|
|
||||||
this.fragments.delete(connectionId);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clear all fragments
|
|
||||||
*/
|
|
||||||
clear(): void {
|
|
||||||
this.fragments.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Destroy the handler and clean up resources
|
|
||||||
*/
|
|
||||||
destroy(): void {
|
|
||||||
if (this.cleanupTimer) {
|
|
||||||
clearInterval(this.cleanupTimer);
|
|
||||||
this.cleanupTimer = undefined;
|
|
||||||
}
|
|
||||||
this.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the number of tracked connections
|
|
||||||
*/
|
|
||||||
get size(): number {
|
|
||||||
return this.fragments.size;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
/**
|
|
||||||
* Common Protocol Infrastructure
|
|
||||||
*
|
|
||||||
* Shared utilities and types for protocol handling
|
|
||||||
*/
|
|
||||||
|
|
||||||
export * from './fragment-handler.js';
|
|
||||||
export * from './types.js';
|
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
/**
|
|
||||||
* Common Protocol Types
|
|
||||||
*
|
|
||||||
* Shared types used across different protocol implementations
|
|
||||||
*/
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Supported protocol types
|
|
||||||
*/
|
|
||||||
export type TProtocolType = 'tls' | 'http' | 'https' | 'websocket' | 'unknown';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Protocol detection result
|
|
||||||
*/
|
|
||||||
export interface IProtocolDetectionResult {
|
|
||||||
protocol: TProtocolType;
|
|
||||||
confidence: number; // 0-100
|
|
||||||
requiresMoreData?: boolean;
|
|
||||||
metadata?: {
|
|
||||||
version?: string;
|
|
||||||
method?: string;
|
|
||||||
[key: string]: any;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Routing information extracted from protocols
|
|
||||||
*/
|
|
||||||
export interface IRoutingInfo {
|
|
||||||
domain?: string;
|
|
||||||
port?: number;
|
|
||||||
path?: string;
|
|
||||||
protocol: TProtocolType;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Connection context for protocol operations
|
|
||||||
*/
|
|
||||||
export interface IConnectionContext {
|
|
||||||
id: string;
|
|
||||||
sourceIp?: string;
|
|
||||||
sourcePort?: number;
|
|
||||||
destIp?: string;
|
|
||||||
destPort?: number;
|
|
||||||
timestamp?: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Protocol detection options
|
|
||||||
*/
|
|
||||||
export interface IProtocolDetectionOptions {
|
|
||||||
quickMode?: boolean; // Only do minimal detection
|
|
||||||
extractRouting?: boolean; // Extract routing information
|
|
||||||
maxWaitTime?: number; // Max time to wait for complete data
|
|
||||||
maxBufferSize?: number; // Max buffer size for fragmented data
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Base interface for protocol detectors
|
|
||||||
*/
|
|
||||||
export interface IProtocolDetector {
|
|
||||||
/**
|
|
||||||
* Check if this detector can handle the data
|
|
||||||
*/
|
|
||||||
canHandle(data: Buffer): boolean;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Perform quick detection (first few bytes only)
|
|
||||||
*/
|
|
||||||
quickDetect(data: Buffer): IProtocolDetectionResult;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extract routing information if possible
|
|
||||||
*/
|
|
||||||
extractRouting?(data: Buffer, context?: IConnectionContext): IRoutingInfo | null;
|
|
||||||
}
|
|
||||||
@@ -5,4 +5,3 @@
|
|||||||
|
|
||||||
export * from './constants.js';
|
export * from './constants.js';
|
||||||
export * from './types.js';
|
export * from './types.js';
|
||||||
export * from './parser.js';
|
|
||||||
@@ -1,219 +0,0 @@
|
|||||||
/**
|
|
||||||
* HTTP Protocol Parser
|
|
||||||
* Generic HTTP parsing utilities
|
|
||||||
*/
|
|
||||||
|
|
||||||
import { HTTP_METHODS, type THttpMethod, type THttpVersion } from './constants.js';
|
|
||||||
import type { IHttpRequestLine, IHttpHeader } from './types.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* HTTP parser utilities
|
|
||||||
*/
|
|
||||||
export class HttpParser {
|
|
||||||
/**
|
|
||||||
* Check if string is a valid HTTP method
|
|
||||||
*/
|
|
||||||
static isHttpMethod(str: string): str is THttpMethod {
|
|
||||||
return HTTP_METHODS.includes(str as THttpMethod);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse HTTP request line
|
|
||||||
*/
|
|
||||||
static parseRequestLine(line: string): IHttpRequestLine | null {
|
|
||||||
const parts = line.trim().split(' ');
|
|
||||||
|
|
||||||
if (parts.length !== 3) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
const [method, path, version] = parts;
|
|
||||||
|
|
||||||
// Validate method
|
|
||||||
if (!this.isHttpMethod(method)) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate version
|
|
||||||
if (!version.startsWith('HTTP/')) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
method: method as THttpMethod,
|
|
||||||
path,
|
|
||||||
version: version as THttpVersion
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse HTTP header line
|
|
||||||
*/
|
|
||||||
static parseHeaderLine(line: string): IHttpHeader | null {
|
|
||||||
const colonIndex = line.indexOf(':');
|
|
||||||
|
|
||||||
if (colonIndex === -1) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
const name = line.slice(0, colonIndex).trim();
|
|
||||||
const value = line.slice(colonIndex + 1).trim();
|
|
||||||
|
|
||||||
if (!name) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return { name, value };
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse HTTP headers from lines
|
|
||||||
*/
|
|
||||||
static parseHeaders(lines: string[]): Record<string, string> {
|
|
||||||
const headers: Record<string, string> = {};
|
|
||||||
|
|
||||||
for (const line of lines) {
|
|
||||||
const header = this.parseHeaderLine(line);
|
|
||||||
if (header) {
|
|
||||||
// Convert header names to lowercase for consistency
|
|
||||||
headers[header.name.toLowerCase()] = header.value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return headers;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extract domain from Host header value
|
|
||||||
*/
|
|
||||||
static extractDomainFromHost(hostHeader: string): string {
|
|
||||||
// Remove port if present
|
|
||||||
const colonIndex = hostHeader.lastIndexOf(':');
|
|
||||||
if (colonIndex !== -1) {
|
|
||||||
// Check if it's not part of IPv6 address
|
|
||||||
const beforeColon = hostHeader.slice(0, colonIndex);
|
|
||||||
if (!beforeColon.includes(']')) {
|
|
||||||
return beforeColon;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return hostHeader;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Validate domain name
|
|
||||||
*/
|
|
||||||
static isValidDomain(domain: string): boolean {
|
|
||||||
// Basic domain validation
|
|
||||||
if (!domain || domain.length > 253) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for valid characters and structure
|
|
||||||
const domainRegex = /^(?!-)[A-Za-z0-9-]{1,63}(?<!-)(\.[A-Za-z0-9-]{1,63})*$/;
|
|
||||||
return domainRegex.test(domain);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extract line from buffer
|
|
||||||
*/
|
|
||||||
static extractLine(buffer: Buffer, offset: number = 0): { line: string; nextOffset: number } | null {
|
|
||||||
// Look for CRLF
|
|
||||||
const crlfIndex = buffer.indexOf('\r\n', offset);
|
|
||||||
if (crlfIndex === -1) {
|
|
||||||
// Look for just LF
|
|
||||||
const lfIndex = buffer.indexOf('\n', offset);
|
|
||||||
if (lfIndex === -1) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
line: buffer.slice(offset, lfIndex).toString('utf8'),
|
|
||||||
nextOffset: lfIndex + 1
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
line: buffer.slice(offset, crlfIndex).toString('utf8'),
|
|
||||||
nextOffset: crlfIndex + 2
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Check if buffer contains printable ASCII
|
|
||||||
*/
|
|
||||||
static isPrintableAscii(buffer: Buffer, length?: number): boolean {
|
|
||||||
const checkLength = Math.min(length || buffer.length, buffer.length);
|
|
||||||
|
|
||||||
for (let i = 0; i < checkLength; i++) {
|
|
||||||
const byte = buffer[i];
|
|
||||||
// Allow printable ASCII (32-126) plus tab (9), LF (10), and CR (13)
|
|
||||||
if (byte < 32 || byte > 126) {
|
|
||||||
if (byte !== 9 && byte !== 10 && byte !== 13) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Quick check if buffer starts with HTTP method
|
|
||||||
*/
|
|
||||||
static quickCheck(buffer: Buffer): boolean {
|
|
||||||
if (buffer.length < 3) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check common HTTP methods
|
|
||||||
const start = buffer.slice(0, 7).toString('ascii');
|
|
||||||
return start.startsWith('GET ') ||
|
|
||||||
start.startsWith('POST ') ||
|
|
||||||
start.startsWith('PUT ') ||
|
|
||||||
start.startsWith('DELETE ') ||
|
|
||||||
start.startsWith('HEAD ') ||
|
|
||||||
start.startsWith('OPTIONS') ||
|
|
||||||
start.startsWith('PATCH ') ||
|
|
||||||
start.startsWith('CONNECT') ||
|
|
||||||
start.startsWith('TRACE ');
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parse query string
|
|
||||||
*/
|
|
||||||
static parseQueryString(queryString: string): Record<string, string> {
|
|
||||||
const params: Record<string, string> = {};
|
|
||||||
|
|
||||||
if (!queryString) {
|
|
||||||
return params;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove leading '?' if present
|
|
||||||
if (queryString.startsWith('?')) {
|
|
||||||
queryString = queryString.slice(1);
|
|
||||||
}
|
|
||||||
|
|
||||||
const pairs = queryString.split('&');
|
|
||||||
for (const pair of pairs) {
|
|
||||||
const [key, value] = pair.split('=');
|
|
||||||
if (key) {
|
|
||||||
params[decodeURIComponent(key)] = value ? decodeURIComponent(value) : '';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return params;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Build query string from params
|
|
||||||
*/
|
|
||||||
static buildQueryString(params: Record<string, string>): string {
|
|
||||||
const pairs: string[] = [];
|
|
||||||
|
|
||||||
for (const [key, value] of Object.entries(params)) {
|
|
||||||
pairs.push(`${encodeURIComponent(key)}=${encodeURIComponent(value)}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
return pairs.length > 0 ? '?' + pairs.join('&') : '';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,12 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Protocol-specific modules for smartproxy
|
* Protocol-specific modules for smartproxy
|
||||||
*
|
|
||||||
* This directory contains generic protocol knowledge separated from
|
|
||||||
* smartproxy-specific implementation details.
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
export * as common from './common/index.js';
|
|
||||||
export * as tls from './tls/index.js';
|
|
||||||
export * as http from './http/index.js';
|
export * as http from './http/index.js';
|
||||||
export * as proxy from './proxy/index.js';
|
|
||||||
export * as websocket from './websocket/index.js';
|
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
/**
|
|
||||||
* PROXY Protocol Module
|
|
||||||
* Type definitions for HAProxy PROXY protocol v1/v2
|
|
||||||
*/
|
|
||||||
|
|
||||||
export * from './types.js';
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
/**
|
|
||||||
* PROXY Protocol Type Definitions
|
|
||||||
* Based on HAProxy PROXY protocol specification
|
|
||||||
*/
|
|
||||||
|
|
||||||
/**
|
|
||||||
* PROXY protocol version
|
|
||||||
*/
|
|
||||||
export type TProxyProtocolVersion = 'v1' | 'v2';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Connection protocol type
|
|
||||||
*/
|
|
||||||
export type TProxyProtocol = 'TCP4' | 'TCP6' | 'UDP4' | 'UDP6' | 'UNKNOWN';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Interface representing parsed PROXY protocol information
|
|
||||||
*/
|
|
||||||
export interface IProxyInfo {
|
|
||||||
protocol: TProxyProtocol;
|
|
||||||
sourceIP: string;
|
|
||||||
sourcePort: number;
|
|
||||||
destinationIP: string;
|
|
||||||
destinationPort: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Interface for parse result including remaining data
|
|
||||||
*/
|
|
||||||
export interface IProxyParseResult {
|
|
||||||
proxyInfo: IProxyInfo | null;
|
|
||||||
remainingData: Buffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* PROXY protocol v2 header format
|
|
||||||
*/
|
|
||||||
export interface IProxyV2Header {
|
|
||||||
signature: Buffer;
|
|
||||||
versionCommand: number;
|
|
||||||
family: number;
|
|
||||||
length: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Connection information for PROXY protocol
|
|
||||||
*/
|
|
||||||
export interface IProxyConnectionInfo {
|
|
||||||
sourceIp?: string;
|
|
||||||
sourcePort?: number;
|
|
||||||
destIp?: string;
|
|
||||||
destPort?: number;
|
|
||||||
}
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
/**
|
|
||||||
* TLS alerts
|
|
||||||
*/
|
|
||||||
@@ -1,259 +0,0 @@
|
|||||||
import * as plugins from '../../../plugins.js';
|
|
||||||
import { TlsAlertLevel, TlsAlertDescription, TlsVersion } from '../utils/tls-utils.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* TlsAlert class for creating and sending TLS alert messages
|
|
||||||
*/
|
|
||||||
export class TlsAlert {
|
|
||||||
// Use enum values from TlsAlertLevel
|
|
||||||
static readonly LEVEL_WARNING = TlsAlertLevel.WARNING;
|
|
||||||
static readonly LEVEL_FATAL = TlsAlertLevel.FATAL;
|
|
||||||
|
|
||||||
// Use enum values from TlsAlertDescription
|
|
||||||
static readonly CLOSE_NOTIFY = TlsAlertDescription.CLOSE_NOTIFY;
|
|
||||||
static readonly UNEXPECTED_MESSAGE = TlsAlertDescription.UNEXPECTED_MESSAGE;
|
|
||||||
static readonly BAD_RECORD_MAC = TlsAlertDescription.BAD_RECORD_MAC;
|
|
||||||
static readonly DECRYPTION_FAILED = TlsAlertDescription.DECRYPTION_FAILED;
|
|
||||||
static readonly RECORD_OVERFLOW = TlsAlertDescription.RECORD_OVERFLOW;
|
|
||||||
static readonly DECOMPRESSION_FAILURE = TlsAlertDescription.DECOMPRESSION_FAILURE;
|
|
||||||
static readonly HANDSHAKE_FAILURE = TlsAlertDescription.HANDSHAKE_FAILURE;
|
|
||||||
static readonly NO_CERTIFICATE = TlsAlertDescription.NO_CERTIFICATE;
|
|
||||||
static readonly BAD_CERTIFICATE = TlsAlertDescription.BAD_CERTIFICATE;
|
|
||||||
static readonly UNSUPPORTED_CERTIFICATE = TlsAlertDescription.UNSUPPORTED_CERTIFICATE;
|
|
||||||
static readonly CERTIFICATE_REVOKED = TlsAlertDescription.CERTIFICATE_REVOKED;
|
|
||||||
static readonly CERTIFICATE_EXPIRED = TlsAlertDescription.CERTIFICATE_EXPIRED;
|
|
||||||
static readonly CERTIFICATE_UNKNOWN = TlsAlertDescription.CERTIFICATE_UNKNOWN;
|
|
||||||
static readonly ILLEGAL_PARAMETER = TlsAlertDescription.ILLEGAL_PARAMETER;
|
|
||||||
static readonly UNKNOWN_CA = TlsAlertDescription.UNKNOWN_CA;
|
|
||||||
static readonly ACCESS_DENIED = TlsAlertDescription.ACCESS_DENIED;
|
|
||||||
static readonly DECODE_ERROR = TlsAlertDescription.DECODE_ERROR;
|
|
||||||
static readonly DECRYPT_ERROR = TlsAlertDescription.DECRYPT_ERROR;
|
|
||||||
static readonly EXPORT_RESTRICTION = TlsAlertDescription.EXPORT_RESTRICTION;
|
|
||||||
static readonly PROTOCOL_VERSION = TlsAlertDescription.PROTOCOL_VERSION;
|
|
||||||
static readonly INSUFFICIENT_SECURITY = TlsAlertDescription.INSUFFICIENT_SECURITY;
|
|
||||||
static readonly INTERNAL_ERROR = TlsAlertDescription.INTERNAL_ERROR;
|
|
||||||
static readonly INAPPROPRIATE_FALLBACK = TlsAlertDescription.INAPPROPRIATE_FALLBACK;
|
|
||||||
static readonly USER_CANCELED = TlsAlertDescription.USER_CANCELED;
|
|
||||||
static readonly NO_RENEGOTIATION = TlsAlertDescription.NO_RENEGOTIATION;
|
|
||||||
static readonly MISSING_EXTENSION = TlsAlertDescription.MISSING_EXTENSION;
|
|
||||||
static readonly UNSUPPORTED_EXTENSION = TlsAlertDescription.UNSUPPORTED_EXTENSION;
|
|
||||||
static readonly CERTIFICATE_REQUIRED = TlsAlertDescription.CERTIFICATE_REQUIRED;
|
|
||||||
static readonly UNRECOGNIZED_NAME = TlsAlertDescription.UNRECOGNIZED_NAME;
|
|
||||||
static readonly BAD_CERTIFICATE_STATUS_RESPONSE = TlsAlertDescription.BAD_CERTIFICATE_STATUS_RESPONSE;
|
|
||||||
static readonly BAD_CERTIFICATE_HASH_VALUE = TlsAlertDescription.BAD_CERTIFICATE_HASH_VALUE;
|
|
||||||
static readonly UNKNOWN_PSK_IDENTITY = TlsAlertDescription.UNKNOWN_PSK_IDENTITY;
|
|
||||||
static readonly CERTIFICATE_REQUIRED_1_3 = TlsAlertDescription.CERTIFICATE_REQUIRED_1_3;
|
|
||||||
static readonly NO_APPLICATION_PROTOCOL = TlsAlertDescription.NO_APPLICATION_PROTOCOL;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a TLS alert buffer with the specified level and description code
|
|
||||||
*
|
|
||||||
* @param level Alert level (warning or fatal)
|
|
||||||
* @param description Alert description code
|
|
||||||
* @param tlsVersion TLS version bytes (default is TLS 1.2: 0x0303)
|
|
||||||
* @returns Buffer containing the TLS alert message
|
|
||||||
*/
|
|
||||||
static create(
|
|
||||||
level: number,
|
|
||||||
description: number,
|
|
||||||
tlsVersion: [number, number] = [TlsVersion.TLS1_2[0], TlsVersion.TLS1_2[1]]
|
|
||||||
): Buffer {
|
|
||||||
return Buffer.from([
|
|
||||||
0x15, // Alert record type
|
|
||||||
tlsVersion[0],
|
|
||||||
tlsVersion[1], // TLS version (default to TLS 1.2: 0x0303)
|
|
||||||
0x00,
|
|
||||||
0x02, // Length
|
|
||||||
level, // Alert level
|
|
||||||
description, // Alert description
|
|
||||||
]);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a warning-level TLS alert
|
|
||||||
*
|
|
||||||
* @param description Alert description code
|
|
||||||
* @returns Buffer containing the warning-level TLS alert message
|
|
||||||
*/
|
|
||||||
static createWarning(description: number): Buffer {
|
|
||||||
return this.create(this.LEVEL_WARNING, description);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create a fatal-level TLS alert
|
|
||||||
*
|
|
||||||
* @param description Alert description code
|
|
||||||
* @returns Buffer containing the fatal-level TLS alert message
|
|
||||||
*/
|
|
||||||
static createFatal(description: number): Buffer {
|
|
||||||
return this.create(this.LEVEL_FATAL, description);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Send a TLS alert to a socket and optionally close the connection
|
|
||||||
*
|
|
||||||
* @param socket The socket to send the alert to
|
|
||||||
* @param level Alert level (warning or fatal)
|
|
||||||
* @param description Alert description code
|
|
||||||
* @param closeAfterSend Whether to close the connection after sending the alert
|
|
||||||
* @param closeDelay Milliseconds to wait before closing the connection (default: 200ms)
|
|
||||||
* @returns Promise that resolves when the alert has been sent
|
|
||||||
*/
|
|
||||||
static async send(
|
|
||||||
socket: plugins.net.Socket,
|
|
||||||
level: number,
|
|
||||||
description: number,
|
|
||||||
closeAfterSend: boolean = false,
|
|
||||||
closeDelay: number = 200
|
|
||||||
): Promise<void> {
|
|
||||||
const alert = this.create(level, description);
|
|
||||||
|
|
||||||
return new Promise<void>((resolve, reject) => {
|
|
||||||
try {
|
|
||||||
// Ensure the alert is written as a single packet
|
|
||||||
socket.cork();
|
|
||||||
const writeSuccessful = socket.write(alert, (err) => {
|
|
||||||
if (err) {
|
|
||||||
reject(err);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (closeAfterSend) {
|
|
||||||
setTimeout(() => {
|
|
||||||
socket.end();
|
|
||||||
resolve();
|
|
||||||
}, closeDelay);
|
|
||||||
} else {
|
|
||||||
resolve();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
socket.uncork();
|
|
||||||
|
|
||||||
// If write wasn't successful immediately, wait for drain
|
|
||||||
if (!writeSuccessful && !closeAfterSend) {
|
|
||||||
socket.once('drain', () => {
|
|
||||||
resolve();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
} catch (err) {
|
|
||||||
reject(err);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Pre-defined TLS alert messages
|
|
||||||
*/
|
|
||||||
static readonly alerts = {
|
|
||||||
// Warning level alerts
|
|
||||||
closeNotify: TlsAlert.createWarning(TlsAlert.CLOSE_NOTIFY),
|
|
||||||
unsupportedExtension: TlsAlert.createWarning(TlsAlert.UNSUPPORTED_EXTENSION),
|
|
||||||
certificateRequired: TlsAlert.createWarning(TlsAlert.CERTIFICATE_REQUIRED),
|
|
||||||
unrecognizedName: TlsAlert.createWarning(TlsAlert.UNRECOGNIZED_NAME),
|
|
||||||
noRenegotiation: TlsAlert.createWarning(TlsAlert.NO_RENEGOTIATION),
|
|
||||||
userCanceled: TlsAlert.createWarning(TlsAlert.USER_CANCELED),
|
|
||||||
|
|
||||||
// Warning level alerts for session resumption
|
|
||||||
certificateExpiredWarning: TlsAlert.createWarning(TlsAlert.CERTIFICATE_EXPIRED),
|
|
||||||
handshakeFailureWarning: TlsAlert.createWarning(TlsAlert.HANDSHAKE_FAILURE),
|
|
||||||
insufficientSecurityWarning: TlsAlert.createWarning(TlsAlert.INSUFFICIENT_SECURITY),
|
|
||||||
|
|
||||||
// Fatal level alerts
|
|
||||||
unexpectedMessage: TlsAlert.createFatal(TlsAlert.UNEXPECTED_MESSAGE),
|
|
||||||
badRecordMac: TlsAlert.createFatal(TlsAlert.BAD_RECORD_MAC),
|
|
||||||
recordOverflow: TlsAlert.createFatal(TlsAlert.RECORD_OVERFLOW),
|
|
||||||
handshakeFailure: TlsAlert.createFatal(TlsAlert.HANDSHAKE_FAILURE),
|
|
||||||
badCertificate: TlsAlert.createFatal(TlsAlert.BAD_CERTIFICATE),
|
|
||||||
certificateExpired: TlsAlert.createFatal(TlsAlert.CERTIFICATE_EXPIRED),
|
|
||||||
certificateUnknown: TlsAlert.createFatal(TlsAlert.CERTIFICATE_UNKNOWN),
|
|
||||||
illegalParameter: TlsAlert.createFatal(TlsAlert.ILLEGAL_PARAMETER),
|
|
||||||
unknownCA: TlsAlert.createFatal(TlsAlert.UNKNOWN_CA),
|
|
||||||
accessDenied: TlsAlert.createFatal(TlsAlert.ACCESS_DENIED),
|
|
||||||
decodeError: TlsAlert.createFatal(TlsAlert.DECODE_ERROR),
|
|
||||||
decryptError: TlsAlert.createFatal(TlsAlert.DECRYPT_ERROR),
|
|
||||||
protocolVersion: TlsAlert.createFatal(TlsAlert.PROTOCOL_VERSION),
|
|
||||||
insufficientSecurity: TlsAlert.createFatal(TlsAlert.INSUFFICIENT_SECURITY),
|
|
||||||
internalError: TlsAlert.createFatal(TlsAlert.INTERNAL_ERROR),
|
|
||||||
unrecognizedNameFatal: TlsAlert.createFatal(TlsAlert.UNRECOGNIZED_NAME),
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Utility method to send a warning-level unrecognized_name alert
|
|
||||||
* Specifically designed for SNI issues to encourage the client to retry with SNI
|
|
||||||
*
|
|
||||||
* @param socket The socket to send the alert to
|
|
||||||
* @returns Promise that resolves when the alert has been sent
|
|
||||||
*/
|
|
||||||
static async sendSniRequired(socket: plugins.net.Socket): Promise<void> {
|
|
||||||
return this.send(socket, this.LEVEL_WARNING, this.UNRECOGNIZED_NAME);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Utility method to send a close_notify alert and close the connection
|
|
||||||
*
|
|
||||||
* @param socket The socket to send the alert to
|
|
||||||
* @param closeDelay Milliseconds to wait before closing the connection (default: 200ms)
|
|
||||||
* @returns Promise that resolves when the alert has been sent and the connection closed
|
|
||||||
*/
|
|
||||||
static async sendCloseNotify(socket: plugins.net.Socket, closeDelay: number = 200): Promise<void> {
|
|
||||||
return this.send(socket, this.LEVEL_WARNING, this.CLOSE_NOTIFY, true, closeDelay);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Utility method to send a certificate_expired alert to force new TLS session
|
|
||||||
*
|
|
||||||
* @param socket The socket to send the alert to
|
|
||||||
* @param fatal Whether to send as a fatal alert (default: false)
|
|
||||||
* @param closeAfterSend Whether to close the connection after sending the alert (default: true)
|
|
||||||
* @param closeDelay Milliseconds to wait before closing the connection (default: 200ms)
|
|
||||||
* @returns Promise that resolves when the alert has been sent
|
|
||||||
*/
|
|
||||||
static async sendCertificateExpired(
|
|
||||||
socket: plugins.net.Socket,
|
|
||||||
fatal: boolean = false,
|
|
||||||
closeAfterSend: boolean = true,
|
|
||||||
closeDelay: number = 200
|
|
||||||
): Promise<void> {
|
|
||||||
const level = fatal ? this.LEVEL_FATAL : this.LEVEL_WARNING;
|
|
||||||
return this.send(socket, level, this.CERTIFICATE_EXPIRED, closeAfterSend, closeDelay);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Send a sequence of alerts to force SNI from clients
|
|
||||||
* This combines multiple alerts to ensure maximum browser compatibility
|
|
||||||
*
|
|
||||||
* @param socket The socket to send the alerts to
|
|
||||||
* @returns Promise that resolves when all alerts have been sent
|
|
||||||
*/
|
|
||||||
static async sendForceSniSequence(socket: plugins.net.Socket): Promise<void> {
|
|
||||||
try {
|
|
||||||
// Send unrecognized_name (warning)
|
|
||||||
socket.cork();
|
|
||||||
socket.write(this.alerts.unrecognizedName);
|
|
||||||
socket.uncork();
|
|
||||||
|
|
||||||
// Give the socket time to send the alert
|
|
||||||
return new Promise((resolve) => {
|
|
||||||
setTimeout(resolve, 50);
|
|
||||||
});
|
|
||||||
} catch (err) {
|
|
||||||
return Promise.reject(err);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Send a fatal level alert that immediately terminates the connection
|
|
||||||
*
|
|
||||||
* @param socket The socket to send the alert to
|
|
||||||
* @param description Alert description code
|
|
||||||
* @param closeDelay Milliseconds to wait before closing the connection (default: 100ms)
|
|
||||||
* @returns Promise that resolves when the alert has been sent and the connection closed
|
|
||||||
*/
|
|
||||||
static async sendFatalAndClose(
|
|
||||||
socket: plugins.net.Socket,
|
|
||||||
description: number,
|
|
||||||
closeDelay: number = 100
|
|
||||||
): Promise<void> {
|
|
||||||
return this.send(socket, this.LEVEL_FATAL, description, true, closeDelay);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
/**
|
|
||||||
* TLS Protocol Module
|
|
||||||
* Contains generic TLS protocol knowledge including parsers, constants, and utilities
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Export all sub-modules
|
|
||||||
export * from './alerts/index.js';
|
|
||||||
export * from './sni/index.js';
|
|
||||||
export * from './utils/index.js';
|
|
||||||
|
|
||||||
// Re-export main utilities and types for convenience
|
|
||||||
export {
|
|
||||||
TlsUtils,
|
|
||||||
TlsRecordType,
|
|
||||||
TlsHandshakeType,
|
|
||||||
TlsExtensionType,
|
|
||||||
TlsAlertLevel,
|
|
||||||
TlsAlertDescription,
|
|
||||||
TlsVersion
|
|
||||||
} from './utils/tls-utils.js';
|
|
||||||
export { TlsAlert } from './alerts/tls-alert.js';
|
|
||||||
export { ClientHelloParser } from './sni/client-hello-parser.js';
|
|
||||||
export { SniExtraction } from './sni/sni-extraction.js';
|
|
||||||
|
|
||||||
// Export tlsVersionToString helper
|
|
||||||
export function tlsVersionToString(major: number, minor: number): string | null {
|
|
||||||
if (major === 0x03) {
|
|
||||||
switch (minor) {
|
|
||||||
case 0x00: return 'SSLv3';
|
|
||||||
case 0x01: return 'TLSv1.0';
|
|
||||||
case 0x02: return 'TLSv1.1';
|
|
||||||
case 0x03: return 'TLSv1.2';
|
|
||||||
case 0x04: return 'TLSv1.3';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
@@ -1,629 +0,0 @@
|
|||||||
import { Buffer } from 'node:buffer';
|
|
||||||
import {
|
|
||||||
TlsRecordType,
|
|
||||||
TlsHandshakeType,
|
|
||||||
TlsExtensionType,
|
|
||||||
TlsUtils
|
|
||||||
} from '../utils/tls-utils.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Interface for logging functions used by the parser
|
|
||||||
*/
|
|
||||||
export type LoggerFunction = (message: string) => void;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Result of a session resumption check
|
|
||||||
*/
|
|
||||||
export interface SessionResumptionResult {
|
|
||||||
isResumption: boolean;
|
|
||||||
hasSNI: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Information about parsed TLS extensions
|
|
||||||
*/
|
|
||||||
export interface ExtensionInfo {
|
|
||||||
type: number;
|
|
||||||
length: number;
|
|
||||||
data: Buffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Result of a ClientHello parse operation
|
|
||||||
*/
|
|
||||||
export interface ClientHelloParseResult {
|
|
||||||
isValid: boolean;
|
|
||||||
version?: [number, number];
|
|
||||||
random?: Buffer;
|
|
||||||
sessionId?: Buffer;
|
|
||||||
hasSessionId: boolean;
|
|
||||||
cipherSuites?: Buffer;
|
|
||||||
compressionMethods?: Buffer;
|
|
||||||
extensions: ExtensionInfo[];
|
|
||||||
serverNameList?: string[];
|
|
||||||
hasSessionTicket: boolean;
|
|
||||||
hasPsk: boolean;
|
|
||||||
hasEarlyData: boolean;
|
|
||||||
error?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Fragment tracking information
|
|
||||||
*/
|
|
||||||
export interface FragmentTrackingInfo {
|
|
||||||
buffer: Buffer;
|
|
||||||
timestamp: number;
|
|
||||||
connectionId: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Class for parsing TLS ClientHello messages
|
|
||||||
*/
|
|
||||||
export class ClientHelloParser {
|
|
||||||
// Buffer for handling fragmented ClientHello messages
|
|
||||||
private static fragmentedBuffers: Map<string, FragmentTrackingInfo> = new Map();
|
|
||||||
private static fragmentTimeout: number = 1000; // ms to wait for fragments before cleanup
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clean up expired fragments
|
|
||||||
*/
|
|
||||||
private static cleanupExpiredFragments(): void {
|
|
||||||
const now = Date.now();
|
|
||||||
for (const [connectionId, info] of this.fragmentedBuffers.entries()) {
|
|
||||||
if (now - info.timestamp > this.fragmentTimeout) {
|
|
||||||
this.fragmentedBuffers.delete(connectionId);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Handles potential fragmented ClientHello messages by buffering and reassembling
|
|
||||||
* TLS record fragments that might span multiple TCP packets.
|
|
||||||
*
|
|
||||||
* @param buffer The current buffer fragment
|
|
||||||
* @param connectionId Unique identifier for the connection
|
|
||||||
* @param logger Optional logging function
|
|
||||||
* @returns A complete buffer if reassembly is successful, or undefined if more fragments are needed
|
|
||||||
*/
|
|
||||||
public static handleFragmentedClientHello(
|
|
||||||
buffer: Buffer,
|
|
||||||
connectionId: string,
|
|
||||||
logger?: LoggerFunction
|
|
||||||
): Buffer | undefined {
|
|
||||||
const log = logger || (() => {});
|
|
||||||
|
|
||||||
// Periodically clean up expired fragments
|
|
||||||
this.cleanupExpiredFragments();
|
|
||||||
|
|
||||||
// Check if we've seen this connection before
|
|
||||||
if (!this.fragmentedBuffers.has(connectionId)) {
|
|
||||||
// New connection, start with this buffer
|
|
||||||
this.fragmentedBuffers.set(connectionId, {
|
|
||||||
buffer,
|
|
||||||
timestamp: Date.now(),
|
|
||||||
connectionId
|
|
||||||
});
|
|
||||||
|
|
||||||
// Evaluate if this buffer already contains a complete ClientHello
|
|
||||||
try {
|
|
||||||
if (buffer.length >= 5) {
|
|
||||||
// Get the record length from TLS header
|
|
||||||
const recordLength = (buffer[3] << 8) + buffer[4] + 5; // +5 for the TLS record header itself
|
|
||||||
log(`Initial buffer size: ${buffer.length}, expected record length: ${recordLength}`);
|
|
||||||
|
|
||||||
// Check if this buffer already contains a complete TLS record
|
|
||||||
if (buffer.length >= recordLength) {
|
|
||||||
log(`Initial buffer contains complete ClientHello, length: ${buffer.length}`);
|
|
||||||
return buffer;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log(
|
|
||||||
`Initial buffer too small (${buffer.length} bytes), needs at least 5 bytes for TLS header`
|
|
||||||
);
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
log(`Error checking initial buffer completeness: ${e}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
log(`Started buffering connection ${connectionId}, initial size: ${buffer.length}`);
|
|
||||||
return undefined; // Need more fragments
|
|
||||||
} else {
|
|
||||||
// Existing connection, append this buffer
|
|
||||||
const existingInfo = this.fragmentedBuffers.get(connectionId)!;
|
|
||||||
const newBuffer = Buffer.concat([existingInfo.buffer, buffer]);
|
|
||||||
|
|
||||||
// Update the buffer and timestamp
|
|
||||||
this.fragmentedBuffers.set(connectionId, {
|
|
||||||
...existingInfo,
|
|
||||||
buffer: newBuffer,
|
|
||||||
timestamp: Date.now()
|
|
||||||
});
|
|
||||||
|
|
||||||
log(`Appended to buffer for ${connectionId}, new size: ${newBuffer.length}`);
|
|
||||||
|
|
||||||
// Check if we now have a complete ClientHello
|
|
||||||
try {
|
|
||||||
if (newBuffer.length >= 5) {
|
|
||||||
// Get the record length from TLS header
|
|
||||||
const recordLength = (newBuffer[3] << 8) + newBuffer[4] + 5; // +5 for the TLS record header itself
|
|
||||||
log(
|
|
||||||
`Reassembled buffer size: ${newBuffer.length}, expected record length: ${recordLength}`
|
|
||||||
);
|
|
||||||
|
|
||||||
// Check if we have a complete TLS record now
|
|
||||||
if (newBuffer.length >= recordLength) {
|
|
||||||
log(
|
|
||||||
`Assembled complete ClientHello, length: ${newBuffer.length}, needed: ${recordLength}`
|
|
||||||
);
|
|
||||||
|
|
||||||
// Extract the complete TLS record (might be followed by more data)
|
|
||||||
const completeRecord = newBuffer.slice(0, recordLength);
|
|
||||||
|
|
||||||
// Check if this record is indeed a ClientHello (type 1) at position 5
|
|
||||||
if (
|
|
||||||
completeRecord.length > 5 &&
|
|
||||||
completeRecord[5] === TlsHandshakeType.CLIENT_HELLO
|
|
||||||
) {
|
|
||||||
log(`Verified record is a ClientHello handshake message`);
|
|
||||||
|
|
||||||
// Complete message received, remove from tracking
|
|
||||||
this.fragmentedBuffers.delete(connectionId);
|
|
||||||
return completeRecord;
|
|
||||||
} else {
|
|
||||||
log(`Record is complete but not a ClientHello handshake, continuing to buffer`);
|
|
||||||
// This might be another TLS record type preceding the ClientHello
|
|
||||||
|
|
||||||
// Try checking for a ClientHello starting at the end of this record
|
|
||||||
if (newBuffer.length > recordLength + 5) {
|
|
||||||
const nextRecordType = newBuffer[recordLength];
|
|
||||||
log(
|
|
||||||
`Next record type: ${nextRecordType} (looking for ${TlsRecordType.HANDSHAKE})`
|
|
||||||
);
|
|
||||||
|
|
||||||
if (nextRecordType === TlsRecordType.HANDSHAKE) {
|
|
||||||
const handshakeType = newBuffer[recordLength + 5];
|
|
||||||
log(
|
|
||||||
`Next handshake type: ${handshakeType} (looking for ${TlsHandshakeType.CLIENT_HELLO})`
|
|
||||||
);
|
|
||||||
|
|
||||||
if (handshakeType === TlsHandshakeType.CLIENT_HELLO) {
|
|
||||||
// Found a ClientHello in the next record, return the entire buffer
|
|
||||||
log(`Found ClientHello in subsequent record, returning full buffer`);
|
|
||||||
this.fragmentedBuffers.delete(connectionId);
|
|
||||||
return newBuffer;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
log(`Error checking reassembled buffer completeness: ${e}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
return undefined; // Still need more fragments
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parses a TLS ClientHello message and extracts all components
|
|
||||||
*
|
|
||||||
* @param buffer The buffer containing the ClientHello message
|
|
||||||
* @param logger Optional logging function
|
|
||||||
* @returns Parsed ClientHello or undefined if parsing failed
|
|
||||||
*/
|
|
||||||
public static parseClientHello(
|
|
||||||
buffer: Buffer,
|
|
||||||
logger?: LoggerFunction
|
|
||||||
): ClientHelloParseResult {
|
|
||||||
const log = logger || (() => {});
|
|
||||||
const result: ClientHelloParseResult = {
|
|
||||||
isValid: false,
|
|
||||||
hasSessionId: false,
|
|
||||||
extensions: [],
|
|
||||||
hasSessionTicket: false,
|
|
||||||
hasPsk: false,
|
|
||||||
hasEarlyData: false
|
|
||||||
};
|
|
||||||
|
|
||||||
try {
|
|
||||||
// Check basic validity
|
|
||||||
if (buffer.length < 5) {
|
|
||||||
result.error = 'Buffer too small for TLS record header';
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check record type (must be HANDSHAKE)
|
|
||||||
if (buffer[0] !== TlsRecordType.HANDSHAKE) {
|
|
||||||
result.error = `Not a TLS handshake record: ${buffer[0]}`;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get TLS version from record header
|
|
||||||
const majorVersion = buffer[1];
|
|
||||||
const minorVersion = buffer[2];
|
|
||||||
result.version = [majorVersion, minorVersion];
|
|
||||||
log(`TLS record version: ${majorVersion}.${minorVersion}`);
|
|
||||||
|
|
||||||
// Parse record length (bytes 3-4, big-endian)
|
|
||||||
const recordLength = (buffer[3] << 8) + buffer[4];
|
|
||||||
log(`Record length: ${recordLength}`);
|
|
||||||
|
|
||||||
// Validate record length against buffer size
|
|
||||||
if (buffer.length < recordLength + 5) {
|
|
||||||
result.error = 'Buffer smaller than expected record length';
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start of handshake message in the buffer
|
|
||||||
let pos = 5;
|
|
||||||
|
|
||||||
// Check handshake type (must be CLIENT_HELLO)
|
|
||||||
if (buffer[pos] !== TlsHandshakeType.CLIENT_HELLO) {
|
|
||||||
result.error = `Not a ClientHello message: ${buffer[pos]}`;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip handshake type (1 byte)
|
|
||||||
pos += 1;
|
|
||||||
|
|
||||||
// Parse handshake length (3 bytes, big-endian)
|
|
||||||
const handshakeLength = (buffer[pos] << 16) + (buffer[pos + 1] << 8) + buffer[pos + 2];
|
|
||||||
log(`Handshake length: ${handshakeLength}`);
|
|
||||||
|
|
||||||
// Skip handshake length (3 bytes)
|
|
||||||
pos += 3;
|
|
||||||
|
|
||||||
// Check client version (2 bytes)
|
|
||||||
const clientMajorVersion = buffer[pos];
|
|
||||||
const clientMinorVersion = buffer[pos + 1];
|
|
||||||
log(`Client version: ${clientMajorVersion}.${clientMinorVersion}`);
|
|
||||||
|
|
||||||
// Skip client version (2 bytes)
|
|
||||||
pos += 2;
|
|
||||||
|
|
||||||
// Extract client random (32 bytes)
|
|
||||||
if (pos + 32 > buffer.length) {
|
|
||||||
result.error = 'Buffer too small for client random';
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.random = buffer.slice(pos, pos + 32);
|
|
||||||
log(`Client random: ${result.random.toString('hex')}`);
|
|
||||||
|
|
||||||
// Skip client random (32 bytes)
|
|
||||||
pos += 32;
|
|
||||||
|
|
||||||
// Parse session ID
|
|
||||||
if (pos + 1 > buffer.length) {
|
|
||||||
result.error = 'Buffer too small for session ID length';
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
const sessionIdLength = buffer[pos];
|
|
||||||
log(`Session ID length: ${sessionIdLength}`);
|
|
||||||
pos += 1;
|
|
||||||
|
|
||||||
result.hasSessionId = sessionIdLength > 0;
|
|
||||||
|
|
||||||
if (sessionIdLength > 0) {
|
|
||||||
if (pos + sessionIdLength > buffer.length) {
|
|
||||||
result.error = 'Buffer too small for session ID';
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.sessionId = buffer.slice(pos, pos + sessionIdLength);
|
|
||||||
log(`Session ID: ${result.sessionId.toString('hex')}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip session ID
|
|
||||||
pos += sessionIdLength;
|
|
||||||
|
|
||||||
// Check if we have enough bytes left for cipher suites
|
|
||||||
if (pos + 2 > buffer.length) {
|
|
||||||
result.error = 'Buffer too small for cipher suites length';
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse cipher suites length (2 bytes, big-endian)
|
|
||||||
const cipherSuitesLength = (buffer[pos] << 8) + buffer[pos + 1];
|
|
||||||
log(`Cipher suites length: ${cipherSuitesLength}`);
|
|
||||||
pos += 2;
|
|
||||||
|
|
||||||
// Extract cipher suites
|
|
||||||
if (pos + cipherSuitesLength > buffer.length) {
|
|
||||||
result.error = 'Buffer too small for cipher suites';
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.cipherSuites = buffer.slice(pos, pos + cipherSuitesLength);
|
|
||||||
|
|
||||||
// Skip cipher suites
|
|
||||||
pos += cipherSuitesLength;
|
|
||||||
|
|
||||||
// Check if we have enough bytes left for compression methods
|
|
||||||
if (pos + 1 > buffer.length) {
|
|
||||||
result.error = 'Buffer too small for compression methods length';
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse compression methods length (1 byte)
|
|
||||||
const compressionMethodsLength = buffer[pos];
|
|
||||||
log(`Compression methods length: ${compressionMethodsLength}`);
|
|
||||||
pos += 1;
|
|
||||||
|
|
||||||
// Extract compression methods
|
|
||||||
if (pos + compressionMethodsLength > buffer.length) {
|
|
||||||
result.error = 'Buffer too small for compression methods';
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.compressionMethods = buffer.slice(pos, pos + compressionMethodsLength);
|
|
||||||
|
|
||||||
// Skip compression methods
|
|
||||||
pos += compressionMethodsLength;
|
|
||||||
|
|
||||||
// Check if we have enough bytes for extensions length
|
|
||||||
if (pos + 2 > buffer.length) {
|
|
||||||
// No extensions present - this is valid for older TLS versions
|
|
||||||
result.isValid = true;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse extensions length (2 bytes, big-endian)
|
|
||||||
const extensionsLength = (buffer[pos] << 8) + buffer[pos + 1];
|
|
||||||
log(`Extensions length: ${extensionsLength}`);
|
|
||||||
pos += 2;
|
|
||||||
|
|
||||||
// Extensions end position
|
|
||||||
const extensionsEnd = pos + extensionsLength;
|
|
||||||
|
|
||||||
// Check if extensions length is valid
|
|
||||||
if (extensionsEnd > buffer.length) {
|
|
||||||
result.error = 'Extensions length exceeds buffer size';
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate through extensions
|
|
||||||
const serverNames: string[] = [];
|
|
||||||
|
|
||||||
while (pos + 4 <= extensionsEnd) {
|
|
||||||
// Parse extension type (2 bytes, big-endian)
|
|
||||||
const extensionType = (buffer[pos] << 8) + buffer[pos + 1];
|
|
||||||
log(`Extension type: 0x${extensionType.toString(16).padStart(4, '0')}`);
|
|
||||||
pos += 2;
|
|
||||||
|
|
||||||
// Parse extension length (2 bytes, big-endian)
|
|
||||||
const extensionLength = (buffer[pos] << 8) + buffer[pos + 1];
|
|
||||||
log(`Extension length: ${extensionLength}`);
|
|
||||||
pos += 2;
|
|
||||||
|
|
||||||
// Extract extension data
|
|
||||||
if (pos + extensionLength > extensionsEnd) {
|
|
||||||
result.error = `Extension ${extensionType} data exceeds bounds`;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
const extensionData = buffer.slice(pos, pos + extensionLength);
|
|
||||||
|
|
||||||
// Record all extensions
|
|
||||||
result.extensions.push({
|
|
||||||
type: extensionType,
|
|
||||||
length: extensionLength,
|
|
||||||
data: extensionData
|
|
||||||
});
|
|
||||||
|
|
||||||
// Track specific extension types
|
|
||||||
if (extensionType === TlsExtensionType.SERVER_NAME) {
|
|
||||||
// Server Name Indication (SNI)
|
|
||||||
this.parseServerNameExtension(extensionData, serverNames, logger);
|
|
||||||
} else if (extensionType === TlsExtensionType.SESSION_TICKET) {
|
|
||||||
// Session ticket
|
|
||||||
result.hasSessionTicket = true;
|
|
||||||
} else if (extensionType === TlsExtensionType.PRE_SHARED_KEY) {
|
|
||||||
// TLS 1.3 PSK
|
|
||||||
result.hasPsk = true;
|
|
||||||
} else if (extensionType === TlsExtensionType.EARLY_DATA) {
|
|
||||||
// TLS 1.3 Early Data (0-RTT)
|
|
||||||
result.hasEarlyData = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Move to next extension
|
|
||||||
pos += extensionLength;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store any server names found
|
|
||||||
if (serverNames.length > 0) {
|
|
||||||
result.serverNameList = serverNames;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark as valid if we get here
|
|
||||||
result.isValid = true;
|
|
||||||
return result;
|
|
||||||
} catch (error) {
|
|
||||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
|
||||||
log(`Error parsing ClientHello: ${errorMessage}`);
|
|
||||||
result.error = errorMessage;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Parses the server name extension data and extracts hostnames
|
|
||||||
*
|
|
||||||
* @param data Extension data buffer
|
|
||||||
* @param serverNames Array to populate with found server names
|
|
||||||
* @param logger Optional logging function
|
|
||||||
* @returns true if parsing succeeded
|
|
||||||
*/
|
|
||||||
private static parseServerNameExtension(
|
|
||||||
data: Buffer,
|
|
||||||
serverNames: string[],
|
|
||||||
logger?: LoggerFunction
|
|
||||||
): boolean {
|
|
||||||
const log = logger || (() => {});
|
|
||||||
|
|
||||||
try {
|
|
||||||
// Need at least 2 bytes for server name list length
|
|
||||||
if (data.length < 2) {
|
|
||||||
log('SNI extension too small for server name list length');
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse server name list length (2 bytes)
|
|
||||||
const listLength = (data[0] << 8) + data[1];
|
|
||||||
|
|
||||||
// Skip to first name entry
|
|
||||||
let pos = 2;
|
|
||||||
|
|
||||||
// End of list
|
|
||||||
const listEnd = pos + listLength;
|
|
||||||
|
|
||||||
// Validate length
|
|
||||||
if (listEnd > data.length) {
|
|
||||||
log('SNI server name list exceeds extension data');
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process all name entries
|
|
||||||
while (pos + 3 <= listEnd) {
|
|
||||||
// Name type (1 byte)
|
|
||||||
const nameType = data[pos];
|
|
||||||
pos += 1;
|
|
||||||
|
|
||||||
// For hostname, type must be 0
|
|
||||||
if (nameType !== 0) {
|
|
||||||
// Skip this entry
|
|
||||||
if (pos + 2 <= listEnd) {
|
|
||||||
const nameLength = (data[pos] << 8) + data[pos + 1];
|
|
||||||
pos += 2 + nameLength;
|
|
||||||
continue;
|
|
||||||
} else {
|
|
||||||
log('Malformed SNI entry');
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse hostname length (2 bytes)
|
|
||||||
if (pos + 2 > listEnd) {
|
|
||||||
log('SNI extension truncated');
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
const nameLength = (data[pos] << 8) + data[pos + 1];
|
|
||||||
pos += 2;
|
|
||||||
|
|
||||||
// Extract hostname
|
|
||||||
if (pos + nameLength > listEnd) {
|
|
||||||
log('SNI hostname truncated');
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract the hostname as UTF-8
|
|
||||||
try {
|
|
||||||
const hostname = data.slice(pos, pos + nameLength).toString('utf8');
|
|
||||||
log(`Found SNI hostname: ${hostname}`);
|
|
||||||
serverNames.push(hostname);
|
|
||||||
} catch (err) {
|
|
||||||
log(`Error extracting hostname: ${err}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Move to next entry
|
|
||||||
pos += nameLength;
|
|
||||||
}
|
|
||||||
|
|
||||||
return serverNames.length > 0;
|
|
||||||
} catch (error) {
|
|
||||||
log(`Error parsing SNI extension: ${error}`);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Determines if a ClientHello contains session resumption indicators
|
|
||||||
*
|
|
||||||
* @param buffer The ClientHello buffer
|
|
||||||
* @param logger Optional logging function
|
|
||||||
* @returns Session resumption result
|
|
||||||
*/
|
|
||||||
public static hasSessionResumption(
|
|
||||||
buffer: Buffer,
|
|
||||||
logger?: LoggerFunction
|
|
||||||
): SessionResumptionResult {
|
|
||||||
const log = logger || (() => {});
|
|
||||||
|
|
||||||
if (!TlsUtils.isClientHello(buffer)) {
|
|
||||||
return { isResumption: false, hasSNI: false };
|
|
||||||
}
|
|
||||||
|
|
||||||
const parseResult = this.parseClientHello(buffer, logger);
|
|
||||||
if (!parseResult.isValid) {
|
|
||||||
log(`ClientHello parse failed: ${parseResult.error}`);
|
|
||||||
return { isResumption: false, hasSNI: false };
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check resumption indicators
|
|
||||||
const hasSessionId = parseResult.hasSessionId;
|
|
||||||
const hasSessionTicket = parseResult.hasSessionTicket;
|
|
||||||
const hasPsk = parseResult.hasPsk;
|
|
||||||
const hasEarlyData = parseResult.hasEarlyData;
|
|
||||||
|
|
||||||
// Check for SNI
|
|
||||||
const hasSNI = !!parseResult.serverNameList && parseResult.serverNameList.length > 0;
|
|
||||||
|
|
||||||
// Consider it a resumption if any resumption mechanism is present
|
|
||||||
const isResumption = hasSessionTicket || hasPsk || hasEarlyData ||
|
|
||||||
(hasSessionId && !hasPsk); // Legacy resumption
|
|
||||||
|
|
||||||
// Log details
|
|
||||||
if (isResumption) {
|
|
||||||
log(
|
|
||||||
'Session resumption detected: ' +
|
|
||||||
(hasSessionTicket ? 'session ticket, ' : '') +
|
|
||||||
(hasPsk ? 'PSK, ' : '') +
|
|
||||||
(hasEarlyData ? 'early data, ' : '') +
|
|
||||||
(hasSessionId ? 'session ID' : '') +
|
|
||||||
(hasSNI ? ', with SNI' : ', without SNI')
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return { isResumption, hasSNI };
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Checks if a ClientHello appears to be from a tab reactivation
|
|
||||||
*
|
|
||||||
* @param buffer The ClientHello buffer
|
|
||||||
* @param logger Optional logging function
|
|
||||||
* @returns true if it appears to be a tab reactivation
|
|
||||||
*/
|
|
||||||
public static isTabReactivationHandshake(
|
|
||||||
buffer: Buffer,
|
|
||||||
logger?: LoggerFunction
|
|
||||||
): boolean {
|
|
||||||
const log = logger || (() => {});
|
|
||||||
|
|
||||||
if (!TlsUtils.isClientHello(buffer)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the ClientHello
|
|
||||||
const parseResult = this.parseClientHello(buffer, logger);
|
|
||||||
if (!parseResult.isValid) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tab reactivation pattern: session identifier + (ticket or PSK) but no SNI
|
|
||||||
const hasSessionId = parseResult.hasSessionId;
|
|
||||||
const hasSessionTicket = parseResult.hasSessionTicket;
|
|
||||||
const hasPsk = parseResult.hasPsk;
|
|
||||||
const hasSNI = !!parseResult.serverNameList && parseResult.serverNameList.length > 0;
|
|
||||||
|
|
||||||
if ((hasSessionId && (hasSessionTicket || hasPsk)) && !hasSNI) {
|
|
||||||
log('Detected tab reactivation pattern: session resumption without SNI');
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
/**
|
|
||||||
* TLS SNI (Server Name Indication) protocol utilities
|
|
||||||
*/
|
|
||||||
|
|
||||||
export * from './client-hello-parser.js';
|
|
||||||
export * from './sni-extraction.js';
|
|
||||||
@@ -1,353 +0,0 @@
|
|||||||
import { Buffer } from 'node:buffer';
|
|
||||||
import { TlsExtensionType, TlsUtils } from '../utils/tls-utils.js';
|
|
||||||
import {
|
|
||||||
ClientHelloParser,
|
|
||||||
type LoggerFunction
|
|
||||||
} from './client-hello-parser.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Connection tracking information
|
|
||||||
*/
|
|
||||||
export interface ConnectionInfo {
|
|
||||||
sourceIp: string;
|
|
||||||
sourcePort: number;
|
|
||||||
destIp: string;
|
|
||||||
destPort: number;
|
|
||||||
timestamp?: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Utilities for extracting SNI information from TLS handshakes
|
|
||||||
*/
|
|
||||||
export class SniExtraction {
|
|
||||||
/**
|
|
||||||
* Extracts the SNI (Server Name Indication) from a TLS ClientHello message.
|
|
||||||
*
|
|
||||||
* @param buffer The buffer containing the TLS ClientHello message
|
|
||||||
* @param logger Optional logging function
|
|
||||||
* @returns The extracted server name or undefined if not found
|
|
||||||
*/
|
|
||||||
public static extractSNI(buffer: Buffer, logger?: LoggerFunction): string | undefined {
|
|
||||||
const log = logger || (() => {});
|
|
||||||
|
|
||||||
try {
|
|
||||||
// Parse the ClientHello
|
|
||||||
const parseResult = ClientHelloParser.parseClientHello(buffer, logger);
|
|
||||||
if (!parseResult.isValid) {
|
|
||||||
log(`Failed to parse ClientHello: ${parseResult.error}`);
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if ServerName extension was found
|
|
||||||
if (parseResult.serverNameList && parseResult.serverNameList.length > 0) {
|
|
||||||
// Use the first hostname (most common case)
|
|
||||||
const serverName = parseResult.serverNameList[0];
|
|
||||||
log(`Found SNI: ${serverName}`);
|
|
||||||
return serverName;
|
|
||||||
}
|
|
||||||
|
|
||||||
log('No SNI extension found in ClientHello');
|
|
||||||
return undefined;
|
|
||||||
} catch (error) {
|
|
||||||
log(`Error extracting SNI: ${error instanceof Error ? error.message : String(error)}`);
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Attempts to extract SNI from the PSK extension in a TLS 1.3 ClientHello.
|
|
||||||
*
|
|
||||||
* In TLS 1.3, when a client attempts to resume a session, it may include
|
|
||||||
* the server name in the PSK identity hint rather than in the SNI extension.
|
|
||||||
*
|
|
||||||
* @param buffer The buffer containing the TLS ClientHello message
|
|
||||||
* @param logger Optional logging function
|
|
||||||
* @returns The extracted server name or undefined if not found
|
|
||||||
*/
|
|
||||||
public static extractSNIFromPSKExtension(
|
|
||||||
buffer: Buffer,
|
|
||||||
logger?: LoggerFunction
|
|
||||||
): string | undefined {
|
|
||||||
const log = logger || (() => {});
|
|
||||||
|
|
||||||
try {
|
|
||||||
// Ensure this is a ClientHello
|
|
||||||
if (!TlsUtils.isClientHello(buffer)) {
|
|
||||||
log('Not a ClientHello message');
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the ClientHello to find PSK extension
|
|
||||||
const parseResult = ClientHelloParser.parseClientHello(buffer, logger);
|
|
||||||
if (!parseResult.isValid || !parseResult.extensions) {
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find the PSK extension
|
|
||||||
const pskExtension = parseResult.extensions.find(ext =>
|
|
||||||
ext.type === TlsExtensionType.PRE_SHARED_KEY);
|
|
||||||
|
|
||||||
if (!pskExtension) {
|
|
||||||
log('No PSK extension found');
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the PSK extension data
|
|
||||||
const data = pskExtension.data;
|
|
||||||
|
|
||||||
// PSK extension structure:
|
|
||||||
// 2 bytes: identities list length
|
|
||||||
if (data.length < 2) return undefined;
|
|
||||||
|
|
||||||
const identitiesLength = (data[0] << 8) + data[1];
|
|
||||||
let pos = 2;
|
|
||||||
|
|
||||||
// End of identities list
|
|
||||||
const identitiesEnd = pos + identitiesLength;
|
|
||||||
if (identitiesEnd > data.length) return undefined;
|
|
||||||
|
|
||||||
// Process each PSK identity
|
|
||||||
while (pos + 2 <= identitiesEnd) {
|
|
||||||
// Identity length (2 bytes)
|
|
||||||
if (pos + 2 > identitiesEnd) break;
|
|
||||||
|
|
||||||
const identityLength = (data[pos] << 8) + data[pos + 1];
|
|
||||||
pos += 2;
|
|
||||||
|
|
||||||
if (pos + identityLength > identitiesEnd) break;
|
|
||||||
|
|
||||||
// Try to extract hostname from identity
|
|
||||||
// Chrome often embeds the hostname in the PSK identity
|
|
||||||
// This is a heuristic as there's no standard format
|
|
||||||
if (identityLength > 0) {
|
|
||||||
const identity = data.slice(pos, pos + identityLength);
|
|
||||||
|
|
||||||
// Skip identity bytes
|
|
||||||
pos += identityLength;
|
|
||||||
|
|
||||||
// Skip obfuscated ticket age (4 bytes)
|
|
||||||
if (pos + 4 <= identitiesEnd) {
|
|
||||||
pos += 4;
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to parse the identity as UTF-8
|
|
||||||
try {
|
|
||||||
const identityStr = identity.toString('utf8');
|
|
||||||
log(`PSK identity: ${identityStr}`);
|
|
||||||
|
|
||||||
// Check if the identity contains hostname hints
|
|
||||||
// Chrome often embeds the hostname in a known format
|
|
||||||
// Try to extract using common patterns
|
|
||||||
|
|
||||||
// Pattern 1: Look for domain name pattern
|
|
||||||
const domainPattern =
|
|
||||||
/([a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?\.)+[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?/i;
|
|
||||||
const domainMatch = identityStr.match(domainPattern);
|
|
||||||
if (domainMatch && domainMatch[0]) {
|
|
||||||
log(`Found domain in PSK identity: ${domainMatch[0]}`);
|
|
||||||
return domainMatch[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pattern 2: Chrome sometimes uses a specific format with delimiters
|
|
||||||
// This is a heuristic approach since the format isn't standardized
|
|
||||||
const parts = identityStr.split('|');
|
|
||||||
if (parts.length > 1) {
|
|
||||||
for (const part of parts) {
|
|
||||||
if (part.includes('.') && !part.includes('/')) {
|
|
||||||
const possibleDomain = part.trim();
|
|
||||||
if (/^[a-z0-9.-]+$/i.test(possibleDomain)) {
|
|
||||||
log(`Found possible domain in PSK delimiter format: ${possibleDomain}`);
|
|
||||||
return possibleDomain;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
log('Failed to parse PSK identity as UTF-8');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log('No hostname found in PSK extension');
|
|
||||||
return undefined;
|
|
||||||
} catch (error) {
|
|
||||||
log(`Error parsing PSK: ${error instanceof Error ? error.message : String(error)}`);
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Main entry point for SNI extraction with support for fragmented messages
|
|
||||||
* and session resumption edge cases.
|
|
||||||
*
|
|
||||||
* @param buffer The buffer containing TLS data
|
|
||||||
* @param connectionInfo Connection tracking information
|
|
||||||
* @param logger Optional logging function
|
|
||||||
* @param cachedSni Optional previously cached SNI value
|
|
||||||
* @returns The extracted server name or undefined
|
|
||||||
*/
|
|
||||||
public static extractSNIWithResumptionSupport(
|
|
||||||
buffer: Buffer,
|
|
||||||
connectionInfo?: ConnectionInfo,
|
|
||||||
logger?: LoggerFunction,
|
|
||||||
cachedSni?: string
|
|
||||||
): string | undefined {
|
|
||||||
const log = logger || (() => {});
|
|
||||||
|
|
||||||
// Log buffer details for debugging
|
|
||||||
if (logger) {
|
|
||||||
log(`Buffer size: ${buffer.length} bytes`);
|
|
||||||
log(`Buffer starts with: ${buffer.slice(0, Math.min(10, buffer.length)).toString('hex')}`);
|
|
||||||
|
|
||||||
if (buffer.length >= 5) {
|
|
||||||
const recordType = buffer[0];
|
|
||||||
const majorVersion = buffer[1];
|
|
||||||
const minorVersion = buffer[2];
|
|
||||||
const recordLength = (buffer[3] << 8) + buffer[4];
|
|
||||||
|
|
||||||
log(
|
|
||||||
`TLS Record: type=${recordType}, version=${majorVersion}.${minorVersion}, length=${recordLength}`
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we need to handle fragmented packets
|
|
||||||
let processBuffer = buffer;
|
|
||||||
if (connectionInfo) {
|
|
||||||
const connectionId = TlsUtils.createConnectionId(connectionInfo);
|
|
||||||
const reassembledBuffer = ClientHelloParser.handleFragmentedClientHello(
|
|
||||||
buffer,
|
|
||||||
connectionId,
|
|
||||||
logger
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!reassembledBuffer) {
|
|
||||||
log(`Waiting for more fragments on connection ${connectionId}`);
|
|
||||||
return undefined; // Need more fragments to complete ClientHello
|
|
||||||
}
|
|
||||||
|
|
||||||
processBuffer = reassembledBuffer;
|
|
||||||
log(`Using reassembled buffer of length ${processBuffer.length}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
// First try the standard SNI extraction
|
|
||||||
const standardSni = this.extractSNI(processBuffer, logger);
|
|
||||||
if (standardSni) {
|
|
||||||
log(`Found standard SNI: ${standardSni}`);
|
|
||||||
return standardSni;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for session resumption when standard SNI extraction fails
|
|
||||||
if (TlsUtils.isClientHello(processBuffer)) {
|
|
||||||
const resumptionInfo = ClientHelloParser.hasSessionResumption(processBuffer, logger);
|
|
||||||
|
|
||||||
if (resumptionInfo.isResumption) {
|
|
||||||
log(`Detected session resumption in ClientHello without standard SNI`);
|
|
||||||
|
|
||||||
// Try to extract SNI from PSK extension
|
|
||||||
const pskSni = this.extractSNIFromPSKExtension(processBuffer, logger);
|
|
||||||
if (pskSni) {
|
|
||||||
log(`Extracted SNI from PSK extension: ${pskSni}`);
|
|
||||||
return pskSni;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If cached SNI was provided, use it for application data packets
|
|
||||||
if (cachedSni && TlsUtils.isTlsApplicationData(buffer)) {
|
|
||||||
log(`Using provided cached SNI for application data: ${cachedSni}`);
|
|
||||||
return cachedSni;
|
|
||||||
}
|
|
||||||
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Unified method for processing a TLS packet and extracting SNI.
|
|
||||||
* Main entry point for SNI extraction that handles all edge cases.
|
|
||||||
*
|
|
||||||
* @param buffer The buffer containing TLS data
|
|
||||||
* @param connectionInfo Connection tracking information
|
|
||||||
* @param logger Optional logging function
|
|
||||||
* @param cachedSni Optional previously cached SNI value
|
|
||||||
* @returns The extracted server name or undefined
|
|
||||||
*/
|
|
||||||
public static processTlsPacket(
|
|
||||||
buffer: Buffer,
|
|
||||||
connectionInfo: ConnectionInfo,
|
|
||||||
logger?: LoggerFunction,
|
|
||||||
cachedSni?: string
|
|
||||||
): string | undefined {
|
|
||||||
const log = logger || (() => {});
|
|
||||||
|
|
||||||
// Add timestamp if not provided
|
|
||||||
if (!connectionInfo.timestamp) {
|
|
||||||
connectionInfo.timestamp = Date.now();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if this is a TLS handshake or application data
|
|
||||||
if (!TlsUtils.isTlsHandshake(buffer) && !TlsUtils.isTlsApplicationData(buffer)) {
|
|
||||||
log('Not a TLS handshake or application data packet');
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create connection ID for tracking
|
|
||||||
const connectionId = TlsUtils.createConnectionId(connectionInfo);
|
|
||||||
log(`Processing TLS packet for connection ${connectionId}, buffer length: ${buffer.length}`);
|
|
||||||
|
|
||||||
// Handle application data with cached SNI (for connection racing)
|
|
||||||
if (TlsUtils.isTlsApplicationData(buffer)) {
|
|
||||||
// If explicit cachedSni was provided, use it
|
|
||||||
if (cachedSni) {
|
|
||||||
log(`Using provided cached SNI for application data: ${cachedSni}`);
|
|
||||||
return cachedSni;
|
|
||||||
}
|
|
||||||
|
|
||||||
log('Application data packet without cached SNI, cannot determine hostname');
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enhanced session resumption detection
|
|
||||||
if (TlsUtils.isClientHello(buffer)) {
|
|
||||||
const resumptionInfo = ClientHelloParser.hasSessionResumption(buffer, logger);
|
|
||||||
|
|
||||||
if (resumptionInfo.isResumption) {
|
|
||||||
log(`Session resumption detected in TLS packet`);
|
|
||||||
|
|
||||||
// Always try standard SNI extraction first
|
|
||||||
const standardSni = this.extractSNI(buffer, logger);
|
|
||||||
if (standardSni) {
|
|
||||||
log(`Found standard SNI in session resumption: ${standardSni}`);
|
|
||||||
return standardSni;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enhanced session resumption SNI extraction
|
|
||||||
// Try extracting from PSK identity
|
|
||||||
const pskSni = this.extractSNIFromPSKExtension(buffer, logger);
|
|
||||||
if (pskSni) {
|
|
||||||
log(`Extracted SNI from PSK extension: ${pskSni}`);
|
|
||||||
return pskSni;
|
|
||||||
}
|
|
||||||
|
|
||||||
log(`Session resumption without extractable SNI`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// For handshake messages, try the full extraction process
|
|
||||||
const sni = this.extractSNIWithResumptionSupport(buffer, connectionInfo, logger);
|
|
||||||
|
|
||||||
if (sni) {
|
|
||||||
log(`Successfully extracted SNI: ${sni}`);
|
|
||||||
return sni;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we couldn't extract an SNI, check if this is a valid ClientHello
|
|
||||||
if (TlsUtils.isClientHello(buffer)) {
|
|
||||||
log('Valid ClientHello detected, but no SNI extracted - might need more data');
|
|
||||||
}
|
|
||||||
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
/**
|
|
||||||
* TLS utilities
|
|
||||||
*/
|
|
||||||
@@ -1,201 +0,0 @@
|
|||||||
import * as plugins from '../../../plugins.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* TLS record types as defined in various RFCs
|
|
||||||
*/
|
|
||||||
export enum TlsRecordType {
|
|
||||||
CHANGE_CIPHER_SPEC = 20,
|
|
||||||
ALERT = 21,
|
|
||||||
HANDSHAKE = 22,
|
|
||||||
APPLICATION_DATA = 23,
|
|
||||||
HEARTBEAT = 24, // RFC 6520
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* TLS handshake message types
|
|
||||||
*/
|
|
||||||
export enum TlsHandshakeType {
|
|
||||||
HELLO_REQUEST = 0,
|
|
||||||
CLIENT_HELLO = 1,
|
|
||||||
SERVER_HELLO = 2,
|
|
||||||
NEW_SESSION_TICKET = 4,
|
|
||||||
ENCRYPTED_EXTENSIONS = 8, // TLS 1.3
|
|
||||||
CERTIFICATE = 11,
|
|
||||||
SERVER_KEY_EXCHANGE = 12,
|
|
||||||
CERTIFICATE_REQUEST = 13,
|
|
||||||
SERVER_HELLO_DONE = 14,
|
|
||||||
CERTIFICATE_VERIFY = 15,
|
|
||||||
CLIENT_KEY_EXCHANGE = 16,
|
|
||||||
FINISHED = 20,
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* TLS extension types
|
|
||||||
*/
|
|
||||||
export enum TlsExtensionType {
|
|
||||||
SERVER_NAME = 0, // SNI
|
|
||||||
MAX_FRAGMENT_LENGTH = 1,
|
|
||||||
CLIENT_CERTIFICATE_URL = 2,
|
|
||||||
TRUSTED_CA_KEYS = 3,
|
|
||||||
TRUNCATED_HMAC = 4,
|
|
||||||
STATUS_REQUEST = 5, // OCSP
|
|
||||||
SUPPORTED_GROUPS = 10, // Previously named "elliptic_curves"
|
|
||||||
EC_POINT_FORMATS = 11,
|
|
||||||
SIGNATURE_ALGORITHMS = 13,
|
|
||||||
APPLICATION_LAYER_PROTOCOL_NEGOTIATION = 16, // ALPN
|
|
||||||
SIGNED_CERTIFICATE_TIMESTAMP = 18, // Certificate Transparency
|
|
||||||
PADDING = 21,
|
|
||||||
SESSION_TICKET = 35,
|
|
||||||
PRE_SHARED_KEY = 41, // TLS 1.3
|
|
||||||
EARLY_DATA = 42, // TLS 1.3 0-RTT
|
|
||||||
SUPPORTED_VERSIONS = 43, // TLS 1.3
|
|
||||||
COOKIE = 44, // TLS 1.3
|
|
||||||
PSK_KEY_EXCHANGE_MODES = 45, // TLS 1.3
|
|
||||||
CERTIFICATE_AUTHORITIES = 47, // TLS 1.3
|
|
||||||
POST_HANDSHAKE_AUTH = 49, // TLS 1.3
|
|
||||||
SIGNATURE_ALGORITHMS_CERT = 50, // TLS 1.3
|
|
||||||
KEY_SHARE = 51, // TLS 1.3
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* TLS alert levels
|
|
||||||
*/
|
|
||||||
export enum TlsAlertLevel {
|
|
||||||
WARNING = 1,
|
|
||||||
FATAL = 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* TLS alert description codes
|
|
||||||
*/
|
|
||||||
export enum TlsAlertDescription {
|
|
||||||
CLOSE_NOTIFY = 0,
|
|
||||||
UNEXPECTED_MESSAGE = 10,
|
|
||||||
BAD_RECORD_MAC = 20,
|
|
||||||
DECRYPTION_FAILED = 21, // TLS 1.0 only
|
|
||||||
RECORD_OVERFLOW = 22,
|
|
||||||
DECOMPRESSION_FAILURE = 30, // TLS 1.2 and below
|
|
||||||
HANDSHAKE_FAILURE = 40,
|
|
||||||
NO_CERTIFICATE = 41, // SSLv3 only
|
|
||||||
BAD_CERTIFICATE = 42,
|
|
||||||
UNSUPPORTED_CERTIFICATE = 43,
|
|
||||||
CERTIFICATE_REVOKED = 44,
|
|
||||||
CERTIFICATE_EXPIRED = 45,
|
|
||||||
CERTIFICATE_UNKNOWN = 46,
|
|
||||||
ILLEGAL_PARAMETER = 47,
|
|
||||||
UNKNOWN_CA = 48,
|
|
||||||
ACCESS_DENIED = 49,
|
|
||||||
DECODE_ERROR = 50,
|
|
||||||
DECRYPT_ERROR = 51,
|
|
||||||
EXPORT_RESTRICTION = 60, // TLS 1.0 only
|
|
||||||
PROTOCOL_VERSION = 70,
|
|
||||||
INSUFFICIENT_SECURITY = 71,
|
|
||||||
INTERNAL_ERROR = 80,
|
|
||||||
INAPPROPRIATE_FALLBACK = 86,
|
|
||||||
USER_CANCELED = 90,
|
|
||||||
NO_RENEGOTIATION = 100, // TLS 1.2 and below
|
|
||||||
MISSING_EXTENSION = 109, // TLS 1.3
|
|
||||||
UNSUPPORTED_EXTENSION = 110, // TLS 1.3
|
|
||||||
CERTIFICATE_REQUIRED = 111, // TLS 1.3
|
|
||||||
UNRECOGNIZED_NAME = 112,
|
|
||||||
BAD_CERTIFICATE_STATUS_RESPONSE = 113,
|
|
||||||
BAD_CERTIFICATE_HASH_VALUE = 114, // TLS 1.2 and below
|
|
||||||
UNKNOWN_PSK_IDENTITY = 115,
|
|
||||||
CERTIFICATE_REQUIRED_1_3 = 116, // TLS 1.3
|
|
||||||
NO_APPLICATION_PROTOCOL = 120,
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* TLS version codes (major.minor)
|
|
||||||
*/
|
|
||||||
export const TlsVersion = {
|
|
||||||
SSL3: [0x03, 0x00],
|
|
||||||
TLS1_0: [0x03, 0x01],
|
|
||||||
TLS1_1: [0x03, 0x02],
|
|
||||||
TLS1_2: [0x03, 0x03],
|
|
||||||
TLS1_3: [0x03, 0x04],
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Utility functions for TLS protocol operations
|
|
||||||
*/
|
|
||||||
export class TlsUtils {
|
|
||||||
/**
|
|
||||||
* Checks if a buffer contains a TLS handshake record
|
|
||||||
* @param buffer The buffer to check
|
|
||||||
* @returns true if the buffer starts with a TLS handshake record
|
|
||||||
*/
|
|
||||||
public static isTlsHandshake(buffer: Buffer): boolean {
|
|
||||||
return buffer.length > 0 && buffer[0] === TlsRecordType.HANDSHAKE;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Checks if a buffer contains TLS application data
|
|
||||||
* @param buffer The buffer to check
|
|
||||||
* @returns true if the buffer starts with a TLS application data record
|
|
||||||
*/
|
|
||||||
public static isTlsApplicationData(buffer: Buffer): boolean {
|
|
||||||
return buffer.length > 0 && buffer[0] === TlsRecordType.APPLICATION_DATA;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Checks if a buffer contains a TLS alert record
|
|
||||||
* @param buffer The buffer to check
|
|
||||||
* @returns true if the buffer starts with a TLS alert record
|
|
||||||
*/
|
|
||||||
public static isTlsAlert(buffer: Buffer): boolean {
|
|
||||||
return buffer.length > 0 && buffer[0] === TlsRecordType.ALERT;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Checks if a buffer contains a TLS ClientHello message
|
|
||||||
* @param buffer The buffer to check
|
|
||||||
* @returns true if the buffer appears to be a ClientHello message
|
|
||||||
*/
|
|
||||||
public static isClientHello(buffer: Buffer): boolean {
|
|
||||||
// Minimum ClientHello size (TLS record header + handshake header)
|
|
||||||
if (buffer.length < 9) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check record type (must be TLS_HANDSHAKE_RECORD_TYPE)
|
|
||||||
if (buffer[0] !== TlsRecordType.HANDSHAKE) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip version and length in TLS record header (5 bytes total)
|
|
||||||
// Check handshake type at byte 5 (must be CLIENT_HELLO)
|
|
||||||
return buffer[5] === TlsHandshakeType.CLIENT_HELLO;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Gets the record length from a TLS record header
|
|
||||||
* @param buffer Buffer containing a TLS record
|
|
||||||
* @returns The record length if the buffer is valid, -1 otherwise
|
|
||||||
*/
|
|
||||||
public static getTlsRecordLength(buffer: Buffer): number {
|
|
||||||
if (buffer.length < 5) {
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bytes 3-4 contain the record length (big-endian)
|
|
||||||
return (buffer[3] << 8) + buffer[4];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a connection ID based on source/destination information
|
|
||||||
* Used to track fragmented ClientHello messages across multiple packets
|
|
||||||
*
|
|
||||||
* @param connectionInfo Object containing connection identifiers
|
|
||||||
* @returns A string ID for the connection
|
|
||||||
*/
|
|
||||||
public static createConnectionId(connectionInfo: {
|
|
||||||
sourceIp?: string;
|
|
||||||
sourcePort?: number;
|
|
||||||
destIp?: string;
|
|
||||||
destPort?: number;
|
|
||||||
}): string {
|
|
||||||
const { sourceIp, sourcePort, destIp, destPort } = connectionInfo;
|
|
||||||
return `${sourceIp}:${sourcePort}-${destIp}:${destPort}`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,60 +0,0 @@
|
|||||||
/**
|
|
||||||
* WebSocket Protocol Constants
|
|
||||||
* Based on RFC 6455
|
|
||||||
*/
|
|
||||||
|
|
||||||
/**
|
|
||||||
* WebSocket opcode types
|
|
||||||
*/
|
|
||||||
export enum WebSocketOpcode {
|
|
||||||
CONTINUATION = 0x0,
|
|
||||||
TEXT = 0x1,
|
|
||||||
BINARY = 0x2,
|
|
||||||
CLOSE = 0x8,
|
|
||||||
PING = 0x9,
|
|
||||||
PONG = 0xa,
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* WebSocket close codes
|
|
||||||
*/
|
|
||||||
export enum WebSocketCloseCode {
|
|
||||||
NORMAL_CLOSURE = 1000,
|
|
||||||
GOING_AWAY = 1001,
|
|
||||||
PROTOCOL_ERROR = 1002,
|
|
||||||
UNSUPPORTED_DATA = 1003,
|
|
||||||
NO_STATUS_RECEIVED = 1005,
|
|
||||||
ABNORMAL_CLOSURE = 1006,
|
|
||||||
INVALID_FRAME_PAYLOAD_DATA = 1007,
|
|
||||||
POLICY_VIOLATION = 1008,
|
|
||||||
MESSAGE_TOO_BIG = 1009,
|
|
||||||
MISSING_EXTENSION = 1010,
|
|
||||||
INTERNAL_ERROR = 1011,
|
|
||||||
SERVICE_RESTART = 1012,
|
|
||||||
TRY_AGAIN_LATER = 1013,
|
|
||||||
BAD_GATEWAY = 1014,
|
|
||||||
TLS_HANDSHAKE = 1015,
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* WebSocket protocol version
|
|
||||||
*/
|
|
||||||
export const WEBSOCKET_VERSION = 13;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* WebSocket magic string for handshake
|
|
||||||
*/
|
|
||||||
export const WEBSOCKET_MAGIC_STRING = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* WebSocket headers
|
|
||||||
*/
|
|
||||||
export const WEBSOCKET_HEADERS = {
|
|
||||||
UPGRADE: 'upgrade',
|
|
||||||
CONNECTION: 'connection',
|
|
||||||
SEC_WEBSOCKET_KEY: 'sec-websocket-key',
|
|
||||||
SEC_WEBSOCKET_VERSION: 'sec-websocket-version',
|
|
||||||
SEC_WEBSOCKET_ACCEPT: 'sec-websocket-accept',
|
|
||||||
SEC_WEBSOCKET_PROTOCOL: 'sec-websocket-protocol',
|
|
||||||
SEC_WEBSOCKET_EXTENSIONS: 'sec-websocket-extensions',
|
|
||||||
} as const;
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
/**
|
|
||||||
* WebSocket Protocol Module
|
|
||||||
* WebSocket protocol utilities and constants
|
|
||||||
*/
|
|
||||||
|
|
||||||
export * from './constants.js';
|
|
||||||
export * from './types.js';
|
|
||||||
export * from './utils.js';
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
/**
|
|
||||||
* WebSocket Protocol Type Definitions
|
|
||||||
*/
|
|
||||||
|
|
||||||
import type { WebSocketOpcode, WebSocketCloseCode } from './constants.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* WebSocket frame header
|
|
||||||
*/
|
|
||||||
export interface IWebSocketFrameHeader {
|
|
||||||
fin: boolean;
|
|
||||||
rsv1: boolean;
|
|
||||||
rsv2: boolean;
|
|
||||||
rsv3: boolean;
|
|
||||||
opcode: WebSocketOpcode;
|
|
||||||
masked: boolean;
|
|
||||||
payloadLength: number;
|
|
||||||
maskingKey?: Buffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* WebSocket frame
|
|
||||||
*/
|
|
||||||
export interface IWebSocketFrame {
|
|
||||||
header: IWebSocketFrameHeader;
|
|
||||||
payload: Buffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* WebSocket close frame payload
|
|
||||||
*/
|
|
||||||
export interface IWebSocketClosePayload {
|
|
||||||
code: WebSocketCloseCode;
|
|
||||||
reason?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* WebSocket handshake request headers
|
|
||||||
*/
|
|
||||||
export interface IWebSocketHandshakeHeaders {
|
|
||||||
upgrade: string;
|
|
||||||
connection: string;
|
|
||||||
'sec-websocket-key': string;
|
|
||||||
'sec-websocket-version': string;
|
|
||||||
'sec-websocket-protocol'?: string;
|
|
||||||
'sec-websocket-extensions'?: string;
|
|
||||||
[key: string]: string | undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Type for WebSocket raw data (matching ws library)
|
|
||||||
*/
|
|
||||||
export type RawData = Buffer | ArrayBuffer | Buffer[] | any;
|
|
||||||
@@ -1,98 +0,0 @@
|
|||||||
/**
|
|
||||||
* WebSocket Protocol Utilities
|
|
||||||
*/
|
|
||||||
|
|
||||||
import * as crypto from 'node:crypto';
|
|
||||||
import { WEBSOCKET_MAGIC_STRING } from './constants.js';
|
|
||||||
import type { RawData } from './types.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the length of a WebSocket message regardless of its type
|
|
||||||
* (handles all possible WebSocket message data types)
|
|
||||||
*/
|
|
||||||
export function getMessageSize(data: RawData): number {
|
|
||||||
if (typeof data === 'string') {
|
|
||||||
// For string data, get the byte length
|
|
||||||
return Buffer.from(data, 'utf8').length;
|
|
||||||
} else if (data instanceof Buffer) {
|
|
||||||
// For Node.js Buffer
|
|
||||||
return data.length;
|
|
||||||
} else if (data instanceof ArrayBuffer) {
|
|
||||||
// For ArrayBuffer
|
|
||||||
return data.byteLength;
|
|
||||||
} else if (Array.isArray(data)) {
|
|
||||||
// For array of buffers, sum their lengths
|
|
||||||
return data.reduce((sum, chunk) => {
|
|
||||||
if (chunk instanceof Buffer) {
|
|
||||||
return sum + chunk.length;
|
|
||||||
} else if (chunk instanceof ArrayBuffer) {
|
|
||||||
return sum + chunk.byteLength;
|
|
||||||
}
|
|
||||||
return sum;
|
|
||||||
}, 0);
|
|
||||||
} else {
|
|
||||||
// For other types, try to determine the size or return 0
|
|
||||||
try {
|
|
||||||
return Buffer.from(data).length;
|
|
||||||
} catch (e) {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Convert any raw WebSocket data to Buffer for consistent handling
|
|
||||||
*/
|
|
||||||
export function toBuffer(data: RawData): Buffer {
|
|
||||||
if (typeof data === 'string') {
|
|
||||||
return Buffer.from(data, 'utf8');
|
|
||||||
} else if (data instanceof Buffer) {
|
|
||||||
return data;
|
|
||||||
} else if (data instanceof ArrayBuffer) {
|
|
||||||
return Buffer.from(data);
|
|
||||||
} else if (Array.isArray(data)) {
|
|
||||||
// For array of buffers, concatenate them
|
|
||||||
return Buffer.concat(data.map(chunk => {
|
|
||||||
if (chunk instanceof Buffer) {
|
|
||||||
return chunk;
|
|
||||||
} else if (chunk instanceof ArrayBuffer) {
|
|
||||||
return Buffer.from(chunk);
|
|
||||||
}
|
|
||||||
return Buffer.from(chunk);
|
|
||||||
}));
|
|
||||||
} else {
|
|
||||||
// For other types, try to convert to Buffer or return empty Buffer
|
|
||||||
try {
|
|
||||||
return Buffer.from(data);
|
|
||||||
} catch (e) {
|
|
||||||
return Buffer.alloc(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Generate WebSocket accept key from client key
|
|
||||||
*/
|
|
||||||
export function generateAcceptKey(clientKey: string): string {
|
|
||||||
const hash = crypto.createHash('sha1');
|
|
||||||
hash.update(clientKey + WEBSOCKET_MAGIC_STRING);
|
|
||||||
return hash.digest('base64');
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Validate WebSocket upgrade request
|
|
||||||
*/
|
|
||||||
export function isWebSocketUpgrade(headers: Record<string, string>): boolean {
|
|
||||||
const upgrade = headers['upgrade'];
|
|
||||||
const connection = headers['connection'];
|
|
||||||
|
|
||||||
return upgrade?.toLowerCase() === 'websocket' &&
|
|
||||||
connection?.toLowerCase().includes('upgrade');
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Generate random WebSocket key for client handshake
|
|
||||||
*/
|
|
||||||
export function generateWebSocketKey(): string {
|
|
||||||
return crypto.randomBytes(16).toString('base64');
|
|
||||||
}
|
|
||||||
@@ -274,6 +274,12 @@ export class SocketHandlerServer {
|
|||||||
backend.pipe(socket);
|
backend.pipe(socket);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Track backend socket for cleanup on stop()
|
||||||
|
this.activeSockets.add(backend);
|
||||||
|
backend.on('close', () => {
|
||||||
|
this.activeSockets.delete(backend);
|
||||||
|
});
|
||||||
|
|
||||||
// Connect timeout: if backend doesn't connect within 30s, destroy both
|
// Connect timeout: if backend doesn't connect within 30s, destroy both
|
||||||
backend.setTimeout(30_000);
|
backend.setTimeout(30_000);
|
||||||
|
|
||||||
|
|||||||
@@ -7,9 +7,54 @@
|
|||||||
|
|
||||||
import * as plugins from '../../../../plugins.js';
|
import * as plugins from '../../../../plugins.js';
|
||||||
import type { IRouteConfig, TPortRange, IRouteContext } from '../../models/route-types.js';
|
import type { IRouteConfig, TPortRange, IRouteContext } from '../../models/route-types.js';
|
||||||
import { ProtocolDetector } from '../../../../detection/index.js';
|
|
||||||
import { createSocketTracker } from '../../../../core/utils/socket-tracker.js';
|
import { createSocketTracker } from '../../../../core/utils/socket-tracker.js';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Minimal HTTP request parser for socket handlers.
|
||||||
|
* Parses method, path, and optionally headers from a raw buffer.
|
||||||
|
*/
|
||||||
|
function parseHttpRequest(data: Buffer, extractHeaders: boolean = false): {
|
||||||
|
method: string;
|
||||||
|
path: string;
|
||||||
|
headers: Record<string, string>;
|
||||||
|
isComplete: boolean;
|
||||||
|
body?: string;
|
||||||
|
} | null {
|
||||||
|
const str = data.toString('utf8');
|
||||||
|
const headerEnd = str.indexOf('\r\n\r\n');
|
||||||
|
const isComplete = headerEnd !== -1;
|
||||||
|
const headerSection = isComplete ? str.slice(0, headerEnd) : str;
|
||||||
|
const lines = headerSection.split('\r\n');
|
||||||
|
const requestLine = lines[0];
|
||||||
|
if (!requestLine) return null;
|
||||||
|
|
||||||
|
const parts = requestLine.split(' ');
|
||||||
|
if (parts.length < 2) return null;
|
||||||
|
|
||||||
|
const method = parts[0];
|
||||||
|
const path = parts[1];
|
||||||
|
|
||||||
|
// Quick check: valid HTTP method
|
||||||
|
const validMethods = ['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD', 'OPTIONS', 'CONNECT', 'TRACE'];
|
||||||
|
if (!validMethods.includes(method)) return null;
|
||||||
|
|
||||||
|
const headers: Record<string, string> = {};
|
||||||
|
if (extractHeaders) {
|
||||||
|
for (let i = 1; i < lines.length; i++) {
|
||||||
|
const colonIdx = lines[i].indexOf(':');
|
||||||
|
if (colonIdx > 0) {
|
||||||
|
const name = lines[i].slice(0, colonIdx).trim().toLowerCase();
|
||||||
|
const value = lines[i].slice(colonIdx + 1).trim();
|
||||||
|
headers[name] = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const body = isComplete ? str.slice(headerEnd + 4) : undefined;
|
||||||
|
|
||||||
|
return { method, path, headers, isComplete, body };
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Pre-built socket handlers for common use cases
|
* Pre-built socket handlers for common use cases
|
||||||
*/
|
*/
|
||||||
@@ -104,30 +149,19 @@ export const SocketHandlers = {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* HTTP redirect handler
|
* HTTP redirect handler
|
||||||
* Uses the centralized detection module for HTTP parsing
|
|
||||||
*/
|
*/
|
||||||
httpRedirect: (locationTemplate: string, statusCode: number = 301) => (socket: plugins.net.Socket, context: IRouteContext) => {
|
httpRedirect: (locationTemplate: string, statusCode: number = 301) => (socket: plugins.net.Socket, context: IRouteContext) => {
|
||||||
const tracker = createSocketTracker(socket);
|
const tracker = createSocketTracker(socket);
|
||||||
const connectionId = ProtocolDetector.createConnectionId({
|
|
||||||
socketId: context.connectionId || `${Date.now()}-${Math.random()}`
|
|
||||||
});
|
|
||||||
|
|
||||||
const handleData = async (data: Buffer) => {
|
const handleData = (data: Buffer) => {
|
||||||
// Use detection module for parsing
|
const parsed = parseHttpRequest(data);
|
||||||
const detectionResult = await ProtocolDetector.detectWithConnectionTracking(
|
|
||||||
data,
|
|
||||||
connectionId,
|
|
||||||
{ extractFullHeaders: false } // We only need method and path
|
|
||||||
);
|
|
||||||
|
|
||||||
if (detectionResult.protocol === 'http' && detectionResult.connectionInfo.path) {
|
|
||||||
const method = detectionResult.connectionInfo.method || 'GET';
|
|
||||||
const path = detectionResult.connectionInfo.path || '/';
|
|
||||||
|
|
||||||
|
if (parsed) {
|
||||||
|
const path = parsed.path || '/';
|
||||||
const domain = context.domain || 'localhost';
|
const domain = context.domain || 'localhost';
|
||||||
const port = context.port;
|
const port = context.port;
|
||||||
|
|
||||||
let finalLocation = locationTemplate
|
const finalLocation = locationTemplate
|
||||||
.replace('{domain}', domain)
|
.replace('{domain}', domain)
|
||||||
.replace('{port}', String(port))
|
.replace('{port}', String(port))
|
||||||
.replace('{path}', path)
|
.replace('{path}', path)
|
||||||
@@ -146,18 +180,13 @@ export const SocketHandlers = {
|
|||||||
|
|
||||||
socket.write(response);
|
socket.write(response);
|
||||||
} else {
|
} else {
|
||||||
// Not a valid HTTP request, close connection
|
|
||||||
socket.write('HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\n');
|
socket.write('HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\n');
|
||||||
}
|
}
|
||||||
|
|
||||||
socket.end();
|
socket.end();
|
||||||
// Clean up detection state
|
|
||||||
ProtocolDetector.cleanupConnections();
|
|
||||||
// Clean up all tracked resources
|
|
||||||
tracker.cleanup();
|
tracker.cleanup();
|
||||||
};
|
};
|
||||||
|
|
||||||
// Use tracker to manage the listener
|
|
||||||
socket.once('data', handleData);
|
socket.once('data', handleData);
|
||||||
|
|
||||||
tracker.addListener('error', (err) => {
|
tracker.addListener('error', (err) => {
|
||||||
@@ -171,45 +200,31 @@ export const SocketHandlers = {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* HTTP server handler for ACME challenges and other HTTP needs
|
* HTTP server handler for ACME challenges and other HTTP needs
|
||||||
* Uses the centralized detection module for HTTP parsing
|
|
||||||
*/
|
*/
|
||||||
httpServer: (handler: (req: { method: string; url: string; headers: Record<string, string>; body?: string }, res: { status: (code: number) => void; header: (name: string, value: string) => void; send: (data: string) => void; end: () => void }) => void) => (socket: plugins.net.Socket, context: IRouteContext) => {
|
httpServer: (handler: (req: { method: string; url: string; headers: Record<string, string>; body?: string }, res: { status: (code: number) => void; header: (name: string, value: string) => void; send: (data: string) => void; end: () => void }) => void) => (socket: plugins.net.Socket, context: IRouteContext) => {
|
||||||
const tracker = createSocketTracker(socket);
|
const tracker = createSocketTracker(socket);
|
||||||
let requestParsed = false;
|
let requestParsed = false;
|
||||||
let responseTimer: NodeJS.Timeout | null = null;
|
let responseTimer: NodeJS.Timeout | null = null;
|
||||||
const connectionId = ProtocolDetector.createConnectionId({
|
|
||||||
socketId: context.connectionId || `${Date.now()}-${Math.random()}`
|
|
||||||
});
|
|
||||||
|
|
||||||
const processData = async (data: Buffer) => {
|
const processData = (data: Buffer) => {
|
||||||
if (requestParsed) return; // Only handle the first request
|
if (requestParsed) return;
|
||||||
|
|
||||||
// Use HttpDetector for parsing
|
const parsed = parseHttpRequest(data, true);
|
||||||
const detectionResult = await ProtocolDetector.detectWithConnectionTracking(
|
|
||||||
data,
|
|
||||||
connectionId,
|
|
||||||
{ extractFullHeaders: true }
|
|
||||||
);
|
|
||||||
|
|
||||||
if (detectionResult.protocol !== 'http' || !detectionResult.isComplete) {
|
if (!parsed || !parsed.isComplete) {
|
||||||
// Not a complete HTTP request yet
|
return; // Not a complete HTTP request yet
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
requestParsed = true;
|
requestParsed = true;
|
||||||
// Remove data listener after parsing request
|
|
||||||
socket.removeListener('data', processData);
|
socket.removeListener('data', processData);
|
||||||
const connInfo = detectionResult.connectionInfo;
|
|
||||||
|
|
||||||
// Create request object from detection result
|
|
||||||
const req = {
|
const req = {
|
||||||
method: connInfo.method || 'GET',
|
method: parsed.method,
|
||||||
url: connInfo.path || '/',
|
url: parsed.path,
|
||||||
headers: connInfo.headers || {},
|
headers: parsed.headers,
|
||||||
body: detectionResult.remainingBuffer?.toString() || ''
|
body: parsed.body || ''
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create response object
|
|
||||||
let statusCode = 200;
|
let statusCode = 200;
|
||||||
const responseHeaders: Record<string, string> = {};
|
const responseHeaders: Record<string, string> = {};
|
||||||
let ended = false;
|
let ended = false;
|
||||||
@@ -225,7 +240,6 @@ export const SocketHandlers = {
|
|||||||
if (ended) return;
|
if (ended) return;
|
||||||
ended = true;
|
ended = true;
|
||||||
|
|
||||||
// Clear response timer since we're sending now
|
|
||||||
if (responseTimer) {
|
if (responseTimer) {
|
||||||
clearTimeout(responseTimer);
|
clearTimeout(responseTimer);
|
||||||
responseTimer = null;
|
responseTimer = null;
|
||||||
@@ -261,26 +275,22 @@ export const SocketHandlers = {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
handler(req, res);
|
handler(req, res);
|
||||||
// Ensure response is sent even if handler doesn't call send()
|
|
||||||
responseTimer = setTimeout(() => {
|
responseTimer = setTimeout(() => {
|
||||||
if (!ended) {
|
if (!ended) {
|
||||||
res.send('');
|
res.send('');
|
||||||
}
|
}
|
||||||
responseTimer = null;
|
responseTimer = null;
|
||||||
}, 1000);
|
}, 1000);
|
||||||
// Track and unref the timer
|
|
||||||
tracker.addTimer(responseTimer);
|
tracker.addTimer(responseTimer);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (!ended) {
|
if (!ended) {
|
||||||
res.status(500);
|
res.status(500);
|
||||||
res.send('Internal Server Error');
|
res.send('Internal Server Error');
|
||||||
}
|
}
|
||||||
// Use safeDestroy for error cases
|
|
||||||
tracker.safeDestroy(error instanceof Error ? error : new Error('Handler error'));
|
tracker.safeDestroy(error instanceof Error ? error : new Error('Handler error'));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Use tracker to manage listeners
|
|
||||||
tracker.addListener('data', processData);
|
tracker.addListener('data', processData);
|
||||||
|
|
||||||
tracker.addListener('error', (err) => {
|
tracker.addListener('error', (err) => {
|
||||||
@@ -290,14 +300,10 @@ export const SocketHandlers = {
|
|||||||
});
|
});
|
||||||
|
|
||||||
tracker.addListener('close', () => {
|
tracker.addListener('close', () => {
|
||||||
// Clear any pending response timer
|
|
||||||
if (responseTimer) {
|
if (responseTimer) {
|
||||||
clearTimeout(responseTimer);
|
clearTimeout(responseTimer);
|
||||||
responseTimer = null;
|
responseTimer = null;
|
||||||
}
|
}
|
||||||
// Clean up detection state
|
|
||||||
ProtocolDetector.cleanupConnections();
|
|
||||||
// Clean up all tracked resources
|
|
||||||
tracker.cleanup();
|
tracker.cleanup();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -305,11 +311,6 @@ export const SocketHandlers = {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a socket handler route configuration
|
* Create a socket handler route configuration
|
||||||
* @param domains Domain(s) to match
|
|
||||||
* @param ports Port(s) to listen on
|
|
||||||
* @param handler Socket handler function
|
|
||||||
* @param options Additional route options
|
|
||||||
* @returns Route configuration object
|
|
||||||
*/
|
*/
|
||||||
export function createSocketHandlerRoute(
|
export function createSocketHandlerRoute(
|
||||||
domains: string | string[],
|
domains: string | string[],
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import type { IRouteConfig, IRouteMatch, IRouteAction, TPortRange } from '../mod
|
|||||||
export class RouteValidator {
|
export class RouteValidator {
|
||||||
private static readonly VALID_TLS_MODES = ['terminate', 'passthrough', 'terminate-and-reencrypt'];
|
private static readonly VALID_TLS_MODES = ['terminate', 'passthrough', 'terminate-and-reencrypt'];
|
||||||
private static readonly VALID_ACTION_TYPES = ['forward', 'socket-handler'];
|
private static readonly VALID_ACTION_TYPES = ['forward', 'socket-handler'];
|
||||||
private static readonly VALID_PROTOCOLS = ['tcp', 'http', 'https', 'ws', 'wss'];
|
private static readonly VALID_PROTOCOLS = ['tcp', 'http', 'https', 'ws', 'wss', 'udp', 'quic', 'http3'];
|
||||||
private static readonly MAX_PORTS = 100;
|
private static readonly MAX_PORTS = 100;
|
||||||
private static readonly MAX_DOMAINS = 1000;
|
private static readonly MAX_DOMAINS = 1000;
|
||||||
private static readonly MAX_HEADER_SIZE = 8192;
|
private static readonly MAX_HEADER_SIZE = 8192;
|
||||||
@@ -173,6 +173,22 @@ export class RouteValidator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QUIC routes require TLS with termination (QUIC mandates TLS 1.3)
|
||||||
|
if (route.action.udp?.quic && route.action.type === 'forward') {
|
||||||
|
if (!route.action.tls) {
|
||||||
|
errors.push('QUIC routes require TLS configuration (action.tls) — QUIC mandates TLS 1.3');
|
||||||
|
} else if (route.action.tls.mode === 'passthrough') {
|
||||||
|
errors.push('QUIC routes cannot use TLS mode "passthrough" — use "terminate" or "terminate-and-reencrypt"');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Protocol quic/http3 requires transport udp or all
|
||||||
|
if (route.match?.protocol && ['quic', 'http3'].includes(route.match.protocol)) {
|
||||||
|
if (route.match.transport && route.match.transport !== 'udp' && route.match.transport !== 'all') {
|
||||||
|
errors.push(`Protocol "${route.match.protocol}" requires transport "udp" or "all"`);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate security settings
|
// Validate security settings
|
||||||
@@ -619,6 +635,15 @@ export function validateRouteAction(action: IRouteAction): { valid: boolean; err
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QUIC routes require TLS with termination
|
||||||
|
if (action.udp?.quic && action.type === 'forward') {
|
||||||
|
if (!action.tls) {
|
||||||
|
errors.push('QUIC routes require TLS configuration — QUIC mandates TLS 1.3');
|
||||||
|
} else if (action.tls.mode === 'passthrough') {
|
||||||
|
errors.push('QUIC routes cannot use TLS mode "passthrough"');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (action.type === 'socket-handler') {
|
if (action.type === 'socket-handler') {
|
||||||
if (!action.socketHandler && !action.datagramHandler) {
|
if (!action.socketHandler && !action.datagramHandler) {
|
||||||
errors.push('Socket handler or datagram handler function is required for socket-handler action');
|
errors.push('Socket handler or datagram handler function is required for socket-handler action');
|
||||||
|
|||||||
@@ -4,6 +4,3 @@
|
|||||||
|
|
||||||
// Export types and models
|
// Export types and models
|
||||||
export * from './models/http-types.js';
|
export * from './models/http-types.js';
|
||||||
|
|
||||||
// Export router functionality
|
|
||||||
export * from './router/index.js';
|
|
||||||
|
|||||||
@@ -1,266 +0,0 @@
|
|||||||
import * as plugins from '../../plugins.js';
|
|
||||||
import type { IRouteConfig } from '../../proxies/smart-proxy/models/route-types.js';
|
|
||||||
import { DomainMatcher, PathMatcher } from '../../core/routing/matchers/index.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Interface for router result with additional metadata
|
|
||||||
*/
|
|
||||||
export interface RouterResult {
|
|
||||||
route: IRouteConfig;
|
|
||||||
pathMatch?: string;
|
|
||||||
pathParams?: Record<string, string>;
|
|
||||||
pathRemainder?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Logger interface for HttpRouter
|
|
||||||
*/
|
|
||||||
export interface ILogger {
|
|
||||||
debug?: (message: string, data?: any) => void;
|
|
||||||
info: (message: string, data?: any) => void;
|
|
||||||
warn: (message: string, data?: any) => void;
|
|
||||||
error: (message: string, data?: any) => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Unified HTTP Router for reverse proxy requests
|
|
||||||
*
|
|
||||||
* Domain matching patterns:
|
|
||||||
* - Exact matches: "example.com"
|
|
||||||
* - Wildcard subdomains: "*.example.com" (matches any subdomain of example.com)
|
|
||||||
* - TLD wildcards: "example.*" (matches example.com, example.org, etc.)
|
|
||||||
* - Complex wildcards: "*.lossless*" (matches any subdomain of any lossless domain)
|
|
||||||
* - Default fallback: "*" (matches any unmatched domain)
|
|
||||||
*
|
|
||||||
* Path pattern matching:
|
|
||||||
* - Exact path: "/api/users"
|
|
||||||
* - Wildcard paths: "/api/*"
|
|
||||||
* - Path parameters: "/users/:id/profile"
|
|
||||||
*/
|
|
||||||
export class HttpRouter {
|
|
||||||
// Store routes sorted by priority
|
|
||||||
private routes: IRouteConfig[] = [];
|
|
||||||
// Default route to use when no match is found (optional)
|
|
||||||
private defaultRoute?: IRouteConfig;
|
|
||||||
// Logger interface
|
|
||||||
private logger: ILogger;
|
|
||||||
|
|
||||||
constructor(
|
|
||||||
routes?: IRouteConfig[],
|
|
||||||
logger?: ILogger
|
|
||||||
) {
|
|
||||||
this.logger = logger || {
|
|
||||||
error: console.error.bind(console),
|
|
||||||
warn: console.warn.bind(console),
|
|
||||||
info: console.info.bind(console),
|
|
||||||
debug: console.debug?.bind(console)
|
|
||||||
};
|
|
||||||
|
|
||||||
if (routes) {
|
|
||||||
this.setRoutes(routes);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Sets a new set of routes
|
|
||||||
* @param routes Array of route configurations
|
|
||||||
*/
|
|
||||||
public setRoutes(routes: IRouteConfig[]): void {
|
|
||||||
this.routes = [...routes];
|
|
||||||
|
|
||||||
// Sort routes by priority (higher priority first)
|
|
||||||
this.routes.sort((a, b) => {
|
|
||||||
const priorityA = a.priority ?? 0;
|
|
||||||
const priorityB = b.priority ?? 0;
|
|
||||||
return priorityB - priorityA;
|
|
||||||
});
|
|
||||||
|
|
||||||
// Find default route if any (route with "*" as domain)
|
|
||||||
this.defaultRoute = this.routes.find(route => {
|
|
||||||
const domains = Array.isArray(route.match.domains)
|
|
||||||
? route.match.domains
|
|
||||||
: route.match.domains ? [route.match.domains] : [];
|
|
||||||
return domains.includes('*');
|
|
||||||
});
|
|
||||||
|
|
||||||
const uniqueDomains = this.getHostnames();
|
|
||||||
this.logger.info(`HttpRouter initialized with ${this.routes.length} routes (${uniqueDomains.length} unique hosts)`);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Routes a request based on hostname and path
|
|
||||||
* @param req The incoming HTTP request
|
|
||||||
* @returns The matching route or undefined if no match found
|
|
||||||
*/
|
|
||||||
public routeReq(req: plugins.http.IncomingMessage): IRouteConfig | undefined {
|
|
||||||
const result = this.routeReqWithDetails(req);
|
|
||||||
return result ? result.route : undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Routes a request with detailed matching information
|
|
||||||
* @param req The incoming HTTP request
|
|
||||||
* @returns Detailed routing result including matched route and path information
|
|
||||||
*/
|
|
||||||
public routeReqWithDetails(req: plugins.http.IncomingMessage): RouterResult | undefined {
|
|
||||||
// Extract and validate host header
|
|
||||||
const originalHost = req.headers.host;
|
|
||||||
if (!originalHost) {
|
|
||||||
this.logger.error('No host header found in request');
|
|
||||||
return this.defaultRoute ? { route: this.defaultRoute } : undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse URL for path matching
|
|
||||||
const parsedUrl = plugins.url.parse(req.url || '/');
|
|
||||||
const urlPath = parsedUrl.pathname || '/';
|
|
||||||
|
|
||||||
// Extract hostname without port
|
|
||||||
const hostWithoutPort = originalHost.split(':')[0].toLowerCase();
|
|
||||||
|
|
||||||
// Find matching route
|
|
||||||
const matchingRoute = this.findMatchingRoute(hostWithoutPort, urlPath);
|
|
||||||
|
|
||||||
if (matchingRoute) {
|
|
||||||
return matchingRoute;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fall back to default route if available
|
|
||||||
if (this.defaultRoute) {
|
|
||||||
this.logger.warn(`No specific route found for host: ${hostWithoutPort}, using default`);
|
|
||||||
return { route: this.defaultRoute };
|
|
||||||
}
|
|
||||||
|
|
||||||
this.logger.error(`No route found for host: ${hostWithoutPort}`);
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Find the best matching route for a given hostname and path
|
|
||||||
*/
|
|
||||||
private findMatchingRoute(hostname: string, path: string): RouterResult | undefined {
|
|
||||||
// Try each route in priority order
|
|
||||||
for (const route of this.routes) {
|
|
||||||
// Skip disabled routes
|
|
||||||
if (route.enabled === false) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check domain match
|
|
||||||
if (route.match.domains) {
|
|
||||||
const domains = Array.isArray(route.match.domains)
|
|
||||||
? route.match.domains
|
|
||||||
: [route.match.domains];
|
|
||||||
|
|
||||||
// Check if any domain pattern matches
|
|
||||||
const domainMatches = domains.some(domain =>
|
|
||||||
DomainMatcher.match(domain, hostname)
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!domainMatches) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check path match if specified
|
|
||||||
if (route.match.path) {
|
|
||||||
const pathResult = PathMatcher.match(route.match.path, path);
|
|
||||||
if (pathResult.matches) {
|
|
||||||
return {
|
|
||||||
route,
|
|
||||||
pathMatch: pathResult.pathMatch || path,
|
|
||||||
pathParams: pathResult.params,
|
|
||||||
pathRemainder: pathResult.pathRemainder
|
|
||||||
};
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// No path specified, so domain match is sufficient
|
|
||||||
return { route };
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Gets all currently active route configurations
|
|
||||||
* @returns Array of all active routes
|
|
||||||
*/
|
|
||||||
public getRoutes(): IRouteConfig[] {
|
|
||||||
return [...this.routes];
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Gets all hostnames that this router is configured to handle
|
|
||||||
* @returns Array of unique hostnames
|
|
||||||
*/
|
|
||||||
public getHostnames(): string[] {
|
|
||||||
const hostnames = new Set<string>();
|
|
||||||
for (const route of this.routes) {
|
|
||||||
if (!route.match.domains) continue;
|
|
||||||
|
|
||||||
const domains = Array.isArray(route.match.domains)
|
|
||||||
? route.match.domains
|
|
||||||
: [route.match.domains];
|
|
||||||
|
|
||||||
for (const domain of domains) {
|
|
||||||
if (domain !== '*') {
|
|
||||||
hostnames.add(domain.toLowerCase());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Array.from(hostnames);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Adds a single new route configuration
|
|
||||||
* @param route The route configuration to add
|
|
||||||
*/
|
|
||||||
public addRoute(route: IRouteConfig): void {
|
|
||||||
this.routes.push(route);
|
|
||||||
|
|
||||||
// Re-sort routes by priority
|
|
||||||
this.routes.sort((a, b) => {
|
|
||||||
const priorityA = a.priority ?? 0;
|
|
||||||
const priorityB = b.priority ?? 0;
|
|
||||||
return priorityB - priorityA;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Removes routes by domain pattern
|
|
||||||
* @param domain The domain pattern to remove routes for
|
|
||||||
* @returns Boolean indicating whether any routes were removed
|
|
||||||
*/
|
|
||||||
public removeRoutesByDomain(domain: string): boolean {
|
|
||||||
const initialCount = this.routes.length;
|
|
||||||
|
|
||||||
// Filter out routes that match the domain
|
|
||||||
this.routes = this.routes.filter(route => {
|
|
||||||
if (!route.match.domains) return true;
|
|
||||||
|
|
||||||
const domains = Array.isArray(route.match.domains)
|
|
||||||
? route.match.domains
|
|
||||||
: [route.match.domains];
|
|
||||||
|
|
||||||
return !domains.includes(domain);
|
|
||||||
});
|
|
||||||
|
|
||||||
return this.routes.length !== initialCount;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Remove a specific route by reference
|
|
||||||
* @param route The route to remove
|
|
||||||
* @returns Boolean indicating if the route was found and removed
|
|
||||||
*/
|
|
||||||
public removeRoute(route: IRouteConfig): boolean {
|
|
||||||
const index = this.routes.indexOf(route);
|
|
||||||
if (index !== -1) {
|
|
||||||
this.routes.splice(index, 1);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
/**
|
|
||||||
* HTTP routing
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Export the unified HttpRouter
|
|
||||||
export { HttpRouter } from './http-router.js';
|
|
||||||
export type { RouterResult, ILogger } from './http-router.js';
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
/**
|
|
||||||
* TLS module for smartproxy
|
|
||||||
* Re-exports protocol components and provides smartproxy-specific functionality
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Re-export all protocol components from protocols/tls
|
|
||||||
export * from '../protocols/tls/index.js';
|
|
||||||
|
|
||||||
// Export smartproxy-specific SNI handler
|
|
||||||
export * from './sni/sni-handler.js';
|
|
||||||
|
|
||||||
// Create a namespace for SNI utilities
|
|
||||||
import { SniHandler } from './sni/sni-handler.js';
|
|
||||||
import { SniExtraction } from '../protocols/tls/sni/sni-extraction.js';
|
|
||||||
import { ClientHelloParser } from '../protocols/tls/sni/client-hello-parser.js';
|
|
||||||
|
|
||||||
// Export utility objects for convenience
|
|
||||||
export const SNI = {
|
|
||||||
// Main handler class (for backward compatibility)
|
|
||||||
Handler: SniHandler,
|
|
||||||
|
|
||||||
// Utility classes
|
|
||||||
Extraction: SniExtraction,
|
|
||||||
Parser: ClientHelloParser,
|
|
||||||
|
|
||||||
// Convenience functions
|
|
||||||
extractSNI: SniHandler.extractSNI,
|
|
||||||
processTlsPacket: SniHandler.processTlsPacket,
|
|
||||||
};
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
/**
|
|
||||||
* SNI handling
|
|
||||||
*/
|
|
||||||
@@ -1,264 +0,0 @@
|
|||||||
import { Buffer } from 'node:buffer';
|
|
||||||
import {
|
|
||||||
TlsRecordType,
|
|
||||||
TlsHandshakeType,
|
|
||||||
TlsExtensionType,
|
|
||||||
TlsUtils
|
|
||||||
} from '../../protocols/tls/utils/tls-utils.js';
|
|
||||||
import {
|
|
||||||
ClientHelloParser,
|
|
||||||
type LoggerFunction
|
|
||||||
} from '../../protocols/tls/sni/client-hello-parser.js';
|
|
||||||
import {
|
|
||||||
SniExtraction,
|
|
||||||
type ConnectionInfo
|
|
||||||
} from '../../protocols/tls/sni/sni-extraction.js';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* SNI (Server Name Indication) handler for TLS connections.
|
|
||||||
* Provides robust extraction of SNI values from TLS ClientHello messages
|
|
||||||
* with support for fragmented packets, TLS 1.3 resumption, Chrome-specific
|
|
||||||
* connection behaviors, and tab hibernation/reactivation scenarios.
|
|
||||||
*
|
|
||||||
* This class retains the original API but leverages the new modular implementation
|
|
||||||
* for better maintainability and testability.
|
|
||||||
*/
|
|
||||||
export class SniHandler {
|
|
||||||
// Re-export constants for backward compatibility
|
|
||||||
private static readonly TLS_HANDSHAKE_RECORD_TYPE = TlsRecordType.HANDSHAKE;
|
|
||||||
private static readonly TLS_APPLICATION_DATA_TYPE = TlsRecordType.APPLICATION_DATA;
|
|
||||||
private static readonly TLS_CLIENT_HELLO_HANDSHAKE_TYPE = TlsHandshakeType.CLIENT_HELLO;
|
|
||||||
private static readonly TLS_SNI_EXTENSION_TYPE = TlsExtensionType.SERVER_NAME;
|
|
||||||
private static readonly TLS_SESSION_TICKET_EXTENSION_TYPE = TlsExtensionType.SESSION_TICKET;
|
|
||||||
private static readonly TLS_SNI_HOST_NAME_TYPE = 0; // NameType.HOST_NAME in RFC 6066
|
|
||||||
private static readonly TLS_PSK_EXTENSION_TYPE = TlsExtensionType.PRE_SHARED_KEY;
|
|
||||||
private static readonly TLS_PSK_KE_MODES_EXTENSION_TYPE = TlsExtensionType.PSK_KEY_EXCHANGE_MODES;
|
|
||||||
private static readonly TLS_EARLY_DATA_EXTENSION_TYPE = TlsExtensionType.EARLY_DATA;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Checks if a buffer contains a TLS handshake message (record type 22)
|
|
||||||
* @param buffer - The buffer to check
|
|
||||||
* @returns true if the buffer starts with a TLS handshake record type
|
|
||||||
*/
|
|
||||||
public static isTlsHandshake(buffer: Buffer): boolean {
|
|
||||||
return TlsUtils.isTlsHandshake(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Checks if a buffer contains TLS application data (record type 23)
|
|
||||||
* @param buffer - The buffer to check
|
|
||||||
* @returns true if the buffer starts with a TLS application data record type
|
|
||||||
*/
|
|
||||||
public static isTlsApplicationData(buffer: Buffer): boolean {
|
|
||||||
return TlsUtils.isTlsApplicationData(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a connection ID based on source/destination information
|
|
||||||
* Used to track fragmented ClientHello messages across multiple packets
|
|
||||||
*
|
|
||||||
* @param connectionInfo - Object containing connection identifiers (IP/port)
|
|
||||||
* @returns A string ID for the connection
|
|
||||||
*/
|
|
||||||
public static createConnectionId(connectionInfo: {
|
|
||||||
sourceIp?: string;
|
|
||||||
sourcePort?: number;
|
|
||||||
destIp?: string;
|
|
||||||
destPort?: number;
|
|
||||||
}): string {
|
|
||||||
return TlsUtils.createConnectionId(connectionInfo);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Handles potential fragmented ClientHello messages by buffering and reassembling
|
|
||||||
* TLS record fragments that might span multiple TCP packets.
|
|
||||||
*
|
|
||||||
* @param buffer - The current buffer fragment
|
|
||||||
* @param connectionId - Unique identifier for the connection
|
|
||||||
* @param enableLogging - Whether to enable logging
|
|
||||||
* @returns A complete buffer if reassembly is successful, or undefined if more fragments are needed
|
|
||||||
*/
|
|
||||||
public static handleFragmentedClientHello(
|
|
||||||
buffer: Buffer,
|
|
||||||
connectionId: string,
|
|
||||||
enableLogging: boolean = false
|
|
||||||
): Buffer | undefined {
|
|
||||||
const logger = enableLogging ?
|
|
||||||
(message: string) => console.log(`[SNI Fragment] ${message}`) :
|
|
||||||
undefined;
|
|
||||||
|
|
||||||
return ClientHelloParser.handleFragmentedClientHello(buffer, connectionId, logger);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Checks if a buffer contains a TLS ClientHello message
|
|
||||||
* @param buffer - The buffer to check
|
|
||||||
* @returns true if the buffer appears to be a ClientHello message
|
|
||||||
*/
|
|
||||||
public static isClientHello(buffer: Buffer): boolean {
|
|
||||||
return TlsUtils.isClientHello(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Checks if a ClientHello message contains session resumption indicators
|
|
||||||
* such as session tickets or PSK (Pre-Shared Key) extensions.
|
|
||||||
*
|
|
||||||
* @param buffer - The buffer containing a ClientHello message
|
|
||||||
* @param enableLogging - Whether to enable logging
|
|
||||||
* @returns Object containing details about session resumption and SNI presence
|
|
||||||
*/
|
|
||||||
public static hasSessionResumption(
|
|
||||||
buffer: Buffer,
|
|
||||||
enableLogging: boolean = false
|
|
||||||
): { isResumption: boolean; hasSNI: boolean } {
|
|
||||||
const logger = enableLogging ?
|
|
||||||
(message: string) => console.log(`[Session Resumption] ${message}`) :
|
|
||||||
undefined;
|
|
||||||
|
|
||||||
return ClientHelloParser.hasSessionResumption(buffer, logger);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Detects characteristics of a tab reactivation TLS handshake
|
|
||||||
* These often have specific patterns in Chrome and other browsers
|
|
||||||
*
|
|
||||||
* @param buffer - The buffer containing a ClientHello message
|
|
||||||
* @param enableLogging - Whether to enable logging
|
|
||||||
* @returns true if this appears to be a tab reactivation handshake
|
|
||||||
*/
|
|
||||||
public static isTabReactivationHandshake(
|
|
||||||
buffer: Buffer,
|
|
||||||
enableLogging: boolean = false
|
|
||||||
): boolean {
|
|
||||||
const logger = enableLogging ?
|
|
||||||
(message: string) => console.log(`[Tab Reactivation] ${message}`) :
|
|
||||||
undefined;
|
|
||||||
|
|
||||||
return ClientHelloParser.isTabReactivationHandshake(buffer, logger);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Extracts the SNI (Server Name Indication) from a TLS ClientHello message.
|
|
||||||
* Implements robust parsing with support for session resumption edge cases.
|
|
||||||
*
|
|
||||||
* @param buffer - The buffer containing the TLS ClientHello message
|
|
||||||
* @param enableLogging - Whether to enable detailed debug logging
|
|
||||||
* @returns The extracted server name or undefined if not found
|
|
||||||
*/
|
|
||||||
public static extractSNI(buffer: Buffer, enableLogging: boolean = false): string | undefined {
|
|
||||||
const logger = enableLogging ?
|
|
||||||
(message: string) => console.log(`[SNI Extraction] ${message}`) :
|
|
||||||
undefined;
|
|
||||||
|
|
||||||
return SniExtraction.extractSNI(buffer, logger);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Attempts to extract SNI from the PSK extension in a TLS 1.3 ClientHello.
|
|
||||||
*
|
|
||||||
* In TLS 1.3, when a client attempts to resume a session, it may include
|
|
||||||
* the server name in the PSK identity hint rather than in the SNI extension.
|
|
||||||
*
|
|
||||||
* @param buffer - The buffer containing the TLS ClientHello message
|
|
||||||
* @param enableLogging - Whether to enable detailed debug logging
|
|
||||||
* @returns The extracted server name or undefined if not found
|
|
||||||
*/
|
|
||||||
public static extractSNIFromPSKExtension(
|
|
||||||
buffer: Buffer,
|
|
||||||
enableLogging: boolean = false
|
|
||||||
): string | undefined {
|
|
||||||
const logger = enableLogging ?
|
|
||||||
(message: string) => console.log(`[PSK-SNI Extraction] ${message}`) :
|
|
||||||
undefined;
|
|
||||||
|
|
||||||
return SniExtraction.extractSNIFromPSKExtension(buffer, logger);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Checks if the buffer contains TLS 1.3 early data (0-RTT)
|
|
||||||
* @param buffer - The buffer to check
|
|
||||||
* @param enableLogging - Whether to enable logging
|
|
||||||
* @returns true if early data is detected
|
|
||||||
*/
|
|
||||||
public static hasEarlyData(buffer: Buffer, enableLogging: boolean = false): boolean {
|
|
||||||
// This functionality has been moved to ClientHelloParser
|
|
||||||
// We can implement it in terms of the parse result if needed
|
|
||||||
const logger = enableLogging ?
|
|
||||||
(message: string) => console.log(`[Early Data] ${message}`) :
|
|
||||||
undefined;
|
|
||||||
|
|
||||||
const parseResult = ClientHelloParser.parseClientHello(buffer, logger);
|
|
||||||
return parseResult.isValid && parseResult.hasEarlyData;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Attempts to extract SNI from an initial ClientHello packet and handles
|
|
||||||
* session resumption edge cases more robustly than the standard extraction.
|
|
||||||
*
|
|
||||||
* This method handles:
|
|
||||||
* 1. Standard SNI extraction
|
|
||||||
* 2. TLS 1.3 PSK-based resumption (Chrome, Firefox, etc.)
|
|
||||||
* 3. Session ticket-based resumption
|
|
||||||
* 4. Fragmented ClientHello messages
|
|
||||||
* 5. TLS 1.3 Early Data (0-RTT)
|
|
||||||
* 6. Chrome's connection racing behaviors
|
|
||||||
*
|
|
||||||
* @param buffer - The buffer containing the TLS ClientHello message
|
|
||||||
* @param connectionInfo - Optional connection information for fragment handling
|
|
||||||
* @param enableLogging - Whether to enable detailed debug logging
|
|
||||||
* @returns The extracted server name or undefined if not found or more data needed
|
|
||||||
*/
|
|
||||||
public static extractSNIWithResumptionSupport(
|
|
||||||
buffer: Buffer,
|
|
||||||
connectionInfo?: {
|
|
||||||
sourceIp?: string;
|
|
||||||
sourcePort?: number;
|
|
||||||
destIp?: string;
|
|
||||||
destPort?: number;
|
|
||||||
},
|
|
||||||
enableLogging: boolean = false
|
|
||||||
): string | undefined {
|
|
||||||
const logger = enableLogging ?
|
|
||||||
(message: string) => console.log(`[SNI Extraction] ${message}`) :
|
|
||||||
undefined;
|
|
||||||
|
|
||||||
return SniExtraction.extractSNIWithResumptionSupport(
|
|
||||||
buffer,
|
|
||||||
connectionInfo as ConnectionInfo,
|
|
||||||
logger
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Main entry point for SNI extraction that handles all edge cases.
|
|
||||||
* This should be called for each TLS packet received from a client.
|
|
||||||
*
|
|
||||||
* The method uses connection tracking to handle fragmented ClientHello
|
|
||||||
* messages and various TLS 1.3 behaviors, including Chrome's connection
|
|
||||||
* racing patterns and tab reactivation behaviors.
|
|
||||||
*
|
|
||||||
* @param buffer - The buffer containing TLS data
|
|
||||||
* @param connectionInfo - Connection metadata (IPs and ports)
|
|
||||||
* @param enableLogging - Whether to enable detailed debug logging
|
|
||||||
* @param cachedSni - Optional cached SNI from previous connections (for racing detection)
|
|
||||||
* @returns The extracted server name or undefined if not found or more data needed
|
|
||||||
*/
|
|
||||||
public static processTlsPacket(
|
|
||||||
buffer: Buffer,
|
|
||||||
connectionInfo: {
|
|
||||||
sourceIp: string;
|
|
||||||
sourcePort: number;
|
|
||||||
destIp: string;
|
|
||||||
destPort: number;
|
|
||||||
timestamp?: number;
|
|
||||||
},
|
|
||||||
enableLogging: boolean = false,
|
|
||||||
cachedSni?: string
|
|
||||||
): string | undefined {
|
|
||||||
const logger = enableLogging ?
|
|
||||||
(message: string) => console.log(`[TLS Packet] ${message}`) :
|
|
||||||
undefined;
|
|
||||||
|
|
||||||
return SniExtraction.processTlsPacket(buffer, connectionInfo, logger, cachedSni);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user