feat(rustproxy): introduce a Rust-powered proxy engine and workspace with core crates for proxy functionality, ACME/TLS support, passthrough and HTTP proxies, metrics, nftables integration, routing/security, management IPC, tests, and README updates
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -17,4 +17,5 @@ dist/
|
||||
dist_*/
|
||||
|
||||
#------# custom
|
||||
.claude/*
|
||||
.claude/*
|
||||
rust/target
|
||||
15
changelog.md
15
changelog.md
@@ -1,5 +1,20 @@
|
||||
# Changelog
|
||||
|
||||
## 2026-02-09 - 22.5.0 - feat(rustproxy)
|
||||
introduce a Rust-powered proxy engine and workspace with core crates for proxy functionality, ACME/TLS support, passthrough and HTTP proxies, metrics, nftables integration, routing/security, management IPC, tests, and README updates
|
||||
|
||||
- Add Rust workspace and multiple crates: rustproxy, rustproxy-config, rustproxy-routing, rustproxy-tls, rustproxy-passthrough, rustproxy-http, rustproxy-nftables, rustproxy-metrics, rustproxy-security
|
||||
- Implement ACME integration (instant-acme) and an HTTP-01 challenge server with certificate lifecycle management
|
||||
- Add TLS management: cert store, cert manager, SNI resolver, TLS acceptor/connector and certificate hot-swap support
|
||||
- Implement TCP/TLS passthrough engine with ClientHello SNI parsing, PROXY v1 support, connection tracking and bidirectional forwarder
|
||||
- Add Hyper-based HTTP proxy components: request/response filtering, CORS, auth, header templating and upstream selection with load balancing
|
||||
- Introduce metrics (throughput tracker, metrics collector) and log deduplication utilities
|
||||
- Implement nftables manager and rule builder (safe no-op behavior when not running as root)
|
||||
- Add route types, validation, helpers, route manager and matchers (domain/path/header/ip)
|
||||
- Provide management IPC (JSON over stdin/stdout) for TypeScript wrapper control (start/stop/add/remove ports, load certificates, etc.)
|
||||
- Include extensive unit and integration tests, test helpers, and an example Rust config.json
|
||||
- Update README to document the Rust-powered engine, new features and rustBinaryPath lookup
|
||||
|
||||
## 2026-01-31 - 22.4.2 - fix(tests)
|
||||
shorten long-lived connection test timeouts and update certificate metadata timestamps
|
||||
|
||||
|
||||
719
readme.md
719
readme.md
@@ -1,6 +1,6 @@
|
||||
# @push.rocks/smartproxy 🚀
|
||||
|
||||
**The Swiss Army Knife of Node.js Proxies** - A unified, high-performance proxy toolkit that handles everything from simple HTTP forwarding to complex enterprise routing scenarios.
|
||||
**A high-performance, Rust-powered proxy toolkit for Node.js** — unified route-based configuration for SSL/TLS termination, HTTP/HTTPS reverse proxying, WebSocket support, load balancing, custom protocol handlers, and kernel-level NFTables forwarding.
|
||||
|
||||
## 📦 Installation
|
||||
|
||||
@@ -16,22 +16,26 @@ For reporting bugs, issues, or security vulnerabilities, please visit [community
|
||||
|
||||
## 🎯 What is SmartProxy?
|
||||
|
||||
SmartProxy is a modern, production-ready proxy solution that brings order to the chaos of traffic management. Whether you're building microservices, deploying edge infrastructure, or need a battle-tested reverse proxy, SmartProxy has you covered.
|
||||
SmartProxy is a production-ready proxy solution that takes the complexity out of traffic management. Under the hood, all networking — TCP, TLS, HTTP reverse proxy, connection tracking, security enforcement, and NFTables — is handled by a **Rust engine** for maximum performance, while you configure everything through a clean TypeScript API with full type safety.
|
||||
|
||||
Whether you're building microservices, deploying edge infrastructure, or need a battle-tested reverse proxy with automatic Let's Encrypt certificates, SmartProxy has you covered.
|
||||
|
||||
### ⚡ Key Features
|
||||
|
||||
| Feature | Description |
|
||||
|---------|-------------|
|
||||
| 🦀 **Rust-Powered Engine** | All networking handled by a high-performance Rust binary via IPC |
|
||||
| 🔀 **Unified Route-Based Config** | Clean match/action patterns for intuitive traffic routing |
|
||||
| 🔒 **Automatic SSL/TLS** | Zero-config HTTPS with Let's Encrypt ACME integration |
|
||||
| 🎯 **Flexible Matching** | Route by port, domain, path, client IP, TLS version, or custom logic |
|
||||
| 🎯 **Flexible Matching** | Route by port, domain, path, client IP, TLS version, headers, or custom logic |
|
||||
| 🚄 **High-Performance** | Choose between user-space or kernel-level (NFTables) forwarding |
|
||||
| ⚖️ **Load Balancing** | Distribute traffic with health checks and multiple algorithms |
|
||||
| 🛡️ **Enterprise Security** | IP filtering, rate limiting, authentication, connection limits |
|
||||
| ⚖️ **Load Balancing** | Round-robin, least-connections, IP-hash with health checks |
|
||||
| 🛡️ **Enterprise Security** | IP filtering, rate limiting, basic auth, JWT auth, connection limits |
|
||||
| 🔌 **WebSocket Support** | First-class WebSocket proxying with ping/pong keep-alive |
|
||||
| 🎮 **Custom Protocols** | Socket handlers for implementing any protocol |
|
||||
| 🎮 **Custom Protocols** | Socket handlers for implementing any protocol in TypeScript |
|
||||
| 📊 **Live Metrics** | Real-time throughput, connection counts, and performance data |
|
||||
| 🔧 **Dynamic Management** | Add/remove ports and routes at runtime without restarts |
|
||||
| 🔄 **PROXY Protocol** | Full PROXY protocol v1/v2 support for preserving client information |
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
@@ -43,16 +47,16 @@ import { SmartProxy, createCompleteHttpsServer } from '@push.rocks/smartproxy';
|
||||
// Create a proxy with automatic HTTPS
|
||||
const proxy = new SmartProxy({
|
||||
acme: {
|
||||
email: 'ssl@yourdomain.com', // Your email for Let's Encrypt
|
||||
useProduction: true // Use production servers
|
||||
email: 'ssl@yourdomain.com',
|
||||
useProduction: true
|
||||
},
|
||||
routes: [
|
||||
// Complete HTTPS setup in one line! ✨
|
||||
// Complete HTTPS setup in one call! ✨
|
||||
...createCompleteHttpsServer('app.example.com', {
|
||||
host: 'localhost',
|
||||
port: 3000
|
||||
}, {
|
||||
certificate: 'auto' // Magic! 🎩
|
||||
certificate: 'auto' // Automatic Let's Encrypt cert 🎩
|
||||
})
|
||||
]
|
||||
});
|
||||
@@ -84,10 +88,11 @@ SmartProxy uses a powerful **match/action** pattern that makes routing predictab
|
||||
```
|
||||
|
||||
Every route consists of:
|
||||
- **Match** - What traffic to capture (ports, domains, paths, IPs)
|
||||
- **Action** - What to do with it (forward, redirect, block, socket-handler)
|
||||
- **Security** (optional) - Access controls, rate limits, authentication
|
||||
- **Name/Priority** (optional) - For identification and ordering
|
||||
- **Match** — What traffic to capture (ports, domains, paths, IPs, headers)
|
||||
- **Action** — What to do with it (`forward` or `socket-handler`)
|
||||
- **Security** (optional) — IP allow/block lists, rate limits, authentication
|
||||
- **Headers** (optional) — Request/response header manipulation with template variables
|
||||
- **Name/Priority** (optional) — For identification and ordering
|
||||
|
||||
### 🔄 TLS Modes
|
||||
|
||||
@@ -95,8 +100,8 @@ SmartProxy supports three TLS handling modes:
|
||||
|
||||
| Mode | Description | Use Case |
|
||||
|------|-------------|----------|
|
||||
| `passthrough` | Forward encrypted traffic as-is | Backend handles TLS |
|
||||
| `terminate` | Decrypt at proxy, forward plain | Standard reverse proxy |
|
||||
| `passthrough` | Forward encrypted traffic as-is (SNI-based routing) | Backend handles TLS |
|
||||
| `terminate` | Decrypt at proxy, forward plain HTTP to backend | Standard reverse proxy |
|
||||
| `terminate-and-reencrypt` | Decrypt, then re-encrypt to backend | Zero-trust environments |
|
||||
|
||||
## 💡 Common Use Cases
|
||||
@@ -116,53 +121,61 @@ const proxy = new SmartProxy({
|
||||
### ⚖️ Load Balancer with Health Checks
|
||||
|
||||
```typescript
|
||||
import { createLoadBalancerRoute } from '@push.rocks/smartproxy';
|
||||
import { SmartProxy, createLoadBalancerRoute } from '@push.rocks/smartproxy';
|
||||
|
||||
const route = createLoadBalancerRoute(
|
||||
'app.example.com',
|
||||
[
|
||||
{ host: 'server1.internal', port: 8080 },
|
||||
{ host: 'server2.internal', port: 8080 },
|
||||
{ host: 'server3.internal', port: 8080 }
|
||||
],
|
||||
{
|
||||
tls: { mode: 'terminate', certificate: 'auto' },
|
||||
loadBalancing: {
|
||||
algorithm: 'round-robin',
|
||||
healthCheck: {
|
||||
path: '/health',
|
||||
interval: 30000,
|
||||
timeout: 5000
|
||||
const proxy = new SmartProxy({
|
||||
routes: [
|
||||
createLoadBalancerRoute(
|
||||
'app.example.com',
|
||||
[
|
||||
{ host: 'server1.internal', port: 8080 },
|
||||
{ host: 'server2.internal', port: 8080 },
|
||||
{ host: 'server3.internal', port: 8080 }
|
||||
],
|
||||
{
|
||||
tls: { mode: 'terminate', certificate: 'auto' },
|
||||
loadBalancing: {
|
||||
algorithm: 'round-robin',
|
||||
healthCheck: {
|
||||
path: '/health',
|
||||
interval: 30000,
|
||||
timeout: 5000
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
)
|
||||
]
|
||||
});
|
||||
```
|
||||
|
||||
### 🔌 WebSocket Proxy
|
||||
|
||||
```typescript
|
||||
import { createWebSocketRoute } from '@push.rocks/smartproxy';
|
||||
import { SmartProxy, createWebSocketRoute } from '@push.rocks/smartproxy';
|
||||
|
||||
const route = createWebSocketRoute(
|
||||
'ws.example.com',
|
||||
{ host: 'websocket-server', port: 8080 },
|
||||
{
|
||||
path: '/socket',
|
||||
useTls: true,
|
||||
certificate: 'auto',
|
||||
pingInterval: 30000, // Keep connections alive
|
||||
pingTimeout: 10000
|
||||
}
|
||||
);
|
||||
const proxy = new SmartProxy({
|
||||
routes: [
|
||||
createWebSocketRoute(
|
||||
'ws.example.com',
|
||||
{ host: 'websocket-server', port: 8080 },
|
||||
{
|
||||
path: '/socket',
|
||||
useTls: true,
|
||||
certificate: 'auto',
|
||||
pingInterval: 30000,
|
||||
pingTimeout: 10000
|
||||
}
|
||||
)
|
||||
]
|
||||
});
|
||||
```
|
||||
|
||||
### 🚦 API Gateway with Rate Limiting
|
||||
|
||||
```typescript
|
||||
import { createApiGatewayRoute, addRateLimiting } from '@push.rocks/smartproxy';
|
||||
import { SmartProxy, createApiGatewayRoute, addRateLimiting } from '@push.rocks/smartproxy';
|
||||
|
||||
let route = createApiGatewayRoute(
|
||||
let apiRoute = createApiGatewayRoute(
|
||||
'api.example.com',
|
||||
'/api',
|
||||
{ host: 'api-backend', port: 8080 },
|
||||
@@ -173,20 +186,22 @@ let route = createApiGatewayRoute(
|
||||
}
|
||||
);
|
||||
|
||||
// Add rate limiting - 100 requests per minute per IP
|
||||
route = addRateLimiting(route, {
|
||||
// Add rate limiting — 100 requests per minute per IP
|
||||
apiRoute = addRateLimiting(apiRoute, {
|
||||
maxRequests: 100,
|
||||
window: 60,
|
||||
keyBy: 'ip'
|
||||
});
|
||||
|
||||
const proxy = new SmartProxy({ routes: [apiRoute] });
|
||||
```
|
||||
|
||||
### 🎮 Custom Protocol Handler
|
||||
|
||||
SmartProxy lets you implement any protocol with full socket control:
|
||||
SmartProxy lets you implement any protocol with full socket control. Routes with JavaScript socket handlers are automatically relayed from the Rust engine back to your TypeScript code:
|
||||
|
||||
```typescript
|
||||
import { createSocketHandlerRoute, SocketHandlers } from '@push.rocks/smartproxy';
|
||||
import { SmartProxy, createSocketHandlerRoute, SocketHandlers } from '@push.rocks/smartproxy';
|
||||
|
||||
// Use pre-built handlers
|
||||
const echoRoute = createSocketHandlerRoute(
|
||||
@@ -214,18 +229,21 @@ const customRoute = createSocketHandlerRoute(
|
||||
});
|
||||
}
|
||||
);
|
||||
|
||||
const proxy = new SmartProxy({ routes: [echoRoute, customRoute] });
|
||||
```
|
||||
|
||||
**Pre-built Socket Handlers:**
|
||||
|
||||
| Handler | Description |
|
||||
|---------|-------------|
|
||||
| `SocketHandlers.echo` | Echo server - returns everything sent |
|
||||
| `SocketHandlers.echo` | Echo server — returns everything sent |
|
||||
| `SocketHandlers.proxy(host, port)` | TCP proxy to another server |
|
||||
| `SocketHandlers.lineProtocol(handler)` | Line-based text protocol |
|
||||
| `SocketHandlers.httpResponse(code, body)` | Simple HTTP response |
|
||||
| `SocketHandlers.httpRedirect(url, code)` | HTTP redirect with templates |
|
||||
| `SocketHandlers.httpRedirect(url, code)` | HTTP redirect with template variables (`{domain}`, `{path}`, `{port}`, `{clientIp}`) |
|
||||
| `SocketHandlers.httpServer(handler)` | Full HTTP request/response handling |
|
||||
| `SocketHandlers.httpBlock(status, message)` | HTTP block response |
|
||||
| `SocketHandlers.block(message)` | Block with optional message |
|
||||
|
||||
### ⚡ High-Performance NFTables Forwarding
|
||||
@@ -233,48 +251,73 @@ const customRoute = createSocketHandlerRoute(
|
||||
For ultra-low latency on Linux, use kernel-level forwarding (requires root):
|
||||
|
||||
```typescript
|
||||
import { createNfTablesTerminateRoute } from '@push.rocks/smartproxy';
|
||||
import { SmartProxy, createNfTablesTerminateRoute } from '@push.rocks/smartproxy';
|
||||
|
||||
const route = createNfTablesTerminateRoute(
|
||||
'fast.example.com',
|
||||
{ host: 'backend', port: 8080 },
|
||||
{
|
||||
ports: 443,
|
||||
certificate: 'auto',
|
||||
preserveSourceIP: true, // Backend sees real client IP
|
||||
maxRate: '1gbps' // QoS rate limiting
|
||||
}
|
||||
);
|
||||
const proxy = new SmartProxy({
|
||||
routes: [
|
||||
createNfTablesTerminateRoute(
|
||||
'fast.example.com',
|
||||
{ host: 'backend', port: 8080 },
|
||||
{
|
||||
ports: 443,
|
||||
certificate: 'auto',
|
||||
preserveSourceIP: true, // Backend sees real client IP
|
||||
maxRate: '1gbps' // QoS rate limiting
|
||||
}
|
||||
)
|
||||
]
|
||||
});
|
||||
```
|
||||
|
||||
### 🔒 SNI Passthrough (TLS Passthrough)
|
||||
|
||||
Forward encrypted traffic to backends without terminating TLS — the proxy routes based on the SNI hostname alone:
|
||||
|
||||
```typescript
|
||||
import { SmartProxy, createHttpsPassthroughRoute } from '@push.rocks/smartproxy';
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: [
|
||||
createHttpsPassthroughRoute('secure.example.com', {
|
||||
host: 'backend-that-handles-tls',
|
||||
port: 8443
|
||||
})
|
||||
]
|
||||
});
|
||||
```
|
||||
|
||||
## 🔧 Advanced Features
|
||||
|
||||
### 🎯 Dynamic Routing
|
||||
|
||||
Route traffic based on runtime conditions:
|
||||
Route traffic based on runtime conditions using function-based host/port resolution:
|
||||
|
||||
```typescript
|
||||
{
|
||||
name: 'business-hours-only',
|
||||
match: {
|
||||
ports: 443,
|
||||
domains: 'internal.example.com'
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: (context) => {
|
||||
// Dynamic host selection based on path
|
||||
return context.path?.startsWith('/premium')
|
||||
? 'premium-backend'
|
||||
: 'standard-backend';
|
||||
},
|
||||
port: 8080
|
||||
}]
|
||||
}
|
||||
}
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'dynamic-backend',
|
||||
match: {
|
||||
ports: 443,
|
||||
domains: 'app.example.com'
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: (context) => {
|
||||
return context.path?.startsWith('/premium')
|
||||
? 'premium-backend'
|
||||
: 'standard-backend';
|
||||
},
|
||||
port: 8080
|
||||
}],
|
||||
tls: { mode: 'terminate', certificate: 'auto' }
|
||||
}
|
||||
}]
|
||||
});
|
||||
```
|
||||
|
||||
> **Note:** Routes with dynamic functions (host/port callbacks) are automatically relayed through the TypeScript socket handler server, since JavaScript functions can't be serialized to Rust.
|
||||
|
||||
### 🔒 Security Controls
|
||||
|
||||
Comprehensive per-route security options:
|
||||
@@ -285,7 +328,8 @@ Comprehensive per-route security options:
|
||||
match: { ports: 443, domains: 'api.example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'api-backend', port: 8080 }]
|
||||
targets: [{ host: 'api-backend', port: 8080 }],
|
||||
tls: { mode: 'terminate', certificate: 'auto' }
|
||||
},
|
||||
security: {
|
||||
// IP-based access control
|
||||
@@ -294,17 +338,31 @@ Comprehensive per-route security options:
|
||||
|
||||
// Connection limits
|
||||
maxConnections: 1000,
|
||||
maxConnectionsPerIp: 10,
|
||||
|
||||
// Rate limiting
|
||||
rateLimit: {
|
||||
enabled: true,
|
||||
maxRequests: 100,
|
||||
windowMs: 60000
|
||||
}
|
||||
window: 60
|
||||
},
|
||||
|
||||
// Authentication
|
||||
basicAuth: { users: [{ username: 'admin', password: 'secret' }] },
|
||||
jwtAuth: { secret: 'your-jwt-secret', algorithm: 'HS256' }
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Security modifier helpers** let you add security to any existing route:
|
||||
|
||||
```typescript
|
||||
import { addRateLimiting, addBasicAuth, addJwtAuth } from '@push.rocks/smartproxy';
|
||||
|
||||
let route = createHttpsTerminateRoute('api.example.com', { host: 'backend', port: 8080 });
|
||||
route = addRateLimiting(route, { maxRequests: 100, window: 60, keyBy: 'ip' });
|
||||
route = addBasicAuth(route, { users: [{ username: 'admin', password: 'secret' }] });
|
||||
```
|
||||
|
||||
### 📊 Runtime Management
|
||||
|
||||
Control your proxy without restarts:
|
||||
@@ -313,21 +371,26 @@ Control your proxy without restarts:
|
||||
// Dynamic port management
|
||||
await proxy.addListeningPort(8443);
|
||||
await proxy.removeListeningPort(8080);
|
||||
const ports = await proxy.getListeningPorts();
|
||||
|
||||
// Update routes on the fly
|
||||
// Update routes on the fly (atomic, mutex-locked)
|
||||
await proxy.updateRoutes([...newRoutes]);
|
||||
|
||||
// Monitor status
|
||||
const status = proxy.getStatus();
|
||||
console.log(`Active connections: ${status.activeConnections}`);
|
||||
|
||||
// Get detailed metrics
|
||||
// Get real-time metrics
|
||||
const metrics = proxy.getMetrics();
|
||||
console.log(`Throughput: ${metrics.throughput.bytesPerSecond} bytes/sec`);
|
||||
console.log(`Active connections: ${metrics.connections.active()}`);
|
||||
console.log(`Requests/sec: ${metrics.throughput.requestsPerSecond()}`);
|
||||
|
||||
// Get detailed statistics from the Rust engine
|
||||
const stats = await proxy.getStatistics();
|
||||
|
||||
// Certificate management
|
||||
const certInfo = proxy.getCertificateInfo('example.com');
|
||||
console.log(`Certificate expires: ${certInfo.expiresAt}`);
|
||||
await proxy.provisionCertificate('my-route-name');
|
||||
await proxy.renewCertificate('my-route-name');
|
||||
const certStatus = await proxy.getCertificateStatus('my-route-name');
|
||||
|
||||
// NFTables status
|
||||
const nftStatus = await proxy.getNfTablesStatus();
|
||||
```
|
||||
|
||||
### 🔄 Header Manipulation
|
||||
@@ -338,51 +401,107 @@ Transform requests and responses with template variables:
|
||||
{
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'backend', port: 8080 }],
|
||||
headers: {
|
||||
request: {
|
||||
'X-Real-IP': '{clientIp}',
|
||||
'X-Request-ID': '{uuid}',
|
||||
'X-Forwarded-Proto': 'https'
|
||||
},
|
||||
response: {
|
||||
'X-Powered-By': 'SmartProxy',
|
||||
'Strict-Transport-Security': 'max-age=31536000',
|
||||
'X-Frame-Options': 'DENY'
|
||||
}
|
||||
targets: [{ host: 'backend', port: 8080 }]
|
||||
},
|
||||
headers: {
|
||||
request: {
|
||||
'X-Real-IP': '{clientIp}',
|
||||
'X-Request-ID': '{uuid}',
|
||||
'X-Forwarded-Proto': 'https'
|
||||
},
|
||||
response: {
|
||||
'Strict-Transport-Security': 'max-age=31536000',
|
||||
'X-Frame-Options': 'DENY'
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 🔀 PROXY Protocol Support
|
||||
|
||||
Preserve original client information through proxy chains:
|
||||
|
||||
```typescript
|
||||
const proxy = new SmartProxy({
|
||||
// Accept PROXY protocol from trusted load balancers
|
||||
acceptProxyProtocol: true,
|
||||
proxyIPs: ['10.0.0.1', '10.0.0.2'],
|
||||
|
||||
// Forward PROXY protocol to backends
|
||||
sendProxyProtocol: true,
|
||||
|
||||
routes: [...]
|
||||
});
|
||||
```
|
||||
|
||||
### 🏗️ Custom Certificate Provisioning
|
||||
|
||||
Supply your own certificates or integrate with external certificate providers:
|
||||
|
||||
```typescript
|
||||
const proxy = new SmartProxy({
|
||||
certProvisionFunction: async (domain: string) => {
|
||||
// Return 'http01' to let the built-in ACME handle it
|
||||
if (domain.endsWith('.example.com')) return 'http01';
|
||||
|
||||
// Or return a static certificate object
|
||||
return {
|
||||
publicKey: myPemCert,
|
||||
privateKey: myPemKey,
|
||||
};
|
||||
},
|
||||
certProvisionFallbackToAcme: true, // Fall back to ACME if callback fails
|
||||
routes: [...]
|
||||
});
|
||||
```
|
||||
|
||||
## 🏛️ Architecture
|
||||
|
||||
SmartProxy is built with a modular, extensible architecture:
|
||||
SmartProxy uses a hybrid **Rust + TypeScript** architecture:
|
||||
|
||||
```
|
||||
SmartProxy
|
||||
├── 📋 RouteManager # Route matching and prioritization
|
||||
├── 🔌 PortManager # Dynamic port lifecycle management
|
||||
├── 🔒 SmartCertManager # ACME/Let's Encrypt automation
|
||||
├── 🚦 ConnectionManager # Connection pooling and tracking
|
||||
├── 📊 MetricsCollector # Real-time performance monitoring
|
||||
├── 🛡️ SecurityManager # Access control and rate limiting
|
||||
├── 🔧 ProtocolDetector # Smart HTTP/TLS/WebSocket detection
|
||||
├── ⚡ NFTablesManager # Kernel-level forwarding (Linux)
|
||||
└── 🌐 HttpProxyBridge # HTTP/HTTPS request handling
|
||||
┌─────────────────────────────────────────────────────┐
|
||||
│ Your Application │
|
||||
│ (TypeScript — routes, config, socket handlers) │
|
||||
└──────────────────┬──────────────────────────────────┘
|
||||
│ IPC (JSON over stdin/stdout)
|
||||
┌──────────────────▼──────────────────────────────────┐
|
||||
│ Rust Proxy Engine │
|
||||
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌──────────┐ │
|
||||
│ │ TCP/TLS │ │ HTTP │ │ Route │ │ ACME │ │
|
||||
│ │ Listener│ │ Reverse │ │ Matcher │ │ Cert Mgr │ │
|
||||
│ │ │ │ Proxy │ │ │ │ │ │
|
||||
│ └─────────┘ └─────────┘ └─────────┘ └──────────┘ │
|
||||
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌──────────┐ │
|
||||
│ │ Security│ │ Metrics │ │ Connec- │ │ NFTables │ │
|
||||
│ │ Enforce │ │ Collect │ │ tion │ │ Mgr │ │
|
||||
│ │ │ │ │ │ Tracker │ │ │ │
|
||||
│ └─────────┘ └─────────┘ └─────────┘ └──────────┘ │
|
||||
└──────────────────┬──────────────────────────────────┘
|
||||
│ Unix Socket Relay
|
||||
┌──────────────────▼──────────────────────────────────┐
|
||||
│ TypeScript Socket Handler Server │
|
||||
│ (for JS-defined socket handlers & dynamic routes) │
|
||||
└─────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
- **Rust Engine** handles all networking, TLS, HTTP proxying, connection management, security, and metrics
|
||||
- **TypeScript** provides the npm API, configuration types, route helpers, validation, and socket handler callbacks
|
||||
- **IPC** — JSON commands/events over stdin/stdout for seamless cross-language communication
|
||||
- **Socket Relay** — a Unix domain socket server for routes requiring TypeScript-side handling (socket handlers, dynamic host/port functions)
|
||||
|
||||
## 🎯 Route Configuration Reference
|
||||
|
||||
### Match Criteria
|
||||
|
||||
```typescript
|
||||
interface IRouteMatch {
|
||||
ports: number | number[] | string; // 80, [80, 443], '8000-8999'
|
||||
domains?: string | string[]; // 'example.com', '*.example.com'
|
||||
path?: string; // '/api/*', '/users/:id'
|
||||
clientIp?: string | string[]; // '10.0.0.0/8', ['192.168.*']
|
||||
tlsVersion?: string | string[]; // ['TLSv1.2', 'TLSv1.3']
|
||||
ports: number | number[] | Array<{ from: number; to: number }>; // Port(s) to listen on
|
||||
domains?: string | string[]; // 'example.com', '*.example.com'
|
||||
path?: string; // '/api/*', '/users/:id'
|
||||
clientIp?: string[]; // ['10.0.0.0/8', '192.168.*']
|
||||
tlsVersion?: string[]; // ['TLSv1.2', 'TLSv1.3']
|
||||
headers?: Record<string, string | RegExp>; // Match by HTTP headers
|
||||
}
|
||||
```
|
||||
|
||||
@@ -390,69 +509,251 @@ interface IRouteMatch {
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `forward` | Proxy to one or more backend targets |
|
||||
| `redirect` | HTTP redirect with status code |
|
||||
| `block` | Block the connection |
|
||||
| `socket-handler` | Custom socket handling function |
|
||||
| `forward` | Proxy to one or more backend targets (with optional TLS, WebSocket, load balancing) |
|
||||
| `socket-handler` | Custom socket handling function in TypeScript |
|
||||
|
||||
### Target Options
|
||||
|
||||
```typescript
|
||||
interface IRouteTarget {
|
||||
host: string | string[] | ((context: IRouteContext) => string);
|
||||
port: number | 'preserve' | ((context: IRouteContext) => number);
|
||||
tls?: { ... }; // Per-target TLS override
|
||||
priority?: number; // Target priority
|
||||
match?: ITargetMatch; // Sub-match within a route (by port, path, headers, method)
|
||||
}
|
||||
```
|
||||
|
||||
### TLS Options
|
||||
|
||||
```typescript
|
||||
interface IRouteTls {
|
||||
mode: 'passthrough' | 'terminate' | 'terminate-and-reencrypt';
|
||||
certificate: 'auto' | { key: string; cert: string };
|
||||
// For terminate-and-reencrypt:
|
||||
reencrypt?: {
|
||||
host: string;
|
||||
port: number;
|
||||
ca?: string; // Custom CA for backend
|
||||
certificate: 'auto' | {
|
||||
key: string;
|
||||
cert: string;
|
||||
ca?: string;
|
||||
keyFile?: string;
|
||||
certFile?: string;
|
||||
};
|
||||
acme?: {
|
||||
email: string;
|
||||
useProduction?: boolean;
|
||||
challengePort?: number;
|
||||
renewBeforeDays?: number;
|
||||
};
|
||||
versions?: string[];
|
||||
ciphers?: string[];
|
||||
honorCipherOrder?: boolean;
|
||||
sessionTimeout?: number;
|
||||
}
|
||||
```
|
||||
|
||||
### WebSocket Options
|
||||
|
||||
```typescript
|
||||
interface IRouteWebSocket {
|
||||
enabled: boolean;
|
||||
pingInterval?: number; // ms between pings
|
||||
pingTimeout?: number; // ms to wait for pong
|
||||
maxPayloadSize?: number; // Maximum frame payload
|
||||
subprotocols?: string[]; // Allowed subprotocols
|
||||
allowedOrigins?: string[]; // CORS origins
|
||||
}
|
||||
```
|
||||
|
||||
### Load Balancing Options
|
||||
|
||||
```typescript
|
||||
interface IRouteLoadBalancing {
|
||||
algorithm: 'round-robin' | 'least-connections' | 'ip-hash';
|
||||
healthCheck?: {
|
||||
path: string;
|
||||
interval: number; // ms
|
||||
timeout: number; // ms
|
||||
unhealthyThreshold?: number;
|
||||
healthyThreshold?: number;
|
||||
};
|
||||
}
|
||||
```
|
||||
|
||||
## 🛠️ Helper Functions Reference
|
||||
|
||||
All helpers are fully typed and documented:
|
||||
All helpers are fully typed and return `IRouteConfig` or `IRouteConfig[]`:
|
||||
|
||||
```typescript
|
||||
import {
|
||||
// HTTP/HTTPS
|
||||
createHttpRoute,
|
||||
createHttpsTerminateRoute,
|
||||
createHttpsPassthroughRoute,
|
||||
createHttpToHttpsRedirect,
|
||||
createCompleteHttpsServer,
|
||||
createHttpRoute, // Plain HTTP route
|
||||
createHttpsTerminateRoute, // HTTPS with TLS termination
|
||||
createHttpsPassthroughRoute, // SNI passthrough (no termination)
|
||||
createHttpToHttpsRedirect, // HTTP → HTTPS redirect
|
||||
createCompleteHttpsServer, // HTTPS + redirect combo (returns IRouteConfig[])
|
||||
|
||||
// Load Balancing
|
||||
createLoadBalancerRoute,
|
||||
createSmartLoadBalancer,
|
||||
createLoadBalancerRoute, // Multi-backend with health checks
|
||||
createSmartLoadBalancer, // Dynamic domain-based backend selection
|
||||
|
||||
// API & WebSocket
|
||||
createApiRoute,
|
||||
createApiGatewayRoute,
|
||||
createWebSocketRoute,
|
||||
createApiRoute, // API route with path matching
|
||||
createApiGatewayRoute, // API gateway with CORS
|
||||
createWebSocketRoute, // WebSocket-enabled route
|
||||
|
||||
// Custom Protocols
|
||||
createSocketHandlerRoute,
|
||||
SocketHandlers,
|
||||
createSocketHandlerRoute, // Custom socket handler
|
||||
SocketHandlers, // Pre-built handlers (echo, proxy, block, etc.)
|
||||
|
||||
// NFTables (Linux)
|
||||
createNfTablesRoute,
|
||||
createNfTablesTerminateRoute,
|
||||
createCompleteNfTablesHttpsServer,
|
||||
// NFTables (Linux, requires root)
|
||||
createNfTablesRoute, // Kernel-level packet forwarding
|
||||
createNfTablesTerminateRoute, // NFTables + TLS termination
|
||||
createCompleteNfTablesHttpsServer, // NFTables HTTPS + redirect combo
|
||||
|
||||
// Dynamic Routing
|
||||
createPortMappingRoute,
|
||||
createOffsetPortMappingRoute,
|
||||
createDynamicRoute,
|
||||
createPortMappingRoute, // Port mapping with context
|
||||
createOffsetPortMappingRoute, // Simple port offset
|
||||
createDynamicRoute, // Dynamic host/port via functions
|
||||
|
||||
// Security Modifiers
|
||||
addRateLimiting,
|
||||
addBasicAuth,
|
||||
addJwtAuth
|
||||
addRateLimiting, // Add rate limiting to any route
|
||||
addBasicAuth, // Add basic auth to any route
|
||||
addJwtAuth, // Add JWT auth to any route
|
||||
|
||||
// Route Utilities
|
||||
mergeRouteConfigs, // Deep-merge two route configs
|
||||
findMatchingRoutes, // Find routes matching criteria
|
||||
findBestMatchingRoute, // Find best matching route
|
||||
cloneRoute, // Deep-clone a route
|
||||
generateRouteId, // Generate deterministic route ID
|
||||
RouteValidator, // Validate route configurations
|
||||
} from '@push.rocks/smartproxy';
|
||||
```
|
||||
|
||||
## 📖 API Documentation
|
||||
|
||||
### SmartProxy Class
|
||||
|
||||
```typescript
|
||||
class SmartProxy extends EventEmitter {
|
||||
constructor(options: ISmartProxyOptions);
|
||||
|
||||
// Lifecycle
|
||||
start(): Promise<void>;
|
||||
stop(): Promise<void>;
|
||||
|
||||
// Route Management (atomic, mutex-locked)
|
||||
updateRoutes(routes: IRouteConfig[]): Promise<void>;
|
||||
|
||||
// Port Management
|
||||
addListeningPort(port: number): Promise<void>;
|
||||
removeListeningPort(port: number): Promise<void>;
|
||||
getListeningPorts(): Promise<number[]>;
|
||||
|
||||
// Monitoring & Metrics
|
||||
getMetrics(): IMetrics; // Sync — returns cached metrics adapter
|
||||
getStatistics(): Promise<any>; // Async — queries Rust engine
|
||||
|
||||
// Certificate Management
|
||||
provisionCertificate(routeName: string): Promise<void>;
|
||||
renewCertificate(routeName: string): Promise<void>;
|
||||
getCertificateStatus(routeName: string): Promise<any>;
|
||||
getEligibleDomainsForCertificates(): string[];
|
||||
|
||||
// NFTables
|
||||
getNfTablesStatus(): Promise<Record<string, any>>;
|
||||
|
||||
// Events
|
||||
on(event: 'error', handler: (err: Error) => void): this;
|
||||
}
|
||||
```
|
||||
|
||||
### Configuration Options
|
||||
|
||||
```typescript
|
||||
interface ISmartProxyOptions {
|
||||
routes: IRouteConfig[]; // Required: array of route configs
|
||||
|
||||
// ACME/Let's Encrypt
|
||||
acme?: {
|
||||
email: string; // Contact email for Let's Encrypt
|
||||
useProduction?: boolean; // Use production servers (default: false)
|
||||
port?: number; // HTTP-01 challenge port (default: 80)
|
||||
renewThresholdDays?: number; // Days before expiry to renew (default: 30)
|
||||
autoRenew?: boolean; // Enable auto-renewal (default: true)
|
||||
certificateStore?: string; // Directory to store certs (default: './certs')
|
||||
renewCheckIntervalHours?: number; // Renewal check interval (default: 24)
|
||||
};
|
||||
|
||||
// Custom certificate provisioning
|
||||
certProvisionFunction?: (domain: string) => Promise<ICert | 'http01'>;
|
||||
certProvisionFallbackToAcme?: boolean; // Fall back to ACME on failure (default: true)
|
||||
|
||||
// Global defaults
|
||||
defaults?: {
|
||||
target?: { host: string; port: number };
|
||||
security?: { ipAllowList?: string[]; ipBlockList?: string[]; maxConnections?: number };
|
||||
};
|
||||
|
||||
// PROXY protocol
|
||||
proxyIPs?: string[]; // Trusted proxy IPs
|
||||
acceptProxyProtocol?: boolean; // Accept PROXY protocol headers
|
||||
sendProxyProtocol?: boolean; // Send PROXY protocol to targets
|
||||
|
||||
// Timeouts
|
||||
connectionTimeout?: number; // Backend connection timeout (default: 30s)
|
||||
initialDataTimeout?: number; // Initial data/SNI timeout (default: 120s)
|
||||
socketTimeout?: number; // Socket inactivity timeout (default: 1h)
|
||||
maxConnectionLifetime?: number; // Max connection lifetime (default: 24h)
|
||||
inactivityTimeout?: number; // Inactivity timeout (default: 4h)
|
||||
gracefulShutdownTimeout?: number; // Shutdown grace period (default: 30s)
|
||||
|
||||
// Connection limits
|
||||
maxConnectionsPerIP?: number; // Per-IP connection limit (default: 100)
|
||||
connectionRateLimitPerMinute?: number; // Per-IP rate limit (default: 300/min)
|
||||
|
||||
// Keep-alive
|
||||
keepAliveTreatment?: 'standard' | 'extended' | 'immortal';
|
||||
keepAliveInactivityMultiplier?: number; // (default: 6)
|
||||
extendedKeepAliveLifetime?: number; // (default: 7 days)
|
||||
|
||||
// Metrics
|
||||
metrics?: {
|
||||
enabled?: boolean;
|
||||
sampleIntervalMs?: number;
|
||||
retentionSeconds?: number;
|
||||
};
|
||||
|
||||
// Behavior
|
||||
enableDetailedLogging?: boolean; // Verbose connection logging
|
||||
enableTlsDebugLogging?: boolean; // TLS handshake debug logging
|
||||
|
||||
// Rust binary
|
||||
rustBinaryPath?: string; // Custom path to the Rust binary
|
||||
}
|
||||
```
|
||||
|
||||
### NfTablesProxy Class
|
||||
|
||||
A standalone class for managing nftables NAT rules directly (Linux only, requires root):
|
||||
|
||||
```typescript
|
||||
import { NfTablesProxy } from '@push.rocks/smartproxy';
|
||||
|
||||
const nftProxy = new NfTablesProxy({
|
||||
fromPorts: [80, 443],
|
||||
toHost: 'backend-server',
|
||||
toPorts: [8080, 8443],
|
||||
protocol: 'tcp',
|
||||
preserveSourceIP: true,
|
||||
enableIPv6: true,
|
||||
maxRate: '1gbps',
|
||||
useIPSets: true
|
||||
});
|
||||
|
||||
await nftProxy.start(); // Apply nftables rules
|
||||
const status = nftProxy.getStatus();
|
||||
await nftProxy.stop(); // Remove rules
|
||||
```
|
||||
|
||||
## 🐛 Troubleshooting
|
||||
|
||||
### Certificate Issues
|
||||
@@ -460,93 +761,41 @@ import {
|
||||
- ✅ Port 80 must be accessible for ACME HTTP-01 challenges
|
||||
- ✅ Check DNS propagation with `dig` or `nslookup`
|
||||
- ✅ Verify the email in ACME configuration is valid
|
||||
- ✅ Use `getCertificateStatus('route-name')` to check cert state
|
||||
|
||||
### Connection Problems
|
||||
- ✅ Check route priorities (higher number = matched first)
|
||||
- ✅ Verify security rules aren't blocking legitimate traffic
|
||||
- ✅ Test with `curl -v` for detailed connection output
|
||||
- ✅ Enable debug logging for verbose output
|
||||
- ✅ Enable debug logging with `enableDetailedLogging: true`
|
||||
|
||||
### Rust Binary Not Found
|
||||
SmartProxy searches for the Rust binary in this order:
|
||||
1. `SMARTPROXY_RUST_BINARY` environment variable
|
||||
2. Platform-specific npm package (`@push.rocks/smartproxy-linux-x64`, etc.)
|
||||
3. Local dev build (`./rust/target/release/rustproxy`)
|
||||
4. System PATH (`rustproxy`)
|
||||
|
||||
Set `rustBinaryPath` in options to override.
|
||||
|
||||
### Performance Tuning
|
||||
- ✅ Use NFTables forwarding for high-traffic routes (Linux only)
|
||||
- ✅ Enable connection keep-alive where appropriate
|
||||
- ✅ Monitor metrics to identify bottlenecks
|
||||
- ✅ Adjust `maxConnections` based on your server resources
|
||||
|
||||
### Debug Mode
|
||||
|
||||
```typescript
|
||||
const proxy = new SmartProxy({
|
||||
enableDetailedLogging: true, // Verbose connection logging
|
||||
routes: [...]
|
||||
});
|
||||
```
|
||||
- ✅ Use `getMetrics()` and `getStatistics()` to identify bottlenecks
|
||||
- ✅ Adjust `maxConnectionsPerIP` and `connectionRateLimitPerMinute` based on your workload
|
||||
- ✅ Use `passthrough` TLS mode when backend can handle TLS directly
|
||||
|
||||
## 🏆 Best Practices
|
||||
|
||||
1. **📝 Use Helper Functions** - They provide sensible defaults and prevent common mistakes
|
||||
2. **🎯 Set Route Priorities** - More specific routes should have higher priority values
|
||||
3. **🔒 Enable Security** - Always use IP filtering and rate limiting for public services
|
||||
4. **📊 Monitor Metrics** - Use the built-in metrics to identify issues early
|
||||
5. **🔄 Certificate Monitoring** - Set up alerts for certificate expiration
|
||||
6. **🛑 Graceful Shutdown** - Always call `proxy.stop()` for clean connection termination
|
||||
7. **🔧 Test Routes** - Validate your route configurations before deploying to production
|
||||
|
||||
## 📖 API Documentation
|
||||
|
||||
### SmartProxy Class
|
||||
|
||||
```typescript
|
||||
class SmartProxy {
|
||||
constructor(options: ISmartProxyOptions);
|
||||
|
||||
// Lifecycle
|
||||
start(): Promise<void>;
|
||||
stop(): Promise<void>;
|
||||
|
||||
// Route Management
|
||||
updateRoutes(routes: IRouteConfig[]): Promise<void>;
|
||||
|
||||
// Port Management
|
||||
addListeningPort(port: number): Promise<void>;
|
||||
removeListeningPort(port: number): Promise<void>;
|
||||
getListeningPorts(): number[];
|
||||
|
||||
// Monitoring
|
||||
getStatus(): IProxyStatus;
|
||||
getMetrics(): IMetrics;
|
||||
|
||||
// Certificate Management
|
||||
getCertificateInfo(domain: string): ICertStatus | null;
|
||||
}
|
||||
```
|
||||
|
||||
### Configuration Options
|
||||
|
||||
```typescript
|
||||
interface ISmartProxyOptions {
|
||||
routes: IRouteConfig[]; // Required: array of route configs
|
||||
|
||||
// ACME/Let's Encrypt
|
||||
acme?: {
|
||||
email: string; // Contact email
|
||||
useProduction?: boolean; // Use production servers (default: false)
|
||||
port?: number; // Challenge port (default: 80)
|
||||
renewThresholdDays?: number; // Days before expiry to renew (default: 30)
|
||||
};
|
||||
|
||||
// Defaults
|
||||
defaults?: {
|
||||
target?: { host: string; port: number };
|
||||
security?: IRouteSecurity;
|
||||
tls?: IRouteTls;
|
||||
};
|
||||
|
||||
// Behavior
|
||||
enableDetailedLogging?: boolean;
|
||||
gracefulShutdownTimeout?: number; // ms to wait for connections to close
|
||||
}
|
||||
```
|
||||
1. **📝 Use Helper Functions** — They provide sensible defaults and prevent common mistakes
|
||||
2. **🎯 Set Route Priorities** — More specific routes should have higher priority values
|
||||
3. **🔒 Enable Security** — Always use IP filtering and rate limiting for public-facing services
|
||||
4. **📊 Monitor Metrics** — Use the built-in metrics to catch issues early
|
||||
5. **🔄 Certificate Monitoring** — Set up alerts before certificates expire
|
||||
6. **🛑 Graceful Shutdown** — Always call `proxy.stop()` for clean connection termination
|
||||
7. **✅ Validate Routes** — Use `RouteValidator.validateRoutes()` to catch config errors before deployment
|
||||
8. **🔀 Atomic Updates** — Use `updateRoutes()` for hot-reloading routes (mutex-locked, no downtime)
|
||||
9. **🎮 Use Socket Handlers** — For protocols beyond HTTP, implement custom socket handlers instead of fighting the proxy model
|
||||
|
||||
## License and Legal Information
|
||||
|
||||
|
||||
1760
rust/Cargo.lock
generated
Normal file
1760
rust/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
98
rust/Cargo.toml
Normal file
98
rust/Cargo.toml
Normal file
@@ -0,0 +1,98 @@
|
||||
[workspace]
|
||||
resolver = "2"
|
||||
members = [
|
||||
"crates/rustproxy",
|
||||
"crates/rustproxy-config",
|
||||
"crates/rustproxy-routing",
|
||||
"crates/rustproxy-tls",
|
||||
"crates/rustproxy-passthrough",
|
||||
"crates/rustproxy-http",
|
||||
"crates/rustproxy-nftables",
|
||||
"crates/rustproxy-metrics",
|
||||
"crates/rustproxy-security",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "MIT"
|
||||
authors = ["Lossless GmbH <hello@lossless.com>"]
|
||||
|
||||
[workspace.dependencies]
|
||||
# Async runtime
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
|
||||
# Serialization
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
|
||||
# HTTP proxy engine (hyper-based)
|
||||
hyper = { version = "1", features = ["http1", "http2", "server", "client"] }
|
||||
hyper-util = { version = "0.1", features = ["tokio", "http1", "http2", "client-legacy", "server-auto"] }
|
||||
http-body-util = "0.1"
|
||||
bytes = "1"
|
||||
|
||||
# ACME / Let's Encrypt
|
||||
instant-acme = { version = "0.7", features = ["hyper-rustls"] }
|
||||
|
||||
# TLS for passthrough SNI
|
||||
rustls = { version = "0.23", features = ["ring"] }
|
||||
tokio-rustls = "0.26"
|
||||
rustls-pemfile = "2"
|
||||
|
||||
# Self-signed cert generation for tests
|
||||
rcgen = "0.13"
|
||||
|
||||
# Temp directories for tests
|
||||
tempfile = "3"
|
||||
|
||||
# Lock-free atomics
|
||||
arc-swap = "1"
|
||||
|
||||
# Concurrent maps
|
||||
dashmap = "6"
|
||||
|
||||
# Domain wildcard matching
|
||||
glob-match = "0.2"
|
||||
|
||||
# IP/CIDR parsing
|
||||
ipnet = "2"
|
||||
|
||||
# JWT authentication
|
||||
jsonwebtoken = "9"
|
||||
|
||||
# Structured logging
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
# Error handling
|
||||
thiserror = "2"
|
||||
anyhow = "1"
|
||||
|
||||
# CLI
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
|
||||
# Regex for URL rewriting
|
||||
regex = "1"
|
||||
|
||||
# Base64 for basic auth
|
||||
base64 = "0.22"
|
||||
|
||||
# Cancellation / utility
|
||||
tokio-util = "0.7"
|
||||
|
||||
# Async traits
|
||||
async-trait = "0.1"
|
||||
|
||||
# libc for uid checks
|
||||
libc = "0.2"
|
||||
|
||||
# Internal crates
|
||||
rustproxy-config = { path = "crates/rustproxy-config" }
|
||||
rustproxy-routing = { path = "crates/rustproxy-routing" }
|
||||
rustproxy-tls = { path = "crates/rustproxy-tls" }
|
||||
rustproxy-passthrough = { path = "crates/rustproxy-passthrough" }
|
||||
rustproxy-http = { path = "crates/rustproxy-http" }
|
||||
rustproxy-nftables = { path = "crates/rustproxy-nftables" }
|
||||
rustproxy-metrics = { path = "crates/rustproxy-metrics" }
|
||||
rustproxy-security = { path = "crates/rustproxy-security" }
|
||||
145
rust/config/example.json
Normal file
145
rust/config/example.json
Normal file
@@ -0,0 +1,145 @@
|
||||
{
|
||||
"routes": [
|
||||
{
|
||||
"id": "https-passthrough",
|
||||
"name": "HTTPS Passthrough to Backend",
|
||||
"match": {
|
||||
"ports": 443,
|
||||
"domains": "backend.example.com"
|
||||
},
|
||||
"action": {
|
||||
"type": "forward",
|
||||
"targets": [
|
||||
{
|
||||
"host": "10.0.0.1",
|
||||
"port": 443
|
||||
}
|
||||
],
|
||||
"tls": {
|
||||
"mode": "passthrough"
|
||||
}
|
||||
},
|
||||
"priority": 10,
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"id": "https-terminate",
|
||||
"name": "HTTPS Terminate for API",
|
||||
"match": {
|
||||
"ports": 443,
|
||||
"domains": "api.example.com"
|
||||
},
|
||||
"action": {
|
||||
"type": "forward",
|
||||
"targets": [
|
||||
{
|
||||
"host": "localhost",
|
||||
"port": 8080
|
||||
}
|
||||
],
|
||||
"tls": {
|
||||
"mode": "terminate",
|
||||
"certificate": "auto"
|
||||
}
|
||||
},
|
||||
"priority": 20,
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"id": "http-redirect",
|
||||
"name": "HTTP to HTTPS Redirect",
|
||||
"match": {
|
||||
"ports": 80,
|
||||
"domains": ["api.example.com", "www.example.com"]
|
||||
},
|
||||
"action": {
|
||||
"type": "forward",
|
||||
"targets": [
|
||||
{
|
||||
"host": "localhost",
|
||||
"port": 8080
|
||||
}
|
||||
]
|
||||
},
|
||||
"priority": 0
|
||||
},
|
||||
{
|
||||
"id": "load-balanced",
|
||||
"name": "Load Balanced Backend",
|
||||
"match": {
|
||||
"ports": 443,
|
||||
"domains": "*.example.com"
|
||||
},
|
||||
"action": {
|
||||
"type": "forward",
|
||||
"targets": [
|
||||
{
|
||||
"host": "backend1.internal",
|
||||
"port": 8080
|
||||
},
|
||||
{
|
||||
"host": "backend2.internal",
|
||||
"port": 8080
|
||||
},
|
||||
{
|
||||
"host": "backend3.internal",
|
||||
"port": 8080
|
||||
}
|
||||
],
|
||||
"tls": {
|
||||
"mode": "terminate",
|
||||
"certificate": "auto"
|
||||
},
|
||||
"loadBalancing": {
|
||||
"algorithm": "round-robin",
|
||||
"healthCheck": {
|
||||
"path": "/health",
|
||||
"interval": 30,
|
||||
"timeout": 5,
|
||||
"unhealthyThreshold": 3,
|
||||
"healthyThreshold": 2
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": {
|
||||
"ipAllowList": ["10.0.0.0/8", "192.168.0.0/16"],
|
||||
"maxConnections": 1000,
|
||||
"rateLimit": {
|
||||
"enabled": true,
|
||||
"maxRequests": 100,
|
||||
"window": 60,
|
||||
"keyBy": "ip"
|
||||
}
|
||||
},
|
||||
"headers": {
|
||||
"request": {
|
||||
"X-Forwarded-For": "{clientIp}",
|
||||
"X-Real-IP": "{clientIp}"
|
||||
},
|
||||
"response": {
|
||||
"X-Powered-By": "RustProxy"
|
||||
},
|
||||
"cors": {
|
||||
"enabled": true,
|
||||
"allowOrigin": "*",
|
||||
"allowMethods": "GET,POST,PUT,DELETE,OPTIONS",
|
||||
"allowHeaders": "Content-Type,Authorization",
|
||||
"allowCredentials": false,
|
||||
"maxAge": 86400
|
||||
}
|
||||
},
|
||||
"priority": 5
|
||||
}
|
||||
],
|
||||
"acme": {
|
||||
"email": "admin@example.com",
|
||||
"useProduction": false,
|
||||
"port": 80
|
||||
},
|
||||
"connectionTimeout": 30000,
|
||||
"socketTimeout": 3600000,
|
||||
"maxConnectionsPerIp": 100,
|
||||
"connectionRateLimitPerMinute": 300,
|
||||
"keepAliveTreatment": "extended",
|
||||
"enableDetailedLogging": false
|
||||
}
|
||||
13
rust/crates/rustproxy-config/Cargo.toml
Normal file
13
rust/crates/rustproxy-config/Cargo.toml
Normal file
@@ -0,0 +1,13 @@
|
||||
[package]
|
||||
name = "rustproxy-config"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "Configuration types for RustProxy, compatible with SmartProxy JSON schema"
|
||||
|
||||
[dependencies]
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
ipnet = { workspace = true }
|
||||
334
rust/crates/rustproxy-config/src/helpers.rs
Normal file
334
rust/crates/rustproxy-config/src/helpers.rs
Normal file
@@ -0,0 +1,334 @@
|
||||
use crate::route_types::*;
|
||||
use crate::tls_types::*;
|
||||
|
||||
/// Create a simple HTTP forwarding route.
|
||||
/// Equivalent to SmartProxy's `createHttpRoute()`.
|
||||
pub fn create_http_route(
|
||||
domains: impl Into<DomainSpec>,
|
||||
target_host: impl Into<String>,
|
||||
target_port: u16,
|
||||
) -> RouteConfig {
|
||||
RouteConfig {
|
||||
id: None,
|
||||
route_match: RouteMatch {
|
||||
ports: PortRange::Single(80),
|
||||
domains: Some(domains.into()),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
},
|
||||
action: RouteAction {
|
||||
action_type: RouteActionType::Forward,
|
||||
targets: Some(vec![RouteTarget {
|
||||
target_match: None,
|
||||
host: HostSpec::Single(target_host.into()),
|
||||
port: PortSpec::Fixed(target_port),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
priority: None,
|
||||
}]),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
advanced: None,
|
||||
options: None,
|
||||
forwarding_engine: None,
|
||||
nftables: None,
|
||||
send_proxy_protocol: None,
|
||||
},
|
||||
headers: None,
|
||||
security: None,
|
||||
name: None,
|
||||
description: None,
|
||||
priority: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an HTTPS termination route.
|
||||
/// Equivalent to SmartProxy's `createHttpsTerminateRoute()`.
|
||||
pub fn create_https_terminate_route(
|
||||
domains: impl Into<DomainSpec>,
|
||||
target_host: impl Into<String>,
|
||||
target_port: u16,
|
||||
) -> RouteConfig {
|
||||
let mut route = create_http_route(domains, target_host, target_port);
|
||||
route.route_match.ports = PortRange::Single(443);
|
||||
route.action.tls = Some(RouteTls {
|
||||
mode: TlsMode::Terminate,
|
||||
certificate: Some(CertificateSpec::Auto("auto".to_string())),
|
||||
acme: None,
|
||||
versions: None,
|
||||
ciphers: None,
|
||||
honor_cipher_order: None,
|
||||
session_timeout: None,
|
||||
});
|
||||
route
|
||||
}
|
||||
|
||||
/// Create a TLS passthrough route.
|
||||
/// Equivalent to SmartProxy's `createHttpsPassthroughRoute()`.
|
||||
pub fn create_https_passthrough_route(
|
||||
domains: impl Into<DomainSpec>,
|
||||
target_host: impl Into<String>,
|
||||
target_port: u16,
|
||||
) -> RouteConfig {
|
||||
let mut route = create_http_route(domains, target_host, target_port);
|
||||
route.route_match.ports = PortRange::Single(443);
|
||||
route.action.tls = Some(RouteTls {
|
||||
mode: TlsMode::Passthrough,
|
||||
certificate: None,
|
||||
acme: None,
|
||||
versions: None,
|
||||
ciphers: None,
|
||||
honor_cipher_order: None,
|
||||
session_timeout: None,
|
||||
});
|
||||
route
|
||||
}
|
||||
|
||||
/// Create an HTTP-to-HTTPS redirect route.
|
||||
/// Equivalent to SmartProxy's `createHttpToHttpsRedirect()`.
|
||||
pub fn create_http_to_https_redirect(
|
||||
domains: impl Into<DomainSpec>,
|
||||
) -> RouteConfig {
|
||||
let domains = domains.into();
|
||||
RouteConfig {
|
||||
id: None,
|
||||
route_match: RouteMatch {
|
||||
ports: PortRange::Single(80),
|
||||
domains: Some(domains),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
},
|
||||
action: RouteAction {
|
||||
action_type: RouteActionType::Forward,
|
||||
targets: None,
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
advanced: Some(RouteAdvanced {
|
||||
timeout: None,
|
||||
headers: None,
|
||||
keep_alive: None,
|
||||
static_files: None,
|
||||
test_response: Some(RouteTestResponse {
|
||||
status: 301,
|
||||
headers: {
|
||||
let mut h = std::collections::HashMap::new();
|
||||
h.insert("Location".to_string(), "https://{domain}{path}".to_string());
|
||||
h
|
||||
},
|
||||
body: String::new(),
|
||||
}),
|
||||
url_rewrite: None,
|
||||
}),
|
||||
options: None,
|
||||
forwarding_engine: None,
|
||||
nftables: None,
|
||||
send_proxy_protocol: None,
|
||||
},
|
||||
headers: None,
|
||||
security: None,
|
||||
name: Some("HTTP to HTTPS Redirect".to_string()),
|
||||
description: None,
|
||||
priority: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a complete HTTPS server with HTTP redirect.
|
||||
/// Equivalent to SmartProxy's `createCompleteHttpsServer()`.
|
||||
pub fn create_complete_https_server(
|
||||
domain: impl Into<String>,
|
||||
target_host: impl Into<String>,
|
||||
target_port: u16,
|
||||
) -> Vec<RouteConfig> {
|
||||
let domain = domain.into();
|
||||
let target_host = target_host.into();
|
||||
|
||||
vec![
|
||||
create_http_to_https_redirect(DomainSpec::Single(domain.clone())),
|
||||
create_https_terminate_route(
|
||||
DomainSpec::Single(domain),
|
||||
target_host,
|
||||
target_port,
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
/// Create a load balancer route.
|
||||
/// Equivalent to SmartProxy's `createLoadBalancerRoute()`.
|
||||
pub fn create_load_balancer_route(
|
||||
domains: impl Into<DomainSpec>,
|
||||
targets: Vec<(String, u16)>,
|
||||
tls: Option<RouteTls>,
|
||||
) -> RouteConfig {
|
||||
let route_targets: Vec<RouteTarget> = targets
|
||||
.into_iter()
|
||||
.map(|(host, port)| RouteTarget {
|
||||
target_match: None,
|
||||
host: HostSpec::Single(host),
|
||||
port: PortSpec::Fixed(port),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
priority: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let port = if tls.is_some() { 443 } else { 80 };
|
||||
|
||||
RouteConfig {
|
||||
id: None,
|
||||
route_match: RouteMatch {
|
||||
ports: PortRange::Single(port),
|
||||
domains: Some(domains.into()),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
},
|
||||
action: RouteAction {
|
||||
action_type: RouteActionType::Forward,
|
||||
targets: Some(route_targets),
|
||||
tls,
|
||||
websocket: None,
|
||||
load_balancing: Some(RouteLoadBalancing {
|
||||
algorithm: LoadBalancingAlgorithm::RoundRobin,
|
||||
health_check: None,
|
||||
}),
|
||||
advanced: None,
|
||||
options: None,
|
||||
forwarding_engine: None,
|
||||
nftables: None,
|
||||
send_proxy_protocol: None,
|
||||
},
|
||||
headers: None,
|
||||
security: None,
|
||||
name: Some("Load Balancer".to_string()),
|
||||
description: None,
|
||||
priority: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
// Convenience conversions for DomainSpec
|
||||
impl From<&str> for DomainSpec {
|
||||
fn from(s: &str) -> Self {
|
||||
DomainSpec::Single(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for DomainSpec {
|
||||
fn from(s: String) -> Self {
|
||||
DomainSpec::Single(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<String>> for DomainSpec {
|
||||
fn from(v: Vec<String>) -> Self {
|
||||
DomainSpec::List(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<&str>> for DomainSpec {
|
||||
fn from(v: Vec<&str>) -> Self {
|
||||
DomainSpec::List(v.into_iter().map(|s| s.to_string()).collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tls_types::TlsMode;
|
||||
|
||||
#[test]
|
||||
fn test_create_http_route() {
|
||||
let route = create_http_route("example.com", "localhost", 8080);
|
||||
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
|
||||
let domains = route.route_match.domains.as_ref().unwrap().to_vec();
|
||||
assert_eq!(domains, vec!["example.com"]);
|
||||
let target = &route.action.targets.as_ref().unwrap()[0];
|
||||
assert_eq!(target.host.first(), "localhost");
|
||||
assert_eq!(target.port.resolve(80), 8080);
|
||||
assert!(route.action.tls.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_https_terminate_route() {
|
||||
let route = create_https_terminate_route("api.example.com", "backend", 3000);
|
||||
assert_eq!(route.route_match.ports.to_ports(), vec![443]);
|
||||
let tls = route.action.tls.as_ref().unwrap();
|
||||
assert_eq!(tls.mode, TlsMode::Terminate);
|
||||
assert!(tls.certificate.as_ref().unwrap().is_auto());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_https_passthrough_route() {
|
||||
let route = create_https_passthrough_route("secure.example.com", "backend", 443);
|
||||
assert_eq!(route.route_match.ports.to_ports(), vec![443]);
|
||||
let tls = route.action.tls.as_ref().unwrap();
|
||||
assert_eq!(tls.mode, TlsMode::Passthrough);
|
||||
assert!(tls.certificate.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_http_to_https_redirect() {
|
||||
let route = create_http_to_https_redirect("example.com");
|
||||
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
|
||||
assert!(route.action.targets.is_none());
|
||||
let test_response = route.action.advanced.as_ref().unwrap().test_response.as_ref().unwrap();
|
||||
assert_eq!(test_response.status, 301);
|
||||
assert!(test_response.headers.contains_key("Location"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_complete_https_server() {
|
||||
let routes = create_complete_https_server("example.com", "backend", 8080);
|
||||
assert_eq!(routes.len(), 2);
|
||||
// First route is HTTP redirect
|
||||
assert_eq!(routes[0].route_match.ports.to_ports(), vec![80]);
|
||||
// Second route is HTTPS terminate
|
||||
assert_eq!(routes[1].route_match.ports.to_ports(), vec![443]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_load_balancer_route() {
|
||||
let targets = vec![
|
||||
("backend1".to_string(), 8080),
|
||||
("backend2".to_string(), 8080),
|
||||
("backend3".to_string(), 8080),
|
||||
];
|
||||
let route = create_load_balancer_route("*.example.com", targets, None);
|
||||
assert_eq!(route.route_match.ports.to_ports(), vec![80]);
|
||||
assert_eq!(route.action.targets.as_ref().unwrap().len(), 3);
|
||||
let lb = route.action.load_balancing.as_ref().unwrap();
|
||||
assert_eq!(lb.algorithm, LoadBalancingAlgorithm::RoundRobin);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_spec_from_str() {
|
||||
let spec: DomainSpec = "example.com".into();
|
||||
assert_eq!(spec.to_vec(), vec!["example.com"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_spec_from_vec() {
|
||||
let spec: DomainSpec = vec!["a.com", "b.com"].into();
|
||||
assert_eq!(spec.to_vec(), vec!["a.com", "b.com"]);
|
||||
}
|
||||
}
|
||||
19
rust/crates/rustproxy-config/src/lib.rs
Normal file
19
rust/crates/rustproxy-config/src/lib.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
//! # rustproxy-config
|
||||
//!
|
||||
//! Configuration types for RustProxy, fully compatible with SmartProxy's JSON schema.
|
||||
//! All types use `#[serde(rename_all = "camelCase")]` to match TypeScript field naming.
|
||||
|
||||
pub mod route_types;
|
||||
pub mod proxy_options;
|
||||
pub mod tls_types;
|
||||
pub mod security_types;
|
||||
pub mod validation;
|
||||
pub mod helpers;
|
||||
|
||||
// Re-export all primary types
|
||||
pub use route_types::*;
|
||||
pub use proxy_options::*;
|
||||
pub use tls_types::*;
|
||||
pub use security_types::*;
|
||||
pub use validation::*;
|
||||
pub use helpers::*;
|
||||
439
rust/crates/rustproxy-config/src/proxy_options.rs
Normal file
439
rust/crates/rustproxy-config/src/proxy_options.rs
Normal file
@@ -0,0 +1,439 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::route_types::RouteConfig;
|
||||
|
||||
/// Global ACME configuration options.
|
||||
/// Matches TypeScript: `IAcmeOptions`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AcmeOptions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enabled: Option<bool>,
|
||||
/// Required when any route uses certificate: 'auto'
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub email: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub environment: Option<AcmeEnvironment>,
|
||||
/// Alias for email
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub account_email: Option<String>,
|
||||
/// Port for HTTP-01 challenges (default: 80)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub port: Option<u16>,
|
||||
/// Use Let's Encrypt production (default: false)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_production: Option<bool>,
|
||||
/// Days before expiry to renew (default: 30)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub renew_threshold_days: Option<u32>,
|
||||
/// Enable automatic renewal (default: true)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub auto_renew: Option<bool>,
|
||||
/// Directory to store certificates (default: './certs')
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub certificate_store: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub skip_configured_certs: Option<bool>,
|
||||
/// How often to check for renewals (default: 24)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub renew_check_interval_hours: Option<u32>,
|
||||
}
|
||||
|
||||
/// ACME environment.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum AcmeEnvironment {
|
||||
Production,
|
||||
Staging,
|
||||
}
|
||||
|
||||
/// Default target configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DefaultTarget {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
/// Default security configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DefaultSecurity {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ip_allow_list: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ip_block_list: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_connections: Option<u64>,
|
||||
}
|
||||
|
||||
/// Default configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DefaultConfig {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub target: Option<DefaultTarget>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub security: Option<DefaultSecurity>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub preserve_source_ip: Option<bool>,
|
||||
}
|
||||
|
||||
/// Keep-alive treatment.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum KeepAliveTreatment {
|
||||
Standard,
|
||||
Extended,
|
||||
Immortal,
|
||||
}
|
||||
|
||||
/// Metrics configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct MetricsConfig {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enabled: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sample_interval_ms: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub retention_seconds: Option<u64>,
|
||||
}
|
||||
|
||||
/// RustProxy configuration options.
|
||||
/// Matches TypeScript: `ISmartProxyOptions`
|
||||
///
|
||||
/// This is the top-level configuration that can be loaded from a JSON file
|
||||
/// or constructed programmatically.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RustProxyOptions {
|
||||
/// The unified configuration array (required)
|
||||
pub routes: Vec<RouteConfig>,
|
||||
|
||||
/// Preserve client IP when forwarding
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub preserve_source_ip: Option<bool>,
|
||||
|
||||
/// List of trusted proxy IPs that can send PROXY protocol
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub proxy_ips: Option<Vec<String>>,
|
||||
|
||||
/// Global option to accept PROXY protocol
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub accept_proxy_protocol: Option<bool>,
|
||||
|
||||
/// Global option to send PROXY protocol to all targets
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub send_proxy_protocol: Option<bool>,
|
||||
|
||||
/// Global/default settings
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub defaults: Option<DefaultConfig>,
|
||||
|
||||
// ─── Timeout Settings ────────────────────────────────────────────
|
||||
|
||||
/// Timeout for establishing connection to backend (ms), default: 30000
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub connection_timeout: Option<u64>,
|
||||
|
||||
/// Timeout for initial data/SNI (ms), default: 60000
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub initial_data_timeout: Option<u64>,
|
||||
|
||||
/// Socket inactivity timeout (ms), default: 3600000
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub socket_timeout: Option<u64>,
|
||||
|
||||
/// How often to check for inactive connections (ms), default: 60000
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub inactivity_check_interval: Option<u64>,
|
||||
|
||||
/// Default max connection lifetime (ms), default: 86400000
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_connection_lifetime: Option<u64>,
|
||||
|
||||
/// Inactivity timeout (ms), default: 14400000
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub inactivity_timeout: Option<u64>,
|
||||
|
||||
/// Maximum time to wait for connections to close during shutdown (ms)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub graceful_shutdown_timeout: Option<u64>,
|
||||
|
||||
// ─── Socket Optimization ─────────────────────────────────────────
|
||||
|
||||
/// Disable Nagle's algorithm (default: true)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub no_delay: Option<bool>,
|
||||
|
||||
/// Enable TCP keepalive (default: true)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub keep_alive: Option<bool>,
|
||||
|
||||
/// Initial delay before sending keepalive probes (ms)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub keep_alive_initial_delay: Option<u64>,
|
||||
|
||||
/// Maximum bytes to buffer during connection setup
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_pending_data_size: Option<u64>,
|
||||
|
||||
// ─── Enhanced Features ───────────────────────────────────────────
|
||||
|
||||
/// Disable inactivity checking entirely
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub disable_inactivity_check: Option<bool>,
|
||||
|
||||
/// Enable TCP keep-alive probes
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enable_keep_alive_probes: Option<bool>,
|
||||
|
||||
/// Enable detailed connection logging
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enable_detailed_logging: Option<bool>,
|
||||
|
||||
/// Enable TLS handshake debug logging
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enable_tls_debug_logging: Option<bool>,
|
||||
|
||||
/// Randomize timeouts to prevent thundering herd
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enable_randomized_timeouts: Option<bool>,
|
||||
|
||||
// ─── Rate Limiting ───────────────────────────────────────────────
|
||||
|
||||
/// Maximum simultaneous connections from a single IP
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_connections_per_ip: Option<u64>,
|
||||
|
||||
/// Max new connections per minute from a single IP
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub connection_rate_limit_per_minute: Option<u64>,
|
||||
|
||||
// ─── Keep-Alive Settings ─────────────────────────────────────────
|
||||
|
||||
/// How to treat keep-alive connections
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub keep_alive_treatment: Option<KeepAliveTreatment>,
|
||||
|
||||
/// Multiplier for inactivity timeout for keep-alive connections
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub keep_alive_inactivity_multiplier: Option<f64>,
|
||||
|
||||
/// Extended lifetime for keep-alive connections (ms)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub extended_keep_alive_lifetime: Option<u64>,
|
||||
|
||||
// ─── HttpProxy Integration ───────────────────────────────────────
|
||||
|
||||
/// Array of ports to forward to HttpProxy
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_http_proxy: Option<Vec<u16>>,
|
||||
|
||||
/// Port where HttpProxy is listening (default: 8443)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub http_proxy_port: Option<u16>,
|
||||
|
||||
// ─── Metrics ─────────────────────────────────────────────────────
|
||||
|
||||
/// Metrics configuration
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metrics: Option<MetricsConfig>,
|
||||
|
||||
// ─── ACME ────────────────────────────────────────────────────────
|
||||
|
||||
/// Global ACME configuration
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub acme: Option<AcmeOptions>,
|
||||
}
|
||||
|
||||
impl Default for RustProxyOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
routes: Vec::new(),
|
||||
preserve_source_ip: None,
|
||||
proxy_ips: None,
|
||||
accept_proxy_protocol: None,
|
||||
send_proxy_protocol: None,
|
||||
defaults: None,
|
||||
connection_timeout: None,
|
||||
initial_data_timeout: None,
|
||||
socket_timeout: None,
|
||||
inactivity_check_interval: None,
|
||||
max_connection_lifetime: None,
|
||||
inactivity_timeout: None,
|
||||
graceful_shutdown_timeout: None,
|
||||
no_delay: None,
|
||||
keep_alive: None,
|
||||
keep_alive_initial_delay: None,
|
||||
max_pending_data_size: None,
|
||||
disable_inactivity_check: None,
|
||||
enable_keep_alive_probes: None,
|
||||
enable_detailed_logging: None,
|
||||
enable_tls_debug_logging: None,
|
||||
enable_randomized_timeouts: None,
|
||||
max_connections_per_ip: None,
|
||||
connection_rate_limit_per_minute: None,
|
||||
keep_alive_treatment: None,
|
||||
keep_alive_inactivity_multiplier: None,
|
||||
extended_keep_alive_lifetime: None,
|
||||
use_http_proxy: None,
|
||||
http_proxy_port: None,
|
||||
metrics: None,
|
||||
acme: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RustProxyOptions {
|
||||
/// Load configuration from a JSON file.
|
||||
pub fn from_file(path: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
let options: Self = serde_json::from_str(&content)?;
|
||||
Ok(options)
|
||||
}
|
||||
|
||||
/// Get the effective connection timeout in milliseconds.
|
||||
pub fn effective_connection_timeout(&self) -> u64 {
|
||||
self.connection_timeout.unwrap_or(30_000)
|
||||
}
|
||||
|
||||
/// Get the effective initial data timeout in milliseconds.
|
||||
pub fn effective_initial_data_timeout(&self) -> u64 {
|
||||
self.initial_data_timeout.unwrap_or(60_000)
|
||||
}
|
||||
|
||||
/// Get the effective socket timeout in milliseconds.
|
||||
pub fn effective_socket_timeout(&self) -> u64 {
|
||||
self.socket_timeout.unwrap_or(3_600_000)
|
||||
}
|
||||
|
||||
/// Get the effective max connection lifetime in milliseconds.
|
||||
pub fn effective_max_connection_lifetime(&self) -> u64 {
|
||||
self.max_connection_lifetime.unwrap_or(86_400_000)
|
||||
}
|
||||
|
||||
/// Get all unique ports that routes listen on.
|
||||
pub fn all_listening_ports(&self) -> Vec<u16> {
|
||||
let mut ports: Vec<u16> = self.routes
|
||||
.iter()
|
||||
.flat_map(|r| r.listening_ports())
|
||||
.collect();
|
||||
ports.sort();
|
||||
ports.dedup();
|
||||
ports
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::helpers::*;
|
||||
|
||||
#[test]
|
||||
fn test_serde_roundtrip_minimal() {
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![create_http_route("example.com", "localhost", 8080)],
|
||||
..Default::default()
|
||||
};
|
||||
let json = serde_json::to_string(&options).unwrap();
|
||||
let parsed: RustProxyOptions = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.routes.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serde_roundtrip_full() {
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
create_http_route("a.com", "backend1", 8080),
|
||||
create_https_passthrough_route("b.com", "backend2", 443),
|
||||
],
|
||||
connection_timeout: Some(5000),
|
||||
socket_timeout: Some(60000),
|
||||
max_connections_per_ip: Some(100),
|
||||
acme: Some(AcmeOptions {
|
||||
enabled: Some(true),
|
||||
email: Some("admin@example.com".to_string()),
|
||||
environment: Some(AcmeEnvironment::Staging),
|
||||
account_email: None,
|
||||
port: None,
|
||||
use_production: None,
|
||||
renew_threshold_days: None,
|
||||
auto_renew: None,
|
||||
certificate_store: None,
|
||||
skip_configured_certs: None,
|
||||
renew_check_interval_hours: None,
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
let json = serde_json::to_string_pretty(&options).unwrap();
|
||||
let parsed: RustProxyOptions = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.routes.len(), 2);
|
||||
assert_eq!(parsed.connection_timeout, Some(5000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_timeouts() {
|
||||
let options = RustProxyOptions::default();
|
||||
assert_eq!(options.effective_connection_timeout(), 30_000);
|
||||
assert_eq!(options.effective_initial_data_timeout(), 60_000);
|
||||
assert_eq!(options.effective_socket_timeout(), 3_600_000);
|
||||
assert_eq!(options.effective_max_connection_lifetime(), 86_400_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_timeouts() {
|
||||
let options = RustProxyOptions {
|
||||
connection_timeout: Some(5000),
|
||||
initial_data_timeout: Some(10000),
|
||||
socket_timeout: Some(30000),
|
||||
max_connection_lifetime: Some(60000),
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(options.effective_connection_timeout(), 5000);
|
||||
assert_eq!(options.effective_initial_data_timeout(), 10000);
|
||||
assert_eq!(options.effective_socket_timeout(), 30000);
|
||||
assert_eq!(options.effective_max_connection_lifetime(), 60000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_listening_ports() {
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
create_http_route("a.com", "backend", 8080), // port 80
|
||||
create_https_passthrough_route("b.com", "backend", 443), // port 443
|
||||
create_http_route("c.com", "backend", 9090), // port 80 (duplicate)
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
let ports = options.all_listening_ports();
|
||||
assert_eq!(ports, vec![80, 443]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_camel_case_field_names() {
|
||||
let options = RustProxyOptions {
|
||||
connection_timeout: Some(5000),
|
||||
max_connections_per_ip: Some(100),
|
||||
keep_alive_treatment: Some(KeepAliveTreatment::Extended),
|
||||
..Default::default()
|
||||
};
|
||||
let json = serde_json::to_string(&options).unwrap();
|
||||
assert!(json.contains("connectionTimeout"));
|
||||
assert!(json.contains("maxConnectionsPerIp"));
|
||||
assert!(json.contains("keepAliveTreatment"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_example_json() {
|
||||
let content = std::fs::read_to_string(
|
||||
concat!(env!("CARGO_MANIFEST_DIR"), "/../../config/example.json")
|
||||
).unwrap();
|
||||
let options: RustProxyOptions = serde_json::from_str(&content).unwrap();
|
||||
assert_eq!(options.routes.len(), 4);
|
||||
let ports = options.all_listening_ports();
|
||||
assert!(ports.contains(&80));
|
||||
assert!(ports.contains(&443));
|
||||
}
|
||||
}
|
||||
603
rust/crates/rustproxy-config/src/route_types.rs
Normal file
603
rust/crates/rustproxy-config/src/route_types.rs
Normal file
@@ -0,0 +1,603 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::tls_types::RouteTls;
|
||||
use crate::security_types::RouteSecurity;
|
||||
|
||||
// ─── Port Range ──────────────────────────────────────────────────────
|
||||
|
||||
/// Port range specification format.
|
||||
/// Matches TypeScript: `type TPortRange = number | number[] | Array<{ from: number; to: number }>`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum PortRange {
|
||||
/// Single port number
|
||||
Single(u16),
|
||||
/// Array of port numbers
|
||||
List(Vec<u16>),
|
||||
/// Array of port ranges
|
||||
Ranges(Vec<PortRangeSpec>),
|
||||
}
|
||||
|
||||
impl PortRange {
|
||||
/// Expand the port range into a flat list of ports.
|
||||
pub fn to_ports(&self) -> Vec<u16> {
|
||||
match self {
|
||||
PortRange::Single(p) => vec![*p],
|
||||
PortRange::List(ports) => ports.clone(),
|
||||
PortRange::Ranges(ranges) => {
|
||||
ranges.iter().flat_map(|r| r.from..=r.to).collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A from-to port range.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PortRangeSpec {
|
||||
pub from: u16,
|
||||
pub to: u16,
|
||||
}
|
||||
|
||||
// ─── Route Action Type ───────────────────────────────────────────────
|
||||
|
||||
/// Supported action types for route configurations.
|
||||
/// Matches TypeScript: `type TRouteActionType = 'forward' | 'socket-handler'`
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum RouteActionType {
|
||||
Forward,
|
||||
SocketHandler,
|
||||
}
|
||||
|
||||
// ─── Forwarding Engine ───────────────────────────────────────────────
|
||||
|
||||
/// Forwarding engine specification.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ForwardingEngine {
|
||||
Node,
|
||||
Nftables,
|
||||
}
|
||||
|
||||
// ─── Route Match ─────────────────────────────────────────────────────
|
||||
|
||||
/// Domain specification: single string or array.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum DomainSpec {
|
||||
Single(String),
|
||||
List(Vec<String>),
|
||||
}
|
||||
|
||||
impl DomainSpec {
|
||||
pub fn to_vec(&self) -> Vec<&str> {
|
||||
match self {
|
||||
DomainSpec::Single(s) => vec![s.as_str()],
|
||||
DomainSpec::List(v) => v.iter().map(|s| s.as_str()).collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Header match value: either exact string or regex pattern.
|
||||
/// In JSON, all values come as strings. Regex patterns are prefixed with `/` and suffixed with `/`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum HeaderMatchValue {
|
||||
Exact(String),
|
||||
}
|
||||
|
||||
/// Route match criteria for incoming requests.
|
||||
/// Matches TypeScript: `IRouteMatch`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteMatch {
|
||||
/// Listen on these ports (required)
|
||||
pub ports: PortRange,
|
||||
|
||||
/// Optional domain patterns to match (default: all domains)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub domains: Option<DomainSpec>,
|
||||
|
||||
/// Match specific paths
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub path: Option<String>,
|
||||
|
||||
/// Match specific client IPs
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub client_ip: Option<Vec<String>>,
|
||||
|
||||
/// Match specific TLS versions
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tls_version: Option<Vec<String>>,
|
||||
|
||||
/// Match specific HTTP headers
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub headers: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
// ─── Target Match ────────────────────────────────────────────────────
|
||||
|
||||
/// Target-specific match criteria for sub-routing within a route.
|
||||
/// Matches TypeScript: `ITargetMatch`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct TargetMatch {
|
||||
/// Match specific ports from the route
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ports: Option<Vec<u16>>,
|
||||
/// Match specific paths (supports wildcards like /api/*)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub path: Option<String>,
|
||||
/// Match specific HTTP headers
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub headers: Option<HashMap<String, String>>,
|
||||
/// Match specific HTTP methods
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub method: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
// ─── WebSocket Config ────────────────────────────────────────────────
|
||||
|
||||
/// WebSocket configuration.
|
||||
/// Matches TypeScript: `IRouteWebSocket`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteWebSocket {
|
||||
pub enabled: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ping_interval: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ping_timeout: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_payload_size: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub custom_headers: Option<HashMap<String, String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub subprotocols: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub rewrite_path: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub allowed_origins: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub authenticate_request: Option<bool>,
|
||||
}
|
||||
|
||||
// ─── Load Balancing ──────────────────────────────────────────────────
|
||||
|
||||
/// Load balancing algorithm.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum LoadBalancingAlgorithm {
|
||||
RoundRobin,
|
||||
LeastConnections,
|
||||
IpHash,
|
||||
}
|
||||
|
||||
/// Health check configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct HealthCheck {
|
||||
pub path: String,
|
||||
pub interval: u64,
|
||||
pub timeout: u64,
|
||||
pub unhealthy_threshold: u32,
|
||||
pub healthy_threshold: u32,
|
||||
}
|
||||
|
||||
/// Load balancing configuration.
|
||||
/// Matches TypeScript: `IRouteLoadBalancing`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteLoadBalancing {
|
||||
pub algorithm: LoadBalancingAlgorithm,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub health_check: Option<HealthCheck>,
|
||||
}
|
||||
|
||||
// ─── CORS ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Allowed origin specification.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum AllowOrigin {
|
||||
Single(String),
|
||||
List(Vec<String>),
|
||||
}
|
||||
|
||||
/// CORS configuration for a route.
|
||||
/// Matches TypeScript: `IRouteCors`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteCors {
|
||||
pub enabled: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub allow_origin: Option<AllowOrigin>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub allow_methods: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub allow_headers: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub allow_credentials: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub expose_headers: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_age: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub preflight: Option<bool>,
|
||||
}
|
||||
|
||||
// ─── Headers ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Headers configuration.
|
||||
/// Matches TypeScript: `IRouteHeaders`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteHeaders {
|
||||
/// Headers to add/modify for requests to backend
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub request: Option<HashMap<String, String>>,
|
||||
/// Headers to add/modify for responses to client
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response: Option<HashMap<String, String>>,
|
||||
/// CORS configuration
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cors: Option<RouteCors>,
|
||||
}
|
||||
|
||||
// ─── Static Files ────────────────────────────────────────────────────
|
||||
|
||||
/// Static file server configuration.
|
||||
/// Matches TypeScript: `IRouteStaticFiles`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteStaticFiles {
|
||||
pub root: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub index: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub headers: Option<HashMap<String, String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub directory: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub index_files: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cache_control: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub expires: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub follow_symlinks: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub disable_directory_listing: Option<bool>,
|
||||
}
|
||||
|
||||
// ─── Test Response ───────────────────────────────────────────────────
|
||||
|
||||
/// Test route response configuration.
|
||||
/// Matches TypeScript: `IRouteTestResponse`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteTestResponse {
|
||||
pub status: u16,
|
||||
pub headers: HashMap<String, String>,
|
||||
pub body: String,
|
||||
}
|
||||
|
||||
// ─── URL Rewriting ───────────────────────────────────────────────────
|
||||
|
||||
/// URL rewriting configuration.
|
||||
/// Matches TypeScript: `IRouteUrlRewrite`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteUrlRewrite {
|
||||
/// RegExp pattern to match in URL
|
||||
pub pattern: String,
|
||||
/// Replacement pattern
|
||||
pub target: String,
|
||||
/// RegExp flags
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub flags: Option<String>,
|
||||
/// Only apply to path, not query string
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub only_rewrite_path: Option<bool>,
|
||||
}
|
||||
|
||||
// ─── Advanced Options ────────────────────────────────────────────────
|
||||
|
||||
/// Advanced options for route actions.
|
||||
/// Matches TypeScript: `IRouteAdvanced`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteAdvanced {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub timeout: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub headers: Option<HashMap<String, String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub keep_alive: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub static_files: Option<RouteStaticFiles>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub test_response: Option<RouteTestResponse>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub url_rewrite: Option<RouteUrlRewrite>,
|
||||
}
|
||||
|
||||
// ─── NFTables Options ────────────────────────────────────────────────
|
||||
|
||||
/// NFTables protocol type.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum NfTablesProtocol {
|
||||
Tcp,
|
||||
Udp,
|
||||
All,
|
||||
}
|
||||
|
||||
/// NFTables-specific configuration options.
|
||||
/// Matches TypeScript: `INfTablesOptions`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct NfTablesOptions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub preserve_source_ip: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub protocol: Option<NfTablesProtocol>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_rate: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub priority: Option<i32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub table_name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_ip_sets: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_advanced_nat: Option<bool>,
|
||||
}
|
||||
|
||||
// ─── Backend Protocol ────────────────────────────────────────────────
|
||||
|
||||
/// Backend protocol.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum BackendProtocol {
|
||||
Http1,
|
||||
Http2,
|
||||
}
|
||||
|
||||
/// Action options.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ActionOptions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub backend_protocol: Option<BackendProtocol>,
|
||||
/// Catch-all for additional options
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
// ─── Route Target ────────────────────────────────────────────────────
|
||||
|
||||
/// Host specification: single string or array of strings.
|
||||
/// Note: Dynamic host functions are only available via programmatic API, not JSON.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum HostSpec {
|
||||
Single(String),
|
||||
List(Vec<String>),
|
||||
}
|
||||
|
||||
impl HostSpec {
|
||||
pub fn to_vec(&self) -> Vec<&str> {
|
||||
match self {
|
||||
HostSpec::Single(s) => vec![s.as_str()],
|
||||
HostSpec::List(v) => v.iter().map(|s| s.as_str()).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn first(&self) -> &str {
|
||||
match self {
|
||||
HostSpec::Single(s) => s.as_str(),
|
||||
HostSpec::List(v) => v.first().map(|s| s.as_str()).unwrap_or(""),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Port specification: number or "preserve".
|
||||
/// Note: Dynamic port functions are only available via programmatic API, not JSON.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum PortSpec {
|
||||
/// Fixed port number
|
||||
Fixed(u16),
|
||||
/// Special string value like "preserve"
|
||||
Special(String),
|
||||
}
|
||||
|
||||
impl PortSpec {
|
||||
/// Resolve the port, using incoming_port when "preserve" is specified.
|
||||
pub fn resolve(&self, incoming_port: u16) -> u16 {
|
||||
match self {
|
||||
PortSpec::Fixed(p) => *p,
|
||||
PortSpec::Special(s) if s == "preserve" => incoming_port,
|
||||
PortSpec::Special(_) => incoming_port, // fallback
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Target configuration for forwarding with sub-matching and overrides.
|
||||
/// Matches TypeScript: `IRouteTarget`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteTarget {
|
||||
/// Optional sub-matching criteria within the route
|
||||
#[serde(rename = "match")]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub target_match: Option<TargetMatch>,
|
||||
|
||||
/// Target host(s)
|
||||
pub host: HostSpec,
|
||||
|
||||
/// Target port
|
||||
pub port: PortSpec,
|
||||
|
||||
/// Override route-level TLS settings
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tls: Option<RouteTls>,
|
||||
|
||||
/// Override route-level WebSocket settings
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub websocket: Option<RouteWebSocket>,
|
||||
|
||||
/// Override route-level load balancing
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub load_balancing: Option<RouteLoadBalancing>,
|
||||
|
||||
/// Override route-level proxy protocol setting
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub send_proxy_protocol: Option<bool>,
|
||||
|
||||
/// Override route-level headers
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub headers: Option<RouteHeaders>,
|
||||
|
||||
/// Override route-level advanced settings
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub advanced: Option<RouteAdvanced>,
|
||||
|
||||
/// Priority for matching (higher values checked first, default: 0)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub priority: Option<i32>,
|
||||
}
|
||||
|
||||
// ─── Route Action ────────────────────────────────────────────────────
|
||||
|
||||
/// Action configuration for route handling.
|
||||
/// Matches TypeScript: `IRouteAction`
|
||||
///
|
||||
/// Note: `socketHandler` is not serializable in JSON. Use the programmatic API
|
||||
/// for socket handler routes.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteAction {
|
||||
/// Basic routing type
|
||||
#[serde(rename = "type")]
|
||||
pub action_type: RouteActionType,
|
||||
|
||||
/// Targets for forwarding (array supports multiple targets with sub-matching)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub targets: Option<Vec<RouteTarget>>,
|
||||
|
||||
/// TLS handling (default for all targets)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tls: Option<RouteTls>,
|
||||
|
||||
/// WebSocket support (default for all targets)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub websocket: Option<RouteWebSocket>,
|
||||
|
||||
/// Load balancing options (default for all targets)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub load_balancing: Option<RouteLoadBalancing>,
|
||||
|
||||
/// Advanced options (default for all targets)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub advanced: Option<RouteAdvanced>,
|
||||
|
||||
/// Additional options
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub options: Option<ActionOptions>,
|
||||
|
||||
/// Forwarding engine specification
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub forwarding_engine: Option<ForwardingEngine>,
|
||||
|
||||
/// NFTables-specific options
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub nftables: Option<NfTablesOptions>,
|
||||
|
||||
/// PROXY protocol support (default for all targets)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub send_proxy_protocol: Option<bool>,
|
||||
}
|
||||
|
||||
// ─── Route Config ────────────────────────────────────────────────────
|
||||
|
||||
/// The core unified configuration interface.
|
||||
/// Matches TypeScript: `IRouteConfig`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteConfig {
|
||||
/// Unique identifier
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
|
||||
/// What to match
|
||||
#[serde(rename = "match")]
|
||||
pub route_match: RouteMatch,
|
||||
|
||||
/// What to do with matched traffic
|
||||
pub action: RouteAction,
|
||||
|
||||
/// Custom headers
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub headers: Option<RouteHeaders>,
|
||||
|
||||
/// Security features
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub security: Option<RouteSecurity>,
|
||||
|
||||
/// Human-readable name for this route
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
|
||||
/// Description of the route's purpose
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Controls matching order (higher = matched first)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub priority: Option<i32>,
|
||||
|
||||
/// Arbitrary tags for categorization
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tags: Option<Vec<String>>,
|
||||
|
||||
/// Whether the route is active (default: true)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enabled: Option<bool>,
|
||||
}
|
||||
|
||||
impl RouteConfig {
|
||||
/// Check if this route is enabled (defaults to true).
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.enabled.unwrap_or(true)
|
||||
}
|
||||
|
||||
/// Get the effective priority (defaults to 0).
|
||||
pub fn effective_priority(&self) -> i32 {
|
||||
self.priority.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Get all ports this route listens on.
|
||||
pub fn listening_ports(&self) -> Vec<u16> {
|
||||
self.route_match.ports.to_ports()
|
||||
}
|
||||
|
||||
/// Get the TLS mode for this route (from action-level or first target).
|
||||
pub fn tls_mode(&self) -> Option<&crate::tls_types::TlsMode> {
|
||||
// Check action-level TLS first
|
||||
if let Some(tls) = &self.action.tls {
|
||||
return Some(&tls.mode);
|
||||
}
|
||||
// Check first target's TLS
|
||||
if let Some(targets) = &self.action.targets {
|
||||
if let Some(first) = targets.first() {
|
||||
if let Some(tls) = &first.tls {
|
||||
return Some(&tls.mode);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
132
rust/crates/rustproxy-config/src/security_types.rs
Normal file
132
rust/crates/rustproxy-config/src/security_types.rs
Normal file
@@ -0,0 +1,132 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Rate limiting configuration.
|
||||
/// Matches TypeScript: `IRouteRateLimit`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteRateLimit {
|
||||
pub enabled: bool,
|
||||
pub max_requests: u64,
|
||||
/// Time window in seconds
|
||||
pub window: u64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub key_by: Option<RateLimitKeyBy>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub header_name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error_message: Option<String>,
|
||||
}
|
||||
|
||||
/// Rate limit key selection.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum RateLimitKeyBy {
|
||||
Ip,
|
||||
Path,
|
||||
Header,
|
||||
}
|
||||
|
||||
/// Authentication type.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum AuthenticationType {
|
||||
Basic,
|
||||
Digest,
|
||||
Oauth,
|
||||
Jwt,
|
||||
}
|
||||
|
||||
/// Authentication credentials.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AuthCredentials {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
/// Authentication options.
|
||||
/// Matches TypeScript: `IRouteAuthentication`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteAuthentication {
|
||||
#[serde(rename = "type")]
|
||||
pub auth_type: AuthenticationType,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub credentials: Option<Vec<AuthCredentials>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub realm: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub jwt_secret: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub jwt_issuer: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub oauth_provider: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub oauth_client_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub oauth_client_secret: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub oauth_redirect_uri: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub options: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Basic auth configuration.
|
||||
/// Matches TypeScript: `IRouteSecurity.basicAuth`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct BasicAuthConfig {
|
||||
pub enabled: bool,
|
||||
pub users: Vec<AuthCredentials>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub realm: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub exclude_paths: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// JWT auth configuration.
|
||||
/// Matches TypeScript: `IRouteSecurity.jwtAuth`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct JwtAuthConfig {
|
||||
pub enabled: bool,
|
||||
pub secret: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub algorithm: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub issuer: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub audience: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub expires_in: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub exclude_paths: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// Security options for routes.
|
||||
/// Matches TypeScript: `IRouteSecurity`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteSecurity {
|
||||
/// IP addresses that are allowed to connect
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ip_allow_list: Option<Vec<String>>,
|
||||
/// IP addresses that are blocked from connecting
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ip_block_list: Option<Vec<String>>,
|
||||
/// Maximum concurrent connections
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_connections: Option<u64>,
|
||||
/// Authentication configuration
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub authentication: Option<RouteAuthentication>,
|
||||
/// Rate limiting
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub rate_limit: Option<RouteRateLimit>,
|
||||
/// Basic auth
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub basic_auth: Option<BasicAuthConfig>,
|
||||
/// JWT auth
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub jwt_auth: Option<JwtAuthConfig>,
|
||||
}
|
||||
93
rust/crates/rustproxy-config/src/tls_types.rs
Normal file
93
rust/crates/rustproxy-config/src/tls_types.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// TLS handling modes for route configurations.
|
||||
/// Matches TypeScript: `type TTlsMode = 'passthrough' | 'terminate' | 'terminate-and-reencrypt'`
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum TlsMode {
|
||||
Passthrough,
|
||||
Terminate,
|
||||
TerminateAndReencrypt,
|
||||
}
|
||||
|
||||
/// Static certificate configuration (PEM-encoded).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CertificateConfig {
|
||||
/// PEM-encoded private key
|
||||
pub key: String,
|
||||
/// PEM-encoded certificate
|
||||
pub cert: String,
|
||||
/// PEM-encoded CA chain
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ca: Option<String>,
|
||||
/// Path to key file (overrides key)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub key_file: Option<String>,
|
||||
/// Path to cert file (overrides cert)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cert_file: Option<String>,
|
||||
}
|
||||
|
||||
/// Certificate specification: either automatic (ACME) or static.
|
||||
/// Matches TypeScript: `certificate?: 'auto' | { key, cert, ca?, keyFile?, certFile? }`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum CertificateSpec {
|
||||
/// Use ACME (Let's Encrypt) for automatic provisioning
|
||||
Auto(String), // "auto"
|
||||
/// Static certificate configuration
|
||||
Static(CertificateConfig),
|
||||
}
|
||||
|
||||
impl CertificateSpec {
|
||||
/// Check if this is an auto (ACME) certificate
|
||||
pub fn is_auto(&self) -> bool {
|
||||
matches!(self, CertificateSpec::Auto(s) if s == "auto")
|
||||
}
|
||||
}
|
||||
|
||||
/// ACME configuration for automatic certificate provisioning.
|
||||
/// Matches TypeScript: `IRouteAcme`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteAcme {
|
||||
/// Contact email for ACME account
|
||||
pub email: String,
|
||||
/// Use production ACME servers (default: false)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub use_production: Option<bool>,
|
||||
/// Port for HTTP-01 challenges (default: 80)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub challenge_port: Option<u16>,
|
||||
/// Days before expiry to renew (default: 30)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub renew_before_days: Option<u32>,
|
||||
}
|
||||
|
||||
/// TLS configuration for route actions.
|
||||
/// Matches TypeScript: `IRouteTls`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteTls {
|
||||
/// TLS mode (passthrough, terminate, terminate-and-reencrypt)
|
||||
pub mode: TlsMode,
|
||||
/// Certificate configuration (auto or static)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub certificate: Option<CertificateSpec>,
|
||||
/// ACME options when certificate is 'auto'
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub acme: Option<RouteAcme>,
|
||||
/// Allowed TLS versions
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub versions: Option<Vec<String>>,
|
||||
/// OpenSSL cipher string
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ciphers: Option<String>,
|
||||
/// Use server's cipher preferences
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub honor_cipher_order: Option<bool>,
|
||||
/// TLS session timeout in seconds
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub session_timeout: Option<u64>,
|
||||
}
|
||||
158
rust/crates/rustproxy-config/src/validation.rs
Normal file
158
rust/crates/rustproxy-config/src/validation.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::route_types::{RouteConfig, RouteActionType};
|
||||
|
||||
/// Validation errors for route configurations.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ValidationError {
|
||||
#[error("Route '{name}' has no targets but action type is 'forward'")]
|
||||
MissingTargets { name: String },
|
||||
|
||||
#[error("Route '{name}' has empty targets list")]
|
||||
EmptyTargets { name: String },
|
||||
|
||||
#[error("Route '{name}' has no ports specified")]
|
||||
NoPorts { name: String },
|
||||
|
||||
#[error("Route '{name}' port {port} is invalid (must be 1-65535)")]
|
||||
InvalidPort { name: String, port: u16 },
|
||||
|
||||
#[error("Route '{name}': socket-handler action type is not supported in JSON config")]
|
||||
SocketHandlerInJson { name: String },
|
||||
|
||||
#[error("Route '{name}': duplicate route ID '{id}'")]
|
||||
DuplicateId { name: String, id: String },
|
||||
|
||||
#[error("Route '{name}': {message}")]
|
||||
Custom { name: String, message: String },
|
||||
}
|
||||
|
||||
/// Validate a single route configuration.
|
||||
pub fn validate_route(route: &RouteConfig) -> Result<(), Vec<ValidationError>> {
|
||||
let mut errors = Vec::new();
|
||||
let name = route.name.clone().unwrap_or_else(|| {
|
||||
route.id.clone().unwrap_or_else(|| "unnamed".to_string())
|
||||
});
|
||||
|
||||
// Check ports
|
||||
let ports = route.listening_ports();
|
||||
if ports.is_empty() {
|
||||
errors.push(ValidationError::NoPorts { name: name.clone() });
|
||||
}
|
||||
for &port in &ports {
|
||||
if port == 0 {
|
||||
errors.push(ValidationError::InvalidPort {
|
||||
name: name.clone(),
|
||||
port,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Check forward action has targets
|
||||
if route.action.action_type == RouteActionType::Forward {
|
||||
match &route.action.targets {
|
||||
None => {
|
||||
errors.push(ValidationError::MissingTargets { name: name.clone() });
|
||||
}
|
||||
Some(targets) if targets.is_empty() => {
|
||||
errors.push(ValidationError::EmptyTargets { name: name.clone() });
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(errors)
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate an entire list of routes.
|
||||
pub fn validate_routes(routes: &[RouteConfig]) -> Result<(), Vec<ValidationError>> {
|
||||
let mut all_errors = Vec::new();
|
||||
let mut seen_ids = std::collections::HashSet::new();
|
||||
|
||||
for route in routes {
|
||||
// Check for duplicate IDs
|
||||
if let Some(id) = &route.id {
|
||||
if !seen_ids.insert(id.clone()) {
|
||||
let name = route.name.clone().unwrap_or_else(|| id.clone());
|
||||
all_errors.push(ValidationError::DuplicateId {
|
||||
name,
|
||||
id: id.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Validate individual route
|
||||
if let Err(errors) = validate_route(route) {
|
||||
all_errors.extend(errors);
|
||||
}
|
||||
}
|
||||
|
||||
if all_errors.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(all_errors)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::route_types::*;
|
||||
|
||||
fn make_valid_route() -> RouteConfig {
|
||||
crate::helpers::create_http_route("example.com", "localhost", 8080)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_route_passes() {
|
||||
let route = make_valid_route();
|
||||
assert!(validate_route(&route).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_missing_targets() {
|
||||
let mut route = make_valid_route();
|
||||
route.action.targets = None;
|
||||
let errors = validate_route(&route).unwrap_err();
|
||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::MissingTargets { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_targets() {
|
||||
let mut route = make_valid_route();
|
||||
route.action.targets = Some(vec![]);
|
||||
let errors = validate_route(&route).unwrap_err();
|
||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::EmptyTargets { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_port_zero() {
|
||||
let mut route = make_valid_route();
|
||||
route.route_match.ports = PortRange::Single(0);
|
||||
let errors = validate_route(&route).unwrap_err();
|
||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::InvalidPort { port: 0, .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duplicate_ids() {
|
||||
let mut r1 = make_valid_route();
|
||||
r1.id = Some("route-1".to_string());
|
||||
let mut r2 = make_valid_route();
|
||||
r2.id = Some("route-1".to_string());
|
||||
let errors = validate_routes(&[r1, r2]).unwrap_err();
|
||||
assert!(errors.iter().any(|e| matches!(e, ValidationError::DuplicateId { .. })));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_errors_collected() {
|
||||
let mut r1 = make_valid_route();
|
||||
r1.action.targets = None; // MissingTargets
|
||||
r1.route_match.ports = PortRange::Single(0); // InvalidPort
|
||||
let errors = validate_route(&r1).unwrap_err();
|
||||
assert!(errors.len() >= 2);
|
||||
}
|
||||
}
|
||||
24
rust/crates/rustproxy-http/Cargo.toml
Normal file
24
rust/crates/rustproxy-http/Cargo.toml
Normal file
@@ -0,0 +1,24 @@
|
||||
[package]
|
||||
name = "rustproxy-http"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "Hyper-based HTTP proxy service for RustProxy"
|
||||
|
||||
[dependencies]
|
||||
rustproxy-config = { workspace = true }
|
||||
rustproxy-routing = { workspace = true }
|
||||
rustproxy-security = { workspace = true }
|
||||
rustproxy-metrics = { workspace = true }
|
||||
hyper = { workspace = true }
|
||||
hyper-util = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
http-body-util = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
arc-swap = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
14
rust/crates/rustproxy-http/src/lib.rs
Normal file
14
rust/crates/rustproxy-http/src/lib.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
//! # rustproxy-http
|
||||
//!
|
||||
//! Hyper-based HTTP proxy service for RustProxy.
|
||||
//! Handles HTTP request parsing, route-based forwarding, and response filtering.
|
||||
|
||||
pub mod proxy_service;
|
||||
pub mod request_filter;
|
||||
pub mod response_filter;
|
||||
pub mod template;
|
||||
pub mod upstream_selector;
|
||||
|
||||
pub use proxy_service::*;
|
||||
pub use template::*;
|
||||
pub use upstream_selector::*;
|
||||
827
rust/crates/rustproxy-http/src/proxy_service.rs
Normal file
827
rust/crates/rustproxy-http/src/proxy_service.rs
Normal file
@@ -0,0 +1,827 @@
|
||||
//! Hyper-based HTTP proxy service.
|
||||
//!
|
||||
//! Accepts decrypted TCP streams (from TLS termination or plain TCP),
|
||||
//! parses HTTP requests, matches routes, and forwards to upstream backends.
|
||||
//! Supports HTTP/1.1 keep-alive, HTTP/2 (auto-detect), and WebSocket upgrade.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use http_body_util::{BodyExt, Full, combinators::BoxBody};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use regex::Regex;
|
||||
use tokio::net::TcpStream;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use rustproxy_routing::RouteManager;
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
|
||||
use crate::request_filter::RequestFilter;
|
||||
use crate::response_filter::ResponseFilter;
|
||||
use crate::upstream_selector::UpstreamSelector;
|
||||
|
||||
/// HTTP proxy service that processes HTTP traffic.
|
||||
pub struct HttpProxyService {
|
||||
route_manager: Arc<RouteManager>,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
upstream_selector: UpstreamSelector,
|
||||
}
|
||||
|
||||
impl HttpProxyService {
|
||||
pub fn new(route_manager: Arc<RouteManager>, metrics: Arc<MetricsCollector>) -> Self {
|
||||
Self {
|
||||
route_manager,
|
||||
metrics,
|
||||
upstream_selector: UpstreamSelector::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle an incoming HTTP connection on a plain TCP stream.
|
||||
pub async fn handle_connection(
|
||||
self: Arc<Self>,
|
||||
stream: TcpStream,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
port: u16,
|
||||
) {
|
||||
self.handle_io(stream, peer_addr, port).await;
|
||||
}
|
||||
|
||||
/// Handle an incoming HTTP connection on any IO type (plain TCP or TLS-terminated).
|
||||
///
|
||||
/// Uses HTTP/1.1 with upgrade support. For clients that negotiate HTTP/2,
|
||||
/// use `handle_io_auto` instead.
|
||||
pub async fn handle_io<I>(
|
||||
self: Arc<Self>,
|
||||
stream: I,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
port: u16,
|
||||
)
|
||||
where
|
||||
I: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
let service = hyper::service::service_fn(move |req: Request<Incoming>| {
|
||||
let svc = Arc::clone(&self);
|
||||
let peer = peer_addr;
|
||||
async move {
|
||||
svc.handle_request(req, peer, port).await
|
||||
}
|
||||
});
|
||||
|
||||
// Use http1::Builder with upgrades for WebSocket support
|
||||
let conn = hyper::server::conn::http1::Builder::new()
|
||||
.keep_alive(true)
|
||||
.serve_connection(io, service)
|
||||
.with_upgrades();
|
||||
|
||||
if let Err(e) = conn.await {
|
||||
debug!("HTTP connection error from {}: {}", peer_addr, e);
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a single HTTP request.
|
||||
async fn handle_request(
|
||||
&self,
|
||||
req: Request<Incoming>,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
port: u16,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let host = req.headers()
|
||||
.get("host")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|h| {
|
||||
// Strip port from host header
|
||||
h.split(':').next().unwrap_or(h).to_string()
|
||||
});
|
||||
|
||||
let path = req.uri().path().to_string();
|
||||
let method = req.method().clone();
|
||||
|
||||
// Extract headers for matching
|
||||
let headers: HashMap<String, String> = req.headers()
|
||||
.iter()
|
||||
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
|
||||
.collect();
|
||||
|
||||
debug!("HTTP {} {} (host: {:?}) from {}", method, path, host, peer_addr);
|
||||
|
||||
// Check for CORS preflight
|
||||
if method == hyper::Method::OPTIONS {
|
||||
if let Some(response) = RequestFilter::handle_cors_preflight(&req) {
|
||||
return Ok(response);
|
||||
}
|
||||
}
|
||||
|
||||
// Match route
|
||||
let ctx = rustproxy_routing::MatchContext {
|
||||
port,
|
||||
domain: host.as_deref(),
|
||||
path: Some(&path),
|
||||
client_ip: Some(&peer_addr.ip().to_string()),
|
||||
tls_version: None,
|
||||
headers: Some(&headers),
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
let route_match = match self.route_manager.find_route(&ctx) {
|
||||
Some(rm) => rm,
|
||||
None => {
|
||||
debug!("No route matched for HTTP request to {:?}{}", host, path);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "No route matched"));
|
||||
}
|
||||
};
|
||||
|
||||
let route_id = route_match.route.id.as_deref();
|
||||
self.metrics.connection_opened(route_id);
|
||||
|
||||
// Apply request filters (IP check, rate limiting, auth)
|
||||
if let Some(ref security) = route_match.route.security {
|
||||
if let Some(response) = RequestFilter::apply(security, &req, &peer_addr) {
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(response);
|
||||
}
|
||||
}
|
||||
|
||||
// Check for test response (returns immediately, no upstream needed)
|
||||
if let Some(ref advanced) = route_match.route.action.advanced {
|
||||
if let Some(ref test_response) = advanced.test_response {
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(Self::build_test_response(test_response));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for static file serving
|
||||
if let Some(ref advanced) = route_match.route.action.advanced {
|
||||
if let Some(ref static_files) = advanced.static_files {
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(Self::serve_static_file(&path, static_files));
|
||||
}
|
||||
}
|
||||
|
||||
// Select upstream
|
||||
let target = match route_match.target {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "No target available"));
|
||||
}
|
||||
};
|
||||
|
||||
let upstream = self.upstream_selector.select(target, &peer_addr, port);
|
||||
let upstream_key = format!("{}:{}", upstream.host, upstream.port);
|
||||
self.upstream_selector.connection_started(&upstream_key);
|
||||
|
||||
// Check for WebSocket upgrade
|
||||
let is_websocket = req.headers()
|
||||
.get("upgrade")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|v| v.eq_ignore_ascii_case("websocket"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if is_websocket {
|
||||
let result = self.handle_websocket_upgrade(
|
||||
req, peer_addr, &upstream, route_match.route, route_id, &upstream_key,
|
||||
).await;
|
||||
// Note: for WebSocket, connection_ended is called inside
|
||||
// the spawned tunnel task when the connection closes.
|
||||
return result;
|
||||
}
|
||||
|
||||
// Determine backend protocol
|
||||
let use_h2 = route_match.route.action.options.as_ref()
|
||||
.and_then(|o| o.backend_protocol.as_ref())
|
||||
.map(|p| *p == rustproxy_config::BackendProtocol::Http2)
|
||||
.unwrap_or(false);
|
||||
|
||||
// Build the upstream path (path + query), applying URL rewriting if configured
|
||||
let upstream_path = {
|
||||
let raw_path = match req.uri().query() {
|
||||
Some(q) => format!("{}?{}", path, q),
|
||||
None => path.clone(),
|
||||
};
|
||||
Self::apply_url_rewrite(&raw_path, &route_match.route)
|
||||
};
|
||||
|
||||
// Build upstream request - stream body instead of buffering
|
||||
let (parts, body) = req.into_parts();
|
||||
|
||||
// Apply request headers from route config
|
||||
let mut upstream_headers = parts.headers.clone();
|
||||
if let Some(ref route_headers) = route_match.route.headers {
|
||||
if let Some(ref request_headers) = route_headers.request {
|
||||
for (key, value) in request_headers {
|
||||
if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) {
|
||||
if let Ok(val) = hyper::header::HeaderValue::from_str(value) {
|
||||
upstream_headers.insert(name, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Connect to upstream
|
||||
let upstream_stream = match TcpStream::connect(format!("{}:{}", upstream.host, upstream.port)).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
error!("Failed to connect to upstream {}:{}: {}", upstream.host, upstream.port, e);
|
||||
self.upstream_selector.connection_ended(&upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
|
||||
}
|
||||
};
|
||||
upstream_stream.set_nodelay(true).ok();
|
||||
|
||||
let io = TokioIo::new(upstream_stream);
|
||||
|
||||
let result = if use_h2 {
|
||||
// HTTP/2 backend
|
||||
self.forward_h2(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id).await
|
||||
} else {
|
||||
// HTTP/1.1 backend (default)
|
||||
self.forward_h1(io, parts, body, upstream_headers, &upstream_path, &upstream, route_match.route, route_id).await
|
||||
};
|
||||
self.upstream_selector.connection_ended(&upstream_key);
|
||||
result
|
||||
}
|
||||
|
||||
/// Forward request to backend via HTTP/1.1 with body streaming.
|
||||
async fn forward_h1(
|
||||
&self,
|
||||
io: TokioIo<TcpStream>,
|
||||
parts: hyper::http::request::Parts,
|
||||
body: Incoming,
|
||||
upstream_headers: hyper::HeaderMap,
|
||||
upstream_path: &str,
|
||||
upstream: &crate::upstream_selector::UpstreamSelection,
|
||||
route: &rustproxy_config::RouteConfig,
|
||||
route_id: Option<&str>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await {
|
||||
Ok(h) => h,
|
||||
Err(e) => {
|
||||
error!("Upstream handshake failed: {}", e);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend handshake failed"));
|
||||
}
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = conn.await {
|
||||
debug!("Upstream connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
let mut upstream_req = Request::builder()
|
||||
.method(parts.method)
|
||||
.uri(upstream_path)
|
||||
.version(parts.version);
|
||||
|
||||
if let Some(headers) = upstream_req.headers_mut() {
|
||||
*headers = upstream_headers;
|
||||
if let Ok(host_val) = hyper::header::HeaderValue::from_str(
|
||||
&format!("{}:{}", upstream.host, upstream.port)
|
||||
) {
|
||||
headers.insert(hyper::header::HOST, host_val);
|
||||
}
|
||||
}
|
||||
|
||||
// Stream the request body through to upstream
|
||||
let upstream_req = upstream_req.body(body).unwrap();
|
||||
|
||||
let upstream_response = match sender.send_request(upstream_req).await {
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
error!("Upstream request failed: {}", e);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend request failed"));
|
||||
}
|
||||
};
|
||||
|
||||
self.build_streaming_response(upstream_response, route, route_id).await
|
||||
}
|
||||
|
||||
/// Forward request to backend via HTTP/2 with body streaming.
|
||||
async fn forward_h2(
|
||||
&self,
|
||||
io: TokioIo<TcpStream>,
|
||||
parts: hyper::http::request::Parts,
|
||||
body: Incoming,
|
||||
upstream_headers: hyper::HeaderMap,
|
||||
upstream_path: &str,
|
||||
upstream: &crate::upstream_selector::UpstreamSelection,
|
||||
route: &rustproxy_config::RouteConfig,
|
||||
route_id: Option<&str>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let exec = hyper_util::rt::TokioExecutor::new();
|
||||
let (mut sender, conn) = match hyper::client::conn::http2::handshake(exec, io).await {
|
||||
Ok(h) => h,
|
||||
Err(e) => {
|
||||
error!("HTTP/2 upstream handshake failed: {}", e);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 handshake failed"));
|
||||
}
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = conn.await {
|
||||
debug!("HTTP/2 upstream connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
let mut upstream_req = Request::builder()
|
||||
.method(parts.method)
|
||||
.uri(upstream_path);
|
||||
|
||||
if let Some(headers) = upstream_req.headers_mut() {
|
||||
*headers = upstream_headers;
|
||||
if let Ok(host_val) = hyper::header::HeaderValue::from_str(
|
||||
&format!("{}:{}", upstream.host, upstream.port)
|
||||
) {
|
||||
headers.insert(hyper::header::HOST, host_val);
|
||||
}
|
||||
}
|
||||
|
||||
// Stream the request body through to upstream
|
||||
let upstream_req = upstream_req.body(body).unwrap();
|
||||
|
||||
let upstream_response = match sender.send_request(upstream_req).await {
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
error!("HTTP/2 upstream request failed: {}", e);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend H2 request failed"));
|
||||
}
|
||||
};
|
||||
|
||||
self.build_streaming_response(upstream_response, route, route_id).await
|
||||
}
|
||||
|
||||
/// Build the client-facing response from an upstream response, streaming the body.
|
||||
async fn build_streaming_response(
|
||||
&self,
|
||||
upstream_response: Response<Incoming>,
|
||||
route: &rustproxy_config::RouteConfig,
|
||||
route_id: Option<&str>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let (resp_parts, resp_body) = upstream_response.into_parts();
|
||||
|
||||
let mut response = Response::builder()
|
||||
.status(resp_parts.status);
|
||||
|
||||
if let Some(headers) = response.headers_mut() {
|
||||
*headers = resp_parts.headers;
|
||||
ResponseFilter::apply_headers(route, headers, None);
|
||||
}
|
||||
|
||||
self.metrics.connection_closed(route_id);
|
||||
|
||||
// Stream the response body directly from upstream to client
|
||||
let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(resp_body);
|
||||
|
||||
Ok(response.body(body).unwrap())
|
||||
}
|
||||
|
||||
/// Handle a WebSocket upgrade request.
|
||||
async fn handle_websocket_upgrade(
|
||||
&self,
|
||||
req: Request<Incoming>,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
upstream: &crate::upstream_selector::UpstreamSelection,
|
||||
route: &rustproxy_config::RouteConfig,
|
||||
route_id: Option<&str>,
|
||||
upstream_key: &str,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
// Get WebSocket config from route
|
||||
let ws_config = route.action.websocket.as_ref();
|
||||
|
||||
// Check allowed origins if configured
|
||||
if let Some(ws) = ws_config {
|
||||
if let Some(ref allowed_origins) = ws.allowed_origins {
|
||||
let origin = req.headers()
|
||||
.get("origin")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
if !allowed_origins.is_empty() && !allowed_origins.iter().any(|o| o == "*" || o == origin) {
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::FORBIDDEN, "Origin not allowed"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("WebSocket upgrade from {} -> {}:{}", peer_addr, upstream.host, upstream.port);
|
||||
|
||||
let mut upstream_stream = match TcpStream::connect(
|
||||
format!("{}:{}", upstream.host, upstream.port)
|
||||
).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
error!("WebSocket: failed to connect upstream {}:{}: {}", upstream.host, upstream.port, e);
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend unavailable"));
|
||||
}
|
||||
};
|
||||
upstream_stream.set_nodelay(true).ok();
|
||||
|
||||
let path = req.uri().path().to_string();
|
||||
let upstream_path = {
|
||||
let raw = match req.uri().query() {
|
||||
Some(q) => format!("{}?{}", path, q),
|
||||
None => path,
|
||||
};
|
||||
// Apply rewrite_path if configured
|
||||
if let Some(ws) = ws_config {
|
||||
if let Some(ref rewrite_path) = ws.rewrite_path {
|
||||
rewrite_path.clone()
|
||||
} else {
|
||||
raw
|
||||
}
|
||||
} else {
|
||||
raw
|
||||
}
|
||||
};
|
||||
|
||||
let (parts, _body) = req.into_parts();
|
||||
|
||||
let mut raw_request = format!(
|
||||
"{} {} HTTP/1.1\r\n",
|
||||
parts.method, upstream_path
|
||||
);
|
||||
|
||||
let upstream_host = format!("{}:{}", upstream.host, upstream.port);
|
||||
for (name, value) in parts.headers.iter() {
|
||||
if name == hyper::header::HOST {
|
||||
raw_request.push_str(&format!("host: {}\r\n", upstream_host));
|
||||
} else {
|
||||
raw_request.push_str(&format!("{}: {}\r\n", name, value.to_str().unwrap_or("")));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref route_headers) = route.headers {
|
||||
if let Some(ref request_headers) = route_headers.request {
|
||||
for (key, value) in request_headers {
|
||||
raw_request.push_str(&format!("{}: {}\r\n", key, value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply WebSocket custom headers
|
||||
if let Some(ws) = ws_config {
|
||||
if let Some(ref custom_headers) = ws.custom_headers {
|
||||
for (key, value) in custom_headers {
|
||||
raw_request.push_str(&format!("{}: {}\r\n", key, value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
raw_request.push_str("\r\n");
|
||||
|
||||
if let Err(e) = upstream_stream.write_all(raw_request.as_bytes()).await {
|
||||
error!("WebSocket: failed to send upgrade request to upstream: {}", e);
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend write failed"));
|
||||
}
|
||||
|
||||
let mut response_buf = Vec::with_capacity(4096);
|
||||
let mut temp = [0u8; 1];
|
||||
loop {
|
||||
match upstream_stream.read(&mut temp).await {
|
||||
Ok(0) => {
|
||||
error!("WebSocket: upstream closed before completing handshake");
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend closed"));
|
||||
}
|
||||
Ok(_) => {
|
||||
response_buf.push(temp[0]);
|
||||
if response_buf.len() >= 4 {
|
||||
let len = response_buf.len();
|
||||
if response_buf[len-4..] == *b"\r\n\r\n" {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if response_buf.len() > 8192 {
|
||||
error!("WebSocket: upstream response headers too large");
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend response too large"));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("WebSocket: failed to read upstream response: {}", e);
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(StatusCode::BAD_GATEWAY, "Backend read failed"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response_str = String::from_utf8_lossy(&response_buf);
|
||||
|
||||
let status_line = response_str.lines().next().unwrap_or("");
|
||||
let status_code = status_line
|
||||
.split_whitespace()
|
||||
.nth(1)
|
||||
.and_then(|s| s.parse::<u16>().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
if status_code != 101 {
|
||||
debug!("WebSocket: upstream rejected upgrade with status {}", status_code);
|
||||
self.upstream_selector.connection_ended(upstream_key);
|
||||
self.metrics.connection_closed(route_id);
|
||||
return Ok(error_response(
|
||||
StatusCode::from_u16(status_code).unwrap_or(StatusCode::BAD_GATEWAY),
|
||||
"WebSocket upgrade rejected by backend",
|
||||
));
|
||||
}
|
||||
|
||||
let mut client_resp = Response::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS);
|
||||
|
||||
if let Some(resp_headers) = client_resp.headers_mut() {
|
||||
for line in response_str.lines().skip(1) {
|
||||
let line = line.trim();
|
||||
if line.is_empty() {
|
||||
break;
|
||||
}
|
||||
if let Some((name, value)) = line.split_once(':') {
|
||||
let name = name.trim();
|
||||
let value = value.trim();
|
||||
if let Ok(header_name) = hyper::header::HeaderName::from_bytes(name.as_bytes()) {
|
||||
if let Ok(header_value) = hyper::header::HeaderValue::from_str(value) {
|
||||
resp_headers.insert(header_name, header_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let on_client_upgrade = hyper::upgrade::on(
|
||||
Request::from_parts(parts, http_body_util::Empty::<Bytes>::new())
|
||||
);
|
||||
|
||||
let metrics = Arc::clone(&self.metrics);
|
||||
let route_id_owned = route_id.map(|s| s.to_string());
|
||||
let upstream_selector = self.upstream_selector.clone();
|
||||
let upstream_key_owned = upstream_key.to_string();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let client_upgraded = match on_client_upgrade.await {
|
||||
Ok(upgraded) => upgraded,
|
||||
Err(e) => {
|
||||
debug!("WebSocket: client upgrade failed: {}", e);
|
||||
upstream_selector.connection_ended(&upstream_key_owned);
|
||||
if let Some(ref rid) = route_id_owned {
|
||||
metrics.connection_closed(Some(rid.as_str()));
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let client_io = TokioIo::new(client_upgraded);
|
||||
|
||||
let (mut cr, mut cw) = tokio::io::split(client_io);
|
||||
let (mut ur, mut uw) = tokio::io::split(upstream_stream);
|
||||
|
||||
let c2u = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = match cr.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if uw.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
}
|
||||
let _ = uw.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
let u2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = match ur.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if cw.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
}
|
||||
let _ = cw.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
let bytes_in = c2u.await.unwrap_or(0);
|
||||
let bytes_out = u2c.await.unwrap_or(0);
|
||||
|
||||
debug!("WebSocket tunnel closed: {} bytes in, {} bytes out", bytes_in, bytes_out);
|
||||
|
||||
upstream_selector.connection_ended(&upstream_key_owned);
|
||||
if let Some(ref rid) = route_id_owned {
|
||||
metrics.record_bytes(bytes_in, bytes_out, Some(rid.as_str()));
|
||||
metrics.connection_closed(Some(rid.as_str()));
|
||||
}
|
||||
});
|
||||
|
||||
let body: BoxBody<Bytes, hyper::Error> = BoxBody::new(
|
||||
http_body_util::Empty::<Bytes>::new().map_err(|never| match never {})
|
||||
);
|
||||
Ok(client_resp.body(body).unwrap())
|
||||
}
|
||||
|
||||
/// Build a test response from config (no upstream connection needed).
|
||||
fn build_test_response(config: &rustproxy_config::RouteTestResponse) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let mut response = Response::builder()
|
||||
.status(StatusCode::from_u16(config.status).unwrap_or(StatusCode::OK));
|
||||
|
||||
if let Some(headers) = response.headers_mut() {
|
||||
for (key, value) in &config.headers {
|
||||
if let Ok(name) = hyper::header::HeaderName::from_bytes(key.as_bytes()) {
|
||||
if let Ok(val) = hyper::header::HeaderValue::from_str(value) {
|
||||
headers.insert(name, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let body = Full::new(Bytes::from(config.body.clone()))
|
||||
.map_err(|never| match never {});
|
||||
response.body(BoxBody::new(body)).unwrap()
|
||||
}
|
||||
|
||||
/// Apply URL rewriting rules from route config.
|
||||
fn apply_url_rewrite(path: &str, route: &rustproxy_config::RouteConfig) -> String {
|
||||
let rewrite = match route.action.advanced.as_ref()
|
||||
.and_then(|a| a.url_rewrite.as_ref())
|
||||
{
|
||||
Some(r) => r,
|
||||
None => return path.to_string(),
|
||||
};
|
||||
|
||||
// Determine what to rewrite
|
||||
let (subject, suffix) = if rewrite.only_rewrite_path.unwrap_or(false) {
|
||||
// Only rewrite the path portion (before ?)
|
||||
match path.split_once('?') {
|
||||
Some((p, q)) => (p.to_string(), format!("?{}", q)),
|
||||
None => (path.to_string(), String::new()),
|
||||
}
|
||||
} else {
|
||||
(path.to_string(), String::new())
|
||||
};
|
||||
|
||||
match Regex::new(&rewrite.pattern) {
|
||||
Ok(re) => {
|
||||
let result = re.replace_all(&subject, rewrite.target.as_str());
|
||||
format!("{}{}", result, suffix)
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Invalid URL rewrite pattern '{}': {}", rewrite.pattern, e);
|
||||
path.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Serve a static file from the configured directory.
|
||||
fn serve_static_file(
|
||||
path: &str,
|
||||
config: &rustproxy_config::RouteStaticFiles,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
use std::path::Path;
|
||||
|
||||
let root = Path::new(&config.root);
|
||||
|
||||
// Sanitize path to prevent directory traversal
|
||||
let clean_path = path.trim_start_matches('/');
|
||||
let clean_path = clean_path.replace("..", "");
|
||||
|
||||
let mut file_path = root.join(&clean_path);
|
||||
|
||||
// If path points to a directory, try index files
|
||||
if file_path.is_dir() || clean_path.is_empty() {
|
||||
let index_files = config.index_files.as_deref()
|
||||
.or(config.index.as_deref())
|
||||
.unwrap_or(&[]);
|
||||
let default_index = vec!["index.html".to_string()];
|
||||
let index_files = if index_files.is_empty() { &default_index } else { index_files };
|
||||
|
||||
let mut found = false;
|
||||
for index in index_files {
|
||||
let candidate = if clean_path.is_empty() {
|
||||
root.join(index)
|
||||
} else {
|
||||
file_path.join(index)
|
||||
};
|
||||
if candidate.is_file() {
|
||||
file_path = candidate;
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return error_response(StatusCode::NOT_FOUND, "Not found");
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure the resolved path is within the root (prevent traversal)
|
||||
let canonical_root = match root.canonicalize() {
|
||||
Ok(p) => p,
|
||||
Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"),
|
||||
};
|
||||
let canonical_file = match file_path.canonicalize() {
|
||||
Ok(p) => p,
|
||||
Err(_) => return error_response(StatusCode::NOT_FOUND, "Not found"),
|
||||
};
|
||||
if !canonical_file.starts_with(&canonical_root) {
|
||||
return error_response(StatusCode::FORBIDDEN, "Forbidden");
|
||||
}
|
||||
|
||||
// Check if symlinks are allowed
|
||||
if config.follow_symlinks == Some(false) && canonical_file != file_path {
|
||||
return error_response(StatusCode::FORBIDDEN, "Forbidden");
|
||||
}
|
||||
|
||||
// Read the file
|
||||
match std::fs::read(&file_path) {
|
||||
Ok(content) => {
|
||||
let content_type = guess_content_type(&file_path);
|
||||
let mut response = Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", content_type);
|
||||
|
||||
// Apply cache-control if configured
|
||||
if let Some(ref cache_control) = config.cache_control {
|
||||
response = response.header("Cache-Control", cache_control.as_str());
|
||||
}
|
||||
|
||||
// Apply custom headers
|
||||
if let Some(ref headers) = config.headers {
|
||||
for (key, value) in headers {
|
||||
response = response.header(key.as_str(), value.as_str());
|
||||
}
|
||||
}
|
||||
|
||||
let body = Full::new(Bytes::from(content))
|
||||
.map_err(|never| match never {});
|
||||
response.body(BoxBody::new(body)).unwrap()
|
||||
}
|
||||
Err(_) => error_response(StatusCode::NOT_FOUND, "Not found"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Guess MIME content type from file extension.
|
||||
fn guess_content_type(path: &std::path::Path) -> &'static str {
|
||||
match path.extension().and_then(|e| e.to_str()) {
|
||||
Some("html") | Some("htm") => "text/html; charset=utf-8",
|
||||
Some("css") => "text/css; charset=utf-8",
|
||||
Some("js") | Some("mjs") => "application/javascript; charset=utf-8",
|
||||
Some("json") => "application/json; charset=utf-8",
|
||||
Some("xml") => "application/xml; charset=utf-8",
|
||||
Some("txt") => "text/plain; charset=utf-8",
|
||||
Some("png") => "image/png",
|
||||
Some("jpg") | Some("jpeg") => "image/jpeg",
|
||||
Some("gif") => "image/gif",
|
||||
Some("svg") => "image/svg+xml",
|
||||
Some("ico") => "image/x-icon",
|
||||
Some("woff") => "font/woff",
|
||||
Some("woff2") => "font/woff2",
|
||||
Some("ttf") => "font/ttf",
|
||||
Some("pdf") => "application/pdf",
|
||||
Some("wasm") => "application/wasm",
|
||||
_ => "application/octet-stream",
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HttpProxyService {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
route_manager: Arc::new(RouteManager::new(vec![])),
|
||||
metrics: Arc::new(MetricsCollector::new()),
|
||||
upstream_selector: UpstreamSelector::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let body = Full::new(Bytes::from(message.to_string()))
|
||||
.map_err(|never| match never {});
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header("Content-Type", "text/plain")
|
||||
.body(BoxBody::new(body))
|
||||
.unwrap()
|
||||
}
|
||||
263
rust/crates/rustproxy-http/src/request_filter.rs
Normal file
263
rust/crates/rustproxy-http/src/request_filter.rs
Normal file
@@ -0,0 +1,263 @@
|
||||
//! Request filtering: security checks, auth, CORS preflight.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use http_body_util::Full;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
|
||||
use rustproxy_config::RouteSecurity;
|
||||
use rustproxy_security::{IpFilter, BasicAuthValidator, JwtValidator, RateLimiter};
|
||||
|
||||
pub struct RequestFilter;
|
||||
|
||||
impl RequestFilter {
|
||||
/// Apply security filters. Returns Some(response) if the request should be blocked.
|
||||
pub fn apply(
|
||||
security: &RouteSecurity,
|
||||
req: &Request<Incoming>,
|
||||
peer_addr: &SocketAddr,
|
||||
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
|
||||
Self::apply_with_rate_limiter(security, req, peer_addr, None)
|
||||
}
|
||||
|
||||
/// Apply security filters with an optional shared rate limiter.
|
||||
/// Returns Some(response) if the request should be blocked.
|
||||
pub fn apply_with_rate_limiter(
|
||||
security: &RouteSecurity,
|
||||
req: &Request<Incoming>,
|
||||
peer_addr: &SocketAddr,
|
||||
rate_limiter: Option<&Arc<RateLimiter>>,
|
||||
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
|
||||
let client_ip = peer_addr.ip();
|
||||
let request_path = req.uri().path();
|
||||
|
||||
// IP filter
|
||||
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
|
||||
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
|
||||
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
|
||||
let filter = IpFilter::new(allow, block);
|
||||
let normalized = IpFilter::normalize_ip(&client_ip);
|
||||
if !filter.is_allowed(&normalized) {
|
||||
return Some(error_response(StatusCode::FORBIDDEN, "Access denied"));
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limiting
|
||||
if let Some(ref rate_limit_config) = security.rate_limit {
|
||||
if rate_limit_config.enabled {
|
||||
// Use shared rate limiter if provided, otherwise create ephemeral one
|
||||
let should_block = if let Some(limiter) = rate_limiter {
|
||||
let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
|
||||
!limiter.check(&key)
|
||||
} else {
|
||||
// Create a per-check limiter (less ideal but works for non-shared case)
|
||||
let limiter = RateLimiter::new(
|
||||
rate_limit_config.max_requests,
|
||||
rate_limit_config.window,
|
||||
);
|
||||
let key = Self::rate_limit_key(rate_limit_config, req, peer_addr);
|
||||
!limiter.check(&key)
|
||||
};
|
||||
|
||||
if should_block {
|
||||
let message = rate_limit_config.error_message
|
||||
.as_deref()
|
||||
.unwrap_or("Rate limit exceeded");
|
||||
return Some(error_response(StatusCode::TOO_MANY_REQUESTS, message));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check exclude paths before auth
|
||||
let should_skip_auth = Self::path_matches_exclude_list(request_path, security);
|
||||
|
||||
if !should_skip_auth {
|
||||
// Basic auth
|
||||
if let Some(ref basic_auth) = security.basic_auth {
|
||||
if basic_auth.enabled {
|
||||
// Check basic auth exclude paths
|
||||
let skip_basic = basic_auth.exclude_paths.as_ref()
|
||||
.map(|paths| Self::path_matches_any(request_path, paths))
|
||||
.unwrap_or(false);
|
||||
|
||||
if !skip_basic {
|
||||
let users: Vec<(String, String)> = basic_auth.users.iter()
|
||||
.map(|c| (c.username.clone(), c.password.clone()))
|
||||
.collect();
|
||||
let validator = BasicAuthValidator::new(users, basic_auth.realm.clone());
|
||||
|
||||
let auth_header = req.headers()
|
||||
.get("authorization")
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
match auth_header {
|
||||
Some(header) => {
|
||||
if validator.validate(header).is_none() {
|
||||
return Some(Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header("WWW-Authenticate", validator.www_authenticate())
|
||||
.body(boxed_body("Invalid credentials"))
|
||||
.unwrap());
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Some(Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header("WWW-Authenticate", validator.www_authenticate())
|
||||
.body(boxed_body("Authentication required"))
|
||||
.unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// JWT auth
|
||||
if let Some(ref jwt_auth) = security.jwt_auth {
|
||||
if jwt_auth.enabled {
|
||||
// Check JWT auth exclude paths
|
||||
let skip_jwt = jwt_auth.exclude_paths.as_ref()
|
||||
.map(|paths| Self::path_matches_any(request_path, paths))
|
||||
.unwrap_or(false);
|
||||
|
||||
if !skip_jwt {
|
||||
let validator = JwtValidator::new(
|
||||
&jwt_auth.secret,
|
||||
jwt_auth.algorithm.as_deref(),
|
||||
jwt_auth.issuer.as_deref(),
|
||||
jwt_auth.audience.as_deref(),
|
||||
);
|
||||
|
||||
let auth_header = req.headers()
|
||||
.get("authorization")
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
match auth_header.and_then(JwtValidator::extract_token) {
|
||||
Some(token) => {
|
||||
if validator.validate(token).is_err() {
|
||||
return Some(error_response(StatusCode::UNAUTHORIZED, "Invalid token"));
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Some(error_response(StatusCode::UNAUTHORIZED, "Bearer token required"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if a request path matches any pattern in the exclude list.
|
||||
fn path_matches_exclude_list(_path: &str, _security: &RouteSecurity) -> bool {
|
||||
// No global exclude paths on RouteSecurity currently,
|
||||
// but we check per-auth exclude paths above.
|
||||
// This can be extended if a global exclude_paths is added.
|
||||
false
|
||||
}
|
||||
|
||||
/// Check if a path matches any pattern in the list.
|
||||
/// Supports simple glob patterns: `/health*` matches `/health`, `/healthz`, `/health/check`
|
||||
fn path_matches_any(path: &str, patterns: &[String]) -> bool {
|
||||
for pattern in patterns {
|
||||
if pattern.ends_with('*') {
|
||||
let prefix = &pattern[..pattern.len() - 1];
|
||||
if path.starts_with(prefix) {
|
||||
return true;
|
||||
}
|
||||
} else if path == pattern {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Determine the rate limit key based on configuration.
|
||||
fn rate_limit_key(
|
||||
config: &rustproxy_config::RouteRateLimit,
|
||||
req: &Request<Incoming>,
|
||||
peer_addr: &SocketAddr,
|
||||
) -> String {
|
||||
use rustproxy_config::RateLimitKeyBy;
|
||||
match config.key_by.as_ref().unwrap_or(&RateLimitKeyBy::Ip) {
|
||||
RateLimitKeyBy::Ip => peer_addr.ip().to_string(),
|
||||
RateLimitKeyBy::Path => req.uri().path().to_string(),
|
||||
RateLimitKeyBy::Header => {
|
||||
if let Some(ref header_name) = config.header_name {
|
||||
req.headers()
|
||||
.get(header_name.as_str())
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("unknown")
|
||||
.to_string()
|
||||
} else {
|
||||
peer_addr.ip().to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check IP-based security (for use in passthrough / TCP-level connections).
|
||||
/// Returns true if allowed, false if blocked.
|
||||
pub fn check_ip_security(security: &RouteSecurity, client_ip: &std::net::IpAddr) -> bool {
|
||||
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
|
||||
let allow = security.ip_allow_list.as_deref().unwrap_or(&[]);
|
||||
let block = security.ip_block_list.as_deref().unwrap_or(&[]);
|
||||
let filter = IpFilter::new(allow, block);
|
||||
let normalized = IpFilter::normalize_ip(client_ip);
|
||||
filter.is_allowed(&normalized)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle CORS preflight (OPTIONS) requests.
|
||||
/// Returns Some(response) if this is a CORS preflight that should be handled.
|
||||
pub fn handle_cors_preflight(
|
||||
req: &Request<Incoming>,
|
||||
) -> Option<Response<BoxBody<Bytes, hyper::Error>>> {
|
||||
if req.method() != hyper::Method::OPTIONS {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check for CORS preflight indicators
|
||||
let has_origin = req.headers().contains_key("origin");
|
||||
let has_request_method = req.headers().contains_key("access-control-request-method");
|
||||
|
||||
if !has_origin || !has_request_method {
|
||||
return None;
|
||||
}
|
||||
|
||||
let origin = req.headers()
|
||||
.get("origin")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("*");
|
||||
|
||||
Some(Response::builder()
|
||||
.status(StatusCode::NO_CONTENT)
|
||||
.header("Access-Control-Allow-Origin", origin)
|
||||
.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
|
||||
.header("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
|
||||
.header("Access-Control-Max-Age", "86400")
|
||||
.body(boxed_body(""))
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
fn error_response(status: StatusCode, message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header("Content-Type", "text/plain")
|
||||
.body(boxed_body(message))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn boxed_body(data: &str) -> BoxBody<Bytes, hyper::Error> {
|
||||
BoxBody::new(Full::new(Bytes::from(data.to_string())).map_err(|never| match never {}))
|
||||
}
|
||||
92
rust/crates/rustproxy-http/src/response_filter.rs
Normal file
92
rust/crates/rustproxy-http/src/response_filter.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
//! Response filtering: CORS headers, custom headers, security headers.
|
||||
|
||||
use hyper::header::{HeaderMap, HeaderName, HeaderValue};
|
||||
use rustproxy_config::RouteConfig;
|
||||
|
||||
use crate::template::{RequestContext, expand_template};
|
||||
|
||||
pub struct ResponseFilter;
|
||||
|
||||
impl ResponseFilter {
|
||||
/// Apply response headers from route config and CORS settings.
|
||||
/// If a `RequestContext` is provided, template variables in header values will be expanded.
|
||||
pub fn apply_headers(route: &RouteConfig, headers: &mut HeaderMap, req_ctx: Option<&RequestContext>) {
|
||||
// Apply custom response headers from route config
|
||||
if let Some(ref route_headers) = route.headers {
|
||||
if let Some(ref response_headers) = route_headers.response {
|
||||
for (key, value) in response_headers {
|
||||
if let Ok(name) = HeaderName::from_bytes(key.as_bytes()) {
|
||||
let expanded = match req_ctx {
|
||||
Some(ctx) => expand_template(value, ctx),
|
||||
None => value.clone(),
|
||||
};
|
||||
if let Ok(val) = HeaderValue::from_str(&expanded) {
|
||||
headers.insert(name, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply CORS headers if configured
|
||||
if let Some(ref cors) = route_headers.cors {
|
||||
if cors.enabled {
|
||||
Self::apply_cors_headers(cors, headers);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_cors_headers(cors: &rustproxy_config::RouteCors, headers: &mut HeaderMap) {
|
||||
// Allow-Origin
|
||||
if let Some(ref origin) = cors.allow_origin {
|
||||
let origin_str = match origin {
|
||||
rustproxy_config::AllowOrigin::Single(s) => s.clone(),
|
||||
rustproxy_config::AllowOrigin::List(list) => list.join(", "),
|
||||
};
|
||||
if let Ok(val) = HeaderValue::from_str(&origin_str) {
|
||||
headers.insert("access-control-allow-origin", val);
|
||||
}
|
||||
} else {
|
||||
headers.insert(
|
||||
"access-control-allow-origin",
|
||||
HeaderValue::from_static("*"),
|
||||
);
|
||||
}
|
||||
|
||||
// Allow-Methods
|
||||
if let Some(ref methods) = cors.allow_methods {
|
||||
if let Ok(val) = HeaderValue::from_str(methods) {
|
||||
headers.insert("access-control-allow-methods", val);
|
||||
}
|
||||
}
|
||||
|
||||
// Allow-Headers
|
||||
if let Some(ref allow_headers) = cors.allow_headers {
|
||||
if let Ok(val) = HeaderValue::from_str(allow_headers) {
|
||||
headers.insert("access-control-allow-headers", val);
|
||||
}
|
||||
}
|
||||
|
||||
// Allow-Credentials
|
||||
if cors.allow_credentials == Some(true) {
|
||||
headers.insert(
|
||||
"access-control-allow-credentials",
|
||||
HeaderValue::from_static("true"),
|
||||
);
|
||||
}
|
||||
|
||||
// Expose-Headers
|
||||
if let Some(ref expose) = cors.expose_headers {
|
||||
if let Ok(val) = HeaderValue::from_str(expose) {
|
||||
headers.insert("access-control-expose-headers", val);
|
||||
}
|
||||
}
|
||||
|
||||
// Max-Age
|
||||
if let Some(max_age) = cors.max_age {
|
||||
if let Ok(val) = HeaderValue::from_str(&max_age.to_string()) {
|
||||
headers.insert("access-control-max-age", val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
162
rust/crates/rustproxy-http/src/template.rs
Normal file
162
rust/crates/rustproxy-http/src/template.rs
Normal file
@@ -0,0 +1,162 @@
|
||||
//! Header template variable expansion.
|
||||
//!
|
||||
//! Supports expanding template variables like `{clientIp}`, `{domain}`, etc.
|
||||
//! in header values before they are applied to requests or responses.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
/// Context for template variable expansion.
|
||||
pub struct RequestContext {
|
||||
pub client_ip: String,
|
||||
pub domain: String,
|
||||
pub port: u16,
|
||||
pub path: String,
|
||||
pub route_name: String,
|
||||
pub connection_id: u64,
|
||||
}
|
||||
|
||||
/// Expand template variables in a header value.
|
||||
/// Supported variables: {clientIp}, {domain}, {port}, {path}, {routeName}, {connectionId}, {timestamp}
|
||||
pub fn expand_template(template: &str, ctx: &RequestContext) -> String {
|
||||
let timestamp = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
template
|
||||
.replace("{clientIp}", &ctx.client_ip)
|
||||
.replace("{domain}", &ctx.domain)
|
||||
.replace("{port}", &ctx.port.to_string())
|
||||
.replace("{path}", &ctx.path)
|
||||
.replace("{routeName}", &ctx.route_name)
|
||||
.replace("{connectionId}", &ctx.connection_id.to_string())
|
||||
.replace("{timestamp}", ×tamp.to_string())
|
||||
}
|
||||
|
||||
/// Expand templates in a map of header key-value pairs.
|
||||
pub fn expand_headers(
|
||||
headers: &HashMap<String, String>,
|
||||
ctx: &RequestContext,
|
||||
) -> HashMap<String, String> {
|
||||
headers.iter()
|
||||
.map(|(k, v)| (k.clone(), expand_template(v, ctx)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_context() -> RequestContext {
|
||||
RequestContext {
|
||||
client_ip: "192.168.1.100".to_string(),
|
||||
domain: "example.com".to_string(),
|
||||
port: 443,
|
||||
path: "/api/v1/users".to_string(),
|
||||
route_name: "api-route".to_string(),
|
||||
connection_id: 42,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_client_ip() {
|
||||
let ctx = test_context();
|
||||
assert_eq!(expand_template("{clientIp}", &ctx), "192.168.1.100");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_domain() {
|
||||
let ctx = test_context();
|
||||
assert_eq!(expand_template("{domain}", &ctx), "example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_port() {
|
||||
let ctx = test_context();
|
||||
assert_eq!(expand_template("{port}", &ctx), "443");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_path() {
|
||||
let ctx = test_context();
|
||||
assert_eq!(expand_template("{path}", &ctx), "/api/v1/users");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_route_name() {
|
||||
let ctx = test_context();
|
||||
assert_eq!(expand_template("{routeName}", &ctx), "api-route");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_connection_id() {
|
||||
let ctx = test_context();
|
||||
assert_eq!(expand_template("{connectionId}", &ctx), "42");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_timestamp() {
|
||||
let ctx = test_context();
|
||||
let result = expand_template("{timestamp}", &ctx);
|
||||
// Timestamp should be a valid number
|
||||
let ts: u64 = result.parse().expect("timestamp should be a number");
|
||||
// Should be a reasonable Unix timestamp (after 2020)
|
||||
assert!(ts > 1_577_836_800);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_mixed_template() {
|
||||
let ctx = test_context();
|
||||
let result = expand_template("client={clientIp}, host={domain}:{port}", &ctx);
|
||||
assert_eq!(result, "client=192.168.1.100, host=example.com:443");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_no_variables() {
|
||||
let ctx = test_context();
|
||||
assert_eq!(expand_template("plain-value", &ctx), "plain-value");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_empty_string() {
|
||||
let ctx = test_context();
|
||||
assert_eq!(expand_template("", &ctx), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_multiple_same_variable() {
|
||||
let ctx = test_context();
|
||||
let result = expand_template("{clientIp}-{clientIp}", &ctx);
|
||||
assert_eq!(result, "192.168.1.100-192.168.1.100");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_headers_map() {
|
||||
let ctx = test_context();
|
||||
let mut headers = HashMap::new();
|
||||
headers.insert("X-Forwarded-For".to_string(), "{clientIp}".to_string());
|
||||
headers.insert("X-Route".to_string(), "{routeName}".to_string());
|
||||
headers.insert("X-Static".to_string(), "no-template".to_string());
|
||||
|
||||
let result = expand_headers(&headers, &ctx);
|
||||
assert_eq!(result.get("X-Forwarded-For").unwrap(), "192.168.1.100");
|
||||
assert_eq!(result.get("X-Route").unwrap(), "api-route");
|
||||
assert_eq!(result.get("X-Static").unwrap(), "no-template");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_all_variables_in_one() {
|
||||
let ctx = test_context();
|
||||
let template = "{clientIp}|{domain}|{port}|{path}|{routeName}|{connectionId}";
|
||||
let result = expand_template(template, &ctx);
|
||||
assert_eq!(result, "192.168.1.100|example.com|443|/api/v1/users|api-route|42");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_unknown_variable_left_as_is() {
|
||||
let ctx = test_context();
|
||||
let result = expand_template("{unknownVar}", &ctx);
|
||||
assert_eq!(result, "{unknownVar}");
|
||||
}
|
||||
}
|
||||
222
rust/crates/rustproxy-http/src/upstream_selector.rs
Normal file
222
rust/crates/rustproxy-http/src/upstream_selector.rs
Normal file
@@ -0,0 +1,222 @@
|
||||
//! Route-aware upstream selection with load balancing.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use rustproxy_config::{RouteTarget, LoadBalancingAlgorithm};
|
||||
|
||||
/// Upstream selection result.
|
||||
pub struct UpstreamSelection {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub use_tls: bool,
|
||||
}
|
||||
|
||||
/// Selects upstream backends with load balancing support.
|
||||
pub struct UpstreamSelector {
|
||||
/// Round-robin counters per route (keyed by first target host:port)
|
||||
round_robin: Mutex<HashMap<String, AtomicUsize>>,
|
||||
/// Active connection counts per host (keyed by "host:port")
|
||||
active_connections: Arc<DashMap<String, AtomicU64>>,
|
||||
}
|
||||
|
||||
impl UpstreamSelector {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
round_robin: Mutex::new(HashMap::new()),
|
||||
active_connections: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Select an upstream target based on the route target config and load balancing.
|
||||
pub fn select(
|
||||
&self,
|
||||
target: &RouteTarget,
|
||||
client_addr: &SocketAddr,
|
||||
incoming_port: u16,
|
||||
) -> UpstreamSelection {
|
||||
let hosts = target.host.to_vec();
|
||||
let port = target.port.resolve(incoming_port);
|
||||
|
||||
if hosts.len() <= 1 {
|
||||
return UpstreamSelection {
|
||||
host: hosts.first().map(|s| s.to_string()).unwrap_or_default(),
|
||||
port,
|
||||
use_tls: target.tls.is_some(),
|
||||
};
|
||||
}
|
||||
|
||||
// Determine load balancing algorithm
|
||||
let algorithm = target.load_balancing.as_ref()
|
||||
.map(|lb| &lb.algorithm)
|
||||
.unwrap_or(&LoadBalancingAlgorithm::RoundRobin);
|
||||
|
||||
let idx = match algorithm {
|
||||
LoadBalancingAlgorithm::RoundRobin => {
|
||||
self.round_robin_select(&hosts, port)
|
||||
}
|
||||
LoadBalancingAlgorithm::IpHash => {
|
||||
let hash = Self::ip_hash(client_addr);
|
||||
hash % hosts.len()
|
||||
}
|
||||
LoadBalancingAlgorithm::LeastConnections => {
|
||||
self.least_connections_select(&hosts, port)
|
||||
}
|
||||
};
|
||||
|
||||
UpstreamSelection {
|
||||
host: hosts[idx].to_string(),
|
||||
port,
|
||||
use_tls: target.tls.is_some(),
|
||||
}
|
||||
}
|
||||
|
||||
fn round_robin_select(&self, hosts: &[&str], port: u16) -> usize {
|
||||
let key = format!("{}:{}", hosts[0], port);
|
||||
let mut counters = self.round_robin.lock().unwrap();
|
||||
let counter = counters
|
||||
.entry(key)
|
||||
.or_insert_with(|| AtomicUsize::new(0));
|
||||
let idx = counter.fetch_add(1, Ordering::Relaxed);
|
||||
idx % hosts.len()
|
||||
}
|
||||
|
||||
fn least_connections_select(&self, hosts: &[&str], port: u16) -> usize {
|
||||
let mut min_conns = u64::MAX;
|
||||
let mut min_idx = 0;
|
||||
|
||||
for (i, host) in hosts.iter().enumerate() {
|
||||
let key = format!("{}:{}", host, port);
|
||||
let conns = self.active_connections
|
||||
.get(&key)
|
||||
.map(|entry| entry.value().load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
if conns < min_conns {
|
||||
min_conns = conns;
|
||||
min_idx = i;
|
||||
}
|
||||
}
|
||||
|
||||
min_idx
|
||||
}
|
||||
|
||||
/// Record that a connection to the given host has started.
|
||||
pub fn connection_started(&self, host: &str) {
|
||||
self.active_connections
|
||||
.entry(host.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Record that a connection to the given host has ended.
|
||||
pub fn connection_ended(&self, host: &str) {
|
||||
if let Some(counter) = self.active_connections.get(host) {
|
||||
let prev = counter.value().fetch_sub(1, Ordering::Relaxed);
|
||||
// Guard against underflow (shouldn't happen, but be safe)
|
||||
if prev == 0 {
|
||||
counter.value().store(0, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn ip_hash(addr: &SocketAddr) -> usize {
|
||||
let ip_str = addr.ip().to_string();
|
||||
let mut hash: usize = 5381;
|
||||
for byte in ip_str.bytes() {
|
||||
hash = hash.wrapping_mul(33).wrapping_add(byte as usize);
|
||||
}
|
||||
hash
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for UpstreamSelector {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for UpstreamSelector {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
round_robin: Mutex::new(HashMap::new()),
|
||||
active_connections: Arc::clone(&self.active_connections),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rustproxy_config::*;
|
||||
|
||||
fn make_target(hosts: Vec<&str>, port: u16) -> RouteTarget {
|
||||
RouteTarget {
|
||||
target_match: None,
|
||||
host: if hosts.len() == 1 {
|
||||
HostSpec::Single(hosts[0].to_string())
|
||||
} else {
|
||||
HostSpec::List(hosts.iter().map(|s| s.to_string()).collect())
|
||||
},
|
||||
port: PortSpec::Fixed(port),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
priority: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_host() {
|
||||
let selector = UpstreamSelector::new();
|
||||
let target = make_target(vec!["backend"], 8080);
|
||||
let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();
|
||||
let result = selector.select(&target, &addr, 80);
|
||||
assert_eq!(result.host, "backend");
|
||||
assert_eq!(result.port, 8080);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_round_robin() {
|
||||
let selector = UpstreamSelector::new();
|
||||
let mut target = make_target(vec!["a", "b", "c"], 8080);
|
||||
target.load_balancing = Some(RouteLoadBalancing {
|
||||
algorithm: LoadBalancingAlgorithm::RoundRobin,
|
||||
health_check: None,
|
||||
});
|
||||
let addr: SocketAddr = "127.0.0.1:1234".parse().unwrap();
|
||||
|
||||
let r1 = selector.select(&target, &addr, 80);
|
||||
let r2 = selector.select(&target, &addr, 80);
|
||||
let r3 = selector.select(&target, &addr, 80);
|
||||
let r4 = selector.select(&target, &addr, 80);
|
||||
|
||||
// Should cycle through a, b, c, a
|
||||
assert_eq!(r1.host, "a");
|
||||
assert_eq!(r2.host, "b");
|
||||
assert_eq!(r3.host, "c");
|
||||
assert_eq!(r4.host, "a");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_hash_consistent() {
|
||||
let selector = UpstreamSelector::new();
|
||||
let mut target = make_target(vec!["a", "b", "c"], 8080);
|
||||
target.load_balancing = Some(RouteLoadBalancing {
|
||||
algorithm: LoadBalancingAlgorithm::IpHash,
|
||||
health_check: None,
|
||||
});
|
||||
let addr: SocketAddr = "10.0.0.5:1234".parse().unwrap();
|
||||
|
||||
let r1 = selector.select(&target, &addr, 80);
|
||||
let r2 = selector.select(&target, &addr, 80);
|
||||
// Same IP should always get same backend
|
||||
assert_eq!(r1.host, r2.host);
|
||||
}
|
||||
}
|
||||
15
rust/crates/rustproxy-metrics/Cargo.toml
Normal file
15
rust/crates/rustproxy-metrics/Cargo.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
[package]
|
||||
name = "rustproxy-metrics"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "Metrics and throughput tracking for RustProxy"
|
||||
|
||||
[dependencies]
|
||||
dashmap = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
251
rust/crates/rustproxy-metrics/src/collector.rs
Normal file
251
rust/crates/rustproxy-metrics/src/collector.rs
Normal file
@@ -0,0 +1,251 @@
|
||||
use dashmap::DashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
/// Aggregated metrics snapshot.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Metrics {
|
||||
pub active_connections: u64,
|
||||
pub total_connections: u64,
|
||||
pub bytes_in: u64,
|
||||
pub bytes_out: u64,
|
||||
pub throughput_in_bytes_per_sec: u64,
|
||||
pub throughput_out_bytes_per_sec: u64,
|
||||
pub routes: std::collections::HashMap<String, RouteMetrics>,
|
||||
}
|
||||
|
||||
/// Per-route metrics.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteMetrics {
|
||||
pub active_connections: u64,
|
||||
pub total_connections: u64,
|
||||
pub bytes_in: u64,
|
||||
pub bytes_out: u64,
|
||||
pub throughput_in_bytes_per_sec: u64,
|
||||
pub throughput_out_bytes_per_sec: u64,
|
||||
}
|
||||
|
||||
/// Statistics snapshot.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Statistics {
|
||||
pub active_connections: u64,
|
||||
pub total_connections: u64,
|
||||
pub routes_count: u64,
|
||||
pub listening_ports: Vec<u16>,
|
||||
pub uptime_seconds: u64,
|
||||
}
|
||||
|
||||
/// Metrics collector tracking connections and throughput.
|
||||
pub struct MetricsCollector {
|
||||
active_connections: AtomicU64,
|
||||
total_connections: AtomicU64,
|
||||
total_bytes_in: AtomicU64,
|
||||
total_bytes_out: AtomicU64,
|
||||
/// Per-route active connection counts
|
||||
route_connections: DashMap<String, AtomicU64>,
|
||||
/// Per-route total connection counts
|
||||
route_total_connections: DashMap<String, AtomicU64>,
|
||||
/// Per-route byte counters
|
||||
route_bytes_in: DashMap<String, AtomicU64>,
|
||||
route_bytes_out: DashMap<String, AtomicU64>,
|
||||
}
|
||||
|
||||
impl MetricsCollector {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
active_connections: AtomicU64::new(0),
|
||||
total_connections: AtomicU64::new(0),
|
||||
total_bytes_in: AtomicU64::new(0),
|
||||
total_bytes_out: AtomicU64::new(0),
|
||||
route_connections: DashMap::new(),
|
||||
route_total_connections: DashMap::new(),
|
||||
route_bytes_in: DashMap::new(),
|
||||
route_bytes_out: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a new connection.
|
||||
pub fn connection_opened(&self, route_id: Option<&str>) {
|
||||
self.active_connections.fetch_add(1, Ordering::Relaxed);
|
||||
self.total_connections.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
if let Some(route_id) = route_id {
|
||||
self.route_connections
|
||||
.entry(route_id.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
self.route_total_connections
|
||||
.entry(route_id.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a connection closing.
|
||||
pub fn connection_closed(&self, route_id: Option<&str>) {
|
||||
self.active_connections.fetch_sub(1, Ordering::Relaxed);
|
||||
|
||||
if let Some(route_id) = route_id {
|
||||
if let Some(counter) = self.route_connections.get(route_id) {
|
||||
let val = counter.load(Ordering::Relaxed);
|
||||
if val > 0 {
|
||||
counter.fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Record bytes transferred.
|
||||
pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64, route_id: Option<&str>) {
|
||||
self.total_bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
self.total_bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
|
||||
if let Some(route_id) = route_id {
|
||||
self.route_bytes_in
|
||||
.entry(route_id.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
self.route_bytes_out
|
||||
.entry(route_id.to_string())
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current active connection count.
|
||||
pub fn active_connections(&self) -> u64 {
|
||||
self.active_connections.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get total connection count.
|
||||
pub fn total_connections(&self) -> u64 {
|
||||
self.total_connections.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get total bytes received.
|
||||
pub fn total_bytes_in(&self) -> u64 {
|
||||
self.total_bytes_in.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get total bytes sent.
|
||||
pub fn total_bytes_out(&self) -> u64 {
|
||||
self.total_bytes_out.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Get a full metrics snapshot including per-route data.
|
||||
pub fn snapshot(&self) -> Metrics {
|
||||
let mut routes = std::collections::HashMap::new();
|
||||
|
||||
// Collect per-route metrics
|
||||
for entry in self.route_total_connections.iter() {
|
||||
let route_id = entry.key().clone();
|
||||
let total = entry.value().load(Ordering::Relaxed);
|
||||
let active = self.route_connections
|
||||
.get(&route_id)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
let bytes_in = self.route_bytes_in
|
||||
.get(&route_id)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
let bytes_out = self.route_bytes_out
|
||||
.get(&route_id)
|
||||
.map(|c| c.load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
|
||||
routes.insert(route_id, RouteMetrics {
|
||||
active_connections: active,
|
||||
total_connections: total,
|
||||
bytes_in,
|
||||
bytes_out,
|
||||
throughput_in_bytes_per_sec: 0,
|
||||
throughput_out_bytes_per_sec: 0,
|
||||
});
|
||||
}
|
||||
|
||||
Metrics {
|
||||
active_connections: self.active_connections(),
|
||||
total_connections: self.total_connections(),
|
||||
bytes_in: self.total_bytes_in(),
|
||||
bytes_out: self.total_bytes_out(),
|
||||
throughput_in_bytes_per_sec: 0,
|
||||
throughput_out_bytes_per_sec: 0,
|
||||
routes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MetricsCollector {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_initial_state_zeros() {
|
||||
let collector = MetricsCollector::new();
|
||||
assert_eq!(collector.active_connections(), 0);
|
||||
assert_eq!(collector.total_connections(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_connection_opened_increments() {
|
||||
let collector = MetricsCollector::new();
|
||||
collector.connection_opened(None);
|
||||
assert_eq!(collector.active_connections(), 1);
|
||||
assert_eq!(collector.total_connections(), 1);
|
||||
collector.connection_opened(None);
|
||||
assert_eq!(collector.active_connections(), 2);
|
||||
assert_eq!(collector.total_connections(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_connection_closed_decrements() {
|
||||
let collector = MetricsCollector::new();
|
||||
collector.connection_opened(None);
|
||||
collector.connection_opened(None);
|
||||
assert_eq!(collector.active_connections(), 2);
|
||||
collector.connection_closed(None);
|
||||
assert_eq!(collector.active_connections(), 1);
|
||||
// total_connections should stay at 2
|
||||
assert_eq!(collector.total_connections(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_route_specific_tracking() {
|
||||
let collector = MetricsCollector::new();
|
||||
collector.connection_opened(Some("route-a"));
|
||||
collector.connection_opened(Some("route-a"));
|
||||
collector.connection_opened(Some("route-b"));
|
||||
|
||||
assert_eq!(collector.active_connections(), 3);
|
||||
assert_eq!(collector.total_connections(), 3);
|
||||
|
||||
collector.connection_closed(Some("route-a"));
|
||||
assert_eq!(collector.active_connections(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_record_bytes() {
|
||||
let collector = MetricsCollector::new();
|
||||
collector.record_bytes(100, 200, Some("route-a"));
|
||||
collector.record_bytes(50, 75, Some("route-a"));
|
||||
collector.record_bytes(25, 30, None);
|
||||
|
||||
let total_in = collector.total_bytes_in.load(Ordering::Relaxed);
|
||||
let total_out = collector.total_bytes_out.load(Ordering::Relaxed);
|
||||
assert_eq!(total_in, 175);
|
||||
assert_eq!(total_out, 305);
|
||||
|
||||
// Route-specific bytes
|
||||
let route_in = collector.route_bytes_in.get("route-a").unwrap();
|
||||
assert_eq!(route_in.load(Ordering::Relaxed), 150);
|
||||
}
|
||||
}
|
||||
11
rust/crates/rustproxy-metrics/src/lib.rs
Normal file
11
rust/crates/rustproxy-metrics/src/lib.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
//! # rustproxy-metrics
|
||||
//!
|
||||
//! Metrics and throughput tracking for RustProxy.
|
||||
|
||||
pub mod throughput;
|
||||
pub mod collector;
|
||||
pub mod log_dedup;
|
||||
|
||||
pub use throughput::*;
|
||||
pub use collector::*;
|
||||
pub use log_dedup::*;
|
||||
219
rust/crates/rustproxy-metrics/src/log_dedup.rs
Normal file
219
rust/crates/rustproxy-metrics/src/log_dedup.rs
Normal file
@@ -0,0 +1,219 @@
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
use tracing::info;
|
||||
|
||||
/// An aggregated event during the deduplication window.
|
||||
struct AggregatedEvent {
|
||||
category: String,
|
||||
first_message: String,
|
||||
count: AtomicU64,
|
||||
first_seen: Instant,
|
||||
#[allow(dead_code)]
|
||||
last_seen: Instant,
|
||||
}
|
||||
|
||||
/// Log deduplicator that batches similar events over a time window.
|
||||
///
|
||||
/// Events are grouped by a composite key of `category:key`. Within each
|
||||
/// deduplication window (`flush_interval`) identical events are counted
|
||||
/// instead of being emitted individually. When the window expires (or the
|
||||
/// batch reaches `max_batch_size`) a single summary line is written via
|
||||
/// `tracing::info!`.
|
||||
pub struct LogDeduplicator {
|
||||
events: DashMap<String, AggregatedEvent>,
|
||||
flush_interval: Duration,
|
||||
max_batch_size: u64,
|
||||
#[allow(dead_code)]
|
||||
rapid_threshold: u64, // events/sec that triggers immediate flush
|
||||
}
|
||||
|
||||
impl LogDeduplicator {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
events: DashMap::new(),
|
||||
flush_interval: Duration::from_secs(5),
|
||||
max_batch_size: 100,
|
||||
rapid_threshold: 50,
|
||||
}
|
||||
}
|
||||
|
||||
/// Log an event, deduplicating by `category` + `key`.
|
||||
///
|
||||
/// If the batch for this composite key reaches `max_batch_size` the
|
||||
/// accumulated events are flushed immediately.
|
||||
pub fn log(&self, category: &str, key: &str, message: &str) {
|
||||
let map_key = format!("{}:{}", category, key);
|
||||
let now = Instant::now();
|
||||
|
||||
let entry = self.events.entry(map_key).or_insert_with(|| AggregatedEvent {
|
||||
category: category.to_string(),
|
||||
first_message: message.to_string(),
|
||||
count: AtomicU64::new(0),
|
||||
first_seen: now,
|
||||
last_seen: now,
|
||||
});
|
||||
|
||||
let count = entry.count.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
|
||||
// Check if we should flush (batch size exceeded)
|
||||
if count >= self.max_batch_size {
|
||||
drop(entry);
|
||||
self.flush();
|
||||
}
|
||||
}
|
||||
|
||||
/// Flush all accumulated events, emitting summary log lines.
|
||||
pub fn flush(&self) {
|
||||
// Collect and remove all events
|
||||
self.events.retain(|_key, event| {
|
||||
let count = event.count.load(Ordering::Relaxed);
|
||||
if count > 0 {
|
||||
let elapsed = event.first_seen.elapsed();
|
||||
if count == 1 {
|
||||
info!("[{}] {}", event.category, event.first_message);
|
||||
} else {
|
||||
info!(
|
||||
"[SUMMARY] {} {} events in {:.1}s: {}",
|
||||
count,
|
||||
event.category,
|
||||
elapsed.as_secs_f64(),
|
||||
event.first_message
|
||||
);
|
||||
}
|
||||
}
|
||||
false // remove all entries after flushing
|
||||
});
|
||||
}
|
||||
|
||||
/// Start a background flush task that periodically drains accumulated
|
||||
/// events. The task runs until the supplied `CancellationToken` is
|
||||
/// cancelled, at which point it performs one final flush before exiting.
|
||||
pub fn start_flush_task(self: &Arc<Self>, cancel: tokio_util::sync::CancellationToken) {
|
||||
let dedup = Arc::clone(self);
|
||||
let interval = self.flush_interval;
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
dedup.flush();
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(interval) => {
|
||||
dedup.flush();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LogDeduplicator {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_single_event_emitted_as_is() {
|
||||
let dedup = LogDeduplicator::new();
|
||||
dedup.log("conn", "open", "connection opened from 1.2.3.4");
|
||||
// One event should exist
|
||||
assert_eq!(dedup.events.len(), 1);
|
||||
let entry = dedup.events.get("conn:open").unwrap();
|
||||
assert_eq!(entry.count.load(Ordering::Relaxed), 1);
|
||||
assert_eq!(entry.first_message, "connection opened from 1.2.3.4");
|
||||
drop(entry);
|
||||
dedup.flush();
|
||||
// After flush, map should be empty
|
||||
assert_eq!(dedup.events.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_duplicate_events_aggregated() {
|
||||
let dedup = LogDeduplicator::new();
|
||||
for _ in 0..10 {
|
||||
dedup.log("conn", "timeout", "connection timed out");
|
||||
}
|
||||
assert_eq!(dedup.events.len(), 1);
|
||||
let entry = dedup.events.get("conn:timeout").unwrap();
|
||||
assert_eq!(entry.count.load(Ordering::Relaxed), 10);
|
||||
drop(entry);
|
||||
dedup.flush();
|
||||
assert_eq!(dedup.events.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_different_keys_separate() {
|
||||
let dedup = LogDeduplicator::new();
|
||||
dedup.log("conn", "open", "opened");
|
||||
dedup.log("conn", "close", "closed");
|
||||
dedup.log("tls", "handshake", "TLS handshake");
|
||||
assert_eq!(dedup.events.len(), 3);
|
||||
dedup.flush();
|
||||
assert_eq!(dedup.events.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flush_clears_events() {
|
||||
let dedup = LogDeduplicator::new();
|
||||
dedup.log("a", "b", "msg1");
|
||||
dedup.log("a", "b", "msg2");
|
||||
dedup.flush();
|
||||
assert_eq!(dedup.events.len(), 0);
|
||||
// Logging after flush creates a new entry
|
||||
dedup.log("a", "b", "msg3");
|
||||
assert_eq!(dedup.events.len(), 1);
|
||||
let entry = dedup.events.get("a:b").unwrap();
|
||||
assert_eq!(entry.count.load(Ordering::Relaxed), 1);
|
||||
assert_eq!(entry.first_message, "msg3");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_batch_triggers_flush() {
|
||||
let dedup = LogDeduplicator::new();
|
||||
// max_batch_size defaults to 100
|
||||
for i in 0..100 {
|
||||
dedup.log("flood", "key", &format!("event {}", i));
|
||||
}
|
||||
// After hitting max_batch_size the events map should have been flushed
|
||||
assert_eq!(dedup.events.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_trait() {
|
||||
let dedup = LogDeduplicator::default();
|
||||
assert_eq!(dedup.flush_interval, Duration::from_secs(5));
|
||||
assert_eq!(dedup.max_batch_size, 100);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_background_flush_task() {
|
||||
let dedup = Arc::new(LogDeduplicator {
|
||||
events: DashMap::new(),
|
||||
flush_interval: Duration::from_millis(50),
|
||||
max_batch_size: 100,
|
||||
rapid_threshold: 50,
|
||||
});
|
||||
|
||||
let cancel = tokio_util::sync::CancellationToken::new();
|
||||
dedup.start_flush_task(cancel.clone());
|
||||
|
||||
// Log some events
|
||||
dedup.log("bg", "test", "background flush test");
|
||||
assert_eq!(dedup.events.len(), 1);
|
||||
|
||||
// Wait for the background task to flush
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
assert_eq!(dedup.events.len(), 0);
|
||||
|
||||
// Cancel the task
|
||||
cancel.cancel();
|
||||
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||
}
|
||||
}
|
||||
173
rust/crates/rustproxy-metrics/src/throughput.rs
Normal file
173
rust/crates/rustproxy-metrics/src/throughput.rs
Normal file
@@ -0,0 +1,173 @@
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||||
|
||||
/// A single throughput sample.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ThroughputSample {
|
||||
pub timestamp_ms: u64,
|
||||
pub bytes_in: u64,
|
||||
pub bytes_out: u64,
|
||||
}
|
||||
|
||||
/// Circular buffer for 1Hz throughput sampling.
|
||||
/// Matches smartproxy's ThroughputTracker.
|
||||
pub struct ThroughputTracker {
|
||||
/// Circular buffer of samples
|
||||
samples: Vec<ThroughputSample>,
|
||||
/// Current write index
|
||||
write_index: usize,
|
||||
/// Number of valid samples
|
||||
count: usize,
|
||||
/// Maximum number of samples to retain
|
||||
capacity: usize,
|
||||
/// Accumulated bytes since last sample
|
||||
pending_bytes_in: AtomicU64,
|
||||
pending_bytes_out: AtomicU64,
|
||||
/// When the tracker was created
|
||||
created_at: Instant,
|
||||
}
|
||||
|
||||
impl ThroughputTracker {
|
||||
/// Create a new tracker with the given capacity (seconds of retention).
|
||||
pub fn new(retention_seconds: usize) -> Self {
|
||||
Self {
|
||||
samples: Vec::with_capacity(retention_seconds),
|
||||
write_index: 0,
|
||||
count: 0,
|
||||
capacity: retention_seconds,
|
||||
pending_bytes_in: AtomicU64::new(0),
|
||||
pending_bytes_out: AtomicU64::new(0),
|
||||
created_at: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Record bytes (called from data flow callbacks).
|
||||
pub fn record_bytes(&self, bytes_in: u64, bytes_out: u64) {
|
||||
self.pending_bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
self.pending_bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Take a sample (called at 1Hz).
|
||||
pub fn sample(&mut self) {
|
||||
let bytes_in = self.pending_bytes_in.swap(0, Ordering::Relaxed);
|
||||
let bytes_out = self.pending_bytes_out.swap(0, Ordering::Relaxed);
|
||||
let timestamp_ms = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64;
|
||||
|
||||
let sample = ThroughputSample {
|
||||
timestamp_ms,
|
||||
bytes_in,
|
||||
bytes_out,
|
||||
};
|
||||
|
||||
if self.samples.len() < self.capacity {
|
||||
self.samples.push(sample);
|
||||
} else {
|
||||
self.samples[self.write_index] = sample;
|
||||
}
|
||||
self.write_index = (self.write_index + 1) % self.capacity;
|
||||
self.count = (self.count + 1).min(self.capacity);
|
||||
}
|
||||
|
||||
/// Get throughput over the last N seconds.
|
||||
pub fn throughput(&self, window_seconds: usize) -> (u64, u64) {
|
||||
let window = window_seconds.min(self.count);
|
||||
if window == 0 {
|
||||
return (0, 0);
|
||||
}
|
||||
|
||||
let mut total_in = 0u64;
|
||||
let mut total_out = 0u64;
|
||||
|
||||
for i in 0..window {
|
||||
let idx = if self.write_index >= i + 1 {
|
||||
self.write_index - i - 1
|
||||
} else {
|
||||
self.capacity - (i + 1 - self.write_index)
|
||||
};
|
||||
if idx < self.samples.len() {
|
||||
total_in += self.samples[idx].bytes_in;
|
||||
total_out += self.samples[idx].bytes_out;
|
||||
}
|
||||
}
|
||||
|
||||
(total_in / window as u64, total_out / window as u64)
|
||||
}
|
||||
|
||||
/// Get instant throughput (last 1 second).
|
||||
pub fn instant(&self) -> (u64, u64) {
|
||||
self.throughput(1)
|
||||
}
|
||||
|
||||
/// Get recent throughput (last 10 seconds).
|
||||
pub fn recent(&self) -> (u64, u64) {
|
||||
self.throughput(10)
|
||||
}
|
||||
|
||||
/// How long this tracker has been alive.
|
||||
pub fn uptime(&self) -> std::time::Duration {
|
||||
self.created_at.elapsed()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_empty_throughput() {
|
||||
let tracker = ThroughputTracker::new(60);
|
||||
let (bytes_in, bytes_out) = tracker.throughput(10);
|
||||
assert_eq!(bytes_in, 0);
|
||||
assert_eq!(bytes_out, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_sample() {
|
||||
let mut tracker = ThroughputTracker::new(60);
|
||||
tracker.record_bytes(1000, 2000);
|
||||
tracker.sample();
|
||||
let (bytes_in, bytes_out) = tracker.instant();
|
||||
assert_eq!(bytes_in, 1000);
|
||||
assert_eq!(bytes_out, 2000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circular_buffer_wrap() {
|
||||
let mut tracker = ThroughputTracker::new(3); // Small capacity
|
||||
for i in 0..5 {
|
||||
tracker.record_bytes(i * 100, i * 200);
|
||||
tracker.sample();
|
||||
}
|
||||
// Should still work after wrapping
|
||||
let (bytes_in, bytes_out) = tracker.throughput(3);
|
||||
assert!(bytes_in > 0);
|
||||
assert!(bytes_out > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_window_averaging() {
|
||||
let mut tracker = ThroughputTracker::new(60);
|
||||
// Record 3 samples of different sizes
|
||||
tracker.record_bytes(100, 200);
|
||||
tracker.sample();
|
||||
tracker.record_bytes(200, 400);
|
||||
tracker.sample();
|
||||
tracker.record_bytes(300, 600);
|
||||
tracker.sample();
|
||||
|
||||
// Average over 3 samples: (100+200+300)/3 = 200, (200+400+600)/3 = 400
|
||||
let (avg_in, avg_out) = tracker.throughput(3);
|
||||
assert_eq!(avg_in, 200);
|
||||
assert_eq!(avg_out, 400);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_uptime_positive() {
|
||||
let tracker = ThroughputTracker::new(60);
|
||||
std::thread::sleep(std::time::Duration::from_millis(10));
|
||||
assert!(tracker.uptime().as_millis() >= 10);
|
||||
}
|
||||
}
|
||||
17
rust/crates/rustproxy-nftables/Cargo.toml
Normal file
17
rust/crates/rustproxy-nftables/Cargo.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "rustproxy-nftables"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "NFTables kernel-level forwarding for RustProxy"
|
||||
|
||||
[dependencies]
|
||||
rustproxy-config = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
libc = { workspace = true }
|
||||
10
rust/crates/rustproxy-nftables/src/lib.rs
Normal file
10
rust/crates/rustproxy-nftables/src/lib.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
//! # rustproxy-nftables
|
||||
//!
|
||||
//! NFTables kernel-level forwarding for RustProxy.
|
||||
//! Generates and manages nft CLI rules for DNAT/SNAT.
|
||||
|
||||
pub mod nft_manager;
|
||||
pub mod rule_builder;
|
||||
|
||||
pub use nft_manager::*;
|
||||
pub use rule_builder::*;
|
||||
238
rust/crates/rustproxy-nftables/src/nft_manager.rs
Normal file
238
rust/crates/rustproxy-nftables/src/nft_manager.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
use thiserror::Error;
|
||||
use std::collections::HashMap;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum NftError {
|
||||
#[error("nft command failed: {0}")]
|
||||
CommandFailed(String),
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("Not running as root")]
|
||||
NotRoot,
|
||||
}
|
||||
|
||||
/// Manager for nftables rules.
|
||||
///
|
||||
/// Executes `nft` CLI commands to manage kernel-level packet forwarding.
|
||||
/// Requires root privileges; operations are skipped gracefully if not root.
|
||||
pub struct NftManager {
|
||||
table_name: String,
|
||||
/// Active rules indexed by route ID
|
||||
active_rules: HashMap<String, Vec<String>>,
|
||||
/// Whether the table has been initialized
|
||||
table_initialized: bool,
|
||||
}
|
||||
|
||||
impl NftManager {
|
||||
pub fn new(table_name: Option<String>) -> Self {
|
||||
Self {
|
||||
table_name: table_name.unwrap_or_else(|| "rustproxy".to_string()),
|
||||
active_rules: HashMap::new(),
|
||||
table_initialized: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if we are running as root.
|
||||
fn is_root() -> bool {
|
||||
unsafe { libc::geteuid() == 0 }
|
||||
}
|
||||
|
||||
/// Execute a single nft command via the CLI.
|
||||
async fn exec_nft(command: &str) -> Result<String, NftError> {
|
||||
// The command starts with "nft ", strip it to get the args
|
||||
let args = if command.starts_with("nft ") {
|
||||
&command[4..]
|
||||
} else {
|
||||
command
|
||||
};
|
||||
|
||||
let output = tokio::process::Command::new("nft")
|
||||
.args(args.split_whitespace())
|
||||
.output()
|
||||
.await
|
||||
.map_err(NftError::Io)?;
|
||||
|
||||
if output.status.success() {
|
||||
Ok(String::from_utf8_lossy(&output.stdout).to_string())
|
||||
} else {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
Err(NftError::CommandFailed(format!(
|
||||
"Command '{}' failed: {}",
|
||||
command, stderr
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensure the nftables table and chains are set up.
|
||||
async fn ensure_table(&mut self) -> Result<(), NftError> {
|
||||
if self.table_initialized {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let setup_commands = crate::rule_builder::build_table_setup(&self.table_name);
|
||||
for cmd in &setup_commands {
|
||||
Self::exec_nft(cmd).await?;
|
||||
}
|
||||
|
||||
self.table_initialized = true;
|
||||
info!("NFTables table '{}' initialized", self.table_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Apply rules for a route.
|
||||
///
|
||||
/// Executes the nft commands via the CLI. If not running as root,
|
||||
/// the rules are stored locally but not applied to the kernel.
|
||||
pub async fn apply_rules(&mut self, route_id: &str, rules: Vec<String>) -> Result<(), NftError> {
|
||||
if !Self::is_root() {
|
||||
warn!("Not running as root, nftables rules will not be applied to kernel");
|
||||
self.active_rules.insert(route_id.to_string(), rules);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.ensure_table().await?;
|
||||
|
||||
for cmd in &rules {
|
||||
Self::exec_nft(cmd).await?;
|
||||
debug!("Applied nft rule: {}", cmd);
|
||||
}
|
||||
|
||||
info!("Applied {} nftables rules for route '{}'", rules.len(), route_id);
|
||||
self.active_rules.insert(route_id.to_string(), rules);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove rules for a route.
|
||||
///
|
||||
/// Currently removes the route from tracking. To fully remove specific
|
||||
/// rules would require handle-based tracking; for now, cleanup() removes
|
||||
/// the entire table.
|
||||
pub async fn remove_rules(&mut self, route_id: &str) -> Result<(), NftError> {
|
||||
if let Some(rules) = self.active_rules.remove(route_id) {
|
||||
info!("Removed {} tracked nft rules for route '{}'", rules.len(), route_id);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Clean up all managed rules by deleting the entire nftables table.
|
||||
pub async fn cleanup(&mut self) -> Result<(), NftError> {
|
||||
if !Self::is_root() {
|
||||
warn!("Not running as root, skipping nftables cleanup");
|
||||
self.active_rules.clear();
|
||||
self.table_initialized = false;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if self.table_initialized {
|
||||
let cleanup_commands = crate::rule_builder::build_table_cleanup(&self.table_name);
|
||||
for cmd in &cleanup_commands {
|
||||
match Self::exec_nft(cmd).await {
|
||||
Ok(_) => debug!("Cleanup: {}", cmd),
|
||||
Err(e) => warn!("Cleanup command failed (may be ok): {}", e),
|
||||
}
|
||||
}
|
||||
info!("NFTables table '{}' cleaned up", self.table_name);
|
||||
}
|
||||
|
||||
self.active_rules.clear();
|
||||
self.table_initialized = false;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the table name.
|
||||
pub fn table_name(&self) -> &str {
|
||||
&self.table_name
|
||||
}
|
||||
|
||||
/// Whether the table has been initialized in the kernel.
|
||||
pub fn is_initialized(&self) -> bool {
|
||||
self.table_initialized
|
||||
}
|
||||
|
||||
/// Get the number of active route rule sets.
|
||||
pub fn active_route_count(&self) -> usize {
|
||||
self.active_rules.len()
|
||||
}
|
||||
|
||||
/// Get the status of all active rules.
|
||||
pub fn status(&self) -> HashMap<String, serde_json::Value> {
|
||||
let mut status = HashMap::new();
|
||||
for (route_id, rules) in &self.active_rules {
|
||||
status.insert(
|
||||
route_id.clone(),
|
||||
serde_json::json!({
|
||||
"ruleCount": rules.len(),
|
||||
"rules": rules,
|
||||
}),
|
||||
);
|
||||
}
|
||||
status
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_new_default_table_name() {
|
||||
let mgr = NftManager::new(None);
|
||||
assert_eq!(mgr.table_name(), "rustproxy");
|
||||
assert!(!mgr.is_initialized());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_new_custom_table_name() {
|
||||
let mgr = NftManager::new(Some("custom".to_string()));
|
||||
assert_eq!(mgr.table_name(), "custom");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_apply_rules_non_root() {
|
||||
let mut mgr = NftManager::new(None);
|
||||
// When not root, rules are stored but not applied to kernel
|
||||
let rules = vec!["nft add rule ip rustproxy prerouting tcp dport 443 dnat to 10.0.0.1:8443".to_string()];
|
||||
mgr.apply_rules("route-1", rules).await.unwrap();
|
||||
assert_eq!(mgr.active_route_count(), 1);
|
||||
|
||||
let status = mgr.status();
|
||||
assert!(status.contains_key("route-1"));
|
||||
assert_eq!(status["route-1"]["ruleCount"], 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_remove_rules() {
|
||||
let mut mgr = NftManager::new(None);
|
||||
let rules = vec!["nft add rule test".to_string()];
|
||||
mgr.apply_rules("route-1", rules).await.unwrap();
|
||||
assert_eq!(mgr.active_route_count(), 1);
|
||||
|
||||
mgr.remove_rules("route-1").await.unwrap();
|
||||
assert_eq!(mgr.active_route_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cleanup_non_root() {
|
||||
let mut mgr = NftManager::new(None);
|
||||
let rules = vec!["nft add rule test".to_string()];
|
||||
mgr.apply_rules("route-1", rules).await.unwrap();
|
||||
mgr.apply_rules("route-2", vec!["nft add rule test2".to_string()]).await.unwrap();
|
||||
|
||||
mgr.cleanup().await.unwrap();
|
||||
assert_eq!(mgr.active_route_count(), 0);
|
||||
assert!(!mgr.is_initialized());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_status_multiple_routes() {
|
||||
let mut mgr = NftManager::new(None);
|
||||
mgr.apply_rules("web", vec!["rule1".to_string(), "rule2".to_string()]).await.unwrap();
|
||||
mgr.apply_rules("api", vec!["rule3".to_string()]).await.unwrap();
|
||||
|
||||
let status = mgr.status();
|
||||
assert_eq!(status.len(), 2);
|
||||
assert_eq!(status["web"]["ruleCount"], 2);
|
||||
assert_eq!(status["api"]["ruleCount"], 1);
|
||||
}
|
||||
}
|
||||
123
rust/crates/rustproxy-nftables/src/rule_builder.rs
Normal file
123
rust/crates/rustproxy-nftables/src/rule_builder.rs
Normal file
@@ -0,0 +1,123 @@
|
||||
use rustproxy_config::{NfTablesOptions, NfTablesProtocol};
|
||||
|
||||
/// Build nftables DNAT rule for port forwarding.
|
||||
pub fn build_dnat_rule(
|
||||
table_name: &str,
|
||||
chain_name: &str,
|
||||
source_port: u16,
|
||||
target_host: &str,
|
||||
target_port: u16,
|
||||
options: &NfTablesOptions,
|
||||
) -> Vec<String> {
|
||||
let protocol = match options.protocol.as_ref().unwrap_or(&NfTablesProtocol::Tcp) {
|
||||
NfTablesProtocol::Tcp => "tcp",
|
||||
NfTablesProtocol::Udp => "udp",
|
||||
NfTablesProtocol::All => "tcp", // TODO: handle "all"
|
||||
};
|
||||
|
||||
let mut rules = Vec::new();
|
||||
|
||||
// DNAT rule
|
||||
rules.push(format!(
|
||||
"nft add rule ip {} {} {} dport {} dnat to {}:{}",
|
||||
table_name, chain_name, protocol, source_port, target_host, target_port,
|
||||
));
|
||||
|
||||
// SNAT rule if preserving source IP is not enabled
|
||||
if !options.preserve_source_ip.unwrap_or(false) {
|
||||
rules.push(format!(
|
||||
"nft add rule ip {} postrouting {} dport {} masquerade",
|
||||
table_name, protocol, target_port,
|
||||
));
|
||||
}
|
||||
|
||||
// Rate limiting
|
||||
if let Some(max_rate) = &options.max_rate {
|
||||
rules.push(format!(
|
||||
"nft add rule ip {} {} {} dport {} limit rate {} accept",
|
||||
table_name, chain_name, protocol, source_port, max_rate,
|
||||
));
|
||||
}
|
||||
|
||||
rules
|
||||
}
|
||||
|
||||
/// Build the initial table and chain setup commands.
|
||||
pub fn build_table_setup(table_name: &str) -> Vec<String> {
|
||||
vec![
|
||||
format!("nft add table ip {}", table_name),
|
||||
format!("nft add chain ip {} prerouting {{ type nat hook prerouting priority 0 \\; }}", table_name),
|
||||
format!("nft add chain ip {} postrouting {{ type nat hook postrouting priority 100 \\; }}", table_name),
|
||||
]
|
||||
}
|
||||
|
||||
/// Build cleanup commands to remove the table.
|
||||
pub fn build_table_cleanup(table_name: &str) -> Vec<String> {
|
||||
vec![format!("nft delete table ip {}", table_name)]
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_options() -> NfTablesOptions {
|
||||
NfTablesOptions {
|
||||
preserve_source_ip: None,
|
||||
protocol: None,
|
||||
max_rate: None,
|
||||
priority: None,
|
||||
table_name: None,
|
||||
use_ip_sets: None,
|
||||
use_advanced_nat: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_basic_dnat_rule() {
|
||||
let options = make_options();
|
||||
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
|
||||
assert!(rules.len() >= 1);
|
||||
assert!(rules[0].contains("dnat to 10.0.0.1:8443"));
|
||||
assert!(rules[0].contains("dport 443"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_preserve_source_ip() {
|
||||
let mut options = make_options();
|
||||
options.preserve_source_ip = Some(true);
|
||||
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
|
||||
// When preserving source IP, no masquerade rule
|
||||
assert!(rules.iter().all(|r| !r.contains("masquerade")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_without_preserve_source_ip() {
|
||||
let options = make_options();
|
||||
let rules = build_dnat_rule("rustproxy", "prerouting", 443, "10.0.0.1", 8443, &options);
|
||||
assert!(rules.iter().any(|r| r.contains("masquerade")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_limited_rule() {
|
||||
let mut options = make_options();
|
||||
options.max_rate = Some("100/second".to_string());
|
||||
let rules = build_dnat_rule("rustproxy", "prerouting", 80, "10.0.0.1", 8080, &options);
|
||||
assert!(rules.iter().any(|r| r.contains("limit rate 100/second")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_table_setup_commands() {
|
||||
let commands = build_table_setup("rustproxy");
|
||||
assert_eq!(commands.len(), 3);
|
||||
assert!(commands[0].contains("add table ip rustproxy"));
|
||||
assert!(commands[1].contains("prerouting"));
|
||||
assert!(commands[2].contains("postrouting"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_table_cleanup() {
|
||||
let commands = build_table_cleanup("rustproxy");
|
||||
assert_eq!(commands.len(), 1);
|
||||
assert!(commands[0].contains("delete table ip rustproxy"));
|
||||
}
|
||||
}
|
||||
25
rust/crates/rustproxy-passthrough/Cargo.toml
Normal file
25
rust/crates/rustproxy-passthrough/Cargo.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[package]
|
||||
name = "rustproxy-passthrough"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "Raw TCP/SNI passthrough engine for RustProxy"
|
||||
|
||||
[dependencies]
|
||||
rustproxy-config = { workspace = true }
|
||||
rustproxy-routing = { workspace = true }
|
||||
rustproxy-metrics = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
arc-swap = { workspace = true }
|
||||
rustproxy-http = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
tokio-rustls = { workspace = true }
|
||||
rustls-pemfile = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
155
rust/crates/rustproxy-passthrough/src/connection_record.rs
Normal file
155
rust/crates/rustproxy-passthrough/src/connection_record.rs
Normal file
@@ -0,0 +1,155 @@
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Per-connection tracking record with atomics for lock-free updates.
|
||||
///
|
||||
/// Each field uses atomics so that the forwarding tasks can update
|
||||
/// bytes_received / bytes_sent / last_activity without holding any lock,
|
||||
/// while the zombie scanner reads them concurrently.
|
||||
pub struct ConnectionRecord {
|
||||
/// Unique connection ID assigned by the ConnectionTracker.
|
||||
pub id: u64,
|
||||
/// Wall-clock instant when this connection was created.
|
||||
pub created_at: Instant,
|
||||
/// Milliseconds since `created_at` when the last activity occurred.
|
||||
/// Updated atomically by the forwarding loops.
|
||||
pub last_activity: AtomicU64,
|
||||
/// Total bytes received from the client (inbound).
|
||||
pub bytes_received: AtomicU64,
|
||||
/// Total bytes sent to the client (outbound / from backend).
|
||||
pub bytes_sent: AtomicU64,
|
||||
/// True once the client side of the connection has closed.
|
||||
pub client_closed: AtomicBool,
|
||||
/// True once the backend side of the connection has closed.
|
||||
pub backend_closed: AtomicBool,
|
||||
/// Whether this connection uses TLS (affects zombie thresholds).
|
||||
pub is_tls: AtomicBool,
|
||||
/// Whether this connection has keep-alive semantics.
|
||||
pub has_keep_alive: AtomicBool,
|
||||
}
|
||||
|
||||
impl ConnectionRecord {
|
||||
/// Create a new connection record with the given ID.
|
||||
/// All counters start at zero, all flags start as false.
|
||||
pub fn new(id: u64) -> Self {
|
||||
Self {
|
||||
id,
|
||||
created_at: Instant::now(),
|
||||
last_activity: AtomicU64::new(0),
|
||||
bytes_received: AtomicU64::new(0),
|
||||
bytes_sent: AtomicU64::new(0),
|
||||
client_closed: AtomicBool::new(false),
|
||||
backend_closed: AtomicBool::new(false),
|
||||
is_tls: AtomicBool::new(false),
|
||||
has_keep_alive: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update `last_activity` to reflect the current elapsed time.
|
||||
pub fn touch(&self) {
|
||||
let elapsed_ms = self.created_at.elapsed().as_millis() as u64;
|
||||
self.last_activity.store(elapsed_ms, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Record `n` bytes received from the client (inbound).
|
||||
pub fn record_bytes_in(&self, n: u64) {
|
||||
self.bytes_received.fetch_add(n, Ordering::Relaxed);
|
||||
self.touch();
|
||||
}
|
||||
|
||||
/// Record `n` bytes sent to the client (outbound / from backend).
|
||||
pub fn record_bytes_out(&self, n: u64) {
|
||||
self.bytes_sent.fetch_add(n, Ordering::Relaxed);
|
||||
self.touch();
|
||||
}
|
||||
|
||||
/// How long since the last activity on this connection.
|
||||
pub fn idle_duration(&self) -> Duration {
|
||||
let last_ms = self.last_activity.load(Ordering::Relaxed);
|
||||
let age_ms = self.created_at.elapsed().as_millis() as u64;
|
||||
Duration::from_millis(age_ms.saturating_sub(last_ms))
|
||||
}
|
||||
|
||||
/// Total age of this connection (time since creation).
|
||||
pub fn age(&self) -> Duration {
|
||||
self.created_at.elapsed()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::thread;
|
||||
|
||||
#[test]
|
||||
fn test_new_record() {
|
||||
let record = ConnectionRecord::new(42);
|
||||
assert_eq!(record.id, 42);
|
||||
assert_eq!(record.bytes_received.load(Ordering::Relaxed), 0);
|
||||
assert_eq!(record.bytes_sent.load(Ordering::Relaxed), 0);
|
||||
assert!(!record.client_closed.load(Ordering::Relaxed));
|
||||
assert!(!record.backend_closed.load(Ordering::Relaxed));
|
||||
assert!(!record.is_tls.load(Ordering::Relaxed));
|
||||
assert!(!record.has_keep_alive.load(Ordering::Relaxed));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_record_bytes() {
|
||||
let record = ConnectionRecord::new(1);
|
||||
record.record_bytes_in(100);
|
||||
record.record_bytes_in(200);
|
||||
assert_eq!(record.bytes_received.load(Ordering::Relaxed), 300);
|
||||
|
||||
record.record_bytes_out(50);
|
||||
record.record_bytes_out(75);
|
||||
assert_eq!(record.bytes_sent.load(Ordering::Relaxed), 125);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_touch_updates_activity() {
|
||||
let record = ConnectionRecord::new(1);
|
||||
assert_eq!(record.last_activity.load(Ordering::Relaxed), 0);
|
||||
|
||||
// Sleep briefly so elapsed time is nonzero
|
||||
thread::sleep(Duration::from_millis(10));
|
||||
record.touch();
|
||||
|
||||
let activity = record.last_activity.load(Ordering::Relaxed);
|
||||
assert!(activity >= 10, "last_activity should be at least 10ms, got {}", activity);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_idle_duration() {
|
||||
let record = ConnectionRecord::new(1);
|
||||
// Initially idle_duration ~ age since last_activity is 0
|
||||
thread::sleep(Duration::from_millis(20));
|
||||
let idle = record.idle_duration();
|
||||
assert!(idle >= Duration::from_millis(20));
|
||||
|
||||
// After touch, idle should be near zero
|
||||
record.touch();
|
||||
let idle = record.idle_duration();
|
||||
assert!(idle < Duration::from_millis(10));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_age() {
|
||||
let record = ConnectionRecord::new(1);
|
||||
thread::sleep(Duration::from_millis(20));
|
||||
let age = record.age();
|
||||
assert!(age >= Duration::from_millis(20));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flags() {
|
||||
let record = ConnectionRecord::new(1);
|
||||
record.client_closed.store(true, Ordering::Relaxed);
|
||||
record.is_tls.store(true, Ordering::Relaxed);
|
||||
record.has_keep_alive.store(true, Ordering::Relaxed);
|
||||
|
||||
assert!(record.client_closed.load(Ordering::Relaxed));
|
||||
assert!(!record.backend_closed.load(Ordering::Relaxed));
|
||||
assert!(record.is_tls.load(Ordering::Relaxed));
|
||||
assert!(record.has_keep_alive.load(Ordering::Relaxed));
|
||||
}
|
||||
}
|
||||
402
rust/crates/rustproxy-passthrough/src/connection_tracker.rs
Normal file
402
rust/crates/rustproxy-passthrough/src/connection_tracker.rs
Normal file
@@ -0,0 +1,402 @@
|
||||
use dashmap::DashMap;
|
||||
use std::collections::VecDeque;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use super::connection_record::ConnectionRecord;
|
||||
|
||||
/// Thresholds for zombie detection (non-TLS connections).
|
||||
const HALF_ZOMBIE_TIMEOUT_PLAIN: Duration = Duration::from_secs(30);
|
||||
/// Thresholds for zombie detection (TLS connections).
|
||||
const HALF_ZOMBIE_TIMEOUT_TLS: Duration = Duration::from_secs(300);
|
||||
/// Stuck connection timeout (non-TLS): received data but never sent any.
|
||||
const STUCK_TIMEOUT_PLAIN: Duration = Duration::from_secs(60);
|
||||
/// Stuck connection timeout (TLS): received data but never sent any.
|
||||
const STUCK_TIMEOUT_TLS: Duration = Duration::from_secs(300);
|
||||
|
||||
/// Tracks active connections per IP and enforces per-IP limits and rate limiting.
|
||||
/// Also maintains per-connection records for zombie detection.
|
||||
pub struct ConnectionTracker {
|
||||
/// Active connection counts per IP
|
||||
active: DashMap<IpAddr, AtomicU64>,
|
||||
/// Connection timestamps per IP for rate limiting
|
||||
timestamps: DashMap<IpAddr, VecDeque<Instant>>,
|
||||
/// Maximum concurrent connections per IP (None = unlimited)
|
||||
max_per_ip: Option<u64>,
|
||||
/// Maximum new connections per minute per IP (None = unlimited)
|
||||
rate_limit_per_minute: Option<u64>,
|
||||
/// Per-connection tracking records for zombie detection
|
||||
connections: DashMap<u64, Arc<ConnectionRecord>>,
|
||||
/// Monotonically increasing connection ID counter
|
||||
next_id: AtomicU64,
|
||||
}
|
||||
|
||||
impl ConnectionTracker {
|
||||
pub fn new(max_per_ip: Option<u64>, rate_limit_per_minute: Option<u64>) -> Self {
|
||||
Self {
|
||||
active: DashMap::new(),
|
||||
timestamps: DashMap::new(),
|
||||
max_per_ip,
|
||||
rate_limit_per_minute,
|
||||
connections: DashMap::new(),
|
||||
next_id: AtomicU64::new(1),
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to accept a new connection from the given IP.
|
||||
/// Returns true if allowed, false if over limit.
|
||||
pub fn try_accept(&self, ip: &IpAddr) -> bool {
|
||||
// Check per-IP connection limit
|
||||
if let Some(max) = self.max_per_ip {
|
||||
let count = self.active
|
||||
.get(ip)
|
||||
.map(|c| c.value().load(Ordering::Relaxed))
|
||||
.unwrap_or(0);
|
||||
if count >= max {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check rate limit
|
||||
if let Some(rate_limit) = self.rate_limit_per_minute {
|
||||
let now = Instant::now();
|
||||
let one_minute = std::time::Duration::from_secs(60);
|
||||
let mut entry = self.timestamps.entry(*ip).or_default();
|
||||
let timestamps = entry.value_mut();
|
||||
|
||||
// Remove timestamps older than 1 minute
|
||||
while timestamps.front().is_some_and(|t| now.duration_since(*t) >= one_minute) {
|
||||
timestamps.pop_front();
|
||||
}
|
||||
|
||||
if timestamps.len() as u64 >= rate_limit {
|
||||
return false;
|
||||
}
|
||||
timestamps.push_back(now);
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Record that a connection was opened from the given IP.
|
||||
pub fn connection_opened(&self, ip: &IpAddr) {
|
||||
self.active
|
||||
.entry(*ip)
|
||||
.or_insert_with(|| AtomicU64::new(0))
|
||||
.value()
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/// Record that a connection was closed from the given IP.
|
||||
pub fn connection_closed(&self, ip: &IpAddr) {
|
||||
if let Some(counter) = self.active.get(ip) {
|
||||
let prev = counter.value().fetch_sub(1, Ordering::Relaxed);
|
||||
// Clean up zero entries
|
||||
if prev <= 1 {
|
||||
drop(counter);
|
||||
self.active.remove(ip);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the current number of active connections for an IP.
|
||||
pub fn active_connections(&self, ip: &IpAddr) -> u64 {
|
||||
self.active
|
||||
.get(ip)
|
||||
.map(|c| c.value().load(Ordering::Relaxed))
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Get the total number of tracked IPs.
|
||||
pub fn tracked_ips(&self) -> usize {
|
||||
self.active.len()
|
||||
}
|
||||
|
||||
/// Register a new connection and return its tracking record.
|
||||
///
|
||||
/// The returned `Arc<ConnectionRecord>` should be passed to the forwarding
|
||||
/// loop so it can update bytes / activity atomics in real time.
|
||||
pub fn register_connection(&self, is_tls: bool) -> Arc<ConnectionRecord> {
|
||||
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
|
||||
let record = Arc::new(ConnectionRecord::new(id));
|
||||
record.is_tls.store(is_tls, Ordering::Relaxed);
|
||||
self.connections.insert(id, Arc::clone(&record));
|
||||
record
|
||||
}
|
||||
|
||||
/// Remove a connection record when the connection is fully closed.
|
||||
pub fn unregister_connection(&self, id: u64) {
|
||||
self.connections.remove(&id);
|
||||
}
|
||||
|
||||
/// Scan all tracked connections and return IDs of zombie connections.
|
||||
///
|
||||
/// A connection is considered a zombie in any of these cases:
|
||||
/// - **Full zombie**: both `client_closed` and `backend_closed` are true.
|
||||
/// - **Half zombie**: one side closed for longer than the threshold
|
||||
/// (5 min for TLS, 30s for non-TLS).
|
||||
/// - **Stuck**: `bytes_received > 0` but `bytes_sent == 0` for longer
|
||||
/// than the stuck threshold (5 min for TLS, 60s for non-TLS).
|
||||
pub fn scan_zombies(&self) -> Vec<u64> {
|
||||
let mut zombies = Vec::new();
|
||||
|
||||
for entry in self.connections.iter() {
|
||||
let record = entry.value();
|
||||
let id = *entry.key();
|
||||
let is_tls = record.is_tls.load(Ordering::Relaxed);
|
||||
let client_closed = record.client_closed.load(Ordering::Relaxed);
|
||||
let backend_closed = record.backend_closed.load(Ordering::Relaxed);
|
||||
let idle = record.idle_duration();
|
||||
let bytes_in = record.bytes_received.load(Ordering::Relaxed);
|
||||
let bytes_out = record.bytes_sent.load(Ordering::Relaxed);
|
||||
|
||||
// Full zombie: both sides closed
|
||||
if client_closed && backend_closed {
|
||||
zombies.push(id);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Half zombie: one side closed for too long
|
||||
let half_timeout = if is_tls {
|
||||
HALF_ZOMBIE_TIMEOUT_TLS
|
||||
} else {
|
||||
HALF_ZOMBIE_TIMEOUT_PLAIN
|
||||
};
|
||||
|
||||
if (client_closed || backend_closed) && idle >= half_timeout {
|
||||
zombies.push(id);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Stuck: received data but never sent anything for too long
|
||||
let stuck_timeout = if is_tls {
|
||||
STUCK_TIMEOUT_TLS
|
||||
} else {
|
||||
STUCK_TIMEOUT_PLAIN
|
||||
};
|
||||
|
||||
if bytes_in > 0 && bytes_out == 0 && idle >= stuck_timeout {
|
||||
zombies.push(id);
|
||||
}
|
||||
}
|
||||
|
||||
zombies
|
||||
}
|
||||
|
||||
/// Start a background task that periodically scans for zombie connections.
|
||||
///
|
||||
/// The scanner runs every 10 seconds and logs any zombies it finds.
|
||||
/// It stops when the provided `CancellationToken` is cancelled.
|
||||
pub fn start_zombie_scanner(self: &Arc<Self>, cancel: CancellationToken) {
|
||||
let tracker = Arc::clone(self);
|
||||
tokio::spawn(async move {
|
||||
let interval = Duration::from_secs(10);
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
debug!("Zombie scanner shutting down");
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(interval) => {
|
||||
let zombies = tracker.scan_zombies();
|
||||
if !zombies.is_empty() {
|
||||
warn!(
|
||||
"Detected {} zombie connection(s): {:?}",
|
||||
zombies.len(),
|
||||
zombies
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Get the total number of tracked connections (with records).
|
||||
pub fn total_connections(&self) -> usize {
|
||||
self.connections.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_basic_tracking() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
let ip: IpAddr = "127.0.0.1".parse().unwrap();
|
||||
|
||||
assert!(tracker.try_accept(&ip));
|
||||
tracker.connection_opened(&ip);
|
||||
assert_eq!(tracker.active_connections(&ip), 1);
|
||||
|
||||
tracker.connection_opened(&ip);
|
||||
assert_eq!(tracker.active_connections(&ip), 2);
|
||||
|
||||
tracker.connection_closed(&ip);
|
||||
assert_eq!(tracker.active_connections(&ip), 1);
|
||||
|
||||
tracker.connection_closed(&ip);
|
||||
assert_eq!(tracker.active_connections(&ip), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_per_ip_limit() {
|
||||
let tracker = ConnectionTracker::new(Some(2), None);
|
||||
let ip: IpAddr = "10.0.0.1".parse().unwrap();
|
||||
|
||||
assert!(tracker.try_accept(&ip));
|
||||
tracker.connection_opened(&ip);
|
||||
|
||||
assert!(tracker.try_accept(&ip));
|
||||
tracker.connection_opened(&ip);
|
||||
|
||||
// Third connection should be rejected
|
||||
assert!(!tracker.try_accept(&ip));
|
||||
|
||||
// Different IP should still be allowed
|
||||
let ip2: IpAddr = "10.0.0.2".parse().unwrap();
|
||||
assert!(tracker.try_accept(&ip2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_limit() {
|
||||
let tracker = ConnectionTracker::new(None, Some(3));
|
||||
let ip: IpAddr = "10.0.0.1".parse().unwrap();
|
||||
|
||||
assert!(tracker.try_accept(&ip));
|
||||
assert!(tracker.try_accept(&ip));
|
||||
assert!(tracker.try_accept(&ip));
|
||||
// 4th attempt within the minute should be rejected
|
||||
assert!(!tracker.try_accept(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_limits() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
let ip: IpAddr = "10.0.0.1".parse().unwrap();
|
||||
|
||||
for _ in 0..1000 {
|
||||
assert!(tracker.try_accept(&ip));
|
||||
tracker.connection_opened(&ip);
|
||||
}
|
||||
assert_eq!(tracker.active_connections(&ip), 1000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tracked_ips() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
assert_eq!(tracker.tracked_ips(), 0);
|
||||
|
||||
let ip1: IpAddr = "10.0.0.1".parse().unwrap();
|
||||
let ip2: IpAddr = "10.0.0.2".parse().unwrap();
|
||||
|
||||
tracker.connection_opened(&ip1);
|
||||
tracker.connection_opened(&ip2);
|
||||
assert_eq!(tracker.tracked_ips(), 2);
|
||||
|
||||
tracker.connection_closed(&ip1);
|
||||
assert_eq!(tracker.tracked_ips(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_register_unregister_connection() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
assert_eq!(tracker.total_connections(), 0);
|
||||
|
||||
let record1 = tracker.register_connection(false);
|
||||
assert_eq!(tracker.total_connections(), 1);
|
||||
assert!(!record1.is_tls.load(Ordering::Relaxed));
|
||||
|
||||
let record2 = tracker.register_connection(true);
|
||||
assert_eq!(tracker.total_connections(), 2);
|
||||
assert!(record2.is_tls.load(Ordering::Relaxed));
|
||||
|
||||
// IDs should be unique
|
||||
assert_ne!(record1.id, record2.id);
|
||||
|
||||
tracker.unregister_connection(record1.id);
|
||||
assert_eq!(tracker.total_connections(), 1);
|
||||
|
||||
tracker.unregister_connection(record2.id);
|
||||
assert_eq!(tracker.total_connections(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_full_zombie_detection() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
let record = tracker.register_connection(false);
|
||||
|
||||
// Not a zombie initially
|
||||
assert!(tracker.scan_zombies().is_empty());
|
||||
|
||||
// Set both sides closed -> full zombie
|
||||
record.client_closed.store(true, Ordering::Relaxed);
|
||||
record.backend_closed.store(true, Ordering::Relaxed);
|
||||
|
||||
let zombies = tracker.scan_zombies();
|
||||
assert_eq!(zombies.len(), 1);
|
||||
assert_eq!(zombies[0], record.id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_half_zombie_not_triggered_immediately() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
let record = tracker.register_connection(false);
|
||||
record.touch(); // mark activity now
|
||||
|
||||
// Only one side closed, but just now -> not a zombie yet
|
||||
record.client_closed.store(true, Ordering::Relaxed);
|
||||
assert!(tracker.scan_zombies().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stuck_connection_not_triggered_immediately() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
let record = tracker.register_connection(false);
|
||||
record.touch(); // mark activity now
|
||||
|
||||
// Has received data but sent nothing -> but just started, not stuck yet
|
||||
record.bytes_received.store(1000, Ordering::Relaxed);
|
||||
assert!(tracker.scan_zombies().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unregister_removes_from_zombie_scan() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
let record = tracker.register_connection(false);
|
||||
let id = record.id;
|
||||
|
||||
// Make it a full zombie
|
||||
record.client_closed.store(true, Ordering::Relaxed);
|
||||
record.backend_closed.store(true, Ordering::Relaxed);
|
||||
assert_eq!(tracker.scan_zombies().len(), 1);
|
||||
|
||||
// Unregister should remove it
|
||||
tracker.unregister_connection(id);
|
||||
assert!(tracker.scan_zombies().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_total_connections() {
|
||||
let tracker = ConnectionTracker::new(None, None);
|
||||
assert_eq!(tracker.total_connections(), 0);
|
||||
|
||||
let r1 = tracker.register_connection(false);
|
||||
let r2 = tracker.register_connection(true);
|
||||
let r3 = tracker.register_connection(false);
|
||||
assert_eq!(tracker.total_connections(), 3);
|
||||
|
||||
tracker.unregister_connection(r2.id);
|
||||
assert_eq!(tracker.total_connections(), 2);
|
||||
|
||||
tracker.unregister_connection(r1.id);
|
||||
tracker.unregister_connection(r3.id);
|
||||
assert_eq!(tracker.total_connections(), 0);
|
||||
}
|
||||
}
|
||||
325
rust/crates/rustproxy-passthrough/src/forwarder.rs
Normal file
325
rust/crates/rustproxy-passthrough/src/forwarder.rs
Normal file
@@ -0,0 +1,325 @@
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use tracing::debug;
|
||||
|
||||
use super::connection_record::ConnectionRecord;
|
||||
|
||||
/// Statistics for a forwarded connection.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ForwardStats {
|
||||
pub bytes_in: AtomicU64,
|
||||
pub bytes_out: AtomicU64,
|
||||
}
|
||||
|
||||
/// Perform bidirectional TCP forwarding between client and backend.
|
||||
///
|
||||
/// This is the core data path for passthrough connections.
|
||||
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes.
|
||||
pub async fn forward_bidirectional(
|
||||
mut client: TcpStream,
|
||||
mut backend: TcpStream,
|
||||
initial_data: Option<&[u8]>,
|
||||
) -> std::io::Result<(u64, u64)> {
|
||||
// Send initial data (peeked bytes) to backend
|
||||
if let Some(data) = initial_data {
|
||||
backend.write_all(data).await?;
|
||||
}
|
||||
|
||||
let (mut client_read, mut client_write) = client.split();
|
||||
let (mut backend_read, mut backend_write) = backend.split();
|
||||
|
||||
let client_to_backend = async {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = initial_data.map_or(0u64, |d| d.len() as u64);
|
||||
loop {
|
||||
let n = client_read.read(&mut buf).await?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
backend_write.write_all(&buf[..n]).await?;
|
||||
total += n as u64;
|
||||
}
|
||||
backend_write.shutdown().await?;
|
||||
Ok::<u64, std::io::Error>(total)
|
||||
};
|
||||
|
||||
let backend_to_client = async {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = backend_read.read(&mut buf).await?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
client_write.write_all(&buf[..n]).await?;
|
||||
total += n as u64;
|
||||
}
|
||||
client_write.shutdown().await?;
|
||||
Ok::<u64, std::io::Error>(total)
|
||||
};
|
||||
|
||||
let (c2b, b2c) = tokio::join!(client_to_backend, backend_to_client);
|
||||
|
||||
Ok((c2b.unwrap_or(0), b2c.unwrap_or(0)))
|
||||
}
|
||||
|
||||
/// Perform bidirectional TCP forwarding with inactivity and max lifetime timeouts.
|
||||
///
|
||||
/// Returns (bytes_from_client, bytes_from_backend) when the connection closes or times out.
|
||||
pub async fn forward_bidirectional_with_timeouts(
|
||||
client: TcpStream,
|
||||
mut backend: TcpStream,
|
||||
initial_data: Option<&[u8]>,
|
||||
inactivity_timeout: std::time::Duration,
|
||||
max_lifetime: std::time::Duration,
|
||||
cancel: CancellationToken,
|
||||
) -> std::io::Result<(u64, u64)> {
|
||||
// Send initial data (peeked bytes) to backend
|
||||
if let Some(data) = initial_data {
|
||||
backend.write_all(data).await?;
|
||||
}
|
||||
|
||||
let (mut client_read, mut client_write) = client.into_split();
|
||||
let (mut backend_read, mut backend_write) = backend.into_split();
|
||||
|
||||
let last_activity = Arc::new(AtomicU64::new(0));
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let la1 = Arc::clone(&last_activity);
|
||||
let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);
|
||||
let c2b = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = initial_len;
|
||||
loop {
|
||||
let n = match client_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if backend_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
la1.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
}
|
||||
let _ = backend_write.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
let la2 = Arc::clone(&last_activity);
|
||||
let b2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = match backend_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if client_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
la2.store(start.elapsed().as_millis() as u64, Ordering::Relaxed);
|
||||
}
|
||||
let _ = client_write.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
// Watchdog: inactivity, max lifetime, and cancellation
|
||||
let la_watch = Arc::clone(&last_activity);
|
||||
let c2b_handle = c2b.abort_handle();
|
||||
let b2c_handle = b2c.abort_handle();
|
||||
let watchdog = tokio::spawn(async move {
|
||||
let check_interval = std::time::Duration::from_secs(5);
|
||||
let mut last_seen = 0u64;
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
debug!("Connection cancelled by shutdown");
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(check_interval) => {
|
||||
// Check max lifetime
|
||||
if start.elapsed() >= max_lifetime {
|
||||
debug!("Connection exceeded max lifetime, closing");
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
|
||||
// Check inactivity
|
||||
let current = la_watch.load(Ordering::Relaxed);
|
||||
if current == last_seen {
|
||||
let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
|
||||
if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
|
||||
debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
}
|
||||
last_seen = current;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let bytes_in = c2b.await.unwrap_or(0);
|
||||
let bytes_out = b2c.await.unwrap_or(0);
|
||||
watchdog.abort();
|
||||
Ok((bytes_in, bytes_out))
|
||||
}
|
||||
|
||||
/// Forward bidirectional with a callback for byte counting.
|
||||
pub async fn forward_bidirectional_with_stats(
|
||||
client: TcpStream,
|
||||
backend: TcpStream,
|
||||
initial_data: Option<&[u8]>,
|
||||
stats: Arc<ForwardStats>,
|
||||
) -> std::io::Result<()> {
|
||||
let (bytes_in, bytes_out) = forward_bidirectional(client, backend, initial_data).await?;
|
||||
stats.bytes_in.fetch_add(bytes_in, Ordering::Relaxed);
|
||||
stats.bytes_out.fetch_add(bytes_out, Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Perform bidirectional TCP forwarding with inactivity / lifetime timeouts,
|
||||
/// updating a `ConnectionRecord` with byte counts and activity timestamps
|
||||
/// in real time for zombie detection.
|
||||
///
|
||||
/// When `record` is `None`, this behaves identically to
|
||||
/// `forward_bidirectional_with_timeouts`.
|
||||
///
|
||||
/// The record's `client_closed` / `backend_closed` flags are set when the
|
||||
/// respective copy loop terminates, giving the zombie scanner visibility
|
||||
/// into half-open connections.
|
||||
pub async fn forward_bidirectional_with_record(
|
||||
client: TcpStream,
|
||||
mut backend: TcpStream,
|
||||
initial_data: Option<&[u8]>,
|
||||
inactivity_timeout: std::time::Duration,
|
||||
max_lifetime: std::time::Duration,
|
||||
cancel: CancellationToken,
|
||||
record: Option<Arc<ConnectionRecord>>,
|
||||
) -> std::io::Result<(u64, u64)> {
|
||||
// Send initial data (peeked bytes) to backend
|
||||
if let Some(data) = initial_data {
|
||||
backend.write_all(data).await?;
|
||||
if let Some(ref r) = record {
|
||||
r.record_bytes_in(data.len() as u64);
|
||||
}
|
||||
}
|
||||
|
||||
let (mut client_read, mut client_write) = client.into_split();
|
||||
let (mut backend_read, mut backend_write) = backend.into_split();
|
||||
|
||||
let last_activity = Arc::new(AtomicU64::new(0));
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let la1 = Arc::clone(&last_activity);
|
||||
let initial_len = initial_data.map_or(0u64, |d| d.len() as u64);
|
||||
let rec1 = record.clone();
|
||||
let c2b = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = initial_len;
|
||||
loop {
|
||||
let n = match client_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if backend_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
let now_ms = start.elapsed().as_millis() as u64;
|
||||
la1.store(now_ms, Ordering::Relaxed);
|
||||
if let Some(ref r) = rec1 {
|
||||
r.record_bytes_in(n as u64);
|
||||
}
|
||||
}
|
||||
let _ = backend_write.shutdown().await;
|
||||
// Mark client side as closed
|
||||
if let Some(ref r) = rec1 {
|
||||
r.client_closed.store(true, Ordering::Relaxed);
|
||||
}
|
||||
total
|
||||
});
|
||||
|
||||
let la2 = Arc::clone(&last_activity);
|
||||
let rec2 = record.clone();
|
||||
let b2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = match backend_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if client_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
let now_ms = start.elapsed().as_millis() as u64;
|
||||
la2.store(now_ms, Ordering::Relaxed);
|
||||
if let Some(ref r) = rec2 {
|
||||
r.record_bytes_out(n as u64);
|
||||
}
|
||||
}
|
||||
let _ = client_write.shutdown().await;
|
||||
// Mark backend side as closed
|
||||
if let Some(ref r) = rec2 {
|
||||
r.backend_closed.store(true, Ordering::Relaxed);
|
||||
}
|
||||
total
|
||||
});
|
||||
|
||||
// Watchdog: inactivity, max lifetime, and cancellation
|
||||
let la_watch = Arc::clone(&last_activity);
|
||||
let c2b_handle = c2b.abort_handle();
|
||||
let b2c_handle = b2c.abort_handle();
|
||||
let watchdog = tokio::spawn(async move {
|
||||
let check_interval = std::time::Duration::from_secs(5);
|
||||
let mut last_seen = 0u64;
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
debug!("Connection cancelled by shutdown");
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(check_interval) => {
|
||||
// Check max lifetime
|
||||
if start.elapsed() >= max_lifetime {
|
||||
debug!("Connection exceeded max lifetime, closing");
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
|
||||
// Check inactivity
|
||||
let current = la_watch.load(Ordering::Relaxed);
|
||||
if current == last_seen {
|
||||
let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
|
||||
if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
|
||||
debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
}
|
||||
last_seen = current;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let bytes_in = c2b.await.unwrap_or(0);
|
||||
let bytes_out = b2c.await.unwrap_or(0);
|
||||
watchdog.abort();
|
||||
Ok((bytes_in, bytes_out))
|
||||
}
|
||||
22
rust/crates/rustproxy-passthrough/src/lib.rs
Normal file
22
rust/crates/rustproxy-passthrough/src/lib.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
//! # rustproxy-passthrough
|
||||
//!
|
||||
//! Raw TCP/SNI passthrough engine for RustProxy.
|
||||
//! Handles TCP listening, TLS ClientHello SNI extraction, and bidirectional forwarding.
|
||||
|
||||
pub mod tcp_listener;
|
||||
pub mod sni_parser;
|
||||
pub mod forwarder;
|
||||
pub mod proxy_protocol;
|
||||
pub mod tls_handler;
|
||||
pub mod connection_record;
|
||||
pub mod connection_tracker;
|
||||
pub mod socket_relay;
|
||||
|
||||
pub use tcp_listener::*;
|
||||
pub use sni_parser::*;
|
||||
pub use forwarder::*;
|
||||
pub use proxy_protocol::*;
|
||||
pub use tls_handler::*;
|
||||
pub use connection_record::*;
|
||||
pub use connection_tracker::*;
|
||||
pub use socket_relay::*;
|
||||
129
rust/crates/rustproxy-passthrough/src/proxy_protocol.rs
Normal file
129
rust/crates/rustproxy-passthrough/src/proxy_protocol.rs
Normal file
@@ -0,0 +1,129 @@
|
||||
use std::net::SocketAddr;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ProxyProtocolError {
|
||||
#[error("Invalid PROXY protocol header")]
|
||||
InvalidHeader,
|
||||
#[error("Unsupported PROXY protocol version")]
|
||||
UnsupportedVersion,
|
||||
#[error("Parse error: {0}")]
|
||||
Parse(String),
|
||||
}
|
||||
|
||||
/// Parsed PROXY protocol v1 header.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProxyProtocolHeader {
|
||||
pub source_addr: SocketAddr,
|
||||
pub dest_addr: SocketAddr,
|
||||
pub protocol: ProxyProtocol,
|
||||
}
|
||||
|
||||
/// Protocol in PROXY header.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ProxyProtocol {
|
||||
Tcp4,
|
||||
Tcp6,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
/// Parse a PROXY protocol v1 header from data.
|
||||
///
|
||||
/// Format: `PROXY TCP4 <src_ip> <dst_ip> <src_port> <dst_port>\r\n`
|
||||
pub fn parse_v1(data: &[u8]) -> Result<(ProxyProtocolHeader, usize), ProxyProtocolError> {
|
||||
// Find the end of the header line
|
||||
let line_end = data
|
||||
.windows(2)
|
||||
.position(|w| w == b"\r\n")
|
||||
.ok_or(ProxyProtocolError::InvalidHeader)?;
|
||||
|
||||
let line = std::str::from_utf8(&data[..line_end])
|
||||
.map_err(|_| ProxyProtocolError::InvalidHeader)?;
|
||||
|
||||
if !line.starts_with("PROXY ") {
|
||||
return Err(ProxyProtocolError::InvalidHeader);
|
||||
}
|
||||
|
||||
let parts: Vec<&str> = line.split(' ').collect();
|
||||
if parts.len() != 6 {
|
||||
return Err(ProxyProtocolError::InvalidHeader);
|
||||
}
|
||||
|
||||
let protocol = match parts[1] {
|
||||
"TCP4" => ProxyProtocol::Tcp4,
|
||||
"TCP6" => ProxyProtocol::Tcp6,
|
||||
"UNKNOWN" => ProxyProtocol::Unknown,
|
||||
_ => return Err(ProxyProtocolError::UnsupportedVersion),
|
||||
};
|
||||
|
||||
let src_ip: std::net::IpAddr = parts[2]
|
||||
.parse()
|
||||
.map_err(|_| ProxyProtocolError::Parse("Invalid source IP".to_string()))?;
|
||||
let dst_ip: std::net::IpAddr = parts[3]
|
||||
.parse()
|
||||
.map_err(|_| ProxyProtocolError::Parse("Invalid destination IP".to_string()))?;
|
||||
let src_port: u16 = parts[4]
|
||||
.parse()
|
||||
.map_err(|_| ProxyProtocolError::Parse("Invalid source port".to_string()))?;
|
||||
let dst_port: u16 = parts[5]
|
||||
.parse()
|
||||
.map_err(|_| ProxyProtocolError::Parse("Invalid destination port".to_string()))?;
|
||||
|
||||
let header = ProxyProtocolHeader {
|
||||
source_addr: SocketAddr::new(src_ip, src_port),
|
||||
dest_addr: SocketAddr::new(dst_ip, dst_port),
|
||||
protocol,
|
||||
};
|
||||
|
||||
// Consumed bytes = line + \r\n
|
||||
Ok((header, line_end + 2))
|
||||
}
|
||||
|
||||
/// Generate a PROXY protocol v1 header string.
|
||||
pub fn generate_v1(source: &SocketAddr, dest: &SocketAddr) -> String {
|
||||
let proto = if source.is_ipv4() { "TCP4" } else { "TCP6" };
|
||||
format!(
|
||||
"PROXY {} {} {} {} {}\r\n",
|
||||
proto,
|
||||
source.ip(),
|
||||
dest.ip(),
|
||||
source.port(),
|
||||
dest.port()
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if data starts with a PROXY protocol v1 header.
|
||||
pub fn is_proxy_protocol_v1(data: &[u8]) -> bool {
|
||||
data.starts_with(b"PROXY ")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_v1_tcp4() {
|
||||
let header = b"PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n";
|
||||
let (parsed, consumed) = parse_v1(header).unwrap();
|
||||
assert_eq!(consumed, header.len());
|
||||
assert_eq!(parsed.protocol, ProxyProtocol::Tcp4);
|
||||
assert_eq!(parsed.source_addr.ip().to_string(), "192.168.1.100");
|
||||
assert_eq!(parsed.source_addr.port(), 12345);
|
||||
assert_eq!(parsed.dest_addr.ip().to_string(), "10.0.0.1");
|
||||
assert_eq!(parsed.dest_addr.port(), 443);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_v1() {
|
||||
let source: SocketAddr = "192.168.1.100:12345".parse().unwrap();
|
||||
let dest: SocketAddr = "10.0.0.1:443".parse().unwrap();
|
||||
let header = generate_v1(&source, &dest);
|
||||
assert_eq!(header, "PROXY TCP4 192.168.1.100 10.0.0.1 12345 443\r\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_proxy_protocol() {
|
||||
assert!(is_proxy_protocol_v1(b"PROXY TCP4 ..."));
|
||||
assert!(!is_proxy_protocol_v1(b"GET / HTTP/1.1"));
|
||||
}
|
||||
}
|
||||
287
rust/crates/rustproxy-passthrough/src/sni_parser.rs
Normal file
287
rust/crates/rustproxy-passthrough/src/sni_parser.rs
Normal file
@@ -0,0 +1,287 @@
|
||||
//! ClientHello SNI extraction via manual byte parsing.
|
||||
//! No TLS stack needed - we just parse enough of the ClientHello to extract the SNI.
|
||||
|
||||
/// Result of SNI extraction.
|
||||
#[derive(Debug)]
|
||||
pub enum SniResult {
|
||||
/// Successfully extracted SNI hostname.
|
||||
Found(String),
|
||||
/// TLS ClientHello detected but no SNI extension present.
|
||||
NoSni,
|
||||
/// Not a TLS ClientHello (plain HTTP or other protocol).
|
||||
NotTls,
|
||||
/// Need more data to determine.
|
||||
NeedMoreData,
|
||||
}
|
||||
|
||||
/// Extract the SNI hostname from a TLS ClientHello message.
|
||||
///
|
||||
/// This parses just enough of the TLS record to find the SNI extension,
|
||||
/// without performing any actual TLS operations.
|
||||
pub fn extract_sni(data: &[u8]) -> SniResult {
|
||||
// Minimum TLS record header is 5 bytes
|
||||
if data.len() < 5 {
|
||||
return SniResult::NeedMoreData;
|
||||
}
|
||||
|
||||
// Check for TLS record: content_type=22 (Handshake)
|
||||
if data[0] != 0x16 {
|
||||
return SniResult::NotTls;
|
||||
}
|
||||
|
||||
// TLS version (major.minor) - accept any
|
||||
// data[1..2] = version
|
||||
|
||||
// Record length
|
||||
let record_len = ((data[3] as usize) << 8) | (data[4] as usize);
|
||||
let _total_len = 5 + record_len;
|
||||
|
||||
// We need at least the handshake header (5 TLS + 4 handshake = 9)
|
||||
if data.len() < 9 {
|
||||
return SniResult::NeedMoreData;
|
||||
}
|
||||
|
||||
// Handshake type = 1 (ClientHello)
|
||||
if data[5] != 0x01 {
|
||||
return SniResult::NotTls;
|
||||
}
|
||||
|
||||
// Handshake length (3 bytes) - informational, we parse incrementally
|
||||
let _handshake_len = ((data[6] as usize) << 16)
|
||||
| ((data[7] as usize) << 8)
|
||||
| (data[8] as usize);
|
||||
|
||||
let hello = &data[9..];
|
||||
|
||||
// ClientHello structure:
|
||||
// 2 bytes: client version
|
||||
// 32 bytes: random
|
||||
// 1 byte: session_id length + session_id
|
||||
let mut pos = 2 + 32; // skip version + random
|
||||
|
||||
if pos >= hello.len() {
|
||||
return SniResult::NeedMoreData;
|
||||
}
|
||||
|
||||
// Session ID
|
||||
let session_id_len = hello[pos] as usize;
|
||||
pos += 1 + session_id_len;
|
||||
|
||||
if pos + 2 > hello.len() {
|
||||
return SniResult::NeedMoreData;
|
||||
}
|
||||
|
||||
// Cipher suites
|
||||
let cipher_suites_len = ((hello[pos] as usize) << 8) | (hello[pos + 1] as usize);
|
||||
pos += 2 + cipher_suites_len;
|
||||
|
||||
if pos + 1 > hello.len() {
|
||||
return SniResult::NeedMoreData;
|
||||
}
|
||||
|
||||
// Compression methods
|
||||
let compression_len = hello[pos] as usize;
|
||||
pos += 1 + compression_len;
|
||||
|
||||
if pos + 2 > hello.len() {
|
||||
// No extensions
|
||||
return SniResult::NoSni;
|
||||
}
|
||||
|
||||
// Extensions length
|
||||
let extensions_len = ((hello[pos] as usize) << 8) | (hello[pos + 1] as usize);
|
||||
pos += 2;
|
||||
|
||||
let extensions_end = pos + extensions_len;
|
||||
if extensions_end > hello.len() {
|
||||
// Partial extensions, try to parse what we have
|
||||
}
|
||||
|
||||
// Parse extensions looking for SNI (type 0x0000)
|
||||
while pos + 4 <= hello.len() && pos < extensions_end {
|
||||
let ext_type = ((hello[pos] as u16) << 8) | (hello[pos + 1] as u16);
|
||||
let ext_len = ((hello[pos + 2] as usize) << 8) | (hello[pos + 3] as usize);
|
||||
pos += 4;
|
||||
|
||||
if ext_type == 0x0000 {
|
||||
// SNI extension
|
||||
return parse_sni_extension(&hello[pos..(pos + ext_len).min(hello.len())], ext_len);
|
||||
}
|
||||
|
||||
pos += ext_len;
|
||||
}
|
||||
|
||||
SniResult::NoSni
|
||||
}
|
||||
|
||||
/// Parse the SNI extension data.
|
||||
fn parse_sni_extension(data: &[u8], _ext_len: usize) -> SniResult {
|
||||
if data.len() < 5 {
|
||||
return SniResult::NeedMoreData;
|
||||
}
|
||||
|
||||
// Server name list length
|
||||
let _list_len = ((data[0] as usize) << 8) | (data[1] as usize);
|
||||
|
||||
// Server name type (0 = hostname)
|
||||
if data[2] != 0x00 {
|
||||
return SniResult::NoSni;
|
||||
}
|
||||
|
||||
// Hostname length
|
||||
let name_len = ((data[3] as usize) << 8) | (data[4] as usize);
|
||||
|
||||
if data.len() < 5 + name_len {
|
||||
return SniResult::NeedMoreData;
|
||||
}
|
||||
|
||||
match std::str::from_utf8(&data[5..5 + name_len]) {
|
||||
Ok(hostname) => SniResult::Found(hostname.to_lowercase()),
|
||||
Err(_) => SniResult::NoSni,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the initial bytes look like a TLS ClientHello.
|
||||
pub fn is_tls(data: &[u8]) -> bool {
|
||||
data.len() >= 3 && data[0] == 0x16 && data[1] == 0x03
|
||||
}
|
||||
|
||||
/// Check if the initial bytes look like HTTP.
|
||||
pub fn is_http(data: &[u8]) -> bool {
|
||||
if data.len() < 4 {
|
||||
return false;
|
||||
}
|
||||
// Check for common HTTP methods
|
||||
let starts = [
|
||||
b"GET " as &[u8],
|
||||
b"POST",
|
||||
b"PUT ",
|
||||
b"HEAD",
|
||||
b"DELE",
|
||||
b"PATC",
|
||||
b"OPTI",
|
||||
b"CONN",
|
||||
];
|
||||
starts.iter().any(|s| data.starts_with(s))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_not_tls() {
|
||||
let http_data = b"GET / HTTP/1.1\r\n";
|
||||
assert!(matches!(extract_sni(http_data), SniResult::NotTls));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_too_short() {
|
||||
assert!(matches!(extract_sni(&[0x16, 0x03]), SniResult::NeedMoreData));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_tls() {
|
||||
assert!(is_tls(&[0x16, 0x03, 0x01]));
|
||||
assert!(!is_tls(&[0x47, 0x45, 0x54])); // "GET"
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_http() {
|
||||
assert!(is_http(b"GET /"));
|
||||
assert!(is_http(b"POST /api"));
|
||||
assert!(!is_http(&[0x16, 0x03, 0x01]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_real_client_hello() {
|
||||
// A minimal TLS 1.2 ClientHello with SNI "example.com"
|
||||
let client_hello: Vec<u8> = build_test_client_hello("example.com");
|
||||
match extract_sni(&client_hello) {
|
||||
SniResult::Found(sni) => assert_eq!(sni, "example.com"),
|
||||
other => panic!("Expected Found, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a minimal TLS ClientHello for testing.
|
||||
fn build_test_client_hello(hostname: &str) -> Vec<u8> {
|
||||
let hostname_bytes = hostname.as_bytes();
|
||||
|
||||
// SNI extension
|
||||
let sni_ext_data = {
|
||||
let mut d = Vec::new();
|
||||
// Server name list length
|
||||
let name_entry_len = 3 + hostname_bytes.len(); // type(1) + len(2) + name
|
||||
d.push(((name_entry_len >> 8) & 0xFF) as u8);
|
||||
d.push((name_entry_len & 0xFF) as u8);
|
||||
// Host name type = 0
|
||||
d.push(0x00);
|
||||
// Host name length
|
||||
d.push(((hostname_bytes.len() >> 8) & 0xFF) as u8);
|
||||
d.push((hostname_bytes.len() & 0xFF) as u8);
|
||||
// Host name
|
||||
d.extend_from_slice(hostname_bytes);
|
||||
d
|
||||
};
|
||||
|
||||
// Extension: type=0x0000 (SNI), length, data
|
||||
let sni_extension = {
|
||||
let mut e = Vec::new();
|
||||
e.push(0x00); e.push(0x00); // SNI type
|
||||
e.push(((sni_ext_data.len() >> 8) & 0xFF) as u8);
|
||||
e.push((sni_ext_data.len() & 0xFF) as u8);
|
||||
e.extend_from_slice(&sni_ext_data);
|
||||
e
|
||||
};
|
||||
|
||||
// Extensions block
|
||||
let extensions = {
|
||||
let mut ext = Vec::new();
|
||||
ext.push(((sni_extension.len() >> 8) & 0xFF) as u8);
|
||||
ext.push((sni_extension.len() & 0xFF) as u8);
|
||||
ext.extend_from_slice(&sni_extension);
|
||||
ext
|
||||
};
|
||||
|
||||
// ClientHello body
|
||||
let hello_body = {
|
||||
let mut h = Vec::new();
|
||||
// Client version TLS 1.2
|
||||
h.push(0x03); h.push(0x03);
|
||||
// Random (32 bytes)
|
||||
h.extend_from_slice(&[0u8; 32]);
|
||||
// Session ID length = 0
|
||||
h.push(0x00);
|
||||
// Cipher suites: length=2, one suite
|
||||
h.push(0x00); h.push(0x02);
|
||||
h.push(0x00); h.push(0x2F); // TLS_RSA_WITH_AES_128_CBC_SHA
|
||||
// Compression methods: length=1, null
|
||||
h.push(0x01); h.push(0x00);
|
||||
// Extensions
|
||||
h.extend_from_slice(&extensions);
|
||||
h
|
||||
};
|
||||
|
||||
// Handshake: type=1 (ClientHello), length
|
||||
let handshake = {
|
||||
let mut hs = Vec::new();
|
||||
hs.push(0x01); // ClientHello
|
||||
// 3-byte length
|
||||
hs.push(((hello_body.len() >> 16) & 0xFF) as u8);
|
||||
hs.push(((hello_body.len() >> 8) & 0xFF) as u8);
|
||||
hs.push((hello_body.len() & 0xFF) as u8);
|
||||
hs.extend_from_slice(&hello_body);
|
||||
hs
|
||||
};
|
||||
|
||||
// TLS record: type=0x16, version TLS 1.0, length
|
||||
let mut record = Vec::new();
|
||||
record.push(0x16); // Handshake
|
||||
record.push(0x03); record.push(0x01); // TLS 1.0
|
||||
record.push(((handshake.len() >> 8) & 0xFF) as u8);
|
||||
record.push((handshake.len() & 0xFF) as u8);
|
||||
record.extend_from_slice(&handshake);
|
||||
|
||||
record
|
||||
}
|
||||
}
|
||||
126
rust/crates/rustproxy-passthrough/src/socket_relay.rs
Normal file
126
rust/crates/rustproxy-passthrough/src/socket_relay.rs
Normal file
@@ -0,0 +1,126 @@
|
||||
//! Socket handler relay for connecting client connections to a TypeScript handler
|
||||
//! via a Unix domain socket.
|
||||
//!
|
||||
//! Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay.
|
||||
|
||||
use tokio::net::UnixStream;
|
||||
use tokio::io::{AsyncWriteExt, AsyncReadExt};
|
||||
use tokio::net::TcpStream;
|
||||
use serde::Serialize;
|
||||
use tracing::debug;
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct RelayMetadata {
|
||||
connection_id: u64,
|
||||
remote_ip: String,
|
||||
remote_port: u16,
|
||||
local_port: u16,
|
||||
sni: Option<String>,
|
||||
route_name: String,
|
||||
initial_data_base64: Option<String>,
|
||||
}
|
||||
|
||||
/// Relay a client connection to a TypeScript handler via Unix domain socket.
|
||||
///
|
||||
/// Protocol: Send a JSON metadata line terminated by `\n`, then bidirectional relay.
|
||||
pub async fn relay_to_handler(
|
||||
client: TcpStream,
|
||||
relay_socket_path: &str,
|
||||
connection_id: u64,
|
||||
remote_ip: String,
|
||||
remote_port: u16,
|
||||
local_port: u16,
|
||||
sni: Option<String>,
|
||||
route_name: String,
|
||||
initial_data: Option<&[u8]>,
|
||||
) -> std::io::Result<()> {
|
||||
debug!(
|
||||
"Relaying connection {} to handler socket {}",
|
||||
connection_id, relay_socket_path
|
||||
);
|
||||
|
||||
// Connect to TypeScript handler Unix socket
|
||||
let mut handler = UnixStream::connect(relay_socket_path).await?;
|
||||
|
||||
// Build and send metadata header
|
||||
let initial_data_base64 = initial_data.map(base64_encode);
|
||||
|
||||
let metadata = RelayMetadata {
|
||||
connection_id,
|
||||
remote_ip,
|
||||
remote_port,
|
||||
local_port,
|
||||
sni,
|
||||
route_name,
|
||||
initial_data_base64,
|
||||
};
|
||||
|
||||
let metadata_json = serde_json::to_string(&metadata)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
|
||||
|
||||
handler.write_all(metadata_json.as_bytes()).await?;
|
||||
handler.write_all(b"\n").await?;
|
||||
|
||||
// Bidirectional relay between client and handler
|
||||
let (mut client_read, mut client_write) = client.into_split();
|
||||
let (mut handler_read, mut handler_write) = handler.into_split();
|
||||
|
||||
let c2h = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
loop {
|
||||
let n = match client_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if handler_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let _ = handler_write.shutdown().await;
|
||||
});
|
||||
|
||||
let h2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
loop {
|
||||
let n = match handler_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if client_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let _ = client_write.shutdown().await;
|
||||
});
|
||||
|
||||
let _ = tokio::join!(c2h, h2c);
|
||||
|
||||
debug!("Relay connection {} completed", connection_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Simple base64 encoding without external dependency.
|
||||
fn base64_encode(data: &[u8]) -> String {
|
||||
const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
||||
let mut result = String::new();
|
||||
for chunk in data.chunks(3) {
|
||||
let b0 = chunk[0] as u32;
|
||||
let b1 = if chunk.len() > 1 { chunk[1] as u32 } else { 0 };
|
||||
let b2 = if chunk.len() > 2 { chunk[2] as u32 } else { 0 };
|
||||
let n = (b0 << 16) | (b1 << 8) | b2;
|
||||
result.push(CHARS[((n >> 18) & 0x3F) as usize] as char);
|
||||
result.push(CHARS[((n >> 12) & 0x3F) as usize] as char);
|
||||
if chunk.len() > 1 {
|
||||
result.push(CHARS[((n >> 6) & 0x3F) as usize] as char);
|
||||
} else {
|
||||
result.push('=');
|
||||
}
|
||||
if chunk.len() > 2 {
|
||||
result.push(CHARS[(n & 0x3F) as usize] as char);
|
||||
} else {
|
||||
result.push('=');
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
874
rust/crates/rustproxy-passthrough/src/tcp_listener.rs
Normal file
874
rust/crates/rustproxy-passthrough/src/tcp_listener.rs
Normal file
@@ -0,0 +1,874 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{info, error, debug, warn};
|
||||
use thiserror::Error;
|
||||
|
||||
use rustproxy_routing::RouteManager;
|
||||
use rustproxy_metrics::MetricsCollector;
|
||||
use rustproxy_http::HttpProxyService;
|
||||
use crate::sni_parser;
|
||||
use crate::forwarder;
|
||||
use crate::tls_handler;
|
||||
use crate::connection_tracker::ConnectionTracker;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ListenerError {
|
||||
#[error("Failed to bind port {port}: {source}")]
|
||||
BindFailed { port: u16, source: std::io::Error },
|
||||
#[error("Port {0} already bound")]
|
||||
AlreadyBound(u16),
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
/// TLS configuration for a specific domain.
|
||||
#[derive(Clone)]
|
||||
pub struct TlsCertConfig {
|
||||
pub cert_pem: String,
|
||||
pub key_pem: String,
|
||||
}
|
||||
|
||||
/// Timeout and connection management configuration.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConnectionConfig {
|
||||
/// Timeout for establishing connection to backend (ms)
|
||||
pub connection_timeout_ms: u64,
|
||||
/// Timeout for initial data/SNI peek (ms)
|
||||
pub initial_data_timeout_ms: u64,
|
||||
/// Socket inactivity timeout (ms)
|
||||
pub socket_timeout_ms: u64,
|
||||
/// Maximum connection lifetime (ms)
|
||||
pub max_connection_lifetime_ms: u64,
|
||||
/// Graceful shutdown timeout (ms)
|
||||
pub graceful_shutdown_timeout_ms: u64,
|
||||
/// Maximum connections per IP (None = unlimited)
|
||||
pub max_connections_per_ip: Option<u64>,
|
||||
/// Connection rate limit per minute per IP (None = unlimited)
|
||||
pub connection_rate_limit_per_minute: Option<u64>,
|
||||
/// Keep-alive treatment
|
||||
pub keep_alive_treatment: Option<rustproxy_config::KeepAliveTreatment>,
|
||||
/// Inactivity multiplier for keep-alive connections
|
||||
pub keep_alive_inactivity_multiplier: Option<f64>,
|
||||
/// Extended keep-alive lifetime (ms) for Extended treatment mode
|
||||
pub extended_keep_alive_lifetime_ms: Option<u64>,
|
||||
/// Whether to accept PROXY protocol
|
||||
pub accept_proxy_protocol: bool,
|
||||
/// Whether to send PROXY protocol
|
||||
pub send_proxy_protocol: bool,
|
||||
}
|
||||
|
||||
impl Default for ConnectionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
connection_timeout_ms: 30_000,
|
||||
initial_data_timeout_ms: 60_000,
|
||||
socket_timeout_ms: 3_600_000,
|
||||
max_connection_lifetime_ms: 86_400_000,
|
||||
graceful_shutdown_timeout_ms: 30_000,
|
||||
max_connections_per_ip: None,
|
||||
connection_rate_limit_per_minute: None,
|
||||
keep_alive_treatment: None,
|
||||
keep_alive_inactivity_multiplier: None,
|
||||
extended_keep_alive_lifetime_ms: None,
|
||||
accept_proxy_protocol: false,
|
||||
send_proxy_protocol: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Manages TCP listeners for all configured ports.
|
||||
pub struct TcpListenerManager {
|
||||
/// Active listeners indexed by port
|
||||
listeners: HashMap<u16, tokio::task::JoinHandle<()>>,
|
||||
/// Shared route manager
|
||||
route_manager: Arc<RouteManager>,
|
||||
/// Shared metrics collector
|
||||
metrics: Arc<MetricsCollector>,
|
||||
/// TLS acceptors indexed by domain
|
||||
tls_configs: Arc<HashMap<String, TlsCertConfig>>,
|
||||
/// HTTP proxy service for HTTP-level forwarding
|
||||
http_proxy: Arc<HttpProxyService>,
|
||||
/// Connection configuration
|
||||
conn_config: Arc<ConnectionConfig>,
|
||||
/// Connection tracker for per-IP limits
|
||||
conn_tracker: Arc<ConnectionTracker>,
|
||||
/// Cancellation token for graceful shutdown
|
||||
cancel_token: CancellationToken,
|
||||
}
|
||||
|
||||
impl TcpListenerManager {
|
||||
pub fn new(route_manager: Arc<RouteManager>) -> Self {
|
||||
let metrics = Arc::new(MetricsCollector::new());
|
||||
let http_proxy = Arc::new(HttpProxyService::new(
|
||||
Arc::clone(&route_manager),
|
||||
Arc::clone(&metrics),
|
||||
));
|
||||
let conn_config = ConnectionConfig::default();
|
||||
let conn_tracker = Arc::new(ConnectionTracker::new(
|
||||
conn_config.max_connections_per_ip,
|
||||
conn_config.connection_rate_limit_per_minute,
|
||||
));
|
||||
Self {
|
||||
listeners: HashMap::new(),
|
||||
route_manager,
|
||||
metrics,
|
||||
tls_configs: Arc::new(HashMap::new()),
|
||||
http_proxy,
|
||||
conn_config: Arc::new(conn_config),
|
||||
conn_tracker,
|
||||
cancel_token: CancellationToken::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with a metrics collector.
|
||||
pub fn with_metrics(route_manager: Arc<RouteManager>, metrics: Arc<MetricsCollector>) -> Self {
|
||||
let http_proxy = Arc::new(HttpProxyService::new(
|
||||
Arc::clone(&route_manager),
|
||||
Arc::clone(&metrics),
|
||||
));
|
||||
let conn_config = ConnectionConfig::default();
|
||||
let conn_tracker = Arc::new(ConnectionTracker::new(
|
||||
conn_config.max_connections_per_ip,
|
||||
conn_config.connection_rate_limit_per_minute,
|
||||
));
|
||||
Self {
|
||||
listeners: HashMap::new(),
|
||||
route_manager,
|
||||
metrics,
|
||||
tls_configs: Arc::new(HashMap::new()),
|
||||
http_proxy,
|
||||
conn_config: Arc::new(conn_config),
|
||||
conn_tracker,
|
||||
cancel_token: CancellationToken::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set connection configuration.
|
||||
pub fn set_connection_config(&mut self, config: ConnectionConfig) {
|
||||
self.conn_tracker = Arc::new(ConnectionTracker::new(
|
||||
config.max_connections_per_ip,
|
||||
config.connection_rate_limit_per_minute,
|
||||
));
|
||||
self.conn_config = Arc::new(config);
|
||||
}
|
||||
|
||||
/// Set TLS certificate configurations.
|
||||
pub fn set_tls_configs(&mut self, configs: HashMap<String, TlsCertConfig>) {
|
||||
self.tls_configs = Arc::new(configs);
|
||||
}
|
||||
|
||||
/// Start listening on a port.
|
||||
pub async fn add_port(&mut self, port: u16) -> Result<(), ListenerError> {
|
||||
if self.listeners.contains_key(&port) {
|
||||
return Err(ListenerError::AlreadyBound(port));
|
||||
}
|
||||
|
||||
let addr = format!("0.0.0.0:{}", port);
|
||||
let listener = TcpListener::bind(&addr).await.map_err(|e| {
|
||||
ListenerError::BindFailed { port, source: e }
|
||||
})?;
|
||||
|
||||
info!("Listening on port {}", port);
|
||||
|
||||
let route_manager = Arc::clone(&self.route_manager);
|
||||
let metrics = Arc::clone(&self.metrics);
|
||||
let tls_configs = Arc::clone(&self.tls_configs);
|
||||
let http_proxy = Arc::clone(&self.http_proxy);
|
||||
let conn_config = Arc::clone(&self.conn_config);
|
||||
let conn_tracker = Arc::clone(&self.conn_tracker);
|
||||
let cancel = self.cancel_token.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
Self::accept_loop(
|
||||
listener, port, route_manager, metrics, tls_configs,
|
||||
http_proxy, conn_config, conn_tracker, cancel,
|
||||
).await;
|
||||
});
|
||||
|
||||
self.listeners.insert(port, handle);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop listening on a port.
|
||||
pub fn remove_port(&mut self, port: u16) -> bool {
|
||||
if let Some(handle) = self.listeners.remove(&port) {
|
||||
handle.abort();
|
||||
info!("Stopped listening on port {}", port);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all currently listening ports.
|
||||
pub fn listening_ports(&self) -> Vec<u16> {
|
||||
let mut ports: Vec<u16> = self.listeners.keys().copied().collect();
|
||||
ports.sort();
|
||||
ports
|
||||
}
|
||||
|
||||
/// Stop all listeners gracefully.
|
||||
///
|
||||
/// Signals cancellation and waits up to `graceful_shutdown_timeout_ms` for
|
||||
/// connections to drain, then aborts remaining tasks.
|
||||
pub async fn graceful_stop(&mut self) {
|
||||
let timeout_ms = self.conn_config.graceful_shutdown_timeout_ms;
|
||||
info!("Initiating graceful shutdown (timeout: {}ms)", timeout_ms);
|
||||
|
||||
// Signal all accept loops to stop accepting new connections
|
||||
self.cancel_token.cancel();
|
||||
|
||||
// Wait for existing connections to drain
|
||||
let timeout = std::time::Duration::from_millis(timeout_ms);
|
||||
let deadline = tokio::time::Instant::now() + timeout;
|
||||
|
||||
for (port, handle) in self.listeners.drain() {
|
||||
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
|
||||
if remaining.is_zero() {
|
||||
handle.abort();
|
||||
warn!("Force-stopped listener on port {} (timeout exceeded)", port);
|
||||
} else {
|
||||
match tokio::time::timeout(remaining, handle).await {
|
||||
Ok(_) => info!("Listener on port {} stopped gracefully", port),
|
||||
Err(_) => {
|
||||
warn!("Listener on port {} did not stop in time, aborting", port);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset cancellation token for potential restart
|
||||
self.cancel_token = CancellationToken::new();
|
||||
info!("Graceful shutdown complete");
|
||||
}
|
||||
|
||||
/// Stop all listeners immediately (backward compatibility).
|
||||
pub fn stop_all(&mut self) {
|
||||
self.cancel_token.cancel();
|
||||
for (port, handle) in self.listeners.drain() {
|
||||
handle.abort();
|
||||
info!("Stopped listening on port {}", port);
|
||||
}
|
||||
self.cancel_token = CancellationToken::new();
|
||||
}
|
||||
|
||||
/// Update the route manager (for hot-reload).
|
||||
pub fn update_route_manager(&mut self, route_manager: Arc<RouteManager>) {
|
||||
self.route_manager = route_manager;
|
||||
}
|
||||
|
||||
/// Get a reference to the metrics collector.
|
||||
pub fn metrics(&self) -> &Arc<MetricsCollector> {
|
||||
&self.metrics
|
||||
}
|
||||
|
||||
/// Accept loop for a single port.
|
||||
async fn accept_loop(
|
||||
listener: TcpListener,
|
||||
port: u16,
|
||||
route_manager: Arc<RouteManager>,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
tls_configs: Arc<HashMap<String, TlsCertConfig>>,
|
||||
http_proxy: Arc<HttpProxyService>,
|
||||
conn_config: Arc<ConnectionConfig>,
|
||||
conn_tracker: Arc<ConnectionTracker>,
|
||||
cancel: CancellationToken,
|
||||
) {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
info!("Accept loop on port {} shutting down", port);
|
||||
break;
|
||||
}
|
||||
result = listener.accept() => {
|
||||
match result {
|
||||
Ok((stream, peer_addr)) => {
|
||||
let ip = peer_addr.ip();
|
||||
|
||||
// Check per-IP limits and rate limiting
|
||||
if !conn_tracker.try_accept(&ip) {
|
||||
debug!("Rejected connection from {} (per-IP limit or rate limit)", peer_addr);
|
||||
drop(stream);
|
||||
continue;
|
||||
}
|
||||
|
||||
conn_tracker.connection_opened(&ip);
|
||||
|
||||
let rm = Arc::clone(&route_manager);
|
||||
let m = Arc::clone(&metrics);
|
||||
let tc = Arc::clone(&tls_configs);
|
||||
let hp = Arc::clone(&http_proxy);
|
||||
let cc = Arc::clone(&conn_config);
|
||||
let ct = Arc::clone(&conn_tracker);
|
||||
let cn = cancel.clone();
|
||||
debug!("Accepted connection from {} on port {}", peer_addr, port);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let result = Self::handle_connection(
|
||||
stream, port, peer_addr, rm, m, tc, hp, cc, cn,
|
||||
).await;
|
||||
if let Err(e) = result {
|
||||
debug!("Connection error from {}: {}", peer_addr, e);
|
||||
}
|
||||
ct.connection_closed(&ip);
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Accept error on port {}: {}", port, e);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a single incoming connection.
|
||||
async fn handle_connection(
|
||||
mut stream: tokio::net::TcpStream,
|
||||
port: u16,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
route_manager: Arc<RouteManager>,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
tls_configs: Arc<HashMap<String, TlsCertConfig>>,
|
||||
http_proxy: Arc<HttpProxyService>,
|
||||
conn_config: Arc<ConnectionConfig>,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
stream.set_nodelay(true)?;
|
||||
|
||||
// Handle PROXY protocol if configured
|
||||
let mut effective_peer_addr = peer_addr;
|
||||
if conn_config.accept_proxy_protocol {
|
||||
let mut proxy_peek = vec![0u8; 256];
|
||||
let pn = match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
|
||||
stream.peek(&mut proxy_peek),
|
||||
).await {
|
||||
Ok(Ok(n)) => n,
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Err(_) => return Err("Initial data timeout (proxy protocol peek)".into()),
|
||||
};
|
||||
|
||||
if pn > 0 && crate::proxy_protocol::is_proxy_protocol_v1(&proxy_peek[..pn]) {
|
||||
match crate::proxy_protocol::parse_v1(&proxy_peek[..pn]) {
|
||||
Ok((header, consumed)) => {
|
||||
debug!("PROXY protocol: real client {} -> {}", header.source_addr, header.dest_addr);
|
||||
effective_peer_addr = header.source_addr;
|
||||
// Consume the proxy protocol header bytes
|
||||
let mut discard = vec![0u8; consumed];
|
||||
stream.read_exact(&mut discard).await?;
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to parse PROXY protocol header: {}", e);
|
||||
// Not a PROXY protocol header, continue normally
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let peer_addr = effective_peer_addr;
|
||||
|
||||
// Peek at initial bytes with timeout
|
||||
let mut peek_buf = vec![0u8; 4096];
|
||||
let n = match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
|
||||
stream.peek(&mut peek_buf),
|
||||
).await {
|
||||
Ok(Ok(n)) => n,
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Err(_) => return Err("Initial data timeout".into()),
|
||||
};
|
||||
let initial_data = &peek_buf[..n];
|
||||
|
||||
// Determine connection type and extract SNI if TLS
|
||||
let is_tls = sni_parser::is_tls(initial_data);
|
||||
let is_http = sni_parser::is_http(initial_data);
|
||||
let domain = if is_tls {
|
||||
match sni_parser::extract_sni(initial_data) {
|
||||
sni_parser::SniResult::Found(sni) => Some(sni),
|
||||
sni_parser::SniResult::NoSni => None,
|
||||
sni_parser::SniResult::NeedMoreData => {
|
||||
let mut bigger_buf = vec![0u8; 16384];
|
||||
let n = match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
|
||||
stream.peek(&mut bigger_buf),
|
||||
).await {
|
||||
Ok(Ok(n)) => n,
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Err(_) => return Err("SNI data timeout".into()),
|
||||
};
|
||||
match sni_parser::extract_sni(&bigger_buf[..n]) {
|
||||
sni_parser::SniResult::Found(sni) => Some(sni),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
sni_parser::SniResult::NotTls => None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Match route
|
||||
let ctx = rustproxy_routing::MatchContext {
|
||||
port,
|
||||
domain: domain.as_deref(),
|
||||
path: None,
|
||||
client_ip: Some(&peer_addr.ip().to_string()),
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls,
|
||||
};
|
||||
|
||||
let route_match = route_manager.find_route(&ctx);
|
||||
|
||||
let route_match = match route_match {
|
||||
Some(rm) => rm,
|
||||
None => {
|
||||
debug!("No route matched for port {} domain {:?}", port, domain);
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
let route_id = route_match.route.id.as_deref();
|
||||
|
||||
// Check route-level IP security for passthrough connections
|
||||
if let Some(ref security) = route_match.route.security {
|
||||
if !rustproxy_http::request_filter::RequestFilter::check_ip_security(
|
||||
security,
|
||||
&peer_addr.ip(),
|
||||
) {
|
||||
debug!("Connection from {} blocked by route security", peer_addr);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// Track connection in metrics
|
||||
metrics.connection_opened(route_id);
|
||||
|
||||
let target = match route_match.target {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
debug!("Route matched but no target available");
|
||||
metrics.connection_closed(route_id);
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
let target_host = target.host.first().to_string();
|
||||
let target_port = target.port.resolve(port);
|
||||
let tls_mode = route_match.route.tls_mode();
|
||||
|
||||
// Connection timeout for backend connections
|
||||
let connect_timeout = std::time::Duration::from_millis(conn_config.connection_timeout_ms);
|
||||
let base_inactivity_ms = conn_config.socket_timeout_ms;
|
||||
let (inactivity_timeout, max_lifetime) = match conn_config.keep_alive_treatment.as_ref() {
|
||||
Some(rustproxy_config::KeepAliveTreatment::Extended) => {
|
||||
let multiplier = conn_config.keep_alive_inactivity_multiplier.unwrap_or(6.0);
|
||||
let extended_lifetime = conn_config.extended_keep_alive_lifetime_ms
|
||||
.unwrap_or(7 * 24 * 3600 * 1000); // 7 days default
|
||||
(
|
||||
std::time::Duration::from_millis((base_inactivity_ms as f64 * multiplier) as u64),
|
||||
std::time::Duration::from_millis(extended_lifetime),
|
||||
)
|
||||
}
|
||||
Some(rustproxy_config::KeepAliveTreatment::Immortal) => {
|
||||
(
|
||||
std::time::Duration::from_millis(base_inactivity_ms),
|
||||
std::time::Duration::from_secs(u64::MAX / 2),
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
// Standard
|
||||
(
|
||||
std::time::Duration::from_millis(base_inactivity_ms),
|
||||
std::time::Duration::from_millis(conn_config.max_connection_lifetime_ms),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
// Determine if we should send PROXY protocol to backend
|
||||
let should_send_proxy = conn_config.send_proxy_protocol
|
||||
|| route_match.route.action.send_proxy_protocol.unwrap_or(false)
|
||||
|| target.send_proxy_protocol.unwrap_or(false);
|
||||
|
||||
// Generate PROXY protocol header if needed
|
||||
let proxy_header = if should_send_proxy {
|
||||
let dest = std::net::SocketAddr::new(
|
||||
target_host.parse().unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)),
|
||||
target_port,
|
||||
);
|
||||
Some(crate::proxy_protocol::generate_v1(&peer_addr, &dest))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let result = match tls_mode {
|
||||
Some(rustproxy_config::TlsMode::Passthrough) => {
|
||||
// Raw TCP passthrough - connect to backend and forward
|
||||
let mut backend = match tokio::time::timeout(
|
||||
connect_timeout,
|
||||
tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)),
|
||||
).await {
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Err(_) => return Err("Backend connection timeout".into()),
|
||||
};
|
||||
backend.set_nodelay(true)?;
|
||||
|
||||
// Send PROXY protocol header if configured
|
||||
if let Some(ref header) = proxy_header {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
backend.write_all(header.as_bytes()).await?;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"Passthrough: {} -> {}:{} (SNI: {:?})",
|
||||
peer_addr, target_host, target_port, domain
|
||||
);
|
||||
|
||||
let mut actual_buf = vec![0u8; n];
|
||||
stream.read_exact(&mut actual_buf).await?;
|
||||
|
||||
let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts(
|
||||
stream, backend, Some(&actual_buf),
|
||||
inactivity_timeout, max_lifetime, cancel,
|
||||
).await?;
|
||||
metrics.record_bytes(bytes_in, bytes_out, route_id);
|
||||
Ok(())
|
||||
}
|
||||
Some(rustproxy_config::TlsMode::Terminate) => {
|
||||
let tls_config = Self::find_tls_config(&domain, &tls_configs)?;
|
||||
|
||||
// TLS accept with timeout, applying route-level TLS settings
|
||||
let route_tls = route_match.route.action.tls.as_ref();
|
||||
let acceptor = tls_handler::build_tls_acceptor_with_config(
|
||||
&tls_config.cert_pem, &tls_config.key_pem, route_tls,
|
||||
)?;
|
||||
let tls_stream = match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
|
||||
tls_handler::accept_tls(stream, &acceptor),
|
||||
).await {
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => return Err(e),
|
||||
Err(_) => return Err("TLS handshake timeout".into()),
|
||||
};
|
||||
|
||||
// Peek at decrypted data to determine if HTTP
|
||||
let mut buf_stream = tokio::io::BufReader::new(tls_stream);
|
||||
let peeked = {
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
match buf_stream.fill_buf().await {
|
||||
Ok(data) => sni_parser::is_http(data),
|
||||
Err(_) => false,
|
||||
}
|
||||
};
|
||||
|
||||
if peeked {
|
||||
debug!(
|
||||
"TLS Terminate + HTTP: {} -> {}:{} (domain: {:?})",
|
||||
peer_addr, target_host, target_port, domain
|
||||
);
|
||||
http_proxy.handle_io(buf_stream, peer_addr, port).await;
|
||||
} else {
|
||||
debug!(
|
||||
"TLS Terminate + TCP: {} -> {}:{} (domain: {:?})",
|
||||
peer_addr, target_host, target_port, domain
|
||||
);
|
||||
// Raw TCP forwarding of decrypted stream
|
||||
let backend = match tokio::time::timeout(
|
||||
connect_timeout,
|
||||
tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)),
|
||||
).await {
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Err(_) => return Err("Backend connection timeout".into()),
|
||||
};
|
||||
backend.set_nodelay(true)?;
|
||||
|
||||
let (tls_read, tls_write) = tokio::io::split(buf_stream);
|
||||
let (backend_read, backend_write) = tokio::io::split(backend);
|
||||
|
||||
let (bytes_in, bytes_out) = Self::forward_bidirectional_split_with_timeouts(
|
||||
tls_read, tls_write, backend_read, backend_write,
|
||||
inactivity_timeout, max_lifetime,
|
||||
).await;
|
||||
|
||||
metrics.record_bytes(bytes_in, bytes_out, route_id);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Some(rustproxy_config::TlsMode::TerminateAndReencrypt) => {
|
||||
let route_tls = route_match.route.action.tls.as_ref();
|
||||
Self::handle_tls_terminate_reencrypt(
|
||||
stream, n, &domain, &target_host, target_port,
|
||||
peer_addr, &tls_configs, &metrics, route_id, &conn_config, route_tls,
|
||||
).await
|
||||
}
|
||||
None => {
|
||||
if is_http {
|
||||
// Plain HTTP - use HTTP proxy for request-level routing
|
||||
debug!("HTTP proxy: {} on port {}", peer_addr, port);
|
||||
http_proxy.handle_connection(stream, peer_addr, port).await;
|
||||
Ok(())
|
||||
} else {
|
||||
// Plain TCP forwarding (non-HTTP)
|
||||
let mut backend = match tokio::time::timeout(
|
||||
connect_timeout,
|
||||
tokio::net::TcpStream::connect(format!("{}:{}", target_host, target_port)),
|
||||
).await {
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => return Err(e.into()),
|
||||
Err(_) => return Err("Backend connection timeout".into()),
|
||||
};
|
||||
backend.set_nodelay(true)?;
|
||||
|
||||
// Send PROXY protocol header if configured
|
||||
if let Some(ref header) = proxy_header {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
backend.write_all(header.as_bytes()).await?;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"Forward: {} -> {}:{}",
|
||||
peer_addr, target_host, target_port
|
||||
);
|
||||
|
||||
let mut actual_buf = vec![0u8; n];
|
||||
stream.read_exact(&mut actual_buf).await?;
|
||||
|
||||
let (bytes_in, bytes_out) = forwarder::forward_bidirectional_with_timeouts(
|
||||
stream, backend, Some(&actual_buf),
|
||||
inactivity_timeout, max_lifetime, cancel,
|
||||
).await?;
|
||||
metrics.record_bytes(bytes_in, bytes_out, route_id);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
metrics.connection_closed(route_id);
|
||||
result
|
||||
}
|
||||
|
||||
/// Handle TLS terminate-and-reencrypt: accept TLS from client, connect TLS to backend.
|
||||
async fn handle_tls_terminate_reencrypt(
|
||||
stream: tokio::net::TcpStream,
|
||||
_peek_len: usize,
|
||||
domain: &Option<String>,
|
||||
target_host: &str,
|
||||
target_port: u16,
|
||||
peer_addr: std::net::SocketAddr,
|
||||
tls_configs: &HashMap<String, TlsCertConfig>,
|
||||
metrics: &MetricsCollector,
|
||||
route_id: Option<&str>,
|
||||
conn_config: &ConnectionConfig,
|
||||
route_tls: Option<&rustproxy_config::RouteTls>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let tls_config = Self::find_tls_config(domain, tls_configs)?;
|
||||
let acceptor = tls_handler::build_tls_acceptor_with_config(
|
||||
&tls_config.cert_pem, &tls_config.key_pem, route_tls,
|
||||
)?;
|
||||
|
||||
// Accept TLS from client with timeout
|
||||
let client_tls = match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(conn_config.initial_data_timeout_ms),
|
||||
tls_handler::accept_tls(stream, &acceptor),
|
||||
).await {
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => return Err(e),
|
||||
Err(_) => return Err("TLS handshake timeout".into()),
|
||||
};
|
||||
|
||||
debug!(
|
||||
"TLS Terminate+Reencrypt: {} -> {}:{} (domain: {:?})",
|
||||
peer_addr, target_host, target_port, domain
|
||||
);
|
||||
|
||||
// Connect to backend over TLS with timeout
|
||||
let backend_tls = match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(conn_config.connection_timeout_ms),
|
||||
tls_handler::connect_tls(target_host, target_port),
|
||||
).await {
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => return Err(e),
|
||||
Err(_) => return Err("Backend TLS connection timeout".into()),
|
||||
};
|
||||
|
||||
// Forward between two TLS streams
|
||||
let (client_read, client_write) = tokio::io::split(client_tls);
|
||||
let (backend_read, backend_write) = tokio::io::split(backend_tls);
|
||||
|
||||
let base_inactivity_ms = conn_config.socket_timeout_ms;
|
||||
let (inactivity_timeout, max_lifetime) = match conn_config.keep_alive_treatment.as_ref() {
|
||||
Some(rustproxy_config::KeepAliveTreatment::Extended) => {
|
||||
let multiplier = conn_config.keep_alive_inactivity_multiplier.unwrap_or(6.0);
|
||||
let extended_lifetime = conn_config.extended_keep_alive_lifetime_ms
|
||||
.unwrap_or(7 * 24 * 3600 * 1000); // 7 days default
|
||||
(
|
||||
std::time::Duration::from_millis((base_inactivity_ms as f64 * multiplier) as u64),
|
||||
std::time::Duration::from_millis(extended_lifetime),
|
||||
)
|
||||
}
|
||||
Some(rustproxy_config::KeepAliveTreatment::Immortal) => {
|
||||
(
|
||||
std::time::Duration::from_millis(base_inactivity_ms),
|
||||
std::time::Duration::from_secs(u64::MAX / 2),
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
// Standard
|
||||
(
|
||||
std::time::Duration::from_millis(base_inactivity_ms),
|
||||
std::time::Duration::from_millis(conn_config.max_connection_lifetime_ms),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let (bytes_in, bytes_out) = Self::forward_bidirectional_split_with_timeouts(
|
||||
client_read, client_write, backend_read, backend_write,
|
||||
inactivity_timeout, max_lifetime,
|
||||
).await;
|
||||
|
||||
metrics.record_bytes(bytes_in, bytes_out, route_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Find the TLS config for a given domain.
|
||||
fn find_tls_config<'a>(
|
||||
domain: &Option<String>,
|
||||
tls_configs: &'a HashMap<String, TlsCertConfig>,
|
||||
) -> Result<&'a TlsCertConfig, Box<dyn std::error::Error + Send + Sync>> {
|
||||
if let Some(domain) = domain {
|
||||
// Try exact match
|
||||
if let Some(config) = tls_configs.get(domain) {
|
||||
return Ok(config);
|
||||
}
|
||||
// Try wildcard
|
||||
if let Some(dot_pos) = domain.find('.') {
|
||||
let wildcard = format!("*.{}", &domain[dot_pos + 1..]);
|
||||
if let Some(config) = tls_configs.get(&wildcard) {
|
||||
return Ok(config);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Try default/fallback cert
|
||||
if let Some(config) = tls_configs.get("*") {
|
||||
return Ok(config);
|
||||
}
|
||||
// Try first available cert
|
||||
if let Some((_key, config)) = tls_configs.iter().next() {
|
||||
return Ok(config);
|
||||
}
|
||||
Err("No TLS certificate available for this domain".into())
|
||||
}
|
||||
|
||||
/// Forward bidirectional between two split streams with inactivity and lifetime timeouts.
|
||||
async fn forward_bidirectional_split_with_timeouts<R1, W1, R2, W2>(
|
||||
mut client_read: R1,
|
||||
mut client_write: W1,
|
||||
mut backend_read: R2,
|
||||
mut backend_write: W2,
|
||||
inactivity_timeout: std::time::Duration,
|
||||
max_lifetime: std::time::Duration,
|
||||
) -> (u64, u64)
|
||||
where
|
||||
R1: tokio::io::AsyncRead + Unpin + Send + 'static,
|
||||
W1: tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
R2: tokio::io::AsyncRead + Unpin + Send + 'static,
|
||||
W2: tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
let last_activity = Arc::new(AtomicU64::new(0));
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let la1 = Arc::clone(&last_activity);
|
||||
let c2b = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = match client_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if backend_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
la1.store(
|
||||
start.elapsed().as_millis() as u64,
|
||||
Ordering::Relaxed,
|
||||
);
|
||||
}
|
||||
let _ = backend_write.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
let la2 = Arc::clone(&last_activity);
|
||||
let b2c = tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut total = 0u64;
|
||||
loop {
|
||||
let n = match backend_read.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if client_write.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
total += n as u64;
|
||||
la2.store(
|
||||
start.elapsed().as_millis() as u64,
|
||||
Ordering::Relaxed,
|
||||
);
|
||||
}
|
||||
let _ = client_write.shutdown().await;
|
||||
total
|
||||
});
|
||||
|
||||
// Watchdog task: check for inactivity and max lifetime
|
||||
let la_watch = Arc::clone(&last_activity);
|
||||
let c2b_handle = c2b.abort_handle();
|
||||
let b2c_handle = b2c.abort_handle();
|
||||
let watchdog = tokio::spawn(async move {
|
||||
let check_interval = std::time::Duration::from_secs(5);
|
||||
let mut last_seen = 0u64;
|
||||
loop {
|
||||
tokio::time::sleep(check_interval).await;
|
||||
|
||||
// Check max lifetime
|
||||
if start.elapsed() >= max_lifetime {
|
||||
debug!("Connection exceeded max lifetime, closing");
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
|
||||
// Check inactivity
|
||||
let current = la_watch.load(Ordering::Relaxed);
|
||||
if current == last_seen {
|
||||
// No activity since last check
|
||||
let elapsed_since_activity = start.elapsed().as_millis() as u64 - current;
|
||||
if elapsed_since_activity >= inactivity_timeout.as_millis() as u64 {
|
||||
debug!("Connection inactive for {}ms, closing", elapsed_since_activity);
|
||||
c2b_handle.abort();
|
||||
b2c_handle.abort();
|
||||
break;
|
||||
}
|
||||
}
|
||||
last_seen = current;
|
||||
}
|
||||
});
|
||||
|
||||
let bytes_in = c2b.await.unwrap_or(0);
|
||||
let bytes_out = b2c.await.unwrap_or(0);
|
||||
watchdog.abort();
|
||||
(bytes_in, bytes_out)
|
||||
}
|
||||
}
|
||||
190
rust/crates/rustproxy-passthrough/src/tls_handler.rs
Normal file
190
rust/crates/rustproxy-passthrough/src/tls_handler.rs
Normal file
@@ -0,0 +1,190 @@
|
||||
use std::io::BufReader;
|
||||
use std::sync::Arc;
|
||||
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
|
||||
use rustls::ServerConfig;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::{TlsAcceptor, TlsConnector, server::TlsStream as ServerTlsStream};
|
||||
use tracing::debug;
|
||||
|
||||
/// Ensure the default crypto provider is installed.
|
||||
fn ensure_crypto_provider() {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
}
|
||||
|
||||
/// Build a TLS acceptor from PEM-encoded cert and key data.
|
||||
pub fn build_tls_acceptor(cert_pem: &str, key_pem: &str) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
build_tls_acceptor_with_config(cert_pem, key_pem, None)
|
||||
}
|
||||
|
||||
/// Build a TLS acceptor with optional RouteTls configuration for version/cipher tuning.
|
||||
pub fn build_tls_acceptor_with_config(
|
||||
cert_pem: &str,
|
||||
key_pem: &str,
|
||||
tls_config: Option<&rustproxy_config::RouteTls>,
|
||||
) -> Result<TlsAcceptor, Box<dyn std::error::Error + Send + Sync>> {
|
||||
ensure_crypto_provider();
|
||||
let certs = load_certs(cert_pem)?;
|
||||
let key = load_private_key(key_pem)?;
|
||||
|
||||
let mut config = if let Some(route_tls) = tls_config {
|
||||
// Apply TLS version restrictions
|
||||
let versions = resolve_tls_versions(route_tls.versions.as_deref());
|
||||
let builder = ServerConfig::builder_with_protocol_versions(&versions);
|
||||
builder
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)?
|
||||
} else {
|
||||
ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, key)?
|
||||
};
|
||||
|
||||
// Apply session timeout if configured
|
||||
if let Some(route_tls) = tls_config {
|
||||
if let Some(timeout_secs) = route_tls.session_timeout {
|
||||
config.session_storage = rustls::server::ServerSessionMemoryCache::new(
|
||||
256, // max sessions
|
||||
);
|
||||
debug!("TLS session timeout configured: {}s", timeout_secs);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(TlsAcceptor::from(Arc::new(config)))
|
||||
}
|
||||
|
||||
/// Resolve TLS version strings to rustls SupportedProtocolVersion.
|
||||
fn resolve_tls_versions(versions: Option<&[String]>) -> Vec<&'static rustls::SupportedProtocolVersion> {
|
||||
let versions = match versions {
|
||||
Some(v) if !v.is_empty() => v,
|
||||
_ => return vec![&rustls::version::TLS12, &rustls::version::TLS13],
|
||||
};
|
||||
|
||||
let mut result = Vec::new();
|
||||
for v in versions {
|
||||
match v.as_str() {
|
||||
"TLSv1.2" | "TLS1.2" | "1.2" | "TLSv12" => {
|
||||
if !result.contains(&&rustls::version::TLS12) {
|
||||
result.push(&rustls::version::TLS12);
|
||||
}
|
||||
}
|
||||
"TLSv1.3" | "TLS1.3" | "1.3" | "TLSv13" => {
|
||||
if !result.contains(&&rustls::version::TLS13) {
|
||||
result.push(&rustls::version::TLS13);
|
||||
}
|
||||
}
|
||||
other => {
|
||||
debug!("Unknown TLS version '{}', ignoring", other);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if result.is_empty() {
|
||||
// Fallback to both if no valid versions specified
|
||||
vec![&rustls::version::TLS12, &rustls::version::TLS13]
|
||||
} else {
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
/// Accept a TLS connection from a client stream.
|
||||
pub async fn accept_tls(
|
||||
stream: TcpStream,
|
||||
acceptor: &TlsAcceptor,
|
||||
) -> Result<ServerTlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let tls_stream = acceptor.accept(stream).await?;
|
||||
debug!("TLS handshake completed");
|
||||
Ok(tls_stream)
|
||||
}
|
||||
|
||||
/// Connect to a backend with TLS (for terminate-and-reencrypt mode).
|
||||
pub async fn connect_tls(
|
||||
host: &str,
|
||||
port: u16,
|
||||
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
ensure_crypto_provider();
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
|
||||
let connector = TlsConnector::from(Arc::new(config));
|
||||
|
||||
let stream = TcpStream::connect(format!("{}:{}", host, port)).await?;
|
||||
stream.set_nodelay(true)?;
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(host.to_string())?;
|
||||
let tls_stream = connector.connect(server_name, stream).await?;
|
||||
debug!("Backend TLS connection established to {}:{}", host, port);
|
||||
Ok(tls_stream)
|
||||
}
|
||||
|
||||
/// Load certificates from PEM string.
|
||||
fn load_certs(pem: &str) -> Result<Vec<CertificateDer<'static>>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut reader = BufReader::new(pem.as_bytes());
|
||||
let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
if certs.is_empty() {
|
||||
return Err("No certificates found in PEM data".into());
|
||||
}
|
||||
Ok(certs)
|
||||
}
|
||||
|
||||
/// Load private key from PEM string.
|
||||
fn load_private_key(pem: &str) -> Result<PrivateKeyDer<'static>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut reader = BufReader::new(pem.as_bytes());
|
||||
// Try PKCS8 first, then RSA, then EC
|
||||
let key = rustls_pemfile::private_key(&mut reader)?
|
||||
.ok_or("No private key found in PEM data")?;
|
||||
Ok(key)
|
||||
}
|
||||
|
||||
/// Insecure certificate verifier for backend connections (terminate-and-reencrypt).
|
||||
/// In internal networks, backends may use self-signed certs.
|
||||
#[derive(Debug)]
|
||||
struct InsecureVerifier;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &CertificateDer<'_>,
|
||||
_intermediates: &[CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
vec![
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA384,
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA512,
|
||||
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
rustls::SignatureScheme::ED25519,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA384,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA512,
|
||||
]
|
||||
}
|
||||
}
|
||||
16
rust/crates/rustproxy-routing/Cargo.toml
Normal file
16
rust/crates/rustproxy-routing/Cargo.toml
Normal file
@@ -0,0 +1,16 @@
|
||||
[package]
|
||||
name = "rustproxy-routing"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "Route matching engine for RustProxy"
|
||||
|
||||
[dependencies]
|
||||
rustproxy-config = { workspace = true }
|
||||
glob-match = { workspace = true }
|
||||
ipnet = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
arc-swap = { workspace = true }
|
||||
9
rust/crates/rustproxy-routing/src/lib.rs
Normal file
9
rust/crates/rustproxy-routing/src/lib.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
//! # rustproxy-routing
|
||||
//!
|
||||
//! Route matching engine for RustProxy.
|
||||
//! Provides domain/path/IP/header matchers and a port-indexed RouteManager.
|
||||
|
||||
pub mod route_manager;
|
||||
pub mod matchers;
|
||||
|
||||
pub use route_manager::*;
|
||||
86
rust/crates/rustproxy-routing/src/matchers/domain.rs
Normal file
86
rust/crates/rustproxy-routing/src/matchers/domain.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
/// Match a domain against a pattern supporting wildcards.
|
||||
///
|
||||
/// Supported patterns:
|
||||
/// - `*` matches any domain
|
||||
/// - `*.example.com` matches any subdomain of example.com
|
||||
/// - `example.com` exact match
|
||||
/// - `**.example.com` matches any depth of subdomain
|
||||
pub fn domain_matches(pattern: &str, domain: &str) -> bool {
|
||||
let pattern = pattern.trim().to_lowercase();
|
||||
let domain = domain.trim().to_lowercase();
|
||||
|
||||
if pattern == "*" {
|
||||
return true;
|
||||
}
|
||||
|
||||
if pattern == domain {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Wildcard patterns
|
||||
if pattern.starts_with("*.") {
|
||||
let suffix = &pattern[2..]; // e.g., "example.com"
|
||||
// Match exact parent or any single-level subdomain
|
||||
if domain == suffix {
|
||||
return true;
|
||||
}
|
||||
if domain.ends_with(&format!(".{}", suffix)) {
|
||||
// Check it's a single level subdomain for `*.`
|
||||
let prefix = &domain[..domain.len() - suffix.len() - 1];
|
||||
return !prefix.contains('.');
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if pattern.starts_with("**.") {
|
||||
let suffix = &pattern[3..];
|
||||
// Match exact parent or any depth of subdomain
|
||||
return domain == suffix || domain.ends_with(&format!(".{}", suffix));
|
||||
}
|
||||
|
||||
// Use glob-match for more complex patterns
|
||||
glob_match::glob_match(&pattern, &domain)
|
||||
}
|
||||
|
||||
/// Check if a domain matches any of the given patterns.
|
||||
pub fn domain_matches_any(patterns: &[&str], domain: &str) -> bool {
|
||||
patterns.iter().any(|p| domain_matches(p, domain))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_exact_match() {
|
||||
assert!(domain_matches("example.com", "example.com"));
|
||||
assert!(!domain_matches("example.com", "other.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_all() {
|
||||
assert!(domain_matches("*", "anything.com"));
|
||||
assert!(domain_matches("*", "sub.domain.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_subdomain() {
|
||||
assert!(domain_matches("*.example.com", "www.example.com"));
|
||||
assert!(domain_matches("*.example.com", "api.example.com"));
|
||||
assert!(domain_matches("*.example.com", "example.com"));
|
||||
assert!(!domain_matches("*.example.com", "deep.sub.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_double_wildcard() {
|
||||
assert!(domain_matches("**.example.com", "www.example.com"));
|
||||
assert!(domain_matches("**.example.com", "deep.sub.example.com"));
|
||||
assert!(domain_matches("**.example.com", "example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_case_insensitive() {
|
||||
assert!(domain_matches("Example.COM", "example.com"));
|
||||
assert!(domain_matches("*.EXAMPLE.com", "WWW.example.COM"));
|
||||
}
|
||||
}
|
||||
98
rust/crates/rustproxy-routing/src/matchers/header.rs
Normal file
98
rust/crates/rustproxy-routing/src/matchers/header.rs
Normal file
@@ -0,0 +1,98 @@
|
||||
use std::collections::HashMap;
|
||||
use regex::Regex;
|
||||
|
||||
/// Match HTTP headers against a set of patterns.
|
||||
///
|
||||
/// Pattern values can be:
|
||||
/// - Exact string: `"application/json"`
|
||||
/// - Regex (surrounded by /): `"/^text\/.*/"`
|
||||
pub fn headers_match(
|
||||
patterns: &HashMap<String, String>,
|
||||
headers: &HashMap<String, String>,
|
||||
) -> bool {
|
||||
for (key, pattern) in patterns {
|
||||
let key_lower = key.to_lowercase();
|
||||
|
||||
// Find the header (case-insensitive)
|
||||
let header_value = headers
|
||||
.iter()
|
||||
.find(|(k, _)| k.to_lowercase() == key_lower)
|
||||
.map(|(_, v)| v.as_str());
|
||||
|
||||
let header_value = match header_value {
|
||||
Some(v) => v,
|
||||
None => return false, // Required header not present
|
||||
};
|
||||
|
||||
// Check if pattern is a regex (surrounded by /)
|
||||
if pattern.starts_with('/') && pattern.ends_with('/') && pattern.len() > 2 {
|
||||
let regex_str = &pattern[1..pattern.len() - 1];
|
||||
match Regex::new(regex_str) {
|
||||
Ok(re) => {
|
||||
if !re.is_match(header_value) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Invalid regex, fall back to exact match
|
||||
if header_value != pattern {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Exact match
|
||||
if header_value != pattern {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_exact_header_match() {
|
||||
let patterns: HashMap<String, String> = {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("Content-Type".to_string(), "application/json".to_string());
|
||||
m
|
||||
};
|
||||
let headers: HashMap<String, String> = {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("content-type".to_string(), "application/json".to_string());
|
||||
m
|
||||
};
|
||||
assert!(headers_match(&patterns, &headers));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_regex_header_match() {
|
||||
let patterns: HashMap<String, String> = {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("Content-Type".to_string(), "/^text\\/.*/".to_string());
|
||||
m
|
||||
};
|
||||
let headers: HashMap<String, String> = {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("content-type".to_string(), "text/html".to_string());
|
||||
m
|
||||
};
|
||||
assert!(headers_match(&patterns, &headers));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_missing_header() {
|
||||
let patterns: HashMap<String, String> = {
|
||||
let mut m = HashMap::new();
|
||||
m.insert("X-Custom".to_string(), "value".to_string());
|
||||
m
|
||||
};
|
||||
let headers: HashMap<String, String> = HashMap::new();
|
||||
assert!(!headers_match(&patterns, &headers));
|
||||
}
|
||||
}
|
||||
126
rust/crates/rustproxy-routing/src/matchers/ip.rs
Normal file
126
rust/crates/rustproxy-routing/src/matchers/ip.rs
Normal file
@@ -0,0 +1,126 @@
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
use ipnet::IpNet;
|
||||
|
||||
/// Match an IP address against a pattern.
|
||||
///
|
||||
/// Supported patterns:
|
||||
/// - `*` matches any IP
|
||||
/// - `192.168.1.0/24` CIDR range
|
||||
/// - `192.168.1.100` exact match
|
||||
/// - `192.168.1.*` wildcard (converted to CIDR)
|
||||
/// - `::ffff:192.168.1.100` IPv6-mapped IPv4
|
||||
pub fn ip_matches(pattern: &str, ip: &str) -> bool {
|
||||
let pattern = pattern.trim();
|
||||
|
||||
if pattern == "*" {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Normalize IPv4-mapped IPv6
|
||||
let normalized_ip = normalize_ip_str(ip);
|
||||
|
||||
// Try CIDR match
|
||||
if pattern.contains('/') {
|
||||
if let Ok(net) = IpNet::from_str(pattern) {
|
||||
if let Ok(addr) = IpAddr::from_str(&normalized_ip) {
|
||||
return net.contains(&addr);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Handle wildcard patterns like 192.168.1.*
|
||||
if pattern.contains('*') {
|
||||
let pattern_cidr = wildcard_to_cidr(pattern);
|
||||
if let Some(cidr) = pattern_cidr {
|
||||
if let Ok(net) = IpNet::from_str(&cidr) {
|
||||
if let Ok(addr) = IpAddr::from_str(&normalized_ip) {
|
||||
return net.contains(&addr);
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Exact match
|
||||
let normalized_pattern = normalize_ip_str(pattern);
|
||||
normalized_ip == normalized_pattern
|
||||
}
|
||||
|
||||
/// Check if an IP matches any of the given patterns.
|
||||
pub fn ip_matches_any(patterns: &[String], ip: &str) -> bool {
|
||||
patterns.iter().any(|p| ip_matches(p, ip))
|
||||
}
|
||||
|
||||
/// Normalize IPv4-mapped IPv6 addresses.
|
||||
fn normalize_ip_str(ip: &str) -> String {
|
||||
let ip = ip.trim();
|
||||
if ip.starts_with("::ffff:") {
|
||||
return ip[7..].to_string();
|
||||
}
|
||||
ip.to_string()
|
||||
}
|
||||
|
||||
/// Convert a wildcard IP pattern to CIDR notation.
|
||||
/// e.g., "192.168.1.*" -> "192.168.1.0/24"
|
||||
fn wildcard_to_cidr(pattern: &str) -> Option<String> {
|
||||
let parts: Vec<&str> = pattern.split('.').collect();
|
||||
if parts.len() != 4 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut octets = [0u8; 4];
|
||||
let mut prefix_len = 0;
|
||||
|
||||
for (i, part) in parts.iter().enumerate() {
|
||||
if *part == "*" {
|
||||
break;
|
||||
}
|
||||
if let Ok(n) = part.parse::<u8>() {
|
||||
octets[i] = n;
|
||||
prefix_len += 8;
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
Some(format!("{}.{}.{}.{}/{}", octets[0], octets[1], octets[2], octets[3], prefix_len))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_all() {
|
||||
assert!(ip_matches("*", "192.168.1.100"));
|
||||
assert!(ip_matches("*", "::1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exact_match() {
|
||||
assert!(ip_matches("192.168.1.100", "192.168.1.100"));
|
||||
assert!(!ip_matches("192.168.1.100", "192.168.1.101"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cidr() {
|
||||
assert!(ip_matches("192.168.1.0/24", "192.168.1.100"));
|
||||
assert!(ip_matches("192.168.1.0/24", "192.168.1.1"));
|
||||
assert!(!ip_matches("192.168.1.0/24", "192.168.2.1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_pattern() {
|
||||
assert!(ip_matches("192.168.1.*", "192.168.1.100"));
|
||||
assert!(ip_matches("192.168.1.*", "192.168.1.1"));
|
||||
assert!(!ip_matches("192.168.1.*", "192.168.2.1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ipv6_mapped() {
|
||||
assert!(ip_matches("192.168.1.100", "::ffff:192.168.1.100"));
|
||||
assert!(ip_matches("192.168.1.0/24", "::ffff:192.168.1.50"));
|
||||
}
|
||||
}
|
||||
9
rust/crates/rustproxy-routing/src/matchers/mod.rs
Normal file
9
rust/crates/rustproxy-routing/src/matchers/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
pub mod domain;
|
||||
pub mod path;
|
||||
pub mod ip;
|
||||
pub mod header;
|
||||
|
||||
pub use domain::*;
|
||||
pub use path::*;
|
||||
pub use ip::*;
|
||||
pub use header::*;
|
||||
65
rust/crates/rustproxy-routing/src/matchers/path.rs
Normal file
65
rust/crates/rustproxy-routing/src/matchers/path.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
/// Match a URL path against a pattern supporting wildcards.
|
||||
///
|
||||
/// Supported patterns:
|
||||
/// - `/api/*` matches `/api/anything` (single level)
|
||||
/// - `/api/**` matches `/api/any/depth/here`
|
||||
/// - `/exact/path` exact match
|
||||
/// - `/prefix*` prefix match
|
||||
pub fn path_matches(pattern: &str, path: &str) -> bool {
|
||||
// Exact match
|
||||
if pattern == path {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Double-star: match any depth
|
||||
if pattern.ends_with("/**") {
|
||||
let prefix = &pattern[..pattern.len() - 3];
|
||||
return path == prefix || path.starts_with(&format!("{}/", prefix));
|
||||
}
|
||||
|
||||
// Single-star at end: match single path segment
|
||||
if pattern.ends_with("/*") {
|
||||
let prefix = &pattern[..pattern.len() - 2];
|
||||
if path == prefix {
|
||||
return true;
|
||||
}
|
||||
if path.starts_with(&format!("{}/", prefix)) {
|
||||
let rest = &path[prefix.len() + 1..];
|
||||
// Single level means no more slashes
|
||||
return !rest.contains('/');
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Star anywhere: use glob matching
|
||||
if pattern.contains('*') {
|
||||
return glob_match::glob_match(pattern, path);
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_exact_path() {
|
||||
assert!(path_matches("/api/users", "/api/users"));
|
||||
assert!(!path_matches("/api/users", "/api/posts"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_wildcard() {
|
||||
assert!(path_matches("/api/*", "/api/users"));
|
||||
assert!(path_matches("/api/*", "/api/posts"));
|
||||
assert!(!path_matches("/api/*", "/api/users/123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_double_wildcard() {
|
||||
assert!(path_matches("/api/**", "/api/users"));
|
||||
assert!(path_matches("/api/**", "/api/users/123"));
|
||||
assert!(path_matches("/api/**", "/api/users/123/posts"));
|
||||
}
|
||||
}
|
||||
545
rust/crates/rustproxy-routing/src/route_manager.rs
Normal file
545
rust/crates/rustproxy-routing/src/route_manager.rs
Normal file
@@ -0,0 +1,545 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use rustproxy_config::{RouteConfig, RouteTarget, TlsMode};
|
||||
use crate::matchers;
|
||||
|
||||
/// Context for route matching (subset of connection info).
|
||||
pub struct MatchContext<'a> {
|
||||
pub port: u16,
|
||||
pub domain: Option<&'a str>,
|
||||
pub path: Option<&'a str>,
|
||||
pub client_ip: Option<&'a str>,
|
||||
pub tls_version: Option<&'a str>,
|
||||
pub headers: Option<&'a HashMap<String, String>>,
|
||||
pub is_tls: bool,
|
||||
}
|
||||
|
||||
/// Result of a route match.
|
||||
pub struct RouteMatchResult<'a> {
|
||||
pub route: &'a RouteConfig,
|
||||
pub target: Option<&'a RouteTarget>,
|
||||
}
|
||||
|
||||
/// Port-indexed route lookup with priority-based matching.
|
||||
/// This is the core routing engine.
|
||||
pub struct RouteManager {
|
||||
/// Routes indexed by port for O(1) port lookup.
|
||||
port_index: HashMap<u16, Vec<usize>>,
|
||||
/// All routes, sorted by priority (highest first).
|
||||
routes: Vec<RouteConfig>,
|
||||
}
|
||||
|
||||
impl RouteManager {
|
||||
/// Create a new RouteManager from a list of routes.
|
||||
pub fn new(routes: Vec<RouteConfig>) -> Self {
|
||||
let mut manager = Self {
|
||||
port_index: HashMap::new(),
|
||||
routes: Vec::new(),
|
||||
};
|
||||
|
||||
// Filter enabled routes and sort by priority
|
||||
let mut enabled_routes: Vec<RouteConfig> = routes
|
||||
.into_iter()
|
||||
.filter(|r| r.is_enabled())
|
||||
.collect();
|
||||
enabled_routes.sort_by(|a, b| b.effective_priority().cmp(&a.effective_priority()));
|
||||
|
||||
// Build port index
|
||||
for (idx, route) in enabled_routes.iter().enumerate() {
|
||||
for port in route.listening_ports() {
|
||||
manager.port_index
|
||||
.entry(port)
|
||||
.or_default()
|
||||
.push(idx);
|
||||
}
|
||||
}
|
||||
|
||||
manager.routes = enabled_routes;
|
||||
manager
|
||||
}
|
||||
|
||||
/// Find the best matching route for the given context.
|
||||
pub fn find_route<'a>(&'a self, ctx: &MatchContext<'_>) -> Option<RouteMatchResult<'a>> {
|
||||
// Get routes for this port
|
||||
let route_indices = self.port_index.get(&ctx.port)?;
|
||||
|
||||
for &idx in route_indices {
|
||||
let route = &self.routes[idx];
|
||||
|
||||
if self.matches_route(route, ctx) {
|
||||
// Find the best matching target within the route
|
||||
let target = self.find_target(route, ctx);
|
||||
return Some(RouteMatchResult { route, target });
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if a route matches the given context.
|
||||
fn matches_route(&self, route: &RouteConfig, ctx: &MatchContext<'_>) -> bool {
|
||||
let rm = &route.route_match;
|
||||
|
||||
// Domain matching
|
||||
if let Some(ref domains) = rm.domains {
|
||||
if let Some(domain) = ctx.domain {
|
||||
let patterns = domains.to_vec();
|
||||
if !matchers::domain_matches_any(&patterns, domain) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// If no domain provided but route requires domain, it depends on context
|
||||
// For TLS passthrough, we need SNI; for other cases we may still match
|
||||
}
|
||||
|
||||
// Path matching
|
||||
if let Some(ref pattern) = rm.path {
|
||||
if let Some(path) = ctx.path {
|
||||
if !matchers::path_matches(pattern, path) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// Route requires path but none provided
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Client IP matching
|
||||
if let Some(ref client_ips) = rm.client_ip {
|
||||
if let Some(ip) = ctx.client_ip {
|
||||
if !matchers::ip_matches_any(client_ips, ip) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// TLS version matching
|
||||
if let Some(ref tls_versions) = rm.tls_version {
|
||||
if let Some(version) = ctx.tls_version {
|
||||
if !tls_versions.iter().any(|v| v == version) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Header matching
|
||||
if let Some(ref patterns) = rm.headers {
|
||||
if let Some(headers) = ctx.headers {
|
||||
if !matchers::headers_match(patterns, headers) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Find the best matching target within a route.
|
||||
fn find_target<'a>(&self, route: &'a RouteConfig, ctx: &MatchContext<'_>) -> Option<&'a RouteTarget> {
|
||||
let targets = route.action.targets.as_ref()?;
|
||||
|
||||
if targets.len() == 1 && targets[0].target_match.is_none() {
|
||||
return Some(&targets[0]);
|
||||
}
|
||||
|
||||
// Sort candidates by priority (already in order from config)
|
||||
let mut best: Option<&RouteTarget> = None;
|
||||
let mut best_priority = i32::MIN;
|
||||
|
||||
for target in targets {
|
||||
let priority = target.priority.unwrap_or(0);
|
||||
|
||||
if let Some(ref tm) = target.target_match {
|
||||
if !self.matches_target(tm, ctx) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if priority > best_priority || best.is_none() {
|
||||
best = Some(target);
|
||||
best_priority = priority;
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to first target without match criteria
|
||||
best.or_else(|| {
|
||||
targets.iter().find(|t| t.target_match.is_none())
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if a target match criteria matches the context.
|
||||
fn matches_target(
|
||||
&self,
|
||||
tm: &rustproxy_config::TargetMatch,
|
||||
ctx: &MatchContext<'_>,
|
||||
) -> bool {
|
||||
// Port matching
|
||||
if let Some(ref ports) = tm.ports {
|
||||
if !ports.contains(&ctx.port) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Path matching
|
||||
if let Some(ref pattern) = tm.path {
|
||||
if let Some(path) = ctx.path {
|
||||
if !matchers::path_matches(pattern, path) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Header matching
|
||||
if let Some(ref patterns) = tm.headers {
|
||||
if let Some(headers) = ctx.headers {
|
||||
if !matchers::headers_match(patterns, headers) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Get all unique listening ports.
|
||||
pub fn listening_ports(&self) -> Vec<u16> {
|
||||
let mut ports: Vec<u16> = self.port_index.keys().copied().collect();
|
||||
ports.sort();
|
||||
ports
|
||||
}
|
||||
|
||||
/// Get all routes for a specific port.
|
||||
pub fn routes_for_port(&self, port: u16) -> Vec<&RouteConfig> {
|
||||
self.port_index
|
||||
.get(&port)
|
||||
.map(|indices| indices.iter().map(|&i| &self.routes[i]).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get the total number of enabled routes.
|
||||
pub fn route_count(&self) -> usize {
|
||||
self.routes.len()
|
||||
}
|
||||
|
||||
/// Check if any route on the given port requires SNI.
|
||||
pub fn port_requires_sni(&self, port: u16) -> bool {
|
||||
let routes = self.routes_for_port(port);
|
||||
|
||||
// If multiple passthrough routes on same port, SNI is needed
|
||||
let passthrough_routes: Vec<_> = routes
|
||||
.iter()
|
||||
.filter(|r| {
|
||||
r.tls_mode() == Some(&TlsMode::Passthrough)
|
||||
})
|
||||
.collect();
|
||||
|
||||
if passthrough_routes.len() > 1 {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Single passthrough route with specific domain restriction needs SNI
|
||||
if let Some(route) = passthrough_routes.first() {
|
||||
if let Some(ref domains) = route.route_match.domains {
|
||||
let domain_list = domains.to_vec();
|
||||
// If it's not just a wildcard, SNI is needed
|
||||
if !domain_list.iter().all(|d| *d == "*") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rustproxy_config::*;
|
||||
|
||||
fn make_route(port: u16, domain: Option<&str>, priority: i32) -> RouteConfig {
|
||||
RouteConfig {
|
||||
id: None,
|
||||
route_match: RouteMatch {
|
||||
ports: PortRange::Single(port),
|
||||
domains: domain.map(|d| DomainSpec::Single(d.to_string())),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
},
|
||||
action: RouteAction {
|
||||
action_type: RouteActionType::Forward,
|
||||
targets: Some(vec![RouteTarget {
|
||||
target_match: None,
|
||||
host: HostSpec::Single("localhost".to_string()),
|
||||
port: PortSpec::Fixed(8080),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
priority: None,
|
||||
}]),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
advanced: None,
|
||||
options: None,
|
||||
forwarding_engine: None,
|
||||
nftables: None,
|
||||
send_proxy_protocol: None,
|
||||
},
|
||||
headers: None,
|
||||
security: None,
|
||||
name: None,
|
||||
description: None,
|
||||
priority: Some(priority),
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_basic_routing() {
|
||||
let routes = vec![
|
||||
make_route(80, Some("example.com"), 0),
|
||||
make_route(80, Some("other.com"), 0),
|
||||
];
|
||||
let manager = RouteManager::new(routes);
|
||||
|
||||
let ctx = MatchContext {
|
||||
port: 80,
|
||||
domain: Some("example.com"),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
let result = manager.find_route(&ctx);
|
||||
assert!(result.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_priority_ordering() {
|
||||
let routes = vec![
|
||||
make_route(80, Some("*.example.com"), 0),
|
||||
make_route(80, Some("api.example.com"), 10), // Higher priority
|
||||
];
|
||||
let manager = RouteManager::new(routes);
|
||||
|
||||
let ctx = MatchContext {
|
||||
port: 80,
|
||||
domain: Some("api.example.com"),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
let result = manager.find_route(&ctx).unwrap();
|
||||
// Should match the higher-priority specific route
|
||||
assert!(result.route.route_match.domains.as_ref()
|
||||
.map(|d| d.to_vec())
|
||||
.unwrap()
|
||||
.contains(&"api.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_match() {
|
||||
let routes = vec![make_route(80, Some("example.com"), 0)];
|
||||
let manager = RouteManager::new(routes);
|
||||
|
||||
let ctx = MatchContext {
|
||||
port: 443, // Different port
|
||||
domain: Some("example.com"),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
assert!(manager.find_route(&ctx).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_disabled_routes_excluded() {
|
||||
let mut route = make_route(80, Some("example.com"), 0);
|
||||
route.enabled = Some(false);
|
||||
let manager = RouteManager::new(vec![route]);
|
||||
assert_eq!(manager.route_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_listening_ports() {
|
||||
let routes = vec![
|
||||
make_route(80, Some("a.com"), 0),
|
||||
make_route(443, Some("b.com"), 0),
|
||||
make_route(80, Some("c.com"), 0), // duplicate port
|
||||
];
|
||||
let manager = RouteManager::new(routes);
|
||||
let ports = manager.listening_ports();
|
||||
assert_eq!(ports, vec![80, 443]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_port_requires_sni_single_passthrough() {
|
||||
let mut route = make_route(443, Some("example.com"), 0);
|
||||
route.action.tls = Some(RouteTls {
|
||||
mode: TlsMode::Passthrough,
|
||||
certificate: None,
|
||||
acme: None,
|
||||
versions: None,
|
||||
ciphers: None,
|
||||
honor_cipher_order: None,
|
||||
session_timeout: None,
|
||||
});
|
||||
let manager = RouteManager::new(vec![route]);
|
||||
// Single passthrough route with specific domain needs SNI
|
||||
assert!(manager.port_requires_sni(443));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_port_requires_sni_wildcard_only() {
|
||||
let mut route = make_route(443, Some("*"), 0);
|
||||
route.action.tls = Some(RouteTls {
|
||||
mode: TlsMode::Passthrough,
|
||||
certificate: None,
|
||||
acme: None,
|
||||
versions: None,
|
||||
ciphers: None,
|
||||
honor_cipher_order: None,
|
||||
session_timeout: None,
|
||||
});
|
||||
let manager = RouteManager::new(vec![route]);
|
||||
// Single passthrough route with wildcard doesn't need SNI
|
||||
assert!(!manager.port_requires_sni(443));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_routes_for_port() {
|
||||
let routes = vec![
|
||||
make_route(80, Some("a.com"), 0),
|
||||
make_route(80, Some("b.com"), 0),
|
||||
make_route(443, Some("c.com"), 0),
|
||||
];
|
||||
let manager = RouteManager::new(routes);
|
||||
assert_eq!(manager.routes_for_port(80).len(), 2);
|
||||
assert_eq!(manager.routes_for_port(443).len(), 1);
|
||||
assert_eq!(manager.routes_for_port(8080).len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_domain_matches_any() {
|
||||
let routes = vec![make_route(80, Some("*"), 0)];
|
||||
let manager = RouteManager::new(routes);
|
||||
|
||||
let ctx = MatchContext {
|
||||
port: 80,
|
||||
domain: Some("anything.example.com"),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
assert!(manager.find_route(&ctx).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_domain_route_matches_any_domain() {
|
||||
let routes = vec![make_route(80, None, 0)];
|
||||
let manager = RouteManager::new(routes);
|
||||
|
||||
let ctx = MatchContext {
|
||||
port: 80,
|
||||
domain: Some("example.com"),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls: false,
|
||||
};
|
||||
|
||||
assert!(manager.find_route(&ctx).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_target_sub_matching() {
|
||||
let mut route = make_route(80, Some("example.com"), 0);
|
||||
route.action.targets = Some(vec![
|
||||
RouteTarget {
|
||||
target_match: Some(rustproxy_config::TargetMatch {
|
||||
ports: None,
|
||||
path: Some("/api/*".to_string()),
|
||||
headers: None,
|
||||
method: None,
|
||||
}),
|
||||
host: HostSpec::Single("api-backend".to_string()),
|
||||
port: PortSpec::Fixed(3000),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
priority: Some(10),
|
||||
},
|
||||
RouteTarget {
|
||||
target_match: None,
|
||||
host: HostSpec::Single("default-backend".to_string()),
|
||||
port: PortSpec::Fixed(8080),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
priority: None,
|
||||
},
|
||||
]);
|
||||
let manager = RouteManager::new(vec![route]);
|
||||
|
||||
// Should match the API target
|
||||
let ctx = MatchContext {
|
||||
port: 80,
|
||||
domain: Some("example.com"),
|
||||
path: Some("/api/users"),
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls: false,
|
||||
};
|
||||
let result = manager.find_route(&ctx).unwrap();
|
||||
assert_eq!(result.target.unwrap().host.first(), "api-backend");
|
||||
|
||||
// Should fall back to default target
|
||||
let ctx = MatchContext {
|
||||
port: 80,
|
||||
domain: Some("example.com"),
|
||||
path: Some("/home"),
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
is_tls: false,
|
||||
};
|
||||
let result = manager.find_route(&ctx).unwrap();
|
||||
assert_eq!(result.target.unwrap().host.first(), "default-backend");
|
||||
}
|
||||
}
|
||||
17
rust/crates/rustproxy-security/Cargo.toml
Normal file
17
rust/crates/rustproxy-security/Cargo.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "rustproxy-security"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "IP filtering, rate limiting, and authentication for RustProxy"
|
||||
|
||||
[dependencies]
|
||||
rustproxy-config = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
ipnet = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
111
rust/crates/rustproxy-security/src/basic_auth.rs
Normal file
111
rust/crates/rustproxy-security/src/basic_auth.rs
Normal file
@@ -0,0 +1,111 @@
|
||||
use base64::Engine;
|
||||
use base64::engine::general_purpose::STANDARD as BASE64;
|
||||
|
||||
/// Basic auth validator.
|
||||
pub struct BasicAuthValidator {
|
||||
users: Vec<(String, String)>,
|
||||
realm: String,
|
||||
}
|
||||
|
||||
impl BasicAuthValidator {
|
||||
pub fn new(users: Vec<(String, String)>, realm: Option<String>) -> Self {
|
||||
Self {
|
||||
users,
|
||||
realm: realm.unwrap_or_else(|| "Restricted".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate an Authorization header value.
|
||||
/// Returns the username if valid.
|
||||
pub fn validate(&self, auth_header: &str) -> Option<String> {
|
||||
let auth_header = auth_header.trim();
|
||||
if !auth_header.starts_with("Basic ") {
|
||||
return None;
|
||||
}
|
||||
|
||||
let encoded = &auth_header[6..];
|
||||
let decoded = BASE64.decode(encoded).ok()?;
|
||||
let credentials = String::from_utf8(decoded).ok()?;
|
||||
|
||||
let mut parts = credentials.splitn(2, ':');
|
||||
let username = parts.next()?;
|
||||
let password = parts.next()?;
|
||||
|
||||
for (u, p) in &self.users {
|
||||
if u == username && p == password {
|
||||
return Some(username.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Get the realm for WWW-Authenticate header.
|
||||
pub fn realm(&self) -> &str {
|
||||
&self.realm
|
||||
}
|
||||
|
||||
/// Generate the WWW-Authenticate header value.
|
||||
pub fn www_authenticate(&self) -> String {
|
||||
format!("Basic realm=\"{}\"", self.realm)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use base64::Engine;
|
||||
|
||||
fn make_validator() -> BasicAuthValidator {
|
||||
BasicAuthValidator::new(
|
||||
vec![
|
||||
("admin".to_string(), "secret".to_string()),
|
||||
("user".to_string(), "pass".to_string()),
|
||||
],
|
||||
Some("TestRealm".to_string()),
|
||||
)
|
||||
}
|
||||
|
||||
fn encode_basic(user: &str, pass: &str) -> String {
|
||||
let encoded = BASE64.encode(format!("{}:{}", user, pass));
|
||||
format!("Basic {}", encoded)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_credentials() {
|
||||
let validator = make_validator();
|
||||
let header = encode_basic("admin", "secret");
|
||||
assert_eq!(validator.validate(&header), Some("admin".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_password() {
|
||||
let validator = make_validator();
|
||||
let header = encode_basic("admin", "wrong");
|
||||
assert_eq!(validator.validate(&header), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_not_basic_scheme() {
|
||||
let validator = make_validator();
|
||||
assert_eq!(validator.validate("Bearer sometoken"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_malformed_base64() {
|
||||
let validator = make_validator();
|
||||
assert_eq!(validator.validate("Basic !!!not-base64!!!"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_www_authenticate_format() {
|
||||
let validator = make_validator();
|
||||
assert_eq!(validator.www_authenticate(), "Basic realm=\"TestRealm\"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_realm() {
|
||||
let validator = BasicAuthValidator::new(vec![], None);
|
||||
assert_eq!(validator.www_authenticate(), "Basic realm=\"Restricted\"");
|
||||
}
|
||||
}
|
||||
189
rust/crates/rustproxy-security/src/ip_filter.rs
Normal file
189
rust/crates/rustproxy-security/src/ip_filter.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
use ipnet::IpNet;
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
|
||||
/// IP filter supporting CIDR ranges, wildcards, and exact matches.
|
||||
pub struct IpFilter {
|
||||
allow_list: Vec<IpPattern>,
|
||||
block_list: Vec<IpPattern>,
|
||||
}
|
||||
|
||||
/// Represents an IP pattern for matching.
|
||||
#[derive(Debug)]
|
||||
enum IpPattern {
|
||||
/// Exact IP match
|
||||
Exact(IpAddr),
|
||||
/// CIDR range match
|
||||
Cidr(IpNet),
|
||||
/// Wildcard (matches everything)
|
||||
Wildcard,
|
||||
}
|
||||
|
||||
impl IpPattern {
|
||||
fn parse(s: &str) -> Self {
|
||||
let s = s.trim();
|
||||
if s == "*" {
|
||||
return IpPattern::Wildcard;
|
||||
}
|
||||
if let Ok(net) = IpNet::from_str(s) {
|
||||
return IpPattern::Cidr(net);
|
||||
}
|
||||
if let Ok(addr) = IpAddr::from_str(s) {
|
||||
return IpPattern::Exact(addr);
|
||||
}
|
||||
// Try as CIDR by appending default prefix
|
||||
if let Ok(addr) = IpAddr::from_str(s) {
|
||||
return IpPattern::Exact(addr);
|
||||
}
|
||||
// Fallback: treat as exact, will never match an invalid string
|
||||
IpPattern::Exact(IpAddr::from_str("0.0.0.0").unwrap())
|
||||
}
|
||||
|
||||
fn matches(&self, ip: &IpAddr) -> bool {
|
||||
match self {
|
||||
IpPattern::Wildcard => true,
|
||||
IpPattern::Exact(addr) => addr == ip,
|
||||
IpPattern::Cidr(net) => net.contains(ip),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IpFilter {
|
||||
/// Create a new IP filter from allow and block lists.
|
||||
pub fn new(allow_list: &[String], block_list: &[String]) -> Self {
|
||||
Self {
|
||||
allow_list: allow_list.iter().map(|s| IpPattern::parse(s)).collect(),
|
||||
block_list: block_list.iter().map(|s| IpPattern::parse(s)).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if an IP is allowed.
|
||||
/// If allow_list is non-empty, IP must match at least one entry.
|
||||
/// If block_list is non-empty, IP must NOT match any entry.
|
||||
pub fn is_allowed(&self, ip: &IpAddr) -> bool {
|
||||
// Check block list first
|
||||
if !self.block_list.is_empty() {
|
||||
for pattern in &self.block_list {
|
||||
if pattern.matches(ip) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If allow list is non-empty, must match at least one
|
||||
if !self.allow_list.is_empty() {
|
||||
return self.allow_list.iter().any(|p| p.matches(ip));
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Normalize IPv4-mapped IPv6 addresses (::ffff:x.x.x.x -> x.x.x.x)
|
||||
pub fn normalize_ip(ip: &IpAddr) -> IpAddr {
|
||||
match ip {
|
||||
IpAddr::V6(v6) => {
|
||||
if let Some(v4) = v6.to_ipv4_mapped() {
|
||||
IpAddr::V4(v4)
|
||||
} else {
|
||||
*ip
|
||||
}
|
||||
}
|
||||
_ => *ip,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_empty_lists_allow_all() {
|
||||
let filter = IpFilter::new(&[], &[]);
|
||||
let ip: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
assert!(filter.is_allowed(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allow_list_exact() {
|
||||
let filter = IpFilter::new(
|
||||
&["10.0.0.1".to_string()],
|
||||
&[],
|
||||
);
|
||||
let allowed: IpAddr = "10.0.0.1".parse().unwrap();
|
||||
let denied: IpAddr = "10.0.0.2".parse().unwrap();
|
||||
assert!(filter.is_allowed(&allowed));
|
||||
assert!(!filter.is_allowed(&denied));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allow_list_cidr() {
|
||||
let filter = IpFilter::new(
|
||||
&["10.0.0.0/8".to_string()],
|
||||
&[],
|
||||
);
|
||||
let allowed: IpAddr = "10.255.255.255".parse().unwrap();
|
||||
let denied: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
assert!(filter.is_allowed(&allowed));
|
||||
assert!(!filter.is_allowed(&denied));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_block_list() {
|
||||
let filter = IpFilter::new(
|
||||
&[],
|
||||
&["192.168.1.100".to_string()],
|
||||
);
|
||||
let blocked: IpAddr = "192.168.1.100".parse().unwrap();
|
||||
let allowed: IpAddr = "192.168.1.101".parse().unwrap();
|
||||
assert!(!filter.is_allowed(&blocked));
|
||||
assert!(filter.is_allowed(&allowed));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_block_trumps_allow() {
|
||||
let filter = IpFilter::new(
|
||||
&["10.0.0.0/8".to_string()],
|
||||
&["10.0.0.5".to_string()],
|
||||
);
|
||||
let blocked: IpAddr = "10.0.0.5".parse().unwrap();
|
||||
let allowed: IpAddr = "10.0.0.6".parse().unwrap();
|
||||
assert!(!filter.is_allowed(&blocked));
|
||||
assert!(filter.is_allowed(&allowed));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_allow() {
|
||||
let filter = IpFilter::new(
|
||||
&["*".to_string()],
|
||||
&[],
|
||||
);
|
||||
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
assert!(filter.is_allowed(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_block() {
|
||||
let filter = IpFilter::new(
|
||||
&[],
|
||||
&["*".to_string()],
|
||||
);
|
||||
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
assert!(!filter.is_allowed(&ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_ipv4_mapped_ipv6() {
|
||||
let mapped: IpAddr = "::ffff:192.168.1.1".parse().unwrap();
|
||||
let normalized = IpFilter::normalize_ip(&mapped);
|
||||
let expected: IpAddr = "192.168.1.1".parse().unwrap();
|
||||
assert_eq!(normalized, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_pure_ipv4() {
|
||||
let ip: IpAddr = "10.0.0.1".parse().unwrap();
|
||||
let normalized = IpFilter::normalize_ip(&ip);
|
||||
assert_eq!(normalized, ip);
|
||||
}
|
||||
}
|
||||
174
rust/crates/rustproxy-security/src/jwt_auth.rs
Normal file
174
rust/crates/rustproxy-security/src/jwt_auth.rs
Normal file
@@ -0,0 +1,174 @@
|
||||
use jsonwebtoken::{decode, DecodingKey, Validation, Algorithm};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// JWT claims (minimal structure).
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Claims {
|
||||
pub sub: Option<String>,
|
||||
pub exp: Option<u64>,
|
||||
pub iss: Option<String>,
|
||||
pub aud: Option<String>,
|
||||
}
|
||||
|
||||
/// JWT auth validator.
|
||||
pub struct JwtValidator {
|
||||
decoding_key: DecodingKey,
|
||||
validation: Validation,
|
||||
}
|
||||
|
||||
impl JwtValidator {
|
||||
pub fn new(
|
||||
secret: &str,
|
||||
algorithm: Option<&str>,
|
||||
issuer: Option<&str>,
|
||||
audience: Option<&str>,
|
||||
) -> Self {
|
||||
let algo = match algorithm {
|
||||
Some("HS384") => Algorithm::HS384,
|
||||
Some("HS512") => Algorithm::HS512,
|
||||
Some("RS256") => Algorithm::RS256,
|
||||
_ => Algorithm::HS256,
|
||||
};
|
||||
|
||||
let mut validation = Validation::new(algo);
|
||||
if let Some(iss) = issuer {
|
||||
validation.set_issuer(&[iss]);
|
||||
}
|
||||
if let Some(aud) = audience {
|
||||
validation.set_audience(&[aud]);
|
||||
}
|
||||
|
||||
Self {
|
||||
decoding_key: DecodingKey::from_secret(secret.as_bytes()),
|
||||
validation,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate a JWT token string (without "Bearer " prefix).
|
||||
/// Returns the claims if valid.
|
||||
pub fn validate(&self, token: &str) -> Result<Claims, String> {
|
||||
decode::<Claims>(token, &self.decoding_key, &self.validation)
|
||||
.map(|data| data.claims)
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
/// Extract token from Authorization header.
|
||||
pub fn extract_token(auth_header: &str) -> Option<&str> {
|
||||
let header = auth_header.trim();
|
||||
if header.starts_with("Bearer ") {
|
||||
Some(&header[7..])
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use jsonwebtoken::{encode, EncodingKey, Header};
|
||||
|
||||
fn make_token(secret: &str, claims: &Claims) -> String {
|
||||
encode(
|
||||
&Header::default(),
|
||||
claims,
|
||||
&EncodingKey::from_secret(secret.as_bytes()),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn future_exp() -> u64 {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
+ 3600
|
||||
}
|
||||
|
||||
fn past_exp() -> u64 {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
- 3600
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_token() {
|
||||
let secret = "test-secret";
|
||||
let claims = Claims {
|
||||
sub: Some("user123".to_string()),
|
||||
exp: Some(future_exp()),
|
||||
iss: None,
|
||||
aud: None,
|
||||
};
|
||||
let token = make_token(secret, &claims);
|
||||
let validator = JwtValidator::new(secret, None, None, None);
|
||||
let result = validator.validate(&token);
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap().sub, Some("user123".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expired_token() {
|
||||
let secret = "test-secret";
|
||||
let claims = Claims {
|
||||
sub: Some("user123".to_string()),
|
||||
exp: Some(past_exp()),
|
||||
iss: None,
|
||||
aud: None,
|
||||
};
|
||||
let token = make_token(secret, &claims);
|
||||
let validator = JwtValidator::new(secret, None, None, None);
|
||||
assert!(validator.validate(&token).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wrong_secret() {
|
||||
let claims = Claims {
|
||||
sub: Some("user123".to_string()),
|
||||
exp: Some(future_exp()),
|
||||
iss: None,
|
||||
aud: None,
|
||||
};
|
||||
let token = make_token("correct-secret", &claims);
|
||||
let validator = JwtValidator::new("wrong-secret", None, None, None);
|
||||
assert!(validator.validate(&token).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_issuer_validation() {
|
||||
let secret = "test-secret";
|
||||
let claims = Claims {
|
||||
sub: Some("user123".to_string()),
|
||||
exp: Some(future_exp()),
|
||||
iss: Some("my-issuer".to_string()),
|
||||
aud: None,
|
||||
};
|
||||
let token = make_token(secret, &claims);
|
||||
|
||||
// Correct issuer
|
||||
let validator = JwtValidator::new(secret, None, Some("my-issuer"), None);
|
||||
assert!(validator.validate(&token).is_ok());
|
||||
|
||||
// Wrong issuer
|
||||
let validator = JwtValidator::new(secret, None, Some("other-issuer"), None);
|
||||
assert!(validator.validate(&token).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_token_bearer() {
|
||||
assert_eq!(
|
||||
JwtValidator::extract_token("Bearer abc123"),
|
||||
Some("abc123")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_token_non_bearer() {
|
||||
assert_eq!(JwtValidator::extract_token("Basic abc123"), None);
|
||||
assert_eq!(JwtValidator::extract_token("abc123"), None);
|
||||
}
|
||||
}
|
||||
13
rust/crates/rustproxy-security/src/lib.rs
Normal file
13
rust/crates/rustproxy-security/src/lib.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
//! # rustproxy-security
|
||||
//!
|
||||
//! IP filtering, rate limiting, and authentication for RustProxy.
|
||||
|
||||
pub mod ip_filter;
|
||||
pub mod rate_limiter;
|
||||
pub mod basic_auth;
|
||||
pub mod jwt_auth;
|
||||
|
||||
pub use ip_filter::*;
|
||||
pub use rate_limiter::*;
|
||||
pub use basic_auth::*;
|
||||
pub use jwt_auth::*;
|
||||
97
rust/crates/rustproxy-security/src/rate_limiter.rs
Normal file
97
rust/crates/rustproxy-security/src/rate_limiter.rs
Normal file
@@ -0,0 +1,97 @@
|
||||
use dashmap::DashMap;
|
||||
use std::time::Instant;
|
||||
|
||||
/// Sliding window rate limiter.
|
||||
pub struct RateLimiter {
|
||||
/// Map of key -> list of request timestamps
|
||||
windows: DashMap<String, Vec<Instant>>,
|
||||
/// Maximum requests per window
|
||||
max_requests: u64,
|
||||
/// Window duration in seconds
|
||||
window_seconds: u64,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
pub fn new(max_requests: u64, window_seconds: u64) -> Self {
|
||||
Self {
|
||||
windows: DashMap::new(),
|
||||
max_requests,
|
||||
window_seconds,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a request is allowed for the given key.
|
||||
/// Returns true if allowed, false if rate limited.
|
||||
pub fn check(&self, key: &str) -> bool {
|
||||
let now = Instant::now();
|
||||
let window = std::time::Duration::from_secs(self.window_seconds);
|
||||
|
||||
let mut entry = self.windows.entry(key.to_string()).or_default();
|
||||
let timestamps = entry.value_mut();
|
||||
|
||||
// Remove expired entries
|
||||
timestamps.retain(|t| now.duration_since(*t) < window);
|
||||
|
||||
if timestamps.len() as u64 >= self.max_requests {
|
||||
false
|
||||
} else {
|
||||
timestamps.push(now);
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Clean up expired entries (call periodically).
|
||||
pub fn cleanup(&self) {
|
||||
let now = Instant::now();
|
||||
let window = std::time::Duration::from_secs(self.window_seconds);
|
||||
|
||||
self.windows.retain(|_, timestamps| {
|
||||
timestamps.retain(|t| now.duration_since(*t) < window);
|
||||
!timestamps.is_empty()
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_allow_under_limit() {
|
||||
let limiter = RateLimiter::new(5, 60);
|
||||
for _ in 0..5 {
|
||||
assert!(limiter.check("client-1"));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_block_over_limit() {
|
||||
let limiter = RateLimiter::new(3, 60);
|
||||
assert!(limiter.check("client-1"));
|
||||
assert!(limiter.check("client-1"));
|
||||
assert!(limiter.check("client-1"));
|
||||
assert!(!limiter.check("client-1")); // 4th request blocked
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_different_keys_independent() {
|
||||
let limiter = RateLimiter::new(2, 60);
|
||||
assert!(limiter.check("client-a"));
|
||||
assert!(limiter.check("client-a"));
|
||||
assert!(!limiter.check("client-a")); // blocked
|
||||
// Different key should still be allowed
|
||||
assert!(limiter.check("client-b"));
|
||||
assert!(limiter.check("client-b"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cleanup_removes_expired() {
|
||||
let limiter = RateLimiter::new(100, 0); // 0 second window = immediately expired
|
||||
limiter.check("client-1");
|
||||
// Sleep briefly to let entries expire
|
||||
std::thread::sleep(std::time::Duration::from_millis(10));
|
||||
limiter.cleanup();
|
||||
// After cleanup, the key should be allowed again (entries expired)
|
||||
assert!(limiter.check("client-1"));
|
||||
}
|
||||
}
|
||||
22
rust/crates/rustproxy-tls/Cargo.toml
Normal file
22
rust/crates/rustproxy-tls/Cargo.toml
Normal file
@@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "rustproxy-tls"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "TLS certificate management for RustProxy"
|
||||
|
||||
[dependencies]
|
||||
rustproxy-config = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
instant-acme = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
rcgen = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
360
rust/crates/rustproxy-tls/src/acme.rs
Normal file
360
rust/crates/rustproxy-tls/src/acme.rs
Normal file
@@ -0,0 +1,360 @@
|
||||
//! ACME (Let's Encrypt) integration using instant-acme.
|
||||
//!
|
||||
//! This module handles HTTP-01 challenge creation and certificate provisioning.
|
||||
//! Supports persisting ACME account credentials to disk for reuse across restarts.
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
use instant_acme::{
|
||||
Account, NewAccount, NewOrder, Identifier, ChallengeType, OrderStatus,
|
||||
AccountCredentials,
|
||||
};
|
||||
use rcgen::{CertificateParams, KeyPair};
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AcmeError {
|
||||
#[error("ACME account creation failed: {0}")]
|
||||
AccountCreation(String),
|
||||
#[error("ACME order failed: {0}")]
|
||||
OrderFailed(String),
|
||||
#[error("Challenge failed: {0}")]
|
||||
ChallengeFailed(String),
|
||||
#[error("Certificate finalization failed: {0}")]
|
||||
FinalizationFailed(String),
|
||||
#[error("No HTTP-01 challenge found")]
|
||||
NoHttp01Challenge,
|
||||
#[error("Timeout waiting for order: {0}")]
|
||||
Timeout(String),
|
||||
#[error("Account persistence error: {0}")]
|
||||
Persistence(String),
|
||||
}
|
||||
|
||||
/// Pending HTTP-01 challenge that needs to be served.
|
||||
pub struct PendingChallenge {
|
||||
pub token: String,
|
||||
pub key_authorization: String,
|
||||
pub domain: String,
|
||||
}
|
||||
|
||||
/// ACME client wrapper around instant-acme.
|
||||
pub struct AcmeClient {
|
||||
use_production: bool,
|
||||
email: String,
|
||||
/// Optional directory where account.json is persisted.
|
||||
account_dir: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl AcmeClient {
|
||||
pub fn new(email: String, use_production: bool) -> Self {
|
||||
Self {
|
||||
use_production,
|
||||
email,
|
||||
account_dir: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new client with account persistence at the given directory.
|
||||
pub fn with_persistence(email: String, use_production: bool, account_dir: impl AsRef<Path>) -> Self {
|
||||
Self {
|
||||
use_production,
|
||||
email,
|
||||
account_dir: Some(account_dir.as_ref().to_path_buf()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get or create an ACME account, persisting credentials if account_dir is set.
|
||||
async fn get_or_create_account(&self) -> Result<Account, AcmeError> {
|
||||
let directory_url = self.directory_url();
|
||||
|
||||
// Try to restore from persisted credentials
|
||||
if let Some(ref dir) = self.account_dir {
|
||||
let account_file = dir.join("account.json");
|
||||
if account_file.exists() {
|
||||
match std::fs::read_to_string(&account_file) {
|
||||
Ok(json) => {
|
||||
match serde_json::from_str::<AccountCredentials>(&json) {
|
||||
Ok(credentials) => {
|
||||
match Account::from_credentials(credentials).await {
|
||||
Ok(account) => {
|
||||
debug!("Restored ACME account from {}", account_file.display());
|
||||
return Ok(account);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to restore ACME account, creating new: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Invalid account.json, creating new account: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Could not read account.json: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new account
|
||||
let contact = format!("mailto:{}", self.email);
|
||||
let (account, credentials) = Account::create(
|
||||
&NewAccount {
|
||||
contact: &[&contact],
|
||||
terms_of_service_agreed: true,
|
||||
only_return_existing: false,
|
||||
},
|
||||
directory_url,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| AcmeError::AccountCreation(e.to_string()))?;
|
||||
|
||||
debug!("ACME account created");
|
||||
|
||||
// Persist credentials if we have a directory
|
||||
if let Some(ref dir) = self.account_dir {
|
||||
if let Err(e) = std::fs::create_dir_all(dir) {
|
||||
warn!("Failed to create account directory {}: {}", dir.display(), e);
|
||||
} else {
|
||||
let account_file = dir.join("account.json");
|
||||
match serde_json::to_string_pretty(&credentials) {
|
||||
Ok(json) => {
|
||||
if let Err(e) = std::fs::write(&account_file, &json) {
|
||||
warn!("Failed to persist ACME account to {}: {}", account_file.display(), e);
|
||||
} else {
|
||||
info!("ACME account credentials persisted to {}", account_file.display());
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to serialize account credentials: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(account)
|
||||
}
|
||||
|
||||
/// Request a certificate for a domain using the HTTP-01 challenge.
|
||||
///
|
||||
/// Returns (cert_chain_pem, private_key_pem) on success.
|
||||
///
|
||||
/// The caller must serve the HTTP-01 challenge at:
|
||||
/// `http://<domain>/.well-known/acme-challenge/<token>`
|
||||
///
|
||||
/// The `challenge_handler` closure is called with a `PendingChallenge`
|
||||
/// and must arrange for the challenge response to be served. It should
|
||||
/// return once the challenge is ready to be validated.
|
||||
pub async fn provision<F, Fut>(
|
||||
&self,
|
||||
domain: &str,
|
||||
challenge_handler: F,
|
||||
) -> Result<(String, String), AcmeError>
|
||||
where
|
||||
F: FnOnce(PendingChallenge) -> Fut,
|
||||
Fut: std::future::Future<Output = Result<(), AcmeError>>,
|
||||
{
|
||||
info!("Starting ACME provisioning for {} via {}", domain, self.directory_url());
|
||||
|
||||
// 1. Get or create ACME account (with persistence)
|
||||
let account = self.get_or_create_account().await?;
|
||||
|
||||
// 2. Create order
|
||||
let identifier = Identifier::Dns(domain.to_string());
|
||||
let mut order = account
|
||||
.new_order(&NewOrder {
|
||||
identifiers: &[identifier],
|
||||
})
|
||||
.await
|
||||
.map_err(|e| AcmeError::OrderFailed(e.to_string()))?;
|
||||
|
||||
debug!("ACME order created");
|
||||
|
||||
// 3. Get authorizations and find HTTP-01 challenge
|
||||
let authorizations = order
|
||||
.authorizations()
|
||||
.await
|
||||
.map_err(|e| AcmeError::OrderFailed(e.to_string()))?;
|
||||
|
||||
// Find the HTTP-01 challenge
|
||||
let (challenge_token, challenge_url) = authorizations
|
||||
.iter()
|
||||
.flat_map(|auth| auth.challenges.iter())
|
||||
.find(|c| c.r#type == ChallengeType::Http01)
|
||||
.map(|c| {
|
||||
let key_auth = order.key_authorization(c);
|
||||
(
|
||||
PendingChallenge {
|
||||
token: c.token.clone(),
|
||||
key_authorization: key_auth.as_str().to_string(),
|
||||
domain: domain.to_string(),
|
||||
},
|
||||
c.url.clone(),
|
||||
)
|
||||
})
|
||||
.ok_or(AcmeError::NoHttp01Challenge)?;
|
||||
|
||||
// Call the handler to set up challenge serving
|
||||
challenge_handler(challenge_token).await?;
|
||||
|
||||
// 4. Notify ACME server that challenge is ready
|
||||
order
|
||||
.set_challenge_ready(&challenge_url)
|
||||
.await
|
||||
.map_err(|e| AcmeError::ChallengeFailed(e.to_string()))?;
|
||||
|
||||
debug!("Challenge marked as ready, waiting for validation...");
|
||||
|
||||
// 5. Poll for order to become ready
|
||||
let mut attempts = 0;
|
||||
let state = loop {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
|
||||
let state = order
|
||||
.refresh()
|
||||
.await
|
||||
.map_err(|e| AcmeError::OrderFailed(e.to_string()))?;
|
||||
|
||||
match state.status {
|
||||
OrderStatus::Ready | OrderStatus::Valid => break state.status,
|
||||
OrderStatus::Invalid => {
|
||||
return Err(AcmeError::ChallengeFailed(
|
||||
"Order became invalid (challenge failed)".to_string(),
|
||||
));
|
||||
}
|
||||
_ => {
|
||||
attempts += 1;
|
||||
if attempts > 30 {
|
||||
return Err(AcmeError::Timeout(
|
||||
"Order did not become ready within 60 seconds".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
debug!("Order ready, finalizing...");
|
||||
|
||||
// 6. Generate CSR and finalize
|
||||
let key_pair = KeyPair::generate().map_err(|e| {
|
||||
AcmeError::FinalizationFailed(format!("Key generation failed: {}", e))
|
||||
})?;
|
||||
|
||||
let mut params = CertificateParams::new(vec![domain.to_string()]).map_err(|e| {
|
||||
AcmeError::FinalizationFailed(format!("CSR params failed: {}", e))
|
||||
})?;
|
||||
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
|
||||
|
||||
let csr = params.serialize_request(&key_pair).map_err(|e| {
|
||||
AcmeError::FinalizationFailed(format!("CSR serialization failed: {}", e))
|
||||
})?;
|
||||
|
||||
if state == OrderStatus::Ready {
|
||||
order
|
||||
.finalize(csr.der())
|
||||
.await
|
||||
.map_err(|e| AcmeError::FinalizationFailed(e.to_string()))?;
|
||||
}
|
||||
|
||||
// 7. Wait for certificate to be issued
|
||||
let mut attempts = 0;
|
||||
loop {
|
||||
let state = order
|
||||
.refresh()
|
||||
.await
|
||||
.map_err(|e| AcmeError::OrderFailed(e.to_string()))?;
|
||||
if state.status == OrderStatus::Valid {
|
||||
break;
|
||||
}
|
||||
if state.status == OrderStatus::Invalid {
|
||||
return Err(AcmeError::FinalizationFailed(
|
||||
"Order became invalid during finalization".to_string(),
|
||||
));
|
||||
}
|
||||
attempts += 1;
|
||||
if attempts > 15 {
|
||||
return Err(AcmeError::Timeout(
|
||||
"Certificate not issued within 30 seconds".to_string(),
|
||||
));
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
|
||||
}
|
||||
|
||||
// 8. Download certificate
|
||||
let cert_chain_pem = order
|
||||
.certificate()
|
||||
.await
|
||||
.map_err(|e| AcmeError::FinalizationFailed(e.to_string()))?
|
||||
.ok_or_else(|| {
|
||||
AcmeError::FinalizationFailed("No certificate returned".to_string())
|
||||
})?;
|
||||
|
||||
let private_key_pem = key_pair.serialize_pem();
|
||||
|
||||
info!("Certificate provisioned successfully for {}", domain);
|
||||
|
||||
Ok((cert_chain_pem, private_key_pem))
|
||||
}
|
||||
|
||||
/// Restore an ACME account from stored credentials.
|
||||
pub async fn restore_account(
|
||||
&self,
|
||||
credentials: AccountCredentials,
|
||||
) -> Result<Account, AcmeError> {
|
||||
Account::from_credentials(credentials)
|
||||
.await
|
||||
.map_err(|e| AcmeError::AccountCreation(e.to_string()))
|
||||
}
|
||||
|
||||
/// Get the ACME directory URL based on production/staging.
|
||||
pub fn directory_url(&self) -> &str {
|
||||
if self.use_production {
|
||||
"https://acme-v02.api.letsencrypt.org/directory"
|
||||
} else {
|
||||
"https://acme-staging-v02.api.letsencrypt.org/directory"
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether this client is configured for production.
|
||||
pub fn is_production(&self) -> bool {
|
||||
self.use_production
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_directory_url_staging() {
|
||||
let client = AcmeClient::new("test@example.com".to_string(), false);
|
||||
assert!(client.directory_url().contains("staging"));
|
||||
assert!(!client.is_production());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_directory_url_production() {
|
||||
let client = AcmeClient::new("test@example.com".to_string(), true);
|
||||
assert!(!client.directory_url().contains("staging"));
|
||||
assert!(client.is_production());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_with_persistence_sets_account_dir() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let client = AcmeClient::with_persistence(
|
||||
"test@example.com".to_string(),
|
||||
false,
|
||||
tmp.path(),
|
||||
);
|
||||
assert!(client.account_dir.is_some());
|
||||
assert_eq!(client.account_dir.unwrap(), tmp.path());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_without_persistence_no_account_dir() {
|
||||
let client = AcmeClient::new("test@example.com".to_string(), false);
|
||||
assert!(client.account_dir.is_none());
|
||||
}
|
||||
}
|
||||
183
rust/crates/rustproxy-tls/src/cert_manager.rs
Normal file
183
rust/crates/rustproxy-tls/src/cert_manager.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use thiserror::Error;
|
||||
use tracing::info;
|
||||
|
||||
use crate::cert_store::{CertStore, CertBundle, CertMetadata, CertSource};
|
||||
use crate::acme::AcmeClient;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CertManagerError {
|
||||
#[error("ACME provisioning failed for {domain}: {message}")]
|
||||
AcmeFailure { domain: String, message: String },
|
||||
#[error("Certificate store error: {0}")]
|
||||
Store(#[from] crate::cert_store::CertStoreError),
|
||||
#[error("No ACME email configured")]
|
||||
NoEmail,
|
||||
}
|
||||
|
||||
/// Certificate lifecycle manager.
|
||||
/// Handles ACME provisioning, static cert loading, and renewal.
|
||||
pub struct CertManager {
|
||||
store: CertStore,
|
||||
acme_email: Option<String>,
|
||||
use_production: bool,
|
||||
renew_before_days: u32,
|
||||
}
|
||||
|
||||
impl CertManager {
|
||||
pub fn new(
|
||||
store: CertStore,
|
||||
acme_email: Option<String>,
|
||||
use_production: bool,
|
||||
renew_before_days: u32,
|
||||
) -> Self {
|
||||
Self {
|
||||
store,
|
||||
acme_email,
|
||||
use_production,
|
||||
renew_before_days,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a certificate for a domain (from cache).
|
||||
pub fn get_cert(&self, domain: &str) -> Option<&CertBundle> {
|
||||
self.store.get(domain)
|
||||
}
|
||||
|
||||
/// Create an ACME client using this manager's configuration.
|
||||
/// Returns None if no ACME email is configured.
|
||||
/// Account credentials are persisted in the cert store base directory.
|
||||
pub fn acme_client(&self) -> Option<AcmeClient> {
|
||||
self.acme_email.as_ref().map(|email| {
|
||||
AcmeClient::with_persistence(
|
||||
email.clone(),
|
||||
self.use_production,
|
||||
self.store.base_dir(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// Load a static certificate into the store.
|
||||
pub fn load_static(
|
||||
&mut self,
|
||||
domain: String,
|
||||
bundle: CertBundle,
|
||||
) -> Result<(), CertManagerError> {
|
||||
self.store.store(domain, bundle)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check and return domains that need certificate renewal.
|
||||
///
|
||||
/// A certificate needs renewal if it expires within `renew_before_days`.
|
||||
/// Returns a list of domain names needing renewal.
|
||||
pub fn check_renewals(&self) -> Vec<String> {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let renewal_threshold = self.renew_before_days as u64 * 86400;
|
||||
let mut needs_renewal = Vec::new();
|
||||
|
||||
for (domain, bundle) in self.store.iter() {
|
||||
// Only auto-renew ACME certs
|
||||
if bundle.metadata.source != CertSource::Acme {
|
||||
continue;
|
||||
}
|
||||
|
||||
let time_until_expiry = bundle.metadata.expires_at.saturating_sub(now);
|
||||
if time_until_expiry < renewal_threshold {
|
||||
info!(
|
||||
"Certificate for {} needs renewal (expires in {} days)",
|
||||
domain,
|
||||
time_until_expiry / 86400
|
||||
);
|
||||
needs_renewal.push(domain.clone());
|
||||
}
|
||||
}
|
||||
|
||||
needs_renewal
|
||||
}
|
||||
|
||||
/// Renew a certificate for a domain.
|
||||
///
|
||||
/// Performs the full ACME provision+store flow. The `challenge_setup` closure
|
||||
/// is called to arrange for the HTTP-01 challenge to be served. It receives
|
||||
/// (token, key_authorization) and must make the challenge response available.
|
||||
///
|
||||
/// Returns the new CertBundle on success.
|
||||
pub async fn renew_domain<F, Fut>(
|
||||
&mut self,
|
||||
domain: &str,
|
||||
challenge_setup: F,
|
||||
) -> Result<CertBundle, CertManagerError>
|
||||
where
|
||||
F: FnOnce(String, String) -> Fut,
|
||||
Fut: std::future::Future<Output = ()>,
|
||||
{
|
||||
let acme_client = self.acme_client()
|
||||
.ok_or(CertManagerError::NoEmail)?;
|
||||
|
||||
info!("Renewing certificate for {}", domain);
|
||||
|
||||
let domain_owned = domain.to_string();
|
||||
let result = acme_client.provision(&domain_owned, |pending| {
|
||||
let token = pending.token.clone();
|
||||
let key_auth = pending.key_authorization.clone();
|
||||
async move {
|
||||
challenge_setup(token, key_auth).await;
|
||||
Ok(())
|
||||
}
|
||||
}).await.map_err(|e| CertManagerError::AcmeFailure {
|
||||
domain: domain.to_string(),
|
||||
message: e.to_string(),
|
||||
})?;
|
||||
|
||||
let (cert_pem, key_pem) = result;
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let bundle = CertBundle {
|
||||
cert_pem,
|
||||
key_pem,
|
||||
ca_pem: None,
|
||||
metadata: CertMetadata {
|
||||
domain: domain.to_string(),
|
||||
source: CertSource::Acme,
|
||||
issued_at: now,
|
||||
expires_at: now + 90 * 86400,
|
||||
renewed_at: Some(now),
|
||||
},
|
||||
};
|
||||
|
||||
self.store.store(domain.to_string(), bundle.clone())?;
|
||||
info!("Certificate renewed and stored for {}", domain);
|
||||
|
||||
Ok(bundle)
|
||||
}
|
||||
|
||||
/// Load all certificates from disk.
|
||||
pub fn load_all(&mut self) -> Result<usize, CertManagerError> {
|
||||
let loaded = self.store.load_all()?;
|
||||
info!("Loaded {} certificates from store", loaded);
|
||||
Ok(loaded)
|
||||
}
|
||||
|
||||
/// Whether this manager has an ACME email configured.
|
||||
pub fn has_acme(&self) -> bool {
|
||||
self.acme_email.is_some()
|
||||
}
|
||||
|
||||
/// Get reference to the underlying store.
|
||||
pub fn store(&self) -> &CertStore {
|
||||
&self.store
|
||||
}
|
||||
|
||||
/// Get mutable reference to the underlying store.
|
||||
pub fn store_mut(&mut self) -> &mut CertStore {
|
||||
&mut self.store
|
||||
}
|
||||
}
|
||||
314
rust/crates/rustproxy-tls/src/cert_store.rs
Normal file
314
rust/crates/rustproxy-tls/src/cert_store.rs
Normal file
@@ -0,0 +1,314 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CertStoreError {
|
||||
#[error("Certificate not found for domain: {0}")]
|
||||
NotFound(String),
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("Invalid certificate: {0}")]
|
||||
Invalid(String),
|
||||
#[error("JSON error: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
/// Certificate metadata stored alongside certs on disk.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CertMetadata {
|
||||
pub domain: String,
|
||||
pub source: CertSource,
|
||||
pub issued_at: u64,
|
||||
pub expires_at: u64,
|
||||
pub renewed_at: Option<u64>,
|
||||
}
|
||||
|
||||
/// How a certificate was obtained.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum CertSource {
|
||||
Acme,
|
||||
Static,
|
||||
Custom,
|
||||
SelfSigned,
|
||||
}
|
||||
|
||||
/// An in-memory certificate bundle.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CertBundle {
|
||||
pub key_pem: String,
|
||||
pub cert_pem: String,
|
||||
pub ca_pem: Option<String>,
|
||||
pub metadata: CertMetadata,
|
||||
}
|
||||
|
||||
/// Filesystem-backed certificate store.
|
||||
///
|
||||
/// File layout per domain:
|
||||
/// ```text
|
||||
/// {base_dir}/{domain}/
|
||||
/// key.pem
|
||||
/// cert.pem
|
||||
/// ca.pem (optional)
|
||||
/// metadata.json
|
||||
/// ```
|
||||
pub struct CertStore {
|
||||
base_dir: PathBuf,
|
||||
/// In-memory cache of loaded certs
|
||||
cache: HashMap<String, CertBundle>,
|
||||
}
|
||||
|
||||
impl CertStore {
|
||||
/// Create a new cert store at the given directory.
|
||||
pub fn new(base_dir: impl AsRef<Path>) -> Self {
|
||||
Self {
|
||||
base_dir: base_dir.as_ref().to_path_buf(),
|
||||
cache: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a certificate by domain.
|
||||
pub fn get(&self, domain: &str) -> Option<&CertBundle> {
|
||||
self.cache.get(domain)
|
||||
}
|
||||
|
||||
/// Store a certificate to both cache and filesystem.
|
||||
pub fn store(&mut self, domain: String, bundle: CertBundle) -> Result<(), CertStoreError> {
|
||||
// Sanitize domain for directory name (replace wildcards)
|
||||
let dir_name = domain.replace('*', "_wildcard_");
|
||||
let cert_dir = self.base_dir.join(&dir_name);
|
||||
|
||||
// Create directory
|
||||
std::fs::create_dir_all(&cert_dir)?;
|
||||
|
||||
// Write key
|
||||
std::fs::write(cert_dir.join("key.pem"), &bundle.key_pem)?;
|
||||
|
||||
// Write cert
|
||||
std::fs::write(cert_dir.join("cert.pem"), &bundle.cert_pem)?;
|
||||
|
||||
// Write CA cert if present
|
||||
if let Some(ref ca) = bundle.ca_pem {
|
||||
std::fs::write(cert_dir.join("ca.pem"), ca)?;
|
||||
}
|
||||
|
||||
// Write metadata
|
||||
let metadata_json = serde_json::to_string_pretty(&bundle.metadata)?;
|
||||
std::fs::write(cert_dir.join("metadata.json"), metadata_json)?;
|
||||
|
||||
// Update cache
|
||||
self.cache.insert(domain, bundle);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a certificate exists for a domain.
|
||||
pub fn has(&self, domain: &str) -> bool {
|
||||
self.cache.contains_key(domain)
|
||||
}
|
||||
|
||||
/// Load all certificates from the base directory.
|
||||
pub fn load_all(&mut self) -> Result<usize, CertStoreError> {
|
||||
if !self.base_dir.exists() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
let entries = std::fs::read_dir(&self.base_dir)?;
|
||||
let mut loaded = 0;
|
||||
|
||||
for entry in entries {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
if !path.is_dir() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let metadata_path = path.join("metadata.json");
|
||||
let key_path = path.join("key.pem");
|
||||
let cert_path = path.join("cert.pem");
|
||||
|
||||
// All three files must exist
|
||||
if !metadata_path.exists() || !key_path.exists() || !cert_path.exists() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Load metadata
|
||||
let metadata_str = std::fs::read_to_string(&metadata_path)?;
|
||||
let metadata: CertMetadata = serde_json::from_str(&metadata_str)?;
|
||||
|
||||
// Load key and cert
|
||||
let key_pem = std::fs::read_to_string(&key_path)?;
|
||||
let cert_pem = std::fs::read_to_string(&cert_path)?;
|
||||
|
||||
// Load CA cert if present
|
||||
let ca_path = path.join("ca.pem");
|
||||
let ca_pem = if ca_path.exists() {
|
||||
Some(std::fs::read_to_string(&ca_path)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let domain = metadata.domain.clone();
|
||||
let bundle = CertBundle {
|
||||
key_pem,
|
||||
cert_pem,
|
||||
ca_pem,
|
||||
metadata,
|
||||
};
|
||||
|
||||
self.cache.insert(domain, bundle);
|
||||
loaded += 1;
|
||||
}
|
||||
|
||||
Ok(loaded)
|
||||
}
|
||||
|
||||
/// Get the base directory.
|
||||
pub fn base_dir(&self) -> &Path {
|
||||
&self.base_dir
|
||||
}
|
||||
|
||||
/// Get the number of cached certificates.
|
||||
pub fn count(&self) -> usize {
|
||||
self.cache.len()
|
||||
}
|
||||
|
||||
/// Iterate over all cached certificates.
|
||||
pub fn iter(&self) -> impl Iterator<Item = (&String, &CertBundle)> {
|
||||
self.cache.iter()
|
||||
}
|
||||
|
||||
/// Remove a certificate from cache and filesystem.
|
||||
pub fn remove(&mut self, domain: &str) -> Result<bool, CertStoreError> {
|
||||
let removed = self.cache.remove(domain).is_some();
|
||||
if removed {
|
||||
let dir_name = domain.replace('*', "_wildcard_");
|
||||
let cert_dir = self.base_dir.join(&dir_name);
|
||||
if cert_dir.exists() {
|
||||
std::fs::remove_dir_all(&cert_dir)?;
|
||||
}
|
||||
}
|
||||
Ok(removed)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_test_bundle(domain: &str) -> CertBundle {
|
||||
CertBundle {
|
||||
key_pem: "-----BEGIN PRIVATE KEY-----\ntest-key\n-----END PRIVATE KEY-----\n".to_string(),
|
||||
cert_pem: "-----BEGIN CERTIFICATE-----\ntest-cert\n-----END CERTIFICATE-----\n".to_string(),
|
||||
ca_pem: None,
|
||||
metadata: CertMetadata {
|
||||
domain: domain.to_string(),
|
||||
source: CertSource::Static,
|
||||
issued_at: 1700000000,
|
||||
expires_at: 1700000000 + 90 * 86400,
|
||||
renewed_at: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_store_and_load_roundtrip() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let mut store = CertStore::new(tmp.path());
|
||||
|
||||
let bundle = make_test_bundle("example.com");
|
||||
store.store("example.com".to_string(), bundle.clone()).unwrap();
|
||||
|
||||
// Verify files exist
|
||||
let cert_dir = tmp.path().join("example.com");
|
||||
assert!(cert_dir.join("key.pem").exists());
|
||||
assert!(cert_dir.join("cert.pem").exists());
|
||||
assert!(cert_dir.join("metadata.json").exists());
|
||||
assert!(!cert_dir.join("ca.pem").exists()); // No CA cert
|
||||
|
||||
// Load into a fresh store
|
||||
let mut store2 = CertStore::new(tmp.path());
|
||||
let loaded = store2.load_all().unwrap();
|
||||
assert_eq!(loaded, 1);
|
||||
|
||||
let loaded_bundle = store2.get("example.com").unwrap();
|
||||
assert_eq!(loaded_bundle.key_pem, bundle.key_pem);
|
||||
assert_eq!(loaded_bundle.cert_pem, bundle.cert_pem);
|
||||
assert_eq!(loaded_bundle.metadata.domain, "example.com");
|
||||
assert_eq!(loaded_bundle.metadata.source, CertSource::Static);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_store_with_ca_cert() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let mut store = CertStore::new(tmp.path());
|
||||
|
||||
let mut bundle = make_test_bundle("secure.com");
|
||||
bundle.ca_pem = Some("-----BEGIN CERTIFICATE-----\nca-cert\n-----END CERTIFICATE-----\n".to_string());
|
||||
store.store("secure.com".to_string(), bundle).unwrap();
|
||||
|
||||
let cert_dir = tmp.path().join("secure.com");
|
||||
assert!(cert_dir.join("ca.pem").exists());
|
||||
|
||||
let mut store2 = CertStore::new(tmp.path());
|
||||
store2.load_all().unwrap();
|
||||
let loaded = store2.get("secure.com").unwrap();
|
||||
assert!(loaded.ca_pem.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_all_multiple_certs() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let mut store = CertStore::new(tmp.path());
|
||||
|
||||
store.store("a.com".to_string(), make_test_bundle("a.com")).unwrap();
|
||||
store.store("b.com".to_string(), make_test_bundle("b.com")).unwrap();
|
||||
store.store("c.com".to_string(), make_test_bundle("c.com")).unwrap();
|
||||
|
||||
let mut store2 = CertStore::new(tmp.path());
|
||||
let loaded = store2.load_all().unwrap();
|
||||
assert_eq!(loaded, 3);
|
||||
assert!(store2.has("a.com"));
|
||||
assert!(store2.has("b.com"));
|
||||
assert!(store2.has("c.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_all_missing_directory() {
|
||||
let mut store = CertStore::new("/nonexistent/path/to/certs");
|
||||
let loaded = store.load_all().unwrap();
|
||||
assert_eq!(loaded, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_remove_cert() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let mut store = CertStore::new(tmp.path());
|
||||
|
||||
store.store("remove-me.com".to_string(), make_test_bundle("remove-me.com")).unwrap();
|
||||
assert!(store.has("remove-me.com"));
|
||||
|
||||
let removed = store.remove("remove-me.com").unwrap();
|
||||
assert!(removed);
|
||||
assert!(!store.has("remove-me.com"));
|
||||
assert!(!tmp.path().join("remove-me.com").exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_domain_storage() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let mut store = CertStore::new(tmp.path());
|
||||
|
||||
store.store("*.example.com".to_string(), make_test_bundle("*.example.com")).unwrap();
|
||||
|
||||
// Directory should use sanitized name
|
||||
assert!(tmp.path().join("_wildcard_.example.com").exists());
|
||||
|
||||
let mut store2 = CertStore::new(tmp.path());
|
||||
store2.load_all().unwrap();
|
||||
assert!(store2.has("*.example.com"));
|
||||
}
|
||||
}
|
||||
13
rust/crates/rustproxy-tls/src/lib.rs
Normal file
13
rust/crates/rustproxy-tls/src/lib.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
//! # rustproxy-tls
|
||||
//!
|
||||
//! TLS certificate management for RustProxy.
|
||||
//! Handles ACME (Let's Encrypt), static certificates, and dynamic SNI resolution.
|
||||
|
||||
pub mod cert_store;
|
||||
pub mod cert_manager;
|
||||
pub mod acme;
|
||||
pub mod sni_resolver;
|
||||
|
||||
pub use cert_store::*;
|
||||
pub use cert_manager::*;
|
||||
pub use sni_resolver::*;
|
||||
139
rust/crates/rustproxy-tls/src/sni_resolver.rs
Normal file
139
rust/crates/rustproxy-tls/src/sni_resolver.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use crate::cert_store::CertBundle;
|
||||
|
||||
/// Dynamic SNI-based certificate resolver.
|
||||
/// Used by the TLS stack to select the right certificate based on client SNI.
|
||||
pub struct SniResolver {
|
||||
/// Domain -> certificate bundle mapping
|
||||
certs: RwLock<HashMap<String, Arc<CertBundle>>>,
|
||||
/// Fallback certificate (used when no SNI or no match)
|
||||
fallback: RwLock<Option<Arc<CertBundle>>>,
|
||||
}
|
||||
|
||||
impl SniResolver {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
certs: RwLock::new(HashMap::new()),
|
||||
fallback: RwLock::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a certificate for a domain.
|
||||
pub fn add_cert(&self, domain: String, bundle: CertBundle) {
|
||||
let mut certs = self.certs.write().unwrap();
|
||||
certs.insert(domain, Arc::new(bundle));
|
||||
}
|
||||
|
||||
/// Set the fallback certificate.
|
||||
pub fn set_fallback(&self, bundle: CertBundle) {
|
||||
let mut fallback = self.fallback.write().unwrap();
|
||||
*fallback = Some(Arc::new(bundle));
|
||||
}
|
||||
|
||||
/// Resolve a certificate for the given SNI domain.
|
||||
pub fn resolve(&self, domain: &str) -> Option<Arc<CertBundle>> {
|
||||
let certs = self.certs.read().unwrap();
|
||||
|
||||
// Try exact match
|
||||
if let Some(bundle) = certs.get(domain) {
|
||||
return Some(Arc::clone(bundle));
|
||||
}
|
||||
|
||||
// Try wildcard match (e.g., *.example.com)
|
||||
if let Some(dot_pos) = domain.find('.') {
|
||||
let wildcard = format!("*.{}", &domain[dot_pos + 1..]);
|
||||
if let Some(bundle) = certs.get(&wildcard) {
|
||||
return Some(Arc::clone(bundle));
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback
|
||||
let fallback = self.fallback.read().unwrap();
|
||||
fallback.clone()
|
||||
}
|
||||
|
||||
/// Remove a certificate for a domain.
|
||||
pub fn remove_cert(&self, domain: &str) {
|
||||
let mut certs = self.certs.write().unwrap();
|
||||
certs.remove(domain);
|
||||
}
|
||||
|
||||
/// Get the number of registered certificates.
|
||||
pub fn cert_count(&self) -> usize {
|
||||
self.certs.read().unwrap().len()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SniResolver {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::cert_store::{CertBundle, CertMetadata, CertSource};
|
||||
|
||||
fn make_bundle(domain: &str) -> CertBundle {
|
||||
CertBundle {
|
||||
key_pem: format!("KEY-{}", domain),
|
||||
cert_pem: format!("CERT-{}", domain),
|
||||
ca_pem: None,
|
||||
metadata: CertMetadata {
|
||||
domain: domain.to_string(),
|
||||
source: CertSource::Static,
|
||||
issued_at: 0,
|
||||
expires_at: 0,
|
||||
renewed_at: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exact_domain_resolve() {
|
||||
let resolver = SniResolver::new();
|
||||
resolver.add_cert("example.com".to_string(), make_bundle("example.com"));
|
||||
let result = resolver.resolve("example.com");
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap().cert_pem, "CERT-example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_resolve() {
|
||||
let resolver = SniResolver::new();
|
||||
resolver.add_cert("*.example.com".to_string(), make_bundle("*.example.com"));
|
||||
let result = resolver.resolve("sub.example.com");
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap().cert_pem, "CERT-*.example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fallback() {
|
||||
let resolver = SniResolver::new();
|
||||
resolver.set_fallback(make_bundle("fallback"));
|
||||
let result = resolver.resolve("unknown.com");
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap().cert_pem, "CERT-fallback");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_match_no_fallback() {
|
||||
let resolver = SniResolver::new();
|
||||
resolver.add_cert("example.com".to_string(), make_bundle("example.com"));
|
||||
let result = resolver.resolve("other.com");
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_remove_cert() {
|
||||
let resolver = SniResolver::new();
|
||||
resolver.add_cert("example.com".to_string(), make_bundle("example.com"));
|
||||
assert_eq!(resolver.cert_count(), 1);
|
||||
resolver.remove_cert("example.com");
|
||||
assert_eq!(resolver.cert_count(), 0);
|
||||
assert!(resolver.resolve("example.com").is_none());
|
||||
}
|
||||
}
|
||||
44
rust/crates/rustproxy/Cargo.toml
Normal file
44
rust/crates/rustproxy/Cargo.toml
Normal file
@@ -0,0 +1,44 @@
|
||||
[package]
|
||||
name = "rustproxy"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
authors.workspace = true
|
||||
description = "High-performance multi-protocol proxy built on Pingora, compatible with SmartProxy configuration"
|
||||
|
||||
[[bin]]
|
||||
name = "rustproxy"
|
||||
path = "src/main.rs"
|
||||
|
||||
[lib]
|
||||
name = "rustproxy"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[dependencies]
|
||||
rustproxy-config = { workspace = true }
|
||||
rustproxy-routing = { workspace = true }
|
||||
rustproxy-tls = { workspace = true }
|
||||
rustproxy-passthrough = { workspace = true }
|
||||
rustproxy-http = { workspace = true }
|
||||
rustproxy-nftables = { workspace = true }
|
||||
rustproxy-metrics = { workspace = true }
|
||||
rustproxy-security = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
arc-swap = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
rustls = { workspace = true }
|
||||
tokio-rustls = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
hyper = { workspace = true }
|
||||
hyper-util = { workspace = true }
|
||||
http-body-util = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
rcgen = { workspace = true }
|
||||
177
rust/crates/rustproxy/src/challenge_server.rs
Normal file
177
rust/crates/rustproxy/src/challenge_server.rs
Normal file
@@ -0,0 +1,177 @@
|
||||
//! HTTP-01 ACME challenge server.
|
||||
//!
|
||||
//! A lightweight HTTP server that serves ACME challenge responses at
|
||||
//! `/.well-known/acme-challenge/<token>`.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use dashmap::DashMap;
|
||||
use http_body_util::Full;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, info, error};
|
||||
|
||||
/// ACME HTTP-01 challenge server.
|
||||
pub struct ChallengeServer {
|
||||
/// Token -> key authorization mapping
|
||||
challenges: Arc<DashMap<String, String>>,
|
||||
/// Cancellation token to stop the server
|
||||
cancel: CancellationToken,
|
||||
/// Server task handle
|
||||
handle: Option<tokio::task::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl ChallengeServer {
|
||||
/// Create a new challenge server (not yet started).
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
challenges: Arc::new(DashMap::new()),
|
||||
cancel: CancellationToken::new(),
|
||||
handle: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a challenge token -> key_authorization mapping.
|
||||
pub fn set_challenge(&self, token: String, key_authorization: String) {
|
||||
debug!("Registered ACME challenge: token={}", token);
|
||||
self.challenges.insert(token, key_authorization);
|
||||
}
|
||||
|
||||
/// Remove a challenge token.
|
||||
pub fn remove_challenge(&self, token: &str) {
|
||||
self.challenges.remove(token);
|
||||
}
|
||||
|
||||
/// Start the challenge server on the given port.
|
||||
pub async fn start(&mut self, port: u16) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let addr = format!("0.0.0.0:{}", port);
|
||||
let listener = TcpListener::bind(&addr).await?;
|
||||
info!("ACME challenge server listening on port {}", port);
|
||||
|
||||
let challenges = Arc::clone(&self.challenges);
|
||||
let cancel = self.cancel.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
info!("ACME challenge server stopping");
|
||||
break;
|
||||
}
|
||||
result = listener.accept() => {
|
||||
match result {
|
||||
Ok((stream, _)) => {
|
||||
let challenges = Arc::clone(&challenges);
|
||||
tokio::spawn(async move {
|
||||
let io = TokioIo::new(stream);
|
||||
let service = hyper::service::service_fn(move |req: Request<Incoming>| {
|
||||
let challenges = Arc::clone(&challenges);
|
||||
async move {
|
||||
Self::handle_request(req, &challenges)
|
||||
}
|
||||
});
|
||||
|
||||
let conn = hyper::server::conn::http1::Builder::new()
|
||||
.serve_connection(io, service);
|
||||
|
||||
if let Err(e) = conn.await {
|
||||
debug!("Challenge server connection error: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Challenge server accept error: {}", e);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
self.handle = Some(handle);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop the challenge server.
|
||||
pub async fn stop(&mut self) {
|
||||
self.cancel.cancel();
|
||||
if let Some(handle) = self.handle.take() {
|
||||
let _ = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(5),
|
||||
handle,
|
||||
).await;
|
||||
}
|
||||
self.challenges.clear();
|
||||
self.cancel = CancellationToken::new();
|
||||
info!("ACME challenge server stopped");
|
||||
}
|
||||
|
||||
/// Handle an HTTP request for ACME challenges.
|
||||
fn handle_request(
|
||||
req: Request<Incoming>,
|
||||
challenges: &DashMap<String, String>,
|
||||
) -> Result<Response<Full<Bytes>>, hyper::Error> {
|
||||
let path = req.uri().path();
|
||||
|
||||
if let Some(token) = path.strip_prefix("/.well-known/acme-challenge/") {
|
||||
if let Some(key_auth) = challenges.get(token) {
|
||||
debug!("Serving ACME challenge for token: {}", token);
|
||||
return Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("content-type", "text/plain")
|
||||
.body(Full::new(Bytes::from(key_auth.value().clone())))
|
||||
.unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Full::new(Bytes::from("Not Found")))
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_challenge_server_lifecycle() {
|
||||
let mut server = ChallengeServer::new();
|
||||
|
||||
// Set a challenge before starting
|
||||
server.set_challenge("test-token".to_string(), "test-key-auth".to_string());
|
||||
|
||||
// Start on a random port
|
||||
server.start(19900).await.unwrap();
|
||||
|
||||
// Give server a moment to start
|
||||
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
|
||||
|
||||
// Fetch the challenge
|
||||
let client = tokio::net::TcpStream::connect("127.0.0.1:19900").await.unwrap();
|
||||
let io = TokioIo::new(client);
|
||||
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap();
|
||||
tokio::spawn(async move { let _ = conn.await; });
|
||||
|
||||
let req = Request::get("/.well-known/acme-challenge/test-token")
|
||||
.body(Full::new(Bytes::new()))
|
||||
.unwrap();
|
||||
let resp = sender.send_request(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
// Test 404 for unknown token
|
||||
let req = Request::get("/.well-known/acme-challenge/unknown")
|
||||
.body(Full::new(Bytes::new()))
|
||||
.unwrap();
|
||||
let resp = sender.send_request(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::NOT_FOUND);
|
||||
|
||||
server.stop().await;
|
||||
}
|
||||
}
|
||||
931
rust/crates/rustproxy/src/lib.rs
Normal file
931
rust/crates/rustproxy/src/lib.rs
Normal file
@@ -0,0 +1,931 @@
|
||||
//! # RustProxy
|
||||
//!
|
||||
//! High-performance multi-protocol proxy built on Rust,
|
||||
//! compatible with SmartProxy configuration.
|
||||
//!
|
||||
//! ## Quick Start
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use rustproxy::RustProxy;
|
||||
//! use rustproxy_config::{RustProxyOptions, create_https_passthrough_route};
|
||||
//!
|
||||
//! #[tokio::main]
|
||||
//! async fn main() -> anyhow::Result<()> {
|
||||
//! let options = RustProxyOptions {
|
||||
//! routes: vec![
|
||||
//! create_https_passthrough_route("example.com", "backend", 443),
|
||||
//! ],
|
||||
//! ..Default::default()
|
||||
//! };
|
||||
//!
|
||||
//! let mut proxy = RustProxy::new(options)?;
|
||||
//! proxy.start().await?;
|
||||
//! Ok(())
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
pub mod challenge_server;
|
||||
pub mod management;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use anyhow::Result;
|
||||
use tracing::{info, warn, debug, error};
|
||||
|
||||
// Re-export key types
|
||||
pub use rustproxy_config;
|
||||
pub use rustproxy_routing;
|
||||
pub use rustproxy_passthrough;
|
||||
pub use rustproxy_tls;
|
||||
pub use rustproxy_http;
|
||||
pub use rustproxy_nftables;
|
||||
pub use rustproxy_metrics;
|
||||
pub use rustproxy_security;
|
||||
|
||||
use rustproxy_config::{RouteConfig, RustProxyOptions, TlsMode, CertificateSpec, ForwardingEngine};
|
||||
use rustproxy_routing::RouteManager;
|
||||
use rustproxy_passthrough::{TcpListenerManager, TlsCertConfig, ConnectionConfig};
|
||||
use rustproxy_metrics::{MetricsCollector, Metrics, Statistics};
|
||||
use rustproxy_tls::{CertManager, CertStore, CertBundle, CertMetadata, CertSource};
|
||||
use rustproxy_nftables::{NftManager, rule_builder};
|
||||
|
||||
/// Certificate status.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CertStatus {
|
||||
pub domain: String,
|
||||
pub source: String,
|
||||
pub expires_at: u64,
|
||||
pub is_valid: bool,
|
||||
}
|
||||
|
||||
/// The main RustProxy struct.
|
||||
/// This is the primary public API matching SmartProxy's interface.
|
||||
pub struct RustProxy {
|
||||
options: RustProxyOptions,
|
||||
route_table: ArcSwap<RouteManager>,
|
||||
listener_manager: Option<TcpListenerManager>,
|
||||
metrics: Arc<MetricsCollector>,
|
||||
cert_manager: Option<Arc<tokio::sync::Mutex<CertManager>>>,
|
||||
challenge_server: Option<challenge_server::ChallengeServer>,
|
||||
renewal_handle: Option<tokio::task::JoinHandle<()>>,
|
||||
nft_manager: Option<NftManager>,
|
||||
started: bool,
|
||||
started_at: Option<Instant>,
|
||||
/// Path to a Unix domain socket for relaying socket-handler connections back to TypeScript.
|
||||
socket_handler_relay_path: Option<String>,
|
||||
}
|
||||
|
||||
impl RustProxy {
|
||||
/// Create a new RustProxy instance with the given configuration.
|
||||
pub fn new(mut options: RustProxyOptions) -> Result<Self> {
|
||||
// Apply defaults to routes before validation
|
||||
Self::apply_defaults(&mut options);
|
||||
|
||||
// Validate routes
|
||||
if let Err(errors) = rustproxy_config::validate_routes(&options.routes) {
|
||||
for err in &errors {
|
||||
warn!("Route validation error: {}", err);
|
||||
}
|
||||
if !errors.is_empty() {
|
||||
anyhow::bail!("Route validation failed with {} errors", errors.len());
|
||||
}
|
||||
}
|
||||
|
||||
let route_manager = RouteManager::new(options.routes.clone());
|
||||
|
||||
// Set up certificate manager if ACME is configured
|
||||
let cert_manager = Self::build_cert_manager(&options)
|
||||
.map(|cm| Arc::new(tokio::sync::Mutex::new(cm)));
|
||||
|
||||
Ok(Self {
|
||||
options,
|
||||
route_table: ArcSwap::from(Arc::new(route_manager)),
|
||||
listener_manager: None,
|
||||
metrics: Arc::new(MetricsCollector::new()),
|
||||
cert_manager,
|
||||
challenge_server: None,
|
||||
renewal_handle: None,
|
||||
nft_manager: None,
|
||||
started: false,
|
||||
started_at: None,
|
||||
socket_handler_relay_path: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Apply default configuration to routes that lack targets or security.
|
||||
fn apply_defaults(options: &mut RustProxyOptions) {
|
||||
let defaults = match &options.defaults {
|
||||
Some(d) => d.clone(),
|
||||
None => return,
|
||||
};
|
||||
|
||||
for route in &mut options.routes {
|
||||
// Apply default target if route has no targets
|
||||
if route.action.targets.is_none() {
|
||||
if let Some(ref default_target) = defaults.target {
|
||||
debug!("Applying default target {}:{} to route {:?}",
|
||||
default_target.host, default_target.port,
|
||||
route.name.as_deref().unwrap_or("unnamed"));
|
||||
route.action.targets = Some(vec![
|
||||
rustproxy_config::RouteTarget {
|
||||
target_match: None,
|
||||
host: rustproxy_config::HostSpec::Single(default_target.host.clone()),
|
||||
port: rustproxy_config::PortSpec::Fixed(default_target.port),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
priority: None,
|
||||
}
|
||||
]);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply default security if route has no security
|
||||
if route.security.is_none() {
|
||||
if let Some(ref default_security) = defaults.security {
|
||||
let mut security = rustproxy_config::RouteSecurity {
|
||||
ip_allow_list: None,
|
||||
ip_block_list: None,
|
||||
max_connections: default_security.max_connections,
|
||||
authentication: None,
|
||||
rate_limit: None,
|
||||
basic_auth: None,
|
||||
jwt_auth: None,
|
||||
};
|
||||
|
||||
if let Some(ref allow_list) = default_security.ip_allow_list {
|
||||
security.ip_allow_list = Some(allow_list.clone());
|
||||
}
|
||||
if let Some(ref block_list) = default_security.ip_block_list {
|
||||
security.ip_block_list = Some(block_list.clone());
|
||||
}
|
||||
|
||||
// Only apply if there's something meaningful
|
||||
if security.ip_allow_list.is_some() || security.ip_block_list.is_some() {
|
||||
debug!("Applying default security to route {:?}",
|
||||
route.name.as_deref().unwrap_or("unnamed"));
|
||||
route.security = Some(security);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a CertManager from options.
|
||||
fn build_cert_manager(options: &RustProxyOptions) -> Option<CertManager> {
|
||||
let acme = options.acme.as_ref()?;
|
||||
if !acme.enabled.unwrap_or(false) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let store_path = acme.certificate_store
|
||||
.as_deref()
|
||||
.unwrap_or("./certs");
|
||||
let email = acme.email.clone()
|
||||
.or_else(|| acme.account_email.clone());
|
||||
let use_production = acme.use_production.unwrap_or(false);
|
||||
let renew_before_days = acme.renew_threshold_days.unwrap_or(30);
|
||||
|
||||
let store = CertStore::new(store_path);
|
||||
Some(CertManager::new(store, email, use_production, renew_before_days))
|
||||
}
|
||||
|
||||
/// Build ConnectionConfig from RustProxyOptions.
|
||||
fn build_connection_config(options: &RustProxyOptions) -> ConnectionConfig {
|
||||
ConnectionConfig {
|
||||
connection_timeout_ms: options.effective_connection_timeout(),
|
||||
initial_data_timeout_ms: options.effective_initial_data_timeout(),
|
||||
socket_timeout_ms: options.effective_socket_timeout(),
|
||||
max_connection_lifetime_ms: options.effective_max_connection_lifetime(),
|
||||
graceful_shutdown_timeout_ms: options.graceful_shutdown_timeout.unwrap_or(30_000),
|
||||
max_connections_per_ip: options.max_connections_per_ip,
|
||||
connection_rate_limit_per_minute: options.connection_rate_limit_per_minute,
|
||||
keep_alive_treatment: options.keep_alive_treatment.clone(),
|
||||
keep_alive_inactivity_multiplier: options.keep_alive_inactivity_multiplier,
|
||||
extended_keep_alive_lifetime_ms: options.extended_keep_alive_lifetime,
|
||||
accept_proxy_protocol: options.accept_proxy_protocol.unwrap_or(false),
|
||||
send_proxy_protocol: options.send_proxy_protocol.unwrap_or(false),
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the proxy, binding to all configured ports.
|
||||
pub async fn start(&mut self) -> Result<()> {
|
||||
if self.started {
|
||||
anyhow::bail!("Proxy is already started");
|
||||
}
|
||||
|
||||
info!("Starting RustProxy...");
|
||||
|
||||
// Load persisted certificates
|
||||
if let Some(ref cm) = self.cert_manager {
|
||||
let mut cm = cm.lock().await;
|
||||
match cm.load_all() {
|
||||
Ok(count) => {
|
||||
if count > 0 {
|
||||
info!("Loaded {} persisted certificates", count);
|
||||
}
|
||||
}
|
||||
Err(e) => warn!("Failed to load persisted certificates: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-provision certificates for routes with certificate: 'auto'
|
||||
self.auto_provision_certificates().await;
|
||||
|
||||
let route_manager = self.route_table.load();
|
||||
let ports = route_manager.listening_ports();
|
||||
|
||||
info!("Configured {} routes on {} ports", route_manager.route_count(), ports.len());
|
||||
|
||||
// Create TCP listener manager with metrics
|
||||
let mut listener = TcpListenerManager::with_metrics(
|
||||
Arc::clone(&*route_manager),
|
||||
Arc::clone(&self.metrics),
|
||||
);
|
||||
|
||||
// Apply connection config from options
|
||||
let conn_config = Self::build_connection_config(&self.options);
|
||||
debug!("Connection config: timeout={}ms, initial_data={}ms, socket={}ms, max_life={}ms",
|
||||
conn_config.connection_timeout_ms,
|
||||
conn_config.initial_data_timeout_ms,
|
||||
conn_config.socket_timeout_ms,
|
||||
conn_config.max_connection_lifetime_ms,
|
||||
);
|
||||
listener.set_connection_config(conn_config);
|
||||
|
||||
// Extract TLS configurations from routes and cert manager
|
||||
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
|
||||
|
||||
// Also load certs from cert manager into TLS config
|
||||
if let Some(ref cm) = self.cert_manager {
|
||||
let cm = cm.lock().await;
|
||||
for (domain, bundle) in cm.store().iter() {
|
||||
if !tls_configs.contains_key(domain) {
|
||||
tls_configs.insert(domain.clone(), TlsCertConfig {
|
||||
cert_pem: bundle.cert_pem.clone(),
|
||||
key_pem: bundle.key_pem.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !tls_configs.is_empty() {
|
||||
debug!("Loaded TLS certificates for {} domains", tls_configs.len());
|
||||
listener.set_tls_configs(tls_configs);
|
||||
}
|
||||
|
||||
// Bind all ports
|
||||
for port in &ports {
|
||||
listener.add_port(*port).await?;
|
||||
}
|
||||
|
||||
self.listener_manager = Some(listener);
|
||||
self.started = true;
|
||||
self.started_at = Some(Instant::now());
|
||||
|
||||
// Apply NFTables rules for routes using nftables forwarding engine
|
||||
self.apply_nftables_rules(&self.options.routes.clone()).await;
|
||||
|
||||
// Start renewal timer if ACME is enabled
|
||||
self.start_renewal_timer();
|
||||
|
||||
info!("RustProxy started successfully on ports: {:?}", ports);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Auto-provision certificates for routes that use certificate: 'auto'.
|
||||
async fn auto_provision_certificates(&mut self) {
|
||||
let cm_arc = match self.cert_manager {
|
||||
Some(ref cm) => Arc::clone(cm),
|
||||
None => return,
|
||||
};
|
||||
|
||||
let mut domains_to_provision = Vec::new();
|
||||
|
||||
for route in &self.options.routes {
|
||||
let tls_mode = route.tls_mode();
|
||||
let needs_cert = matches!(
|
||||
tls_mode,
|
||||
Some(TlsMode::Terminate) | Some(TlsMode::TerminateAndReencrypt)
|
||||
);
|
||||
if !needs_cert {
|
||||
continue;
|
||||
}
|
||||
|
||||
let cert_spec = route.action.tls.as_ref()
|
||||
.and_then(|tls| tls.certificate.as_ref());
|
||||
|
||||
if let Some(CertificateSpec::Auto(_)) = cert_spec {
|
||||
if let Some(ref domains) = route.route_match.domains {
|
||||
for domain in domains.to_vec() {
|
||||
let domain = domain.to_string();
|
||||
// Skip if we already have a valid cert
|
||||
let cm = cm_arc.lock().await;
|
||||
if cm.store().has(&domain) {
|
||||
debug!("Already have cert for {}, skipping auto-provision", domain);
|
||||
continue;
|
||||
}
|
||||
drop(cm);
|
||||
domains_to_provision.push(domain);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if domains_to_provision.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
info!("Auto-provisioning certificates for {} domains", domains_to_provision.len());
|
||||
|
||||
// Start challenge server
|
||||
let acme_port = self.options.acme.as_ref()
|
||||
.and_then(|a| a.port)
|
||||
.unwrap_or(80);
|
||||
|
||||
let mut challenge_server = challenge_server::ChallengeServer::new();
|
||||
if let Err(e) = challenge_server.start(acme_port).await {
|
||||
error!("Failed to start ACME challenge server on port {}: {}", acme_port, e);
|
||||
return;
|
||||
}
|
||||
|
||||
for domain in &domains_to_provision {
|
||||
info!("Provisioning certificate for {}", domain);
|
||||
|
||||
let cm = cm_arc.lock().await;
|
||||
let acme_client = cm.acme_client();
|
||||
drop(cm);
|
||||
|
||||
if let Some(acme_client) = acme_client {
|
||||
let challenge_server_ref = &challenge_server;
|
||||
let result = acme_client.provision(domain, |pending| {
|
||||
challenge_server_ref.set_challenge(
|
||||
pending.token.clone(),
|
||||
pending.key_authorization.clone(),
|
||||
);
|
||||
async move { Ok(()) }
|
||||
}).await;
|
||||
|
||||
match result {
|
||||
Ok((cert_pem, key_pem)) => {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let bundle = CertBundle {
|
||||
cert_pem,
|
||||
key_pem,
|
||||
ca_pem: None,
|
||||
metadata: CertMetadata {
|
||||
domain: domain.clone(),
|
||||
source: CertSource::Acme,
|
||||
issued_at: now,
|
||||
expires_at: now + 90 * 86400, // 90 days
|
||||
renewed_at: None,
|
||||
},
|
||||
};
|
||||
|
||||
let mut cm = cm_arc.lock().await;
|
||||
if let Err(e) = cm.load_static(domain.clone(), bundle) {
|
||||
error!("Failed to store certificate for {}: {}", domain, e);
|
||||
}
|
||||
|
||||
info!("Certificate provisioned for {}", domain);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to provision certificate for {}: {}", domain, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
challenge_server.stop().await;
|
||||
}
|
||||
|
||||
/// Start the renewal timer background task.
|
||||
/// The background task checks for expiring certificates and renews them.
|
||||
fn start_renewal_timer(&mut self) {
|
||||
let cm_arc = match self.cert_manager {
|
||||
Some(ref cm) => Arc::clone(cm),
|
||||
None => return,
|
||||
};
|
||||
|
||||
let auto_renew = self.options.acme.as_ref()
|
||||
.and_then(|a| a.auto_renew)
|
||||
.unwrap_or(true);
|
||||
|
||||
if !auto_renew {
|
||||
return;
|
||||
}
|
||||
|
||||
let check_interval_hours = self.options.acme.as_ref()
|
||||
.and_then(|a| a.renew_check_interval_hours)
|
||||
.unwrap_or(24);
|
||||
|
||||
let acme_port = self.options.acme.as_ref()
|
||||
.and_then(|a| a.port)
|
||||
.unwrap_or(80);
|
||||
|
||||
let interval = std::time::Duration::from_secs(check_interval_hours as u64 * 3600);
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
debug!("Certificate renewal check triggered (interval: {}h)", check_interval_hours);
|
||||
|
||||
// Check which domains need renewal
|
||||
let domains = {
|
||||
let cm = cm_arc.lock().await;
|
||||
cm.check_renewals()
|
||||
};
|
||||
|
||||
if domains.is_empty() {
|
||||
debug!("No certificates need renewal");
|
||||
continue;
|
||||
}
|
||||
|
||||
info!("Renewing {} certificate(s)", domains.len());
|
||||
|
||||
// Start challenge server for renewals
|
||||
let mut cs = challenge_server::ChallengeServer::new();
|
||||
if let Err(e) = cs.start(acme_port).await {
|
||||
error!("Failed to start challenge server for renewal: {}", e);
|
||||
continue;
|
||||
}
|
||||
|
||||
for domain in &domains {
|
||||
let cs_ref = &cs;
|
||||
let mut cm = cm_arc.lock().await;
|
||||
let result = cm.renew_domain(domain, |token, key_auth| {
|
||||
cs_ref.set_challenge(token, key_auth);
|
||||
async {}
|
||||
}).await;
|
||||
|
||||
match result {
|
||||
Ok(_bundle) => {
|
||||
info!("Successfully renewed certificate for {}", domain);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to renew certificate for {}: {}", domain, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cs.stop().await;
|
||||
}
|
||||
});
|
||||
|
||||
self.renewal_handle = Some(handle);
|
||||
}
|
||||
|
||||
/// Stop the proxy gracefully.
|
||||
pub async fn stop(&mut self) -> Result<()> {
|
||||
if !self.started {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!("Stopping RustProxy...");
|
||||
|
||||
// Stop renewal timer
|
||||
if let Some(handle) = self.renewal_handle.take() {
|
||||
handle.abort();
|
||||
}
|
||||
|
||||
// Stop challenge server if running
|
||||
if let Some(ref mut cs) = self.challenge_server {
|
||||
cs.stop().await;
|
||||
}
|
||||
self.challenge_server = None;
|
||||
|
||||
// Clean up NFTables rules
|
||||
if let Some(ref mut nft) = self.nft_manager {
|
||||
if let Err(e) = nft.cleanup().await {
|
||||
warn!("NFTables cleanup failed: {}", e);
|
||||
}
|
||||
}
|
||||
self.nft_manager = None;
|
||||
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
listener.graceful_stop().await;
|
||||
}
|
||||
self.listener_manager = None;
|
||||
self.started = false;
|
||||
|
||||
info!("RustProxy stopped");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update routes atomically (hot-reload).
|
||||
pub async fn update_routes(&mut self, routes: Vec<RouteConfig>) -> Result<()> {
|
||||
// Validate new routes
|
||||
rustproxy_config::validate_routes(&routes)
|
||||
.map_err(|errors| {
|
||||
let msgs: Vec<String> = errors.iter().map(|e| e.to_string()).collect();
|
||||
anyhow::anyhow!("Route validation failed: {}", msgs.join(", "))
|
||||
})?;
|
||||
|
||||
let new_manager = RouteManager::new(routes.clone());
|
||||
let new_ports = new_manager.listening_ports();
|
||||
|
||||
info!("Updating routes: {} routes on {} ports",
|
||||
new_manager.route_count(), new_ports.len());
|
||||
|
||||
// Get old ports
|
||||
let old_ports: Vec<u16> = if let Some(ref listener) = self.listener_manager {
|
||||
listener.listening_ports()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
// Atomically swap the route table
|
||||
let new_manager = Arc::new(new_manager);
|
||||
self.route_table.store(Arc::clone(&new_manager));
|
||||
|
||||
// Update listener manager
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
listener.update_route_manager(Arc::clone(&new_manager));
|
||||
|
||||
// Update TLS configs
|
||||
let mut tls_configs = Self::extract_tls_configs(&routes);
|
||||
if let Some(ref cm_arc) = self.cert_manager {
|
||||
let cm = cm_arc.lock().await;
|
||||
for (domain, bundle) in cm.store().iter() {
|
||||
if !tls_configs.contains_key(domain) {
|
||||
tls_configs.insert(domain.clone(), TlsCertConfig {
|
||||
cert_pem: bundle.cert_pem.clone(),
|
||||
key_pem: bundle.key_pem.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
listener.set_tls_configs(tls_configs);
|
||||
|
||||
// Add new ports
|
||||
for port in &new_ports {
|
||||
if !old_ports.contains(port) {
|
||||
listener.add_port(*port).await?;
|
||||
}
|
||||
}
|
||||
|
||||
// Remove old ports no longer needed
|
||||
for port in &old_ports {
|
||||
if !new_ports.contains(port) {
|
||||
listener.remove_port(*port);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update NFTables rules: remove old, apply new
|
||||
self.update_nftables_rules(&routes).await;
|
||||
|
||||
self.options.routes = routes;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Provision a certificate for a named route.
|
||||
pub async fn provision_certificate(&mut self, route_name: &str) -> Result<()> {
|
||||
let cm_arc = self.cert_manager.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("No certificate manager configured (ACME not enabled)"))?;
|
||||
|
||||
// Find the route by name
|
||||
let route = self.options.routes.iter()
|
||||
.find(|r| r.name.as_deref() == Some(route_name))
|
||||
.ok_or_else(|| anyhow::anyhow!("Route '{}' not found", route_name))?;
|
||||
|
||||
let domain = route.route_match.domains.as_ref()
|
||||
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))
|
||||
.ok_or_else(|| anyhow::anyhow!("Route '{}' has no domain", route_name))?;
|
||||
|
||||
info!("Provisioning certificate for route '{}' (domain: {})", route_name, domain);
|
||||
|
||||
// Start challenge server
|
||||
let acme_port = self.options.acme.as_ref()
|
||||
.and_then(|a| a.port)
|
||||
.unwrap_or(80);
|
||||
|
||||
let mut cs = challenge_server::ChallengeServer::new();
|
||||
cs.start(acme_port).await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to start challenge server: {}", e))?;
|
||||
|
||||
let cs_ref = &cs;
|
||||
let mut cm = cm_arc.lock().await;
|
||||
let result = cm.renew_domain(&domain, |token, key_auth| {
|
||||
cs_ref.set_challenge(token, key_auth);
|
||||
async {}
|
||||
}).await;
|
||||
drop(cm);
|
||||
|
||||
cs.stop().await;
|
||||
|
||||
let bundle = result
|
||||
.map_err(|e| anyhow::anyhow!("ACME provisioning failed: {}", e))?;
|
||||
|
||||
// Hot-swap into TLS configs
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
|
||||
tls_configs.insert(domain.clone(), TlsCertConfig {
|
||||
cert_pem: bundle.cert_pem.clone(),
|
||||
key_pem: bundle.key_pem.clone(),
|
||||
});
|
||||
let cm = cm_arc.lock().await;
|
||||
for (d, b) in cm.store().iter() {
|
||||
if !tls_configs.contains_key(d) {
|
||||
tls_configs.insert(d.clone(), TlsCertConfig {
|
||||
cert_pem: b.cert_pem.clone(),
|
||||
key_pem: b.key_pem.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
listener.set_tls_configs(tls_configs);
|
||||
}
|
||||
|
||||
info!("Certificate provisioned and loaded for route '{}'", route_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Renew a certificate for a named route.
|
||||
pub async fn renew_certificate(&mut self, route_name: &str) -> Result<()> {
|
||||
// Renewal is just re-provisioning
|
||||
self.provision_certificate(route_name).await
|
||||
}
|
||||
|
||||
/// Get the status of a certificate for a named route.
|
||||
pub async fn get_certificate_status(&self, route_name: &str) -> Option<CertStatus> {
|
||||
let route = self.options.routes.iter()
|
||||
.find(|r| r.name.as_deref() == Some(route_name))?;
|
||||
|
||||
let domain = route.route_match.domains.as_ref()
|
||||
.and_then(|d| d.to_vec().first().map(|s| s.to_string()))?;
|
||||
|
||||
if let Some(ref cm_arc) = self.cert_manager {
|
||||
let cm = cm_arc.lock().await;
|
||||
if let Some(bundle) = cm.get_cert(&domain) {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
return Some(CertStatus {
|
||||
domain,
|
||||
source: format!("{:?}", bundle.metadata.source),
|
||||
expires_at: bundle.metadata.expires_at,
|
||||
is_valid: bundle.metadata.expires_at > now,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Get current metrics snapshot.
|
||||
pub fn get_metrics(&self) -> Metrics {
|
||||
self.metrics.snapshot()
|
||||
}
|
||||
|
||||
/// Add a listening port at runtime.
|
||||
pub async fn add_listening_port(&mut self, port: u16) -> Result<()> {
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
listener.add_port(port).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove a listening port at runtime.
|
||||
pub async fn remove_listening_port(&mut self, port: u16) -> Result<()> {
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
listener.remove_port(port);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get all currently listening ports.
|
||||
pub fn get_listening_ports(&self) -> Vec<u16> {
|
||||
self.listener_manager
|
||||
.as_ref()
|
||||
.map(|l| l.listening_ports())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get statistics snapshot.
|
||||
pub fn get_statistics(&self) -> Statistics {
|
||||
let uptime = self.started_at
|
||||
.map(|t| t.elapsed().as_secs())
|
||||
.unwrap_or(0);
|
||||
|
||||
Statistics {
|
||||
active_connections: self.metrics.active_connections(),
|
||||
total_connections: self.metrics.total_connections(),
|
||||
routes_count: self.route_table.load().route_count() as u64,
|
||||
listening_ports: self.get_listening_ports(),
|
||||
uptime_seconds: uptime,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the Unix domain socket path for relaying socket-handler connections to TypeScript.
|
||||
pub fn set_socket_handler_relay_path(&mut self, path: Option<String>) {
|
||||
info!("Socket handler relay path set to: {:?}", path);
|
||||
self.socket_handler_relay_path = path;
|
||||
}
|
||||
|
||||
/// Get the current socket handler relay path.
|
||||
pub fn get_socket_handler_relay_path(&self) -> Option<&str> {
|
||||
self.socket_handler_relay_path.as_deref()
|
||||
}
|
||||
|
||||
/// Load a certificate for a domain and hot-swap the TLS configuration.
|
||||
pub async fn load_certificate(
|
||||
&mut self,
|
||||
domain: &str,
|
||||
cert_pem: String,
|
||||
key_pem: String,
|
||||
ca_pem: Option<String>,
|
||||
) -> Result<()> {
|
||||
info!("Loading certificate for domain: {}", domain);
|
||||
|
||||
// Store in cert manager if available
|
||||
if let Some(ref cm_arc) = self.cert_manager {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let bundle = CertBundle {
|
||||
cert_pem: cert_pem.clone(),
|
||||
key_pem: key_pem.clone(),
|
||||
ca_pem: ca_pem.clone(),
|
||||
metadata: CertMetadata {
|
||||
domain: domain.to_string(),
|
||||
source: CertSource::Static,
|
||||
issued_at: now,
|
||||
expires_at: now + 90 * 86400, // assume 90 days
|
||||
renewed_at: None,
|
||||
},
|
||||
};
|
||||
|
||||
let mut cm = cm_arc.lock().await;
|
||||
cm.load_static(domain.to_string(), bundle)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to store certificate: {}", e))?;
|
||||
}
|
||||
|
||||
// Hot-swap TLS config on the listener
|
||||
if let Some(ref mut listener) = self.listener_manager {
|
||||
let mut tls_configs = Self::extract_tls_configs(&self.options.routes);
|
||||
|
||||
// Add the new cert
|
||||
tls_configs.insert(domain.to_string(), TlsCertConfig {
|
||||
cert_pem: cert_pem.clone(),
|
||||
key_pem: key_pem.clone(),
|
||||
});
|
||||
|
||||
// Also include all existing certs from cert manager
|
||||
if let Some(ref cm_arc) = self.cert_manager {
|
||||
let cm = cm_arc.lock().await;
|
||||
for (d, b) in cm.store().iter() {
|
||||
if !tls_configs.contains_key(d) {
|
||||
tls_configs.insert(d.clone(), TlsCertConfig {
|
||||
cert_pem: b.cert_pem.clone(),
|
||||
key_pem: b.key_pem.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
listener.set_tls_configs(tls_configs);
|
||||
}
|
||||
|
||||
info!("Certificate loaded and TLS config updated for {}", domain);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get NFTables status.
|
||||
pub async fn get_nftables_status(&self) -> Result<HashMap<String, serde_json::Value>> {
|
||||
match &self.nft_manager {
|
||||
Some(nft) => Ok(nft.status()),
|
||||
None => Ok(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply NFTables rules for routes using the nftables forwarding engine.
|
||||
async fn apply_nftables_rules(&mut self, routes: &[RouteConfig]) {
|
||||
let nft_routes: Vec<&RouteConfig> = routes.iter()
|
||||
.filter(|r| r.action.forwarding_engine.as_ref() == Some(&ForwardingEngine::Nftables))
|
||||
.collect();
|
||||
|
||||
if nft_routes.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
info!("Applying NFTables rules for {} routes", nft_routes.len());
|
||||
|
||||
let table_name = nft_routes.iter()
|
||||
.find_map(|r| r.action.nftables.as_ref()?.table_name.clone())
|
||||
.unwrap_or_else(|| "rustproxy".to_string());
|
||||
|
||||
let mut nft = NftManager::new(Some(table_name));
|
||||
|
||||
for route in &nft_routes {
|
||||
let route_id = route.id.as_deref()
|
||||
.or(route.name.as_deref())
|
||||
.unwrap_or("unnamed");
|
||||
|
||||
let nft_options = match &route.action.nftables {
|
||||
Some(opts) => opts.clone(),
|
||||
None => rustproxy_config::NfTablesOptions {
|
||||
preserve_source_ip: None,
|
||||
protocol: None,
|
||||
max_rate: None,
|
||||
priority: None,
|
||||
table_name: None,
|
||||
use_ip_sets: None,
|
||||
use_advanced_nat: None,
|
||||
},
|
||||
};
|
||||
|
||||
let targets = match &route.action.targets {
|
||||
Some(targets) => targets,
|
||||
None => {
|
||||
warn!("NFTables route '{}' has no targets, skipping", route_id);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let source_ports = route.route_match.ports.to_ports();
|
||||
for target in targets {
|
||||
let target_host = target.host.first().to_string();
|
||||
let target_port_spec = &target.port;
|
||||
|
||||
for &source_port in &source_ports {
|
||||
let resolved_port = target_port_spec.resolve(source_port);
|
||||
let rules = rule_builder::build_dnat_rule(
|
||||
nft.table_name(),
|
||||
"prerouting",
|
||||
source_port,
|
||||
&target_host,
|
||||
resolved_port,
|
||||
&nft_options,
|
||||
);
|
||||
|
||||
let rule_id = format!("{}-{}-{}", route_id, source_port, resolved_port);
|
||||
if let Err(e) = nft.apply_rules(&rule_id, rules).await {
|
||||
error!("Failed to apply NFTables rules for route '{}': {}", route_id, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.nft_manager = Some(nft);
|
||||
}
|
||||
|
||||
/// Update NFTables rules when routes change.
|
||||
async fn update_nftables_rules(&mut self, new_routes: &[RouteConfig]) {
|
||||
// Clean up old rules
|
||||
if let Some(ref mut nft) = self.nft_manager {
|
||||
if let Err(e) = nft.cleanup().await {
|
||||
warn!("NFTables cleanup during update failed: {}", e);
|
||||
}
|
||||
}
|
||||
self.nft_manager = None;
|
||||
|
||||
// Apply new rules
|
||||
self.apply_nftables_rules(new_routes).await;
|
||||
}
|
||||
|
||||
/// Extract TLS configurations from route configs.
|
||||
fn extract_tls_configs(routes: &[RouteConfig]) -> HashMap<String, TlsCertConfig> {
|
||||
let mut configs = HashMap::new();
|
||||
|
||||
for route in routes {
|
||||
let tls_mode = route.tls_mode();
|
||||
let needs_cert = matches!(
|
||||
tls_mode,
|
||||
Some(TlsMode::Terminate) | Some(TlsMode::TerminateAndReencrypt)
|
||||
);
|
||||
if !needs_cert {
|
||||
continue;
|
||||
}
|
||||
|
||||
let cert_spec = route.action.tls.as_ref()
|
||||
.and_then(|tls| tls.certificate.as_ref());
|
||||
|
||||
if let Some(CertificateSpec::Static(cert_config)) = cert_spec {
|
||||
if let Some(ref domains) = route.route_match.domains {
|
||||
for domain in domains.to_vec() {
|
||||
configs.insert(domain.to_string(), TlsCertConfig {
|
||||
cert_pem: cert_config.cert.clone(),
|
||||
key_pem: cert_config.key.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
configs
|
||||
}
|
||||
}
|
||||
90
rust/crates/rustproxy/src/main.rs
Normal file
90
rust/crates/rustproxy/src/main.rs
Normal file
@@ -0,0 +1,90 @@
|
||||
use clap::Parser;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use anyhow::Result;
|
||||
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy::management;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
|
||||
/// RustProxy - High-performance multi-protocol proxy
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "rustproxy", version, about)]
|
||||
struct Cli {
|
||||
/// Path to JSON configuration file
|
||||
#[arg(short, long, default_value = "config.json")]
|
||||
config: String,
|
||||
|
||||
/// Log level (trace, debug, info, warn, error)
|
||||
#[arg(short, long, default_value = "info")]
|
||||
log_level: String,
|
||||
|
||||
/// Validate configuration without starting
|
||||
#[arg(long)]
|
||||
validate: bool,
|
||||
|
||||
/// Run in management mode (JSON-over-stdin IPC for TypeScript wrapper)
|
||||
#[arg(long)]
|
||||
management: bool,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let cli = Cli::parse();
|
||||
|
||||
// Initialize tracing - write to stderr so stdout is reserved for management IPC
|
||||
tracing_subscriber::fmt()
|
||||
.with_writer(std::io::stderr)
|
||||
.with_env_filter(
|
||||
EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new(&cli.log_level))
|
||||
)
|
||||
.init();
|
||||
|
||||
// Management mode: JSON IPC over stdin/stdout
|
||||
if cli.management {
|
||||
tracing::info!("RustProxy starting in management mode...");
|
||||
return management::management_loop().await;
|
||||
}
|
||||
|
||||
tracing::info!("RustProxy starting...");
|
||||
|
||||
// Load configuration
|
||||
let options = RustProxyOptions::from_file(&cli.config)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to load config '{}': {}", cli.config, e))?;
|
||||
|
||||
tracing::info!(
|
||||
"Loaded {} routes from {}",
|
||||
options.routes.len(),
|
||||
cli.config
|
||||
);
|
||||
|
||||
// Validate-only mode
|
||||
if cli.validate {
|
||||
match rustproxy_config::validate_routes(&options.routes) {
|
||||
Ok(()) => {
|
||||
tracing::info!("Configuration is valid");
|
||||
return Ok(());
|
||||
}
|
||||
Err(errors) => {
|
||||
for err in &errors {
|
||||
tracing::error!("Validation error: {}", err);
|
||||
}
|
||||
anyhow::bail!("{} validation errors found", errors.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create and start proxy
|
||||
let mut proxy = RustProxy::new(options)?;
|
||||
proxy.start().await?;
|
||||
|
||||
// Wait for shutdown signal
|
||||
tracing::info!("RustProxy is running. Press Ctrl+C to stop.");
|
||||
tokio::signal::ctrl_c().await?;
|
||||
|
||||
tracing::info!("Shutdown signal received");
|
||||
proxy.stop().await?;
|
||||
|
||||
tracing::info!("RustProxy shutdown complete");
|
||||
Ok(())
|
||||
}
|
||||
470
rust/crates/rustproxy/src/management.rs
Normal file
470
rust/crates/rustproxy/src/management.rs
Normal file
@@ -0,0 +1,470 @@
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tracing::{info, error};
|
||||
|
||||
use crate::RustProxy;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
|
||||
/// A management request from the TypeScript wrapper.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ManagementRequest {
|
||||
pub id: String,
|
||||
pub method: String,
|
||||
#[serde(default)]
|
||||
pub params: serde_json::Value,
|
||||
}
|
||||
|
||||
/// A management response back to the TypeScript wrapper.
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ManagementResponse {
|
||||
pub id: String,
|
||||
pub success: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// An unsolicited event from the proxy to the TypeScript wrapper.
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ManagementEvent {
|
||||
pub event: String,
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
|
||||
impl ManagementResponse {
|
||||
fn ok(id: String, result: serde_json::Value) -> Self {
|
||||
Self {
|
||||
id,
|
||||
success: true,
|
||||
result: Some(result),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn err(id: String, message: String) -> Self {
|
||||
Self {
|
||||
id,
|
||||
success: false,
|
||||
result: None,
|
||||
error: Some(message),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn send_line(line: &str) {
|
||||
// Use blocking stdout write - we're writing short JSON lines
|
||||
use std::io::Write;
|
||||
let stdout = std::io::stdout();
|
||||
let mut handle = stdout.lock();
|
||||
let _ = handle.write_all(line.as_bytes());
|
||||
let _ = handle.write_all(b"\n");
|
||||
let _ = handle.flush();
|
||||
}
|
||||
|
||||
fn send_response(response: &ManagementResponse) {
|
||||
match serde_json::to_string(response) {
|
||||
Ok(json) => send_line(&json),
|
||||
Err(e) => error!("Failed to serialize management response: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
fn send_event(event: &str, data: serde_json::Value) {
|
||||
let evt = ManagementEvent {
|
||||
event: event.to_string(),
|
||||
data,
|
||||
};
|
||||
match serde_json::to_string(&evt) {
|
||||
Ok(json) => send_line(&json),
|
||||
Err(e) => error!("Failed to serialize management event: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the management loop, reading JSON commands from stdin and writing responses to stdout.
|
||||
pub async fn management_loop() -> Result<()> {
|
||||
let stdin = BufReader::new(tokio::io::stdin());
|
||||
let mut lines = stdin.lines();
|
||||
let mut proxy: Option<RustProxy> = None;
|
||||
|
||||
send_event("ready", serde_json::json!({}));
|
||||
|
||||
loop {
|
||||
let line = match lines.next_line().await {
|
||||
Ok(Some(line)) => line,
|
||||
Ok(None) => {
|
||||
// stdin closed - parent process exited
|
||||
info!("Management stdin closed, shutting down");
|
||||
if let Some(ref mut p) = proxy {
|
||||
let _ = p.stop().await;
|
||||
}
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Error reading management stdin: {}", e);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let line = line.trim().to_string();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let request: ManagementRequest = match serde_json::from_str(&line) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
error!("Failed to parse management request: {}", e);
|
||||
// Send error response without an ID
|
||||
send_response(&ManagementResponse::err(
|
||||
"unknown".to_string(),
|
||||
format!("Failed to parse request: {}", e),
|
||||
));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let response = handle_request(&request, &mut proxy).await;
|
||||
send_response(&response);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_request(
|
||||
request: &ManagementRequest,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let id = request.id.clone();
|
||||
|
||||
match request.method.as_str() {
|
||||
"start" => handle_start(&id, &request.params, proxy).await,
|
||||
"stop" => handle_stop(&id, proxy).await,
|
||||
"updateRoutes" => handle_update_routes(&id, &request.params, proxy).await,
|
||||
"getMetrics" => handle_get_metrics(&id, proxy),
|
||||
"getStatistics" => handle_get_statistics(&id, proxy),
|
||||
"provisionCertificate" => handle_provision_certificate(&id, &request.params, proxy).await,
|
||||
"renewCertificate" => handle_renew_certificate(&id, &request.params, proxy).await,
|
||||
"getCertificateStatus" => handle_get_certificate_status(&id, &request.params, proxy).await,
|
||||
"getListeningPorts" => handle_get_listening_ports(&id, proxy),
|
||||
"getNftablesStatus" => handle_get_nftables_status(&id, proxy).await,
|
||||
"setSocketHandlerRelay" => handle_set_socket_handler_relay(&id, &request.params, proxy).await,
|
||||
"addListeningPort" => handle_add_listening_port(&id, &request.params, proxy).await,
|
||||
"removeListeningPort" => handle_remove_listening_port(&id, &request.params, proxy).await,
|
||||
"loadCertificate" => handle_load_certificate(&id, &request.params, proxy).await,
|
||||
_ => ManagementResponse::err(id, format!("Unknown method: {}", request.method)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_start(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
if proxy.is_some() {
|
||||
return ManagementResponse::err(id.to_string(), "Proxy is already running".to_string());
|
||||
}
|
||||
|
||||
let config = match params.get("config") {
|
||||
Some(config) => config,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'config' parameter".to_string()),
|
||||
};
|
||||
|
||||
let options: RustProxyOptions = match serde_json::from_value(config.clone()) {
|
||||
Ok(o) => o,
|
||||
Err(e) => return ManagementResponse::err(id.to_string(), format!("Invalid config: {}", e)),
|
||||
};
|
||||
|
||||
match RustProxy::new(options) {
|
||||
Ok(mut p) => {
|
||||
match p.start().await {
|
||||
Ok(()) => {
|
||||
send_event("started", serde_json::json!({}));
|
||||
*proxy = Some(p);
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||
}
|
||||
Err(e) => {
|
||||
send_event("error", serde_json::json!({"message": format!("{}", e)}));
|
||||
ManagementResponse::err(id.to_string(), format!("Failed to start: {}", e))
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to create proxy: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_stop(
|
||||
id: &str,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
match proxy.as_mut() {
|
||||
Some(p) => {
|
||||
match p.stop().await {
|
||||
Ok(()) => {
|
||||
*proxy = None;
|
||||
send_event("stopped", serde_json::json!({}));
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to stop: {}", e)),
|
||||
}
|
||||
}
|
||||
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_update_routes(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let routes = match params.get("routes") {
|
||||
Some(routes) => routes,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routes' parameter".to_string()),
|
||||
};
|
||||
|
||||
let routes: Vec<rustproxy_config::RouteConfig> = match serde_json::from_value(routes.clone()) {
|
||||
Ok(r) => r,
|
||||
Err(e) => return ManagementResponse::err(id.to_string(), format!("Invalid routes: {}", e)),
|
||||
};
|
||||
|
||||
match p.update_routes(routes).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to update routes: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_get_metrics(
|
||||
id: &str,
|
||||
proxy: &Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
match proxy.as_ref() {
|
||||
Some(p) => {
|
||||
let metrics = p.get_metrics();
|
||||
match serde_json::to_value(&metrics) {
|
||||
Ok(v) => ManagementResponse::ok(id.to_string(), v),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize metrics: {}", e)),
|
||||
}
|
||||
}
|
||||
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_get_statistics(
|
||||
id: &str,
|
||||
proxy: &Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
match proxy.as_ref() {
|
||||
Some(p) => {
|
||||
let stats = p.get_statistics();
|
||||
match serde_json::to_value(&stats) {
|
||||
Ok(v) => ManagementResponse::ok(id.to_string(), v),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize statistics: {}", e)),
|
||||
}
|
||||
}
|
||||
None => ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_provision_certificate(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
||||
Some(name) => name.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
|
||||
};
|
||||
|
||||
match p.provision_certificate(&route_name).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to provision certificate: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_renew_certificate(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
||||
Some(name) => name.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
|
||||
};
|
||||
|
||||
match p.renew_certificate(&route_name).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to renew certificate: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_get_certificate_status(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_ref() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let route_name = match params.get("routeName").and_then(|v| v.as_str()) {
|
||||
Some(name) => name,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'routeName' parameter".to_string()),
|
||||
};
|
||||
|
||||
match p.get_certificate_status(route_name).await {
|
||||
Some(status) => ManagementResponse::ok(id.to_string(), serde_json::json!({
|
||||
"domain": status.domain,
|
||||
"source": status.source,
|
||||
"expiresAt": status.expires_at,
|
||||
"isValid": status.is_valid,
|
||||
})),
|
||||
None => ManagementResponse::ok(id.to_string(), serde_json::Value::Null),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_get_listening_ports(
|
||||
id: &str,
|
||||
proxy: &Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
match proxy.as_ref() {
|
||||
Some(p) => {
|
||||
let ports = p.get_listening_ports();
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({ "ports": ports }))
|
||||
}
|
||||
None => ManagementResponse::ok(id.to_string(), serde_json::json!({ "ports": [] })),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_get_nftables_status(
|
||||
id: &str,
|
||||
proxy: &Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
match proxy.as_ref() {
|
||||
Some(p) => {
|
||||
match p.get_nftables_status().await {
|
||||
Ok(status) => {
|
||||
match serde_json::to_value(&status) {
|
||||
Ok(v) => ManagementResponse::ok(id.to_string(), v),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to serialize: {}", e)),
|
||||
}
|
||||
}
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to get status: {}", e)),
|
||||
}
|
||||
}
|
||||
None => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_set_socket_handler_relay(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let socket_path = params.get("socketPath")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
info!("setSocketHandlerRelay: socket_path={:?}", socket_path);
|
||||
p.set_socket_handler_relay_path(socket_path);
|
||||
|
||||
ManagementResponse::ok(id.to_string(), serde_json::json!({}))
|
||||
}
|
||||
|
||||
async fn handle_add_listening_port(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let port = match params.get("port").and_then(|v| v.as_u64()) {
|
||||
Some(port) => port as u16,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
|
||||
};
|
||||
|
||||
match p.add_listening_port(port).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to add port {}: {}", port, e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_remove_listening_port(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let port = match params.get("port").and_then(|v| v.as_u64()) {
|
||||
Some(port) => port as u16,
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'port' parameter".to_string()),
|
||||
};
|
||||
|
||||
match p.remove_listening_port(port).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to remove port {}: {}", port, e)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_load_certificate(
|
||||
id: &str,
|
||||
params: &serde_json::Value,
|
||||
proxy: &mut Option<RustProxy>,
|
||||
) -> ManagementResponse {
|
||||
let p = match proxy.as_mut() {
|
||||
Some(p) => p,
|
||||
None => return ManagementResponse::err(id.to_string(), "Proxy is not running".to_string()),
|
||||
};
|
||||
|
||||
let domain = match params.get("domain").and_then(|v| v.as_str()) {
|
||||
Some(d) => d.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'domain' parameter".to_string()),
|
||||
};
|
||||
|
||||
let cert = match params.get("cert").and_then(|v| v.as_str()) {
|
||||
Some(c) => c.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'cert' parameter".to_string()),
|
||||
};
|
||||
|
||||
let key = match params.get("key").and_then(|v| v.as_str()) {
|
||||
Some(k) => k.to_string(),
|
||||
None => return ManagementResponse::err(id.to_string(), "Missing 'key' parameter".to_string()),
|
||||
};
|
||||
|
||||
let ca = params.get("ca").and_then(|v| v.as_str()).map(|s| s.to_string());
|
||||
|
||||
info!("loadCertificate: domain={}", domain);
|
||||
|
||||
// Load cert into cert manager and hot-swap TLS config
|
||||
match p.load_certificate(&domain, cert, key, ca).await {
|
||||
Ok(()) => ManagementResponse::ok(id.to_string(), serde_json::json!({})),
|
||||
Err(e) => ManagementResponse::err(id.to_string(), format!("Failed to load certificate for {}: {}", domain, e)),
|
||||
}
|
||||
}
|
||||
402
rust/crates/rustproxy/tests/common/mod.rs
Normal file
402
rust/crates/rustproxy/tests/common/mod.rs
Normal file
@@ -0,0 +1,402 @@
|
||||
use std::sync::atomic::{AtomicU16, Ordering};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
/// Atomic port allocator starting at 19000 to avoid collisions.
|
||||
static PORT_COUNTER: AtomicU16 = AtomicU16::new(19000);
|
||||
|
||||
/// Get the next available port for testing.
|
||||
pub fn next_port() -> u16 {
|
||||
PORT_COUNTER.fetch_add(1, Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Start a simple TCP echo server that echoes back whatever it receives.
|
||||
/// Returns the join handle for the server task.
|
||||
pub async fn start_echo_server(port: u16) -> JoinHandle<()> {
|
||||
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.expect("Failed to bind echo server");
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (mut stream, _) = match listener.accept().await {
|
||||
Ok(conn) => conn,
|
||||
Err(_) => break,
|
||||
};
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
loop {
|
||||
let n = match stream.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if stream.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Start a TCP echo server that prefixes responses to identify which backend responded.
|
||||
pub async fn start_prefix_echo_server(port: u16, prefix: &str) -> JoinHandle<()> {
|
||||
let prefix = prefix.to_string();
|
||||
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.expect("Failed to bind prefix echo server");
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (mut stream, _) = match listener.accept().await {
|
||||
Ok(conn) => conn,
|
||||
Err(_) => break,
|
||||
};
|
||||
let pfx = prefix.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
loop {
|
||||
let n = match stream.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
let mut response = pfx.as_bytes().to_vec();
|
||||
response.extend_from_slice(&buf[..n]);
|
||||
if stream.write_all(&response).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Start a simple HTTP server that responds with a fixed status and body.
|
||||
pub async fn start_http_server(port: u16, status: u16, body: &str) -> JoinHandle<()> {
|
||||
let body = body.to_string();
|
||||
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.expect("Failed to bind HTTP server");
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (mut stream, _) = match listener.accept().await {
|
||||
Ok(conn) => conn,
|
||||
Err(_) => break,
|
||||
};
|
||||
let b = body.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 8192];
|
||||
// Read the request
|
||||
let _n = stream.read(&mut buf).await.unwrap_or(0);
|
||||
// Send response
|
||||
let response = format!(
|
||||
"HTTP/1.1 {} OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
|
||||
status,
|
||||
b.len(),
|
||||
b,
|
||||
);
|
||||
let _ = stream.write_all(response.as_bytes()).await;
|
||||
let _ = stream.shutdown().await;
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Start an HTTP backend server that echoes back request details as JSON.
|
||||
/// The response body contains: {"method":"GET","path":"/foo","host":"example.com","backend":"<name>"}
|
||||
/// Supports keep-alive by reading HTTP requests properly.
|
||||
pub async fn start_http_echo_backend(port: u16, backend_name: &str) -> JoinHandle<()> {
|
||||
let name = backend_name.to_string();
|
||||
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap_or_else(|_| panic!("Failed to bind HTTP echo backend on port {}", port));
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (mut stream, _) = match listener.accept().await {
|
||||
Ok(conn) => conn,
|
||||
Err(_) => break,
|
||||
};
|
||||
let backend = name.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 16384];
|
||||
// Read request data
|
||||
let n = match stream.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => return,
|
||||
Ok(n) => n,
|
||||
};
|
||||
let req_str = String::from_utf8_lossy(&buf[..n]);
|
||||
|
||||
// Parse first line: METHOD PATH HTTP/x.x
|
||||
let first_line = req_str.lines().next().unwrap_or("");
|
||||
let parts: Vec<&str> = first_line.split_whitespace().collect();
|
||||
let method = parts.first().copied().unwrap_or("UNKNOWN");
|
||||
let path = parts.get(1).copied().unwrap_or("/");
|
||||
|
||||
// Extract Host header
|
||||
let host = req_str.lines()
|
||||
.find(|l| l.to_lowercase().starts_with("host:"))
|
||||
.map(|l| l[5..].trim())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
let body = format!(
|
||||
r#"{{"method":"{}","path":"{}","host":"{}","backend":"{}"}}"#,
|
||||
method, path, host, backend
|
||||
);
|
||||
|
||||
let response = format!(
|
||||
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
|
||||
body.len(),
|
||||
body,
|
||||
);
|
||||
let _ = stream.write_all(response.as_bytes()).await;
|
||||
let _ = stream.shutdown().await;
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Wrap a future with a timeout, preventing tests from hanging.
|
||||
pub async fn with_timeout<F, T>(future: F, secs: u64) -> Result<T, &'static str>
|
||||
where
|
||||
F: std::future::Future<Output = T>,
|
||||
{
|
||||
match tokio::time::timeout(std::time::Duration::from_secs(secs), future).await {
|
||||
Ok(result) => Ok(result),
|
||||
Err(_) => Err("Test timed out"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait briefly for a server to be ready by attempting TCP connections.
|
||||
pub async fn wait_for_port(port: u16, timeout_ms: u64) -> bool {
|
||||
let start = std::time::Instant::now();
|
||||
let timeout = std::time::Duration::from_millis(timeout_ms);
|
||||
while start.elapsed() < timeout {
|
||||
if tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.is_ok()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Helper to create a minimal route config for testing.
|
||||
pub fn make_test_route(
|
||||
port: u16,
|
||||
domain: Option<&str>,
|
||||
target_host: &str,
|
||||
target_port: u16,
|
||||
) -> rustproxy_config::RouteConfig {
|
||||
rustproxy_config::RouteConfig {
|
||||
id: None,
|
||||
route_match: rustproxy_config::RouteMatch {
|
||||
ports: rustproxy_config::PortRange::Single(port),
|
||||
domains: domain.map(|d| rustproxy_config::DomainSpec::Single(d.to_string())),
|
||||
path: None,
|
||||
client_ip: None,
|
||||
tls_version: None,
|
||||
headers: None,
|
||||
},
|
||||
action: rustproxy_config::RouteAction {
|
||||
action_type: rustproxy_config::RouteActionType::Forward,
|
||||
targets: Some(vec![rustproxy_config::RouteTarget {
|
||||
target_match: None,
|
||||
host: rustproxy_config::HostSpec::Single(target_host.to_string()),
|
||||
port: rustproxy_config::PortSpec::Fixed(target_port),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
send_proxy_protocol: None,
|
||||
headers: None,
|
||||
advanced: None,
|
||||
priority: None,
|
||||
}]),
|
||||
tls: None,
|
||||
websocket: None,
|
||||
load_balancing: None,
|
||||
advanced: None,
|
||||
options: None,
|
||||
forwarding_engine: None,
|
||||
nftables: None,
|
||||
send_proxy_protocol: None,
|
||||
},
|
||||
headers: None,
|
||||
security: None,
|
||||
name: None,
|
||||
description: None,
|
||||
priority: None,
|
||||
tags: None,
|
||||
enabled: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a simple WebSocket echo backend.
|
||||
///
|
||||
/// Accepts WebSocket upgrade requests (HTTP Upgrade: websocket), sends 101 back,
|
||||
/// then echoes all data received on the connection.
|
||||
pub async fn start_ws_echo_backend(port: u16) -> JoinHandle<()> {
|
||||
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap_or_else(|_| panic!("Failed to bind WS echo backend on port {}", port));
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (mut stream, _) = match listener.accept().await {
|
||||
Ok(conn) => conn,
|
||||
Err(_) => break,
|
||||
};
|
||||
tokio::spawn(async move {
|
||||
// Read the HTTP upgrade request
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = match stream.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => return,
|
||||
Ok(n) => n,
|
||||
};
|
||||
|
||||
let req_str = String::from_utf8_lossy(&buf[..n]);
|
||||
|
||||
// Extract Sec-WebSocket-Key for proper handshake
|
||||
let ws_key = req_str.lines()
|
||||
.find(|l| l.to_lowercase().starts_with("sec-websocket-key:"))
|
||||
.map(|l| l.split(':').nth(1).unwrap_or("").trim().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
// Compute Sec-WebSocket-Accept (simplified - just echo for test purposes)
|
||||
// Real implementation would compute SHA-1 + base64
|
||||
let accept_response = format!(
|
||||
"HTTP/1.1 101 Switching Protocols\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Sec-WebSocket-Accept: {}\r\n\
|
||||
\r\n",
|
||||
ws_key
|
||||
);
|
||||
|
||||
if stream.write_all(accept_response.as_bytes()).await.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Echo all data back (raw TCP after upgrade)
|
||||
let mut echo_buf = vec![0u8; 65536];
|
||||
loop {
|
||||
let n = match stream.read(&mut echo_buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if stream.write_all(&echo_buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate a self-signed certificate for testing using rcgen.
|
||||
/// Returns (cert_pem, key_pem).
|
||||
pub fn generate_self_signed_cert(domain: &str) -> (String, String) {
|
||||
use rcgen::{CertificateParams, KeyPair};
|
||||
|
||||
let mut params = CertificateParams::new(vec![domain.to_string()]).unwrap();
|
||||
params.distinguished_name.push(rcgen::DnType::CommonName, domain);
|
||||
|
||||
let key_pair = KeyPair::generate().unwrap();
|
||||
let cert = params.self_signed(&key_pair).unwrap();
|
||||
|
||||
(cert.pem(), key_pair.serialize_pem())
|
||||
}
|
||||
|
||||
/// Start a TLS echo server using the given cert/key.
|
||||
/// Returns the join handle.
|
||||
pub async fn start_tls_echo_server(port: u16, cert_pem: &str, key_pem: &str) -> JoinHandle<()> {
|
||||
use std::sync::Arc;
|
||||
|
||||
let acceptor = rustproxy_passthrough::build_tls_acceptor(cert_pem, key_pem)
|
||||
.expect("Failed to build TLS acceptor");
|
||||
let acceptor = Arc::new(acceptor);
|
||||
|
||||
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.expect("Failed to bind TLS echo server");
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let (stream, _) = match listener.accept().await {
|
||||
Ok(conn) => conn,
|
||||
Err(_) => break,
|
||||
};
|
||||
let acc = acceptor.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut tls_stream = match acc.accept(stream).await {
|
||||
Ok(s) => s,
|
||||
Err(_) => return,
|
||||
};
|
||||
let mut buf = vec![0u8; 65536];
|
||||
loop {
|
||||
let n = match tls_stream.read(&mut buf).await {
|
||||
Ok(0) | Err(_) => break,
|
||||
Ok(n) => n,
|
||||
};
|
||||
if tls_stream.write_all(&buf[..n]).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper to create a TLS terminate route with static cert for testing.
|
||||
pub fn make_tls_terminate_route(
|
||||
port: u16,
|
||||
domain: &str,
|
||||
target_host: &str,
|
||||
target_port: u16,
|
||||
cert_pem: &str,
|
||||
key_pem: &str,
|
||||
) -> rustproxy_config::RouteConfig {
|
||||
let mut route = make_test_route(port, Some(domain), target_host, target_port);
|
||||
route.action.tls = Some(rustproxy_config::RouteTls {
|
||||
mode: rustproxy_config::TlsMode::Terminate,
|
||||
certificate: Some(rustproxy_config::CertificateSpec::Static(
|
||||
rustproxy_config::CertificateConfig {
|
||||
cert: cert_pem.to_string(),
|
||||
key: key_pem.to_string(),
|
||||
ca: None,
|
||||
key_file: None,
|
||||
cert_file: None,
|
||||
},
|
||||
)),
|
||||
acme: None,
|
||||
versions: None,
|
||||
ciphers: None,
|
||||
honor_cipher_order: None,
|
||||
session_timeout: None,
|
||||
});
|
||||
route
|
||||
}
|
||||
|
||||
/// Helper to create a TLS passthrough route for testing.
|
||||
pub fn make_tls_passthrough_route(
|
||||
port: u16,
|
||||
domain: Option<&str>,
|
||||
target_host: &str,
|
||||
target_port: u16,
|
||||
) -> rustproxy_config::RouteConfig {
|
||||
let mut route = make_test_route(port, domain, target_host, target_port);
|
||||
route.action.tls = Some(rustproxy_config::RouteTls {
|
||||
mode: rustproxy_config::TlsMode::Passthrough,
|
||||
certificate: None,
|
||||
acme: None,
|
||||
versions: None,
|
||||
ciphers: None,
|
||||
honor_cipher_order: None,
|
||||
session_timeout: None,
|
||||
});
|
||||
route
|
||||
}
|
||||
453
rust/crates/rustproxy/tests/integration_http_proxy.rs
Normal file
453
rust/crates/rustproxy/tests/integration_http_proxy.rs
Normal file
@@ -0,0 +1,453 @@
|
||||
mod common;
|
||||
|
||||
use common::*;
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
/// Send a raw HTTP request and return the full response as a string.
|
||||
async fn send_http_request(port: u16, host: &str, method: &str, path: &str) -> String {
|
||||
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let request = format!(
|
||||
"{} {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
|
||||
method, path, host,
|
||||
);
|
||||
stream.write_all(request.as_bytes()).await.unwrap();
|
||||
|
||||
let mut response = Vec::new();
|
||||
stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
}
|
||||
|
||||
/// Extract the body from a raw HTTP response string (after the \r\n\r\n).
|
||||
fn extract_body(response: &str) -> &str {
|
||||
response.split("\r\n\r\n").nth(1).unwrap_or("")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_forward_basic() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_http_echo_backend(backend_port, "main").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "anyhost.com", "GET", "/hello").await;
|
||||
let body = extract_body(&response);
|
||||
body.to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.contains(r#""method":"GET"#), "Expected GET method, got: {}", result);
|
||||
assert!(result.contains(r#""path":"/hello"#), "Expected /hello path, got: {}", result);
|
||||
assert!(result.contains(r#""backend":"main"#), "Expected main backend, got: {}", result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_forward_host_routing() {
|
||||
let backend1_port = next_port();
|
||||
let backend2_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _b1 = start_http_echo_backend(backend1_port, "alpha").await;
|
||||
let _b2 = start_http_echo_backend(backend2_port, "beta").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_test_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", backend1_port),
|
||||
make_test_route(proxy_port, Some("beta.example.com"), "127.0.0.1", backend2_port),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Test alpha domain
|
||||
let alpha_result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "alpha.example.com", "GET", "/").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(alpha_result.contains(r#""backend":"alpha"#), "Expected alpha backend, got: {}", alpha_result);
|
||||
|
||||
// Test beta domain
|
||||
let beta_result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "beta.example.com", "GET", "/").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(beta_result.contains(r#""backend":"beta"#), "Expected beta backend, got: {}", beta_result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_forward_path_routing() {
|
||||
let backend1_port = next_port();
|
||||
let backend2_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _b1 = start_http_echo_backend(backend1_port, "api").await;
|
||||
let _b2 = start_http_echo_backend(backend2_port, "web").await;
|
||||
|
||||
let mut api_route = make_test_route(proxy_port, None, "127.0.0.1", backend1_port);
|
||||
api_route.route_match.path = Some("/api/**".to_string());
|
||||
api_route.priority = Some(10);
|
||||
|
||||
let web_route = make_test_route(proxy_port, None, "127.0.0.1", backend2_port);
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![api_route, web_route],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Test API path
|
||||
let api_result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "any.com", "GET", "/api/users").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(api_result.contains(r#""backend":"api"#), "Expected api backend, got: {}", api_result);
|
||||
|
||||
// Test web path (no /api prefix)
|
||||
let web_result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "any.com", "GET", "/index.html").await;
|
||||
extract_body(&response).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(web_result.contains(r#""backend":"web"#), "Expected web backend, got: {}", web_result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_forward_cors_preflight() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_http_echo_backend(backend_port, "main").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Send CORS preflight request
|
||||
let request = format!(
|
||||
"OPTIONS /api/data HTTP/1.1\r\nHost: example.com\r\nOrigin: http://localhost:3000\r\nAccess-Control-Request-Method: POST\r\nConnection: close\r\n\r\n",
|
||||
);
|
||||
stream.write_all(request.as_bytes()).await.unwrap();
|
||||
|
||||
let mut response = Vec::new();
|
||||
stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should get 204 No Content with CORS headers
|
||||
assert!(result.contains("204"), "Expected 204 status, got: {}", result);
|
||||
assert!(result.to_lowercase().contains("access-control-allow-origin"),
|
||||
"Expected CORS header, got: {}", result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_forward_backend_error() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
// Start an HTTP server that returns 500
|
||||
let _backend = start_http_server(backend_port, 500, "Internal Error").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "example.com", "GET", "/fail").await;
|
||||
response
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Proxy should relay the 500 from backend
|
||||
assert!(result.contains("500"), "Expected 500 status, got: {}", result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_forward_no_route_matched() {
|
||||
let proxy_port = next_port();
|
||||
|
||||
// Create a route only for a specific domain
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, Some("known.example.com"), "127.0.0.1", 9999)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "unknown.example.com", "GET", "/").await;
|
||||
response
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should get 502 Bad Gateway (no route matched)
|
||||
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_http_forward_backend_unavailable() {
|
||||
let proxy_port = next_port();
|
||||
let dead_port = next_port(); // No server running here
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", dead_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let response = send_http_request(proxy_port, "example.com", "GET", "/").await;
|
||||
response
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should get 502 Bad Gateway (backend unavailable)
|
||||
assert!(result.contains("502"), "Expected 502 status, got: {}", result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_https_terminate_http_forward() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
let domain = "httpproxy.example.com";
|
||||
|
||||
let (cert_pem, key_pem) = generate_self_signed_cert(domain);
|
||||
let _backend = start_http_echo_backend(backend_port, "tls-backend").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(std::sync::Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(tls_config));
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
// Send HTTP request through TLS
|
||||
let request = format!(
|
||||
"GET /api/data HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
|
||||
domain
|
||||
);
|
||||
tls_stream.write_all(request.as_bytes()).await.unwrap();
|
||||
|
||||
let mut response = Vec::new();
|
||||
tls_stream.read_to_end(&mut response).await.unwrap();
|
||||
String::from_utf8_lossy(&response).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let body = extract_body(&result);
|
||||
assert!(body.contains(r#""method":"GET"#), "Expected GET, got: {}", body);
|
||||
assert!(body.contains(r#""path":"/api/data"#), "Expected /api/data, got: {}", body);
|
||||
assert!(body.contains(r#""backend":"tls-backend"#), "Expected tls-backend, got: {}", body);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_websocket_through_proxy() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_ws_echo_backend(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Send WebSocket upgrade request
|
||||
let request = format!(
|
||||
"GET /ws HTTP/1.1\r\n\
|
||||
Host: example.com\r\n\
|
||||
Upgrade: websocket\r\n\
|
||||
Connection: Upgrade\r\n\
|
||||
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
|
||||
Sec-WebSocket-Version: 13\r\n\
|
||||
\r\n"
|
||||
);
|
||||
stream.write_all(request.as_bytes()).await.unwrap();
|
||||
|
||||
// Read the 101 response
|
||||
let mut response_buf = Vec::with_capacity(4096);
|
||||
let mut temp = [0u8; 1];
|
||||
loop {
|
||||
let n = stream.read(&mut temp).await.unwrap();
|
||||
if n == 0 { break; }
|
||||
response_buf.push(temp[0]);
|
||||
if response_buf.len() >= 4 {
|
||||
let len = response_buf.len();
|
||||
if response_buf[len-4..] == *b"\r\n\r\n" {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response_str = String::from_utf8_lossy(&response_buf).to_string();
|
||||
assert!(response_str.contains("101"), "Expected 101 Switching Protocols, got: {}", response_str);
|
||||
assert!(
|
||||
response_str.to_lowercase().contains("upgrade: websocket"),
|
||||
"Expected Upgrade header, got: {}",
|
||||
response_str
|
||||
);
|
||||
|
||||
// After upgrade, send data and verify echo
|
||||
let test_data = b"Hello WebSocket!";
|
||||
stream.write_all(test_data).await.unwrap();
|
||||
|
||||
// Read echoed data
|
||||
let mut echo_buf = vec![0u8; 256];
|
||||
let n = stream.read(&mut echo_buf).await.unwrap();
|
||||
let echoed = &echo_buf[..n];
|
||||
|
||||
assert_eq!(echoed, test_data, "Expected echo of sent data");
|
||||
|
||||
"ok".to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, "ok");
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
/// InsecureVerifier for test TLS client connections.
|
||||
#[derive(Debug)]
|
||||
struct InsecureVerifier;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls::pki_types::CertificateDer<'_>,
|
||||
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
vec![
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
rustls::SignatureScheme::ED25519,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||
]
|
||||
}
|
||||
}
|
||||
250
rust/crates/rustproxy/tests/integration_proxy_lifecycle.rs
Normal file
250
rust/crates/rustproxy/tests/integration_proxy_lifecycle.rs
Normal file
@@ -0,0 +1,250 @@
|
||||
mod common;
|
||||
|
||||
use common::*;
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_start_and_stop() {
|
||||
let port = next_port();
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(port, None, "127.0.0.1", 8080)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
|
||||
// Not listening before start
|
||||
assert!(!wait_for_port(port, 200).await);
|
||||
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(port, 2000).await, "Port should be listening after start");
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
|
||||
// Give the OS a moment to release the port
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
assert!(!wait_for_port(port, 200).await, "Port should not be listening after stop");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_double_start_fails() {
|
||||
let port = next_port();
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(port, None, "127.0.0.1", 8080)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
|
||||
// Second start should fail
|
||||
let result = proxy.start().await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("already started"));
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_routes_hot_reload() {
|
||||
let port = next_port();
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(port, Some("old.example.com"), "127.0.0.1", 8080)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
|
||||
// Update routes atomically
|
||||
let new_routes = vec![
|
||||
make_test_route(port, Some("new.example.com"), "127.0.0.1", 9090),
|
||||
];
|
||||
let result = proxy.update_routes(new_routes).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_add_remove_listening_port() {
|
||||
let port1 = next_port();
|
||||
let port2 = next_port();
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(port1, None, "127.0.0.1", 8080)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(port1, 2000).await);
|
||||
|
||||
// Add a new port
|
||||
proxy.add_listening_port(port2).await.unwrap();
|
||||
assert!(wait_for_port(port2, 2000).await, "New port should be listening");
|
||||
|
||||
// Remove the port
|
||||
proxy.remove_listening_port(port2).await.unwrap();
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
assert!(!wait_for_port(port2, 200).await, "Removed port should not be listening");
|
||||
|
||||
// Original port should still be listening
|
||||
assert!(wait_for_port(port1, 200).await, "Original port should still be listening");
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_statistics() {
|
||||
let port = next_port();
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(port, None, "127.0.0.1", 8080)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
|
||||
let stats = proxy.get_statistics();
|
||||
assert_eq!(stats.routes_count, 1);
|
||||
assert!(stats.listening_ports.contains(&port));
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invalid_routes_rejected() {
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![{
|
||||
let mut route = make_test_route(80, None, "127.0.0.1", 8080);
|
||||
route.action.targets = None; // Invalid: forward without targets
|
||||
route
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = RustProxy::new(options);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_metrics_track_connections() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// No connections yet
|
||||
let stats = proxy.get_statistics();
|
||||
assert_eq!(stats.total_connections, 0);
|
||||
|
||||
// Make a connection and send data
|
||||
{
|
||||
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
stream.write_all(b"hello").await.unwrap();
|
||||
let mut buf = vec![0u8; 16];
|
||||
let _ = stream.read(&mut buf).await;
|
||||
}
|
||||
|
||||
// Small delay for metrics to update
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
|
||||
let stats = proxy.get_statistics();
|
||||
assert!(stats.total_connections > 0, "Expected total_connections > 0, got {}", stats.total_connections);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_metrics_track_bytes() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_http_echo_backend(backend_port, "metrics-test").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Send HTTP request through proxy
|
||||
{
|
||||
let mut stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let request = b"GET /test HTTP/1.1\r\nHost: example.com\r\nConnection: close\r\n\r\n";
|
||||
stream.write_all(request).await.unwrap();
|
||||
let mut response = Vec::new();
|
||||
stream.read_to_end(&mut response).await.unwrap();
|
||||
assert!(!response.is_empty(), "Expected non-empty response");
|
||||
}
|
||||
|
||||
// Small delay for metrics to update
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
|
||||
let stats = proxy.get_statistics();
|
||||
assert!(stats.total_connections > 0,
|
||||
"Expected some connections tracked, got {}", stats.total_connections);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hot_reload_port_changes() {
|
||||
let port1 = next_port();
|
||||
let port2 = next_port();
|
||||
let backend_port = next_port();
|
||||
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
// Start with port1
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(port1, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(port1, 2000).await);
|
||||
assert!(!wait_for_port(port2, 200).await, "port2 should not be listening yet");
|
||||
|
||||
// Update routes to use port2 instead
|
||||
let new_routes = vec![
|
||||
make_test_route(port2, None, "127.0.0.1", backend_port),
|
||||
];
|
||||
proxy.update_routes(new_routes).await.unwrap();
|
||||
|
||||
// Port2 should now be listening, port1 should be closed
|
||||
assert!(wait_for_port(port2, 2000).await, "port2 should be listening after reload");
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
assert!(!wait_for_port(port1, 200).await, "port1 should be closed after reload");
|
||||
|
||||
// Verify port2 works
|
||||
let ports = proxy.get_listening_ports();
|
||||
assert!(ports.contains(&port2), "Expected port2 in listening ports: {:?}", ports);
|
||||
assert!(!ports.contains(&port1), "port1 should not be in listening ports: {:?}", ports);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
197
rust/crates/rustproxy/tests/integration_tcp_passthrough.rs
Normal file
197
rust/crates/rustproxy/tests/integration_tcp_passthrough.rs
Normal file
@@ -0,0 +1,197 @@
|
||||
mod common;
|
||||
|
||||
use common::*;
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tcp_forward_echo() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
// Start echo backend
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
// Configure proxy
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
|
||||
// Wait for proxy to be ready
|
||||
assert!(wait_for_port(proxy_port, 2000).await, "Proxy port not ready");
|
||||
|
||||
// Connect and send data
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
stream.write_all(b"hello world").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, "hello world");
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tcp_forward_large_payload() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Send 1MB of data
|
||||
let data = vec![b'A'; 1_000_000];
|
||||
stream.write_all(&data).await.unwrap();
|
||||
stream.shutdown().await.unwrap();
|
||||
|
||||
// Read all back
|
||||
let mut received = Vec::new();
|
||||
stream.read_to_end(&mut received).await.unwrap();
|
||||
received.len()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, 1_000_000);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tcp_forward_multiple_connections() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..10 {
|
||||
let port = proxy_port;
|
||||
handles.push(tokio::spawn(async move {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap();
|
||||
let msg = format!("connection-{}", i);
|
||||
stream.write_all(msg.as_bytes()).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}));
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
for handle in handles {
|
||||
results.push(handle.await.unwrap());
|
||||
}
|
||||
results
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 10);
|
||||
for (i, r) in result.iter().enumerate() {
|
||||
assert_eq!(r, &format!("connection-{}", i));
|
||||
}
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tcp_forward_backend_unreachable() {
|
||||
let proxy_port = next_port();
|
||||
let dead_port = next_port(); // No server on this port
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", dead_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Connection should complete (proxy accepts it) but data should not flow
|
||||
let result = with_timeout(async {
|
||||
let stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port)).await;
|
||||
stream.is_ok()
|
||||
}, 5)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result, "Should be able to connect to proxy even if backend is down");
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tcp_forward_bidirectional() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
// Start a prefix echo server to verify data flows in both directions
|
||||
let _backend = start_prefix_echo_server(backend_port, "REPLY:").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_test_route(proxy_port, None, "127.0.0.1", backend_port)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
stream.write_all(b"test data").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, "REPLY:test data");
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
247
rust/crates/rustproxy/tests/integration_tls_passthrough.rs
Normal file
247
rust/crates/rustproxy/tests/integration_tls_passthrough.rs
Normal file
@@ -0,0 +1,247 @@
|
||||
mod common;
|
||||
|
||||
use common::*;
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
/// Build a minimal TLS ClientHello with the given SNI domain.
|
||||
/// This is enough for the proxy's SNI parser to extract the domain.
|
||||
fn build_client_hello(domain: &str) -> Vec<u8> {
|
||||
let domain_bytes = domain.as_bytes();
|
||||
let sni_length = domain_bytes.len() as u16;
|
||||
|
||||
// Server Name extension (type 0x0000)
|
||||
let mut sni_ext = Vec::new();
|
||||
sni_ext.extend_from_slice(&[0x00, 0x00]); // extension type: server_name
|
||||
let sni_list_len = sni_length + 5; // 2 (list len) + 1 (type) + 2 (name len) + name
|
||||
sni_ext.extend_from_slice(&(sni_list_len as u16).to_be_bytes()); // extension data length
|
||||
sni_ext.extend_from_slice(&((sni_list_len - 2) as u16).to_be_bytes()); // server name list length
|
||||
sni_ext.push(0x00); // host_name type
|
||||
sni_ext.extend_from_slice(&sni_length.to_be_bytes());
|
||||
sni_ext.extend_from_slice(domain_bytes);
|
||||
|
||||
let extensions_length = sni_ext.len() as u16;
|
||||
|
||||
// ClientHello message
|
||||
let mut client_hello = Vec::new();
|
||||
client_hello.extend_from_slice(&[0x03, 0x03]); // TLS 1.2 version
|
||||
client_hello.extend_from_slice(&[0x00; 32]); // random
|
||||
client_hello.push(0x00); // session_id length
|
||||
client_hello.extend_from_slice(&[0x00, 0x02, 0x00, 0xff]); // cipher suites (1 suite)
|
||||
client_hello.extend_from_slice(&[0x01, 0x00]); // compression methods (null)
|
||||
client_hello.extend_from_slice(&extensions_length.to_be_bytes());
|
||||
client_hello.extend_from_slice(&sni_ext);
|
||||
|
||||
let hello_len = client_hello.len() as u32;
|
||||
|
||||
// Handshake wrapper (type 1 = ClientHello)
|
||||
let mut handshake = Vec::new();
|
||||
handshake.push(0x01); // ClientHello
|
||||
handshake.extend_from_slice(&hello_len.to_be_bytes()[1..4]); // 3-byte length
|
||||
handshake.extend_from_slice(&client_hello);
|
||||
|
||||
let hs_len = handshake.len() as u16;
|
||||
|
||||
// TLS record
|
||||
let mut record = Vec::new();
|
||||
record.push(0x16); // ContentType: Handshake
|
||||
record.extend_from_slice(&[0x03, 0x01]); // TLS 1.0 (record version)
|
||||
record.extend_from_slice(&hs_len.to_be_bytes());
|
||||
record.extend_from_slice(&handshake);
|
||||
|
||||
record
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_passthrough_sni_routing() {
|
||||
let backend1_port = next_port();
|
||||
let backend2_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _b1 = start_prefix_echo_server(backend1_port, "BACKEND1:").await;
|
||||
let _b2 = start_prefix_echo_server(backend2_port, "BACKEND2:").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_passthrough_route(proxy_port, Some("one.example.com"), "127.0.0.1", backend1_port),
|
||||
make_tls_passthrough_route(proxy_port, Some("two.example.com"), "127.0.0.1", backend2_port),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Send a fake ClientHello with SNI "one.example.com"
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("one.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Backend1 should have received the ClientHello and prefixed its response
|
||||
assert!(result.starts_with("BACKEND1:"), "Expected BACKEND1 prefix, got: {}", result);
|
||||
|
||||
// Now test routing to backend2
|
||||
let result2 = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("two.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result2.starts_with("BACKEND2:"), "Expected BACKEND2 prefix, got: {}", result2);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_passthrough_unknown_sni() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_passthrough_route(proxy_port, Some("known.example.com"), "127.0.0.1", backend_port),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Send ClientHello with unknown SNI - should get no response (connection dropped)
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("unknown.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
// Should either get 0 bytes (closed) or an error
|
||||
match stream.read(&mut buf).await {
|
||||
Ok(0) => true, // Connection closed = no route matched
|
||||
Ok(_) => false, // Got data = route shouldn't have matched
|
||||
Err(_) => true, // Error = connection dropped
|
||||
}
|
||||
}, 5)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result, "Unknown SNI should result in dropped connection");
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_passthrough_wildcard_domain() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _backend = start_prefix_echo_server(backend_port, "WILDCARD:").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_passthrough_route(proxy_port, Some("*.example.com"), "127.0.0.1", backend_port),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Should match any subdomain of example.com
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello("anything.example.com");
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.starts_with("WILDCARD:"), "Expected WILDCARD prefix, got: {}", result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_passthrough_multiple_domains() {
|
||||
let b1_port = next_port();
|
||||
let b2_port = next_port();
|
||||
let b3_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let _b1 = start_prefix_echo_server(b1_port, "B1:").await;
|
||||
let _b2 = start_prefix_echo_server(b2_port, "B2:").await;
|
||||
let _b3 = start_prefix_echo_server(b3_port, "B3:").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_passthrough_route(proxy_port, Some("alpha.example.com"), "127.0.0.1", b1_port),
|
||||
make_tls_passthrough_route(proxy_port, Some("beta.example.com"), "127.0.0.1", b2_port),
|
||||
make_tls_passthrough_route(proxy_port, Some("gamma.example.com"), "127.0.0.1", b3_port),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
for (domain, expected_prefix) in [
|
||||
("alpha.example.com", "B1:"),
|
||||
("beta.example.com", "B2:"),
|
||||
("gamma.example.com", "B3:"),
|
||||
] {
|
||||
let result = with_timeout(async {
|
||||
let mut stream = TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
let hello = build_client_hello(domain);
|
||||
stream.write_all(&hello).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let n = stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 5)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
result.starts_with(expected_prefix),
|
||||
"Domain {} should route to {}, got: {}",
|
||||
domain, expected_prefix, result
|
||||
);
|
||||
}
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
324
rust/crates/rustproxy/tests/integration_tls_terminate.rs
Normal file
324
rust/crates/rustproxy/tests/integration_tls_terminate.rs
Normal file
@@ -0,0 +1,324 @@
|
||||
mod common;
|
||||
|
||||
use common::*;
|
||||
use rustproxy::RustProxy;
|
||||
use rustproxy_config::RustProxyOptions;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
/// Create a rustls client config that trusts self-signed certs.
|
||||
fn make_insecure_tls_client_config() -> Arc<rustls::ClientConfig> {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.dangerous()
|
||||
.with_custom_certificate_verifier(Arc::new(InsecureVerifier))
|
||||
.with_no_client_auth();
|
||||
Arc::new(config)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InsecureVerifier;
|
||||
|
||||
impl rustls::client::danger::ServerCertVerifier for InsecureVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &rustls::pki_types::CertificateDer<'_>,
|
||||
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
|
||||
_server_name: &rustls::pki_types::ServerName<'_>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: rustls::pki_types::UnixTime,
|
||||
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::danger::ServerCertVerified::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls12_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn verify_tls13_signature(
|
||||
&self,
|
||||
_message: &[u8],
|
||||
_cert: &rustls::pki_types::CertificateDer<'_>,
|
||||
_dss: &rustls::DigitallySignedStruct,
|
||||
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
|
||||
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
|
||||
}
|
||||
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
vec![
|
||||
rustls::SignatureScheme::RSA_PKCS1_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
rustls::SignatureScheme::ED25519,
|
||||
rustls::SignatureScheme::RSA_PSS_SHA256,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_terminate_basic() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
let domain = "test.example.com";
|
||||
|
||||
// Generate self-signed cert
|
||||
let (cert_pem, key_pem) = generate_self_signed_cert(domain);
|
||||
|
||||
// Start plain TCP echo backend (proxy terminates TLS, sends plain to backend)
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Connect with TLS client
|
||||
let result = with_timeout(async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
tls_stream.write_all(b"hello TLS").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, "hello TLS");
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_terminate_and_reencrypt() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
let domain = "reencrypt.example.com";
|
||||
let backend_domain = "backend.internal";
|
||||
|
||||
// Generate certs
|
||||
let (proxy_cert, proxy_key) = generate_self_signed_cert(domain);
|
||||
let (backend_cert, backend_key) = generate_self_signed_cert(backend_domain);
|
||||
|
||||
// Start TLS echo backend
|
||||
let _backend = start_tls_echo_server(backend_port, &backend_cert, &backend_key).await;
|
||||
|
||||
// Create terminate-and-reencrypt route
|
||||
let mut route = make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &proxy_cert, &proxy_key,
|
||||
);
|
||||
route.action.tls.as_mut().unwrap().mode = rustproxy_config::TlsMode::TerminateAndReencrypt;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![route],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
tls_stream.write_all(b"hello reencrypt").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, "hello reencrypt");
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_terminate_sni_cert_selection() {
|
||||
let backend1_port = next_port();
|
||||
let backend2_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
|
||||
let (cert1, key1) = generate_self_signed_cert("alpha.example.com");
|
||||
let (cert2, key2) = generate_self_signed_cert("beta.example.com");
|
||||
|
||||
let _b1 = start_prefix_echo_server(backend1_port, "ALPHA:").await;
|
||||
let _b2 = start_prefix_echo_server(backend2_port, "BETA:").await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![
|
||||
make_tls_terminate_route(proxy_port, "alpha.example.com", "127.0.0.1", backend1_port, &cert1, &key1),
|
||||
make_tls_terminate_route(proxy_port, "beta.example.com", "127.0.0.1", backend2_port, &cert2, &key2),
|
||||
],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
// Test alpha domain
|
||||
let result = with_timeout(async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from("alpha.example.com".to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
tls_stream.write_all(b"test").await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}, 10)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(result.starts_with("ALPHA:"), "Expected ALPHA prefix, got: {}", result);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_terminate_large_payload() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
let domain = "large.example.com";
|
||||
|
||||
let (cert_pem, key_pem) = generate_self_signed_cert(domain);
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", proxy_port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
// Send 1MB of data
|
||||
let data = vec![b'X'; 1_000_000];
|
||||
tls_stream.write_all(&data).await.unwrap();
|
||||
tls_stream.shutdown().await.unwrap();
|
||||
|
||||
let mut received = Vec::new();
|
||||
tls_stream.read_to_end(&mut received).await.unwrap();
|
||||
received.len()
|
||||
}, 15)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result, 1_000_000);
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_terminate_concurrent() {
|
||||
let backend_port = next_port();
|
||||
let proxy_port = next_port();
|
||||
let domain = "concurrent.example.com";
|
||||
|
||||
let (cert_pem, key_pem) = generate_self_signed_cert(domain);
|
||||
let _backend = start_echo_server(backend_port).await;
|
||||
|
||||
let options = RustProxyOptions {
|
||||
routes: vec![make_tls_terminate_route(
|
||||
proxy_port, domain, "127.0.0.1", backend_port, &cert_pem, &key_pem,
|
||||
)],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut proxy = RustProxy::new(options).unwrap();
|
||||
proxy.start().await.unwrap();
|
||||
assert!(wait_for_port(proxy_port, 2000).await);
|
||||
|
||||
let result = with_timeout(async {
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..10 {
|
||||
let port = proxy_port;
|
||||
let dom = domain.to_string();
|
||||
handles.push(tokio::spawn(async move {
|
||||
let tls_config = make_insecure_tls_client_config();
|
||||
let connector = tokio_rustls::TlsConnector::from(tls_config);
|
||||
|
||||
let stream = tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(dom).unwrap();
|
||||
let mut tls_stream = connector.connect(server_name, stream).await.unwrap();
|
||||
|
||||
let msg = format!("conn-{}", i);
|
||||
tls_stream.write_all(msg.as_bytes()).await.unwrap();
|
||||
|
||||
let mut buf = vec![0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await.unwrap();
|
||||
String::from_utf8_lossy(&buf[..n]).to_string()
|
||||
}));
|
||||
}
|
||||
|
||||
let mut results = Vec::new();
|
||||
for handle in handles {
|
||||
results.push(handle.await.unwrap());
|
||||
}
|
||||
results
|
||||
}, 15)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 10);
|
||||
for (i, r) in result.iter().enumerate() {
|
||||
assert_eq!(r, &format!("conn-{}", i));
|
||||
}
|
||||
|
||||
proxy.stop().await.unwrap();
|
||||
}
|
||||
@@ -1,218 +0,0 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
import * as plugins from '../ts/plugins.js';
|
||||
|
||||
/**
|
||||
* Test that verifies ACME challenge routes are properly created
|
||||
*/
|
||||
tap.test('should create ACME challenge route', async (tools) => {
|
||||
tools.timeout(5000);
|
||||
|
||||
// Create a challenge route manually to test its structure
|
||||
const challengeRoute = {
|
||||
name: 'acme-challenge',
|
||||
priority: 1000,
|
||||
match: {
|
||||
ports: 18080,
|
||||
path: '/.well-known/acme-challenge/*'
|
||||
},
|
||||
action: {
|
||||
type: 'socket-handler' as const,
|
||||
socketHandler: (socket: any, context: any) => {
|
||||
socket.once('data', (data: Buffer) => {
|
||||
const request = data.toString();
|
||||
const lines = request.split('\r\n');
|
||||
const [method, path] = lines[0].split(' ');
|
||||
const token = path?.split('/').pop() || '';
|
||||
|
||||
const response = [
|
||||
'HTTP/1.1 200 OK',
|
||||
'Content-Type: text/plain',
|
||||
`Content-Length: ${token.length}`,
|
||||
'Connection: close',
|
||||
'',
|
||||
token
|
||||
].join('\r\n');
|
||||
|
||||
socket.write(response);
|
||||
socket.end();
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Test that the challenge route has the correct structure
|
||||
expect(challengeRoute).toBeDefined();
|
||||
expect(challengeRoute.match.path).toEqual('/.well-known/acme-challenge/*');
|
||||
expect(challengeRoute.match.ports).toEqual(18080);
|
||||
expect(challengeRoute.action.type).toEqual('socket-handler');
|
||||
expect(challengeRoute.priority).toEqual(1000);
|
||||
|
||||
// Create a proxy with the challenge route
|
||||
const settings = {
|
||||
routes: [
|
||||
{
|
||||
name: 'secure-route',
|
||||
match: {
|
||||
ports: [18443],
|
||||
domains: 'test.local'
|
||||
},
|
||||
action: {
|
||||
type: 'forward' as const,
|
||||
targets: [{ host: 'localhost', port: 8080 }]
|
||||
}
|
||||
},
|
||||
challengeRoute
|
||||
]
|
||||
};
|
||||
|
||||
const proxy = new SmartProxy(settings);
|
||||
|
||||
// Mock NFTables manager
|
||||
(proxy as any).nftablesManager = {
|
||||
ensureNFTablesSetup: async () => {},
|
||||
stop: async () => {}
|
||||
};
|
||||
|
||||
// Mock certificate manager to prevent real ACME initialization
|
||||
(proxy as any).createCertificateManager = async function() {
|
||||
return {
|
||||
setUpdateRoutesCallback: () => {},
|
||||
setHttpProxy: () => {},
|
||||
setGlobalAcmeDefaults: () => {},
|
||||
setAcmeStateManager: () => {},
|
||||
initialize: async () => {},
|
||||
provisionAllCertificates: async () => {},
|
||||
stop: async () => {},
|
||||
getAcmeOptions: () => ({}),
|
||||
getState: () => ({ challengeRouteActive: false })
|
||||
};
|
||||
};
|
||||
|
||||
await proxy.start();
|
||||
|
||||
// Verify the challenge route is in the proxy's routes
|
||||
const proxyRoutes = proxy.routeManager.getRoutes();
|
||||
const foundChallengeRoute = proxyRoutes.find((r: any) => r.name === 'acme-challenge');
|
||||
|
||||
expect(foundChallengeRoute).toBeDefined();
|
||||
expect(foundChallengeRoute?.match.path).toEqual('/.well-known/acme-challenge/*');
|
||||
|
||||
await proxy.stop();
|
||||
});
|
||||
|
||||
tap.test('should handle HTTP request parsing correctly', async (tools) => {
|
||||
tools.timeout(5000);
|
||||
|
||||
let handlerCalled = false;
|
||||
let receivedContext: any;
|
||||
let parsedRequest: any = {};
|
||||
|
||||
const settings = {
|
||||
routes: [
|
||||
{
|
||||
name: 'test-static',
|
||||
match: {
|
||||
ports: [18090],
|
||||
path: '/test/*'
|
||||
},
|
||||
action: {
|
||||
type: 'socket-handler' as const,
|
||||
socketHandler: (socket, context) => {
|
||||
handlerCalled = true;
|
||||
receivedContext = context;
|
||||
|
||||
// Parse HTTP request from socket
|
||||
socket.once('data', (data) => {
|
||||
const request = data.toString();
|
||||
const lines = request.split('\r\n');
|
||||
const [method, path, protocol] = lines[0].split(' ');
|
||||
|
||||
// Parse headers
|
||||
const headers: any = {};
|
||||
for (let i = 1; i < lines.length; i++) {
|
||||
if (lines[i] === '') break;
|
||||
const [key, value] = lines[i].split(': ');
|
||||
if (key && value) {
|
||||
headers[key.toLowerCase()] = value;
|
||||
}
|
||||
}
|
||||
|
||||
// Store parsed request data
|
||||
parsedRequest = { method, path, headers };
|
||||
|
||||
// Send HTTP response
|
||||
const response = [
|
||||
'HTTP/1.1 200 OK',
|
||||
'Content-Type: text/plain',
|
||||
'Content-Length: 2',
|
||||
'Connection: close',
|
||||
'',
|
||||
'OK'
|
||||
].join('\r\n');
|
||||
|
||||
socket.write(response);
|
||||
socket.end();
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
};
|
||||
|
||||
const proxy = new SmartProxy(settings);
|
||||
|
||||
// Mock NFTables manager
|
||||
(proxy as any).nftablesManager = {
|
||||
ensureNFTablesSetup: async () => {},
|
||||
stop: async () => {}
|
||||
};
|
||||
|
||||
await proxy.start();
|
||||
|
||||
// Create a simple HTTP request
|
||||
const client = new plugins.net.Socket();
|
||||
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
client.connect(18090, 'localhost', () => {
|
||||
// Send HTTP request
|
||||
const request = [
|
||||
'GET /test/example HTTP/1.1',
|
||||
'Host: localhost:18090',
|
||||
'User-Agent: test-client',
|
||||
'',
|
||||
''
|
||||
].join('\r\n');
|
||||
|
||||
client.write(request);
|
||||
|
||||
// Wait for response
|
||||
client.on('data', (data) => {
|
||||
const response = data.toString();
|
||||
expect(response).toContain('HTTP/1.1 200');
|
||||
expect(response).toContain('OK');
|
||||
client.end();
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
client.on('error', reject);
|
||||
});
|
||||
|
||||
// Verify handler was called
|
||||
expect(handlerCalled).toBeTrue();
|
||||
expect(receivedContext).toBeDefined();
|
||||
|
||||
// The context passed to socket handlers is IRouteContext, not HTTP request data
|
||||
expect(receivedContext.port).toEqual(18090);
|
||||
expect(receivedContext.routeName).toEqual('test-static');
|
||||
|
||||
// Verify the parsed HTTP request data
|
||||
expect(parsedRequest.path).toEqual('/test/example');
|
||||
expect(parsedRequest.method).toEqual('GET');
|
||||
expect(parsedRequest.headers.host).toEqual('localhost:18090');
|
||||
|
||||
await proxy.stop();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,188 +0,0 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import { AcmeStateManager } from '../ts/proxies/smart-proxy/acme-state-manager.js';
|
||||
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
|
||||
|
||||
tap.test('AcmeStateManager should track challenge routes correctly', async (tools) => {
|
||||
const stateManager = new AcmeStateManager();
|
||||
|
||||
const challengeRoute: IRouteConfig = {
|
||||
name: 'acme-challenge',
|
||||
priority: 1000,
|
||||
match: {
|
||||
ports: 80,
|
||||
path: '/.well-known/acme-challenge/*'
|
||||
},
|
||||
action: {
|
||||
type: 'socket-handler',
|
||||
socketHandler: async (socket, context) => {
|
||||
// Mock handler that would write the challenge response
|
||||
socket.end('challenge response');
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Initially no challenge routes
|
||||
expect(stateManager.isChallengeRouteActive()).toBeFalse();
|
||||
expect(stateManager.getActiveChallengeRoutes()).toEqual([]);
|
||||
|
||||
// Add challenge route
|
||||
stateManager.addChallengeRoute(challengeRoute);
|
||||
expect(stateManager.isChallengeRouteActive()).toBeTrue();
|
||||
expect(stateManager.getActiveChallengeRoutes()).toHaveProperty("length", 1);
|
||||
expect(stateManager.getPrimaryChallengeRoute()).toEqual(challengeRoute);
|
||||
|
||||
// Remove challenge route
|
||||
stateManager.removeChallengeRoute('acme-challenge');
|
||||
expect(stateManager.isChallengeRouteActive()).toBeFalse();
|
||||
expect(stateManager.getActiveChallengeRoutes()).toEqual([]);
|
||||
expect(stateManager.getPrimaryChallengeRoute()).toBeNull();
|
||||
});
|
||||
|
||||
tap.test('AcmeStateManager should track port allocations', async (tools) => {
|
||||
const stateManager = new AcmeStateManager();
|
||||
|
||||
const challengeRoute1: IRouteConfig = {
|
||||
name: 'acme-challenge-1',
|
||||
priority: 1000,
|
||||
match: {
|
||||
ports: 80,
|
||||
path: '/.well-known/acme-challenge/*'
|
||||
},
|
||||
action: {
|
||||
type: 'socket-handler'
|
||||
}
|
||||
};
|
||||
|
||||
const challengeRoute2: IRouteConfig = {
|
||||
name: 'acme-challenge-2',
|
||||
priority: 900,
|
||||
match: {
|
||||
ports: [80, 8080],
|
||||
path: '/.well-known/acme-challenge/*'
|
||||
},
|
||||
action: {
|
||||
type: 'socket-handler'
|
||||
}
|
||||
};
|
||||
|
||||
// Add first route
|
||||
stateManager.addChallengeRoute(challengeRoute1);
|
||||
expect(stateManager.isPortAllocatedForAcme(80)).toBeTrue();
|
||||
expect(stateManager.isPortAllocatedForAcme(8080)).toBeFalse();
|
||||
expect(stateManager.getAcmePorts()).toEqual([80]);
|
||||
|
||||
// Add second route
|
||||
stateManager.addChallengeRoute(challengeRoute2);
|
||||
expect(stateManager.isPortAllocatedForAcme(80)).toBeTrue();
|
||||
expect(stateManager.isPortAllocatedForAcme(8080)).toBeTrue();
|
||||
expect(stateManager.getAcmePorts()).toContain(80);
|
||||
expect(stateManager.getAcmePorts()).toContain(8080);
|
||||
|
||||
// Remove first route - port 80 should still be allocated
|
||||
stateManager.removeChallengeRoute('acme-challenge-1');
|
||||
expect(stateManager.isPortAllocatedForAcme(80)).toBeTrue();
|
||||
expect(stateManager.isPortAllocatedForAcme(8080)).toBeTrue();
|
||||
|
||||
// Remove second route - all ports should be deallocated
|
||||
stateManager.removeChallengeRoute('acme-challenge-2');
|
||||
expect(stateManager.isPortAllocatedForAcme(80)).toBeFalse();
|
||||
expect(stateManager.isPortAllocatedForAcme(8080)).toBeFalse();
|
||||
expect(stateManager.getAcmePorts()).toEqual([]);
|
||||
});
|
||||
|
||||
tap.test('AcmeStateManager should select primary route by priority', async (tools) => {
|
||||
const stateManager = new AcmeStateManager();
|
||||
|
||||
const lowPriorityRoute: IRouteConfig = {
|
||||
name: 'low-priority',
|
||||
priority: 100,
|
||||
match: {
|
||||
ports: 80
|
||||
},
|
||||
action: {
|
||||
type: 'socket-handler'
|
||||
}
|
||||
};
|
||||
|
||||
const highPriorityRoute: IRouteConfig = {
|
||||
name: 'high-priority',
|
||||
priority: 2000,
|
||||
match: {
|
||||
ports: 80
|
||||
},
|
||||
action: {
|
||||
type: 'socket-handler'
|
||||
}
|
||||
};
|
||||
|
||||
const defaultPriorityRoute: IRouteConfig = {
|
||||
name: 'default-priority',
|
||||
// No priority specified - should default to 0
|
||||
match: {
|
||||
ports: 80
|
||||
},
|
||||
action: {
|
||||
type: 'socket-handler'
|
||||
}
|
||||
};
|
||||
|
||||
// Add low priority first
|
||||
stateManager.addChallengeRoute(lowPriorityRoute);
|
||||
expect(stateManager.getPrimaryChallengeRoute()?.name).toEqual('low-priority');
|
||||
|
||||
// Add high priority - should become primary
|
||||
stateManager.addChallengeRoute(highPriorityRoute);
|
||||
expect(stateManager.getPrimaryChallengeRoute()?.name).toEqual('high-priority');
|
||||
|
||||
// Add default priority - primary should remain high priority
|
||||
stateManager.addChallengeRoute(defaultPriorityRoute);
|
||||
expect(stateManager.getPrimaryChallengeRoute()?.name).toEqual('high-priority');
|
||||
|
||||
// Remove high priority - primary should fall back to low priority
|
||||
stateManager.removeChallengeRoute('high-priority');
|
||||
expect(stateManager.getPrimaryChallengeRoute()?.name).toEqual('low-priority');
|
||||
});
|
||||
|
||||
tap.test('AcmeStateManager should handle clear operation', async (tools) => {
|
||||
const stateManager = new AcmeStateManager();
|
||||
|
||||
const challengeRoute1: IRouteConfig = {
|
||||
name: 'route-1',
|
||||
match: {
|
||||
ports: [80, 443]
|
||||
},
|
||||
action: {
|
||||
type: 'socket-handler'
|
||||
}
|
||||
};
|
||||
|
||||
const challengeRoute2: IRouteConfig = {
|
||||
name: 'route-2',
|
||||
match: {
|
||||
ports: 8080
|
||||
},
|
||||
action: {
|
||||
type: 'socket-handler'
|
||||
}
|
||||
};
|
||||
|
||||
// Add routes
|
||||
stateManager.addChallengeRoute(challengeRoute1);
|
||||
stateManager.addChallengeRoute(challengeRoute2);
|
||||
|
||||
// Verify state before clear
|
||||
expect(stateManager.isChallengeRouteActive()).toBeTrue();
|
||||
expect(stateManager.getActiveChallengeRoutes()).toHaveProperty("length", 2);
|
||||
expect(stateManager.getAcmePorts()).toHaveProperty("length", 3);
|
||||
|
||||
// Clear all state
|
||||
stateManager.clear();
|
||||
|
||||
// Verify state after clear
|
||||
expect(stateManager.isChallengeRouteActive()).toBeFalse();
|
||||
expect(stateManager.getActiveChallengeRoutes()).toEqual([]);
|
||||
expect(stateManager.getAcmePorts()).toEqual([]);
|
||||
expect(stateManager.getPrimaryChallengeRoute()).toBeNull();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,122 +0,0 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
|
||||
// Test that certificate provisioning is deferred until after ports are listening
|
||||
tap.test('should defer certificate provisioning until ports are ready', async (tapTest) => {
|
||||
// Track when operations happen
|
||||
let portsListening = false;
|
||||
let certProvisioningStarted = false;
|
||||
let operationOrder: string[] = [];
|
||||
|
||||
// Create proxy with certificate route but without real ACME
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'test-route',
|
||||
match: {
|
||||
ports: 8443,
|
||||
domains: ['test.local']
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8181 }],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto',
|
||||
acme: {
|
||||
email: 'test@local.dev',
|
||||
useProduction: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
// Override the certificate manager creation to avoid real ACME
|
||||
const originalCreateCertManager = proxy['createCertificateManager'];
|
||||
proxy['createCertificateManager'] = async function(...args: any[]) {
|
||||
console.log('Creating mock cert manager');
|
||||
operationOrder.push('create-cert-manager');
|
||||
const mockCertManager = {
|
||||
certStore: null,
|
||||
smartAcme: null,
|
||||
httpProxy: null,
|
||||
renewalTimer: null,
|
||||
pendingChallenges: new Map(),
|
||||
challengeRoute: null,
|
||||
certStatus: new Map(),
|
||||
globalAcmeDefaults: null,
|
||||
updateRoutesCallback: undefined,
|
||||
challengeRouteActive: false,
|
||||
isProvisioning: false,
|
||||
acmeStateManager: null,
|
||||
initialize: async () => {
|
||||
operationOrder.push('cert-manager-init');
|
||||
console.log('Mock cert manager initialized');
|
||||
},
|
||||
provisionAllCertificates: async () => {
|
||||
operationOrder.push('cert-provisioning');
|
||||
certProvisioningStarted = true;
|
||||
// Check that ports are listening when provisioning starts
|
||||
if (!portsListening) {
|
||||
throw new Error('Certificate provisioning started before ports ready!');
|
||||
}
|
||||
console.log('Mock certificate provisioning (ports are ready)');
|
||||
},
|
||||
stop: async () => {},
|
||||
setHttpProxy: () => {},
|
||||
setGlobalAcmeDefaults: () => {},
|
||||
setAcmeStateManager: () => {},
|
||||
setUpdateRoutesCallback: () => {},
|
||||
getAcmeOptions: () => ({}),
|
||||
getState: () => ({ challengeRouteActive: false }),
|
||||
getCertStatus: () => new Map(),
|
||||
checkAndRenewCertificates: async () => {},
|
||||
addChallengeRoute: async () => {},
|
||||
removeChallengeRoute: async () => {},
|
||||
getCertificate: async () => null,
|
||||
isValidCertificate: () => false,
|
||||
waitForProvisioning: async () => {}
|
||||
} as any;
|
||||
|
||||
// Call initialize immediately as the real createCertificateManager does
|
||||
await mockCertManager.initialize();
|
||||
|
||||
return mockCertManager;
|
||||
};
|
||||
|
||||
// Track port manager operations
|
||||
const originalAddPorts = proxy['portManager'].addPorts;
|
||||
proxy['portManager'].addPorts = async function(ports: number[]) {
|
||||
operationOrder.push('ports-starting');
|
||||
const result = await originalAddPorts.call(this, ports);
|
||||
operationOrder.push('ports-ready');
|
||||
portsListening = true;
|
||||
console.log('Ports are now listening');
|
||||
return result;
|
||||
};
|
||||
|
||||
// Start the proxy
|
||||
await proxy.start();
|
||||
|
||||
// Log the operation order for debugging
|
||||
console.log('Operation order:', operationOrder);
|
||||
|
||||
// Verify operations happened in the correct order
|
||||
expect(operationOrder).toContain('create-cert-manager');
|
||||
expect(operationOrder).toContain('cert-manager-init');
|
||||
expect(operationOrder).toContain('ports-starting');
|
||||
expect(operationOrder).toContain('ports-ready');
|
||||
expect(operationOrder).toContain('cert-provisioning');
|
||||
|
||||
// Verify ports were ready before certificate provisioning
|
||||
const portsReadyIndex = operationOrder.indexOf('ports-ready');
|
||||
const certProvisioningIndex = operationOrder.indexOf('cert-provisioning');
|
||||
|
||||
expect(portsReadyIndex).toBeLessThan(certProvisioningIndex);
|
||||
expect(certProvisioningStarted).toEqual(true);
|
||||
expect(portsListening).toEqual(true);
|
||||
|
||||
await proxy.stop();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,204 +0,0 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
import * as net from 'net';
|
||||
|
||||
// Test that certificate provisioning waits for ports to be ready
|
||||
tap.test('should defer certificate provisioning until after ports are listening', async (tapTest) => {
|
||||
// Track the order of operations
|
||||
const operationLog: string[] = [];
|
||||
|
||||
// Create a mock server to verify ports are listening
|
||||
let port80Listening = false;
|
||||
|
||||
// Try to use port 8080 instead of 80 to avoid permission issues in testing
|
||||
const acmePort = 8080;
|
||||
|
||||
// Create proxy with ACME certificate requirement
|
||||
const proxy = new SmartProxy({
|
||||
useHttpProxy: [acmePort],
|
||||
httpProxyPort: 8845, // Use different port to avoid conflicts
|
||||
acme: {
|
||||
email: 'test@test.local',
|
||||
useProduction: false,
|
||||
port: acmePort
|
||||
},
|
||||
routes: [{
|
||||
name: 'test-acme-route',
|
||||
match: {
|
||||
ports: 8443,
|
||||
domains: ['test.local']
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8181 }],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto',
|
||||
acme: {
|
||||
email: 'test@test.local',
|
||||
useProduction: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
// Mock some internal methods to track operation order
|
||||
const originalAddPorts = proxy['portManager'].addPorts;
|
||||
proxy['portManager'].addPorts = async function(ports: number[]) {
|
||||
operationLog.push('Starting port listeners');
|
||||
const result = await originalAddPorts.call(this, ports);
|
||||
operationLog.push('Port listeners started');
|
||||
port80Listening = true;
|
||||
return result;
|
||||
};
|
||||
|
||||
// Track that we created a certificate manager and SmartProxy will call provisionAllCertificates
|
||||
let certManagerCreated = false;
|
||||
|
||||
// Override createCertificateManager to set up our tracking
|
||||
const originalCreateCertManager = (proxy as any).createCertificateManager;
|
||||
(proxy as any).certManagerCreated = false;
|
||||
|
||||
// Mock certificate manager to avoid real ACME initialization
|
||||
(proxy as any).createCertificateManager = async function() {
|
||||
operationLog.push('Creating certificate manager');
|
||||
const mockCertManager = {
|
||||
setUpdateRoutesCallback: () => {},
|
||||
setHttpProxy: () => {},
|
||||
setGlobalAcmeDefaults: () => {},
|
||||
setAcmeStateManager: () => {},
|
||||
initialize: async () => {
|
||||
operationLog.push('Certificate manager initialized');
|
||||
},
|
||||
provisionAllCertificates: async () => {
|
||||
operationLog.push('Starting certificate provisioning');
|
||||
if (!port80Listening) {
|
||||
operationLog.push('ERROR: Certificate provisioning started before ports ready');
|
||||
}
|
||||
operationLog.push('Certificate provisioning completed');
|
||||
},
|
||||
stop: async () => {},
|
||||
getAcmeOptions: () => ({ email: 'test@test.local', useProduction: false }),
|
||||
getState: () => ({ challengeRouteActive: false })
|
||||
};
|
||||
certManagerCreated = true;
|
||||
(proxy as any).certManager = mockCertManager;
|
||||
return mockCertManager;
|
||||
};
|
||||
|
||||
// Start the proxy
|
||||
await proxy.start();
|
||||
|
||||
// Verify the order of operations
|
||||
expect(operationLog).toContain('Starting port listeners');
|
||||
expect(operationLog).toContain('Port listeners started');
|
||||
expect(operationLog).toContain('Starting certificate provisioning');
|
||||
|
||||
// Ensure port listeners started before certificate provisioning
|
||||
const portStartIndex = operationLog.indexOf('Port listeners started');
|
||||
const certStartIndex = operationLog.indexOf('Starting certificate provisioning');
|
||||
|
||||
expect(portStartIndex).toBeLessThan(certStartIndex);
|
||||
expect(operationLog).not.toContain('ERROR: Certificate provisioning started before ports ready');
|
||||
|
||||
await proxy.stop();
|
||||
});
|
||||
|
||||
// Test that ACME challenge route is available when certificate is requested
|
||||
tap.test('should have ACME challenge route ready before certificate provisioning', async (tapTest) => {
|
||||
let challengeRouteActive = false;
|
||||
let certificateProvisioningStarted = false;
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
useHttpProxy: [8080],
|
||||
httpProxyPort: 8846, // Use different port to avoid conflicts
|
||||
acme: {
|
||||
email: 'test@test.local',
|
||||
useProduction: false,
|
||||
port: 8080
|
||||
},
|
||||
routes: [{
|
||||
name: 'test-route',
|
||||
match: {
|
||||
ports: 8443,
|
||||
domains: ['test.example.com']
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8181 }],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto'
|
||||
}
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
// Mock the certificate manager to track operations
|
||||
const originalInitialize = proxy['certManager'] ?
|
||||
proxy['certManager'].initialize : null;
|
||||
|
||||
if (proxy['certManager']) {
|
||||
const certManager = proxy['certManager'];
|
||||
|
||||
// Track when challenge route is added
|
||||
const originalAddChallenge = certManager['addChallengeRoute'];
|
||||
certManager['addChallengeRoute'] = async function() {
|
||||
await originalAddChallenge.call(this);
|
||||
challengeRouteActive = true;
|
||||
};
|
||||
|
||||
// Track when certificate provisioning starts
|
||||
const originalProvisionAcme = certManager['provisionAcmeCertificate'];
|
||||
certManager['provisionAcmeCertificate'] = async function(...args: any[]) {
|
||||
certificateProvisioningStarted = true;
|
||||
// Verify challenge route is active
|
||||
expect(challengeRouteActive).toEqual(true);
|
||||
// Don't actually provision in test
|
||||
return;
|
||||
};
|
||||
}
|
||||
|
||||
// Mock certificate manager to avoid real ACME initialization
|
||||
(proxy as any).createCertificateManager = async function() {
|
||||
const mockCertManager = {
|
||||
setUpdateRoutesCallback: () => {},
|
||||
setHttpProxy: () => {},
|
||||
setGlobalAcmeDefaults: () => {},
|
||||
setAcmeStateManager: () => {},
|
||||
initialize: async () => {
|
||||
challengeRouteActive = true;
|
||||
},
|
||||
provisionAllCertificates: async () => {
|
||||
certificateProvisioningStarted = true;
|
||||
expect(challengeRouteActive).toEqual(true);
|
||||
},
|
||||
stop: async () => {},
|
||||
getAcmeOptions: () => ({ email: 'test@test.local', useProduction: false }),
|
||||
getState: () => ({ challengeRouteActive: false }),
|
||||
addChallengeRoute: async () => {
|
||||
challengeRouteActive = true;
|
||||
},
|
||||
provisionAcmeCertificate: async () => {
|
||||
certificateProvisioningStarted = true;
|
||||
expect(challengeRouteActive).toEqual(true);
|
||||
}
|
||||
};
|
||||
// Call initialize like the real createCertificateManager does
|
||||
await mockCertManager.initialize();
|
||||
return mockCertManager;
|
||||
};
|
||||
|
||||
await proxy.start();
|
||||
|
||||
// Give it a moment to complete initialization
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
|
||||
// Verify challenge route was added before any certificate provisioning
|
||||
expect(challengeRouteActive).toEqual(true);
|
||||
|
||||
await proxy.stop();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,77 +0,0 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import * as plugins from '../ts/plugins.js';
|
||||
import * as smartproxy from '../ts/index.js';
|
||||
|
||||
// This test verifies that SmartProxy correctly uses the updated SmartAcme v8.0.0 API
|
||||
// with the optional wildcard parameter
|
||||
|
||||
tap.test('SmartCertManager should call getCertificateForDomain with wildcard option', async () => {
|
||||
console.log('Testing SmartCertManager with SmartAcme v8.0.0 API...');
|
||||
|
||||
// Create a mock route with ACME certificate configuration
|
||||
const mockRoute: smartproxy.IRouteConfig = {
|
||||
match: {
|
||||
domains: ['test.example.com'],
|
||||
ports: 443
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 8080
|
||||
}],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto',
|
||||
acme: {
|
||||
email: 'test@example.com',
|
||||
useProduction: false
|
||||
}
|
||||
}
|
||||
},
|
||||
name: 'test-route'
|
||||
};
|
||||
|
||||
// Create a certificate manager
|
||||
const certManager = new smartproxy.SmartCertManager(
|
||||
[mockRoute],
|
||||
'./test-certs',
|
||||
{
|
||||
email: 'test@example.com',
|
||||
useProduction: false
|
||||
}
|
||||
);
|
||||
|
||||
// Since we can't actually test ACME in a unit test, we'll just verify the logic
|
||||
// The actual test would be that it builds and runs without errors
|
||||
|
||||
// Test the wildcard logic for different domain types and challenge handlers
|
||||
const testCases = [
|
||||
{ domain: 'example.com', hasDnsChallenge: true, shouldIncludeWildcard: true },
|
||||
{ domain: 'example.com', hasDnsChallenge: false, shouldIncludeWildcard: false },
|
||||
{ domain: 'sub.example.com', hasDnsChallenge: true, shouldIncludeWildcard: true },
|
||||
{ domain: 'sub.example.com', hasDnsChallenge: false, shouldIncludeWildcard: false },
|
||||
{ domain: '*.example.com', hasDnsChallenge: true, shouldIncludeWildcard: false },
|
||||
{ domain: '*.example.com', hasDnsChallenge: false, shouldIncludeWildcard: false },
|
||||
{ domain: 'test', hasDnsChallenge: true, shouldIncludeWildcard: false }, // single label domain
|
||||
{ domain: 'test', hasDnsChallenge: false, shouldIncludeWildcard: false },
|
||||
{ domain: 'my.sub.example.com', hasDnsChallenge: true, shouldIncludeWildcard: true },
|
||||
{ domain: 'my.sub.example.com', hasDnsChallenge: false, shouldIncludeWildcard: false }
|
||||
];
|
||||
|
||||
for (const testCase of testCases) {
|
||||
const shouldIncludeWildcard = !testCase.domain.startsWith('*.') &&
|
||||
testCase.domain.includes('.') &&
|
||||
testCase.domain.split('.').length >= 2 &&
|
||||
testCase.hasDnsChallenge;
|
||||
|
||||
console.log(`Domain: ${testCase.domain}, DNS-01: ${testCase.hasDnsChallenge}, Should include wildcard: ${shouldIncludeWildcard}`);
|
||||
expect(shouldIncludeWildcard).toEqual(testCase.shouldIncludeWildcard);
|
||||
}
|
||||
|
||||
console.log('All wildcard logic tests passed!');
|
||||
});
|
||||
|
||||
tap.start({
|
||||
throwOnError: true
|
||||
});
|
||||
@@ -1,423 +0,0 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
import type { TSmartProxyCertProvisionObject } from '../ts/index.js';
|
||||
import * as fs from 'fs';
|
||||
import * as path from 'path';
|
||||
import { fileURLToPath } from 'url';
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = path.dirname(__filename);
|
||||
|
||||
let testProxy: SmartProxy;
|
||||
|
||||
// Load test certificates from helpers
|
||||
const testCert = fs.readFileSync(path.join(__dirname, 'helpers/test-cert.pem'), 'utf8');
|
||||
const testKey = fs.readFileSync(path.join(__dirname, 'helpers/test-key.pem'), 'utf8');
|
||||
|
||||
// Helper to create a fully mocked certificate manager that doesn't contact ACME servers
|
||||
function createMockCertManager(options: {
|
||||
onProvisionAll?: () => void;
|
||||
onGetCertForDomain?: (domain: string) => void;
|
||||
} = {}) {
|
||||
return {
|
||||
setUpdateRoutesCallback: function(callback: any) {
|
||||
this.updateRoutesCallback = callback;
|
||||
},
|
||||
updateRoutesCallback: null as any,
|
||||
setHttpProxy: function() {},
|
||||
setGlobalAcmeDefaults: function() {},
|
||||
setAcmeStateManager: function() {},
|
||||
setRoutes: function(routes: any) {},
|
||||
initialize: async function() {},
|
||||
provisionAllCertificates: async function() {
|
||||
if (options.onProvisionAll) {
|
||||
options.onProvisionAll();
|
||||
}
|
||||
},
|
||||
stop: async function() {},
|
||||
getAcmeOptions: function() {
|
||||
return { email: 'test@example.com', useProduction: false };
|
||||
},
|
||||
getState: function() {
|
||||
return { challengeRouteActive: false };
|
||||
},
|
||||
smartAcme: {
|
||||
getCertificateForDomain: async (domain: string) => {
|
||||
if (options.onGetCertForDomain) {
|
||||
options.onGetCertForDomain(domain);
|
||||
}
|
||||
throw new Error('Mocked ACME - not calling real servers');
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
tap.test('SmartProxy should support custom certificate provision function', async () => {
|
||||
// Create test certificate object matching ICert interface
|
||||
const testCertObject = {
|
||||
id: 'test-cert-1',
|
||||
domainName: 'test.example.com',
|
||||
created: Date.now(),
|
||||
validUntil: Date.now() + 90 * 24 * 60 * 60 * 1000, // 90 days
|
||||
privateKey: testKey,
|
||||
publicKey: testCert,
|
||||
csr: ''
|
||||
};
|
||||
|
||||
// Custom certificate store for testing
|
||||
const customCerts = new Map<string, typeof testCertObject>();
|
||||
customCerts.set('test.example.com', testCertObject);
|
||||
|
||||
// Create proxy with custom certificate provision
|
||||
testProxy = new SmartProxy({
|
||||
certProvisionFunction: async (domain: string): Promise<TSmartProxyCertProvisionObject> => {
|
||||
console.log(`Custom cert provision called for domain: ${domain}`);
|
||||
|
||||
// Return custom cert for known domains
|
||||
if (customCerts.has(domain)) {
|
||||
console.log(`Returning custom certificate for ${domain}`);
|
||||
return customCerts.get(domain)!;
|
||||
}
|
||||
|
||||
// Fallback to Let's Encrypt for other domains
|
||||
console.log(`Falling back to Let's Encrypt for ${domain}`);
|
||||
return 'http01';
|
||||
},
|
||||
certProvisionFallbackToAcme: true,
|
||||
acme: {
|
||||
email: 'test@example.com',
|
||||
useProduction: false
|
||||
},
|
||||
routes: [
|
||||
{
|
||||
name: 'test-route',
|
||||
match: {
|
||||
ports: [443],
|
||||
domains: ['test.example.com']
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 8080
|
||||
}],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto'
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
expect(testProxy).toBeInstanceOf(SmartProxy);
|
||||
});
|
||||
|
||||
tap.test('Custom certificate provision function should be called', async () => {
|
||||
let provisionCalled = false;
|
||||
const provisionedDomains: string[] = [];
|
||||
|
||||
const testProxy2 = new SmartProxy({
|
||||
certProvisionFunction: async (domain: string): Promise<TSmartProxyCertProvisionObject> => {
|
||||
provisionCalled = true;
|
||||
provisionedDomains.push(domain);
|
||||
|
||||
// Return a test certificate matching ICert interface
|
||||
return {
|
||||
id: `test-cert-${domain}`,
|
||||
domainName: domain,
|
||||
created: Date.now(),
|
||||
validUntil: Date.now() + 90 * 24 * 60 * 60 * 1000,
|
||||
privateKey: testKey,
|
||||
publicKey: testCert,
|
||||
csr: ''
|
||||
};
|
||||
},
|
||||
acme: {
|
||||
email: 'test@example.com',
|
||||
useProduction: false,
|
||||
port: 9080
|
||||
},
|
||||
routes: [
|
||||
{
|
||||
name: 'custom-cert-route',
|
||||
match: {
|
||||
ports: [9443],
|
||||
domains: ['custom.example.com']
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 8080
|
||||
}],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto'
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
// Fully mock the certificate manager to avoid ACME server contact
|
||||
let certManagerCalled = false;
|
||||
(testProxy2 as any).createCertificateManager = async function() {
|
||||
const mockCertManager = createMockCertManager({
|
||||
onProvisionAll: () => {
|
||||
certManagerCalled = true;
|
||||
// Simulate calling the provision function
|
||||
testProxy2.settings.certProvisionFunction?.('custom.example.com');
|
||||
}
|
||||
});
|
||||
|
||||
// Set callback as in real implementation
|
||||
mockCertManager.setUpdateRoutesCallback(async (routes: any) => {
|
||||
await this.updateRoutes(routes);
|
||||
});
|
||||
|
||||
return mockCertManager;
|
||||
};
|
||||
|
||||
// Start the proxy (this will trigger certificate provisioning)
|
||||
await testProxy2.start();
|
||||
|
||||
expect(certManagerCalled).toBeTrue();
|
||||
expect(provisionCalled).toBeTrue();
|
||||
expect(provisionedDomains).toContain('custom.example.com');
|
||||
|
||||
await testProxy2.stop();
|
||||
});
|
||||
|
||||
tap.test('Should fallback to ACME when custom provision fails', async () => {
|
||||
const failedDomains: string[] = [];
|
||||
let acmeAttempted = false;
|
||||
|
||||
const testProxy3 = new SmartProxy({
|
||||
certProvisionFunction: async (domain: string): Promise<TSmartProxyCertProvisionObject> => {
|
||||
failedDomains.push(domain);
|
||||
throw new Error('Custom provision failed for testing');
|
||||
},
|
||||
certProvisionFallbackToAcme: true,
|
||||
acme: {
|
||||
email: 'test@example.com',
|
||||
useProduction: false,
|
||||
port: 9080
|
||||
},
|
||||
routes: [
|
||||
{
|
||||
name: 'fallback-route',
|
||||
match: {
|
||||
ports: [9444],
|
||||
domains: ['fallback.example.com']
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 8080
|
||||
}],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto'
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
// Fully mock the certificate manager to avoid ACME server contact
|
||||
(testProxy3 as any).createCertificateManager = async function() {
|
||||
const mockCertManager = createMockCertManager({
|
||||
onProvisionAll: async () => {
|
||||
// Simulate the provision logic: first try custom function, then ACME
|
||||
try {
|
||||
await testProxy3.settings.certProvisionFunction?.('fallback.example.com');
|
||||
} catch (e) {
|
||||
// Custom provision failed, try ACME
|
||||
acmeAttempted = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Set callback as in real implementation
|
||||
mockCertManager.setUpdateRoutesCallback(async (routes: any) => {
|
||||
await this.updateRoutes(routes);
|
||||
});
|
||||
|
||||
return mockCertManager;
|
||||
};
|
||||
|
||||
// Start the proxy
|
||||
await testProxy3.start();
|
||||
|
||||
// Custom provision should have failed
|
||||
expect(failedDomains).toContain('fallback.example.com');
|
||||
|
||||
// ACME should have been attempted as fallback
|
||||
expect(acmeAttempted).toBeTrue();
|
||||
|
||||
await testProxy3.stop();
|
||||
});
|
||||
|
||||
tap.test('Should not fallback when certProvisionFallbackToAcme is false', async () => {
|
||||
let errorThrown = false;
|
||||
let errorMessage = '';
|
||||
|
||||
const testProxy4 = new SmartProxy({
|
||||
certProvisionFunction: async (_domain: string): Promise<TSmartProxyCertProvisionObject> => {
|
||||
throw new Error('Custom provision failed for testing');
|
||||
},
|
||||
certProvisionFallbackToAcme: false,
|
||||
acme: {
|
||||
email: 'test@example.com',
|
||||
useProduction: false,
|
||||
port: 9082
|
||||
},
|
||||
routes: [
|
||||
{
|
||||
name: 'no-fallback-route',
|
||||
match: {
|
||||
ports: [9449],
|
||||
domains: ['no-fallback.example.com']
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 8080
|
||||
}],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto'
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
// Fully mock the certificate manager to avoid ACME server contact
|
||||
(testProxy4 as any).createCertificateManager = async function() {
|
||||
const mockCertManager = createMockCertManager({
|
||||
onProvisionAll: async () => {
|
||||
// Simulate the provision logic with no fallback
|
||||
try {
|
||||
await testProxy4.settings.certProvisionFunction?.('no-fallback.example.com');
|
||||
} catch (e: any) {
|
||||
errorThrown = true;
|
||||
errorMessage = e.message;
|
||||
// With certProvisionFallbackToAcme=false, the error should propagate
|
||||
if (!testProxy4.settings.certProvisionFallbackToAcme) {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Set callback as in real implementation
|
||||
mockCertManager.setUpdateRoutesCallback(async (routes: any) => {
|
||||
await this.updateRoutes(routes);
|
||||
});
|
||||
|
||||
return mockCertManager;
|
||||
};
|
||||
|
||||
try {
|
||||
await testProxy4.start();
|
||||
} catch (e) {
|
||||
// Expected to fail
|
||||
}
|
||||
|
||||
expect(errorThrown).toBeTrue();
|
||||
expect(errorMessage).toInclude('Custom provision failed for testing');
|
||||
|
||||
await testProxy4.stop();
|
||||
});
|
||||
|
||||
tap.test('Should return http01 for unknown domains', async () => {
|
||||
let returnedHttp01 = false;
|
||||
let acmeAttempted = false;
|
||||
|
||||
const testProxy5 = new SmartProxy({
|
||||
certProvisionFunction: async (domain: string): Promise<TSmartProxyCertProvisionObject> => {
|
||||
if (domain === 'known.example.com') {
|
||||
return {
|
||||
id: `test-cert-${domain}`,
|
||||
domainName: domain,
|
||||
created: Date.now(),
|
||||
validUntil: Date.now() + 90 * 24 * 60 * 60 * 1000,
|
||||
privateKey: testKey,
|
||||
publicKey: testCert,
|
||||
csr: ''
|
||||
};
|
||||
}
|
||||
returnedHttp01 = true;
|
||||
return 'http01';
|
||||
},
|
||||
acme: {
|
||||
email: 'test@example.com',
|
||||
useProduction: false,
|
||||
port: 9081
|
||||
},
|
||||
routes: [
|
||||
{
|
||||
name: 'unknown-domain-route',
|
||||
match: {
|
||||
ports: [9446],
|
||||
domains: ['unknown.example.com']
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 8080
|
||||
}],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto'
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
// Fully mock the certificate manager to avoid ACME server contact
|
||||
(testProxy5 as any).createCertificateManager = async function() {
|
||||
const mockCertManager = createMockCertManager({
|
||||
onProvisionAll: async () => {
|
||||
// Simulate the provision logic: call provision function first
|
||||
const result = await testProxy5.settings.certProvisionFunction?.('unknown.example.com');
|
||||
if (result === 'http01') {
|
||||
// http01 means use ACME
|
||||
acmeAttempted = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Set callback as in real implementation
|
||||
mockCertManager.setUpdateRoutesCallback(async (routes: any) => {
|
||||
await this.updateRoutes(routes);
|
||||
});
|
||||
|
||||
return mockCertManager;
|
||||
};
|
||||
|
||||
await testProxy5.start();
|
||||
|
||||
// Should have returned http01 for unknown domain
|
||||
expect(returnedHttp01).toBeTrue();
|
||||
|
||||
// ACME should have been attempted
|
||||
expect(acmeAttempted).toBeTrue();
|
||||
|
||||
await testProxy5.stop();
|
||||
});
|
||||
|
||||
tap.test('cleanup', async () => {
|
||||
// Clean up any test proxies
|
||||
if (testProxy) {
|
||||
await testProxy.stop();
|
||||
}
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,241 +0,0 @@
|
||||
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
|
||||
const testProxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'test-route',
|
||||
match: { ports: 9443, domains: 'test.local' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8080 }],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto',
|
||||
acme: {
|
||||
email: 'test@test.local',
|
||||
useProduction: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}],
|
||||
acme: {
|
||||
port: 9080 // Use high port for ACME challenges
|
||||
}
|
||||
});
|
||||
|
||||
tap.test('should provision certificate automatically', async () => {
|
||||
// Mock certificate manager to avoid real ACME initialization
|
||||
const mockCertStatus = {
|
||||
domain: 'test-route',
|
||||
status: 'valid' as const,
|
||||
source: 'acme' as const,
|
||||
expiryDate: new Date(Date.now() + 90 * 24 * 60 * 60 * 1000),
|
||||
issueDate: new Date()
|
||||
};
|
||||
|
||||
(testProxy as any).createCertificateManager = async function() {
|
||||
return {
|
||||
setUpdateRoutesCallback: () => {},
|
||||
setHttpProxy: () => {},
|
||||
setGlobalAcmeDefaults: () => {},
|
||||
setAcmeStateManager: () => {},
|
||||
initialize: async () => {},
|
||||
provisionAllCertificates: async () => {},
|
||||
stop: async () => {},
|
||||
getAcmeOptions: () => ({ email: 'test@test.local', useProduction: false }),
|
||||
getState: () => ({ challengeRouteActive: false }),
|
||||
getCertificateStatus: () => mockCertStatus
|
||||
};
|
||||
};
|
||||
|
||||
(testProxy as any).getCertificateStatus = () => mockCertStatus;
|
||||
|
||||
await testProxy.start();
|
||||
|
||||
const status = testProxy.getCertificateStatus('test-route');
|
||||
expect(status).toBeDefined();
|
||||
expect(status.status).toEqual('valid');
|
||||
expect(status.source).toEqual('acme');
|
||||
|
||||
await testProxy.stop();
|
||||
});
|
||||
|
||||
tap.test('should handle static certificates', async () => {
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'static-route',
|
||||
match: { ports: 9444, domains: 'static.example.com' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8080 }],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: {
|
||||
cert: '-----BEGIN CERTIFICATE-----\nMIIC...\n-----END CERTIFICATE-----',
|
||||
key: '-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----'
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
await proxy.start();
|
||||
|
||||
const status = proxy.getCertificateStatus('static-route');
|
||||
expect(status).toBeDefined();
|
||||
expect(status.status).toEqual('valid');
|
||||
expect(status.source).toEqual('static');
|
||||
|
||||
await proxy.stop();
|
||||
});
|
||||
|
||||
tap.test('should handle ACME challenge routes', async () => {
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'auto-cert-route',
|
||||
match: { ports: 9445, domains: 'acme.local' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8080 }],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto',
|
||||
acme: {
|
||||
email: 'acme@test.local',
|
||||
useProduction: false,
|
||||
challengePort: 9081
|
||||
}
|
||||
}
|
||||
}
|
||||
}, {
|
||||
name: 'port-9081-route',
|
||||
match: { ports: 9081, domains: 'acme.local' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8080 }]
|
||||
}
|
||||
}],
|
||||
acme: {
|
||||
port: 9081 // Use high port for ACME challenges
|
||||
}
|
||||
});
|
||||
|
||||
// Mock certificate manager to avoid real ACME initialization
|
||||
(proxy as any).createCertificateManager = async function() {
|
||||
return {
|
||||
setUpdateRoutesCallback: () => {},
|
||||
setHttpProxy: () => {},
|
||||
setGlobalAcmeDefaults: () => {},
|
||||
setAcmeStateManager: () => {},
|
||||
initialize: async () => {},
|
||||
provisionAllCertificates: async () => {},
|
||||
stop: async () => {},
|
||||
getAcmeOptions: () => ({ email: 'acme@test.local', useProduction: false }),
|
||||
getState: () => ({ challengeRouteActive: false })
|
||||
};
|
||||
};
|
||||
|
||||
await proxy.start();
|
||||
|
||||
// Verify the proxy is configured with routes including the necessary port
|
||||
const routes = proxy.settings.routes;
|
||||
|
||||
// Check that we have a route listening on the ACME challenge port
|
||||
const acmeChallengePort = 9081;
|
||||
const routesOnChallengePort = routes.filter((r: any) => {
|
||||
const ports = Array.isArray(r.match.ports) ? r.match.ports : [r.match.ports];
|
||||
return ports.includes(acmeChallengePort);
|
||||
});
|
||||
|
||||
expect(routesOnChallengePort.length).toBeGreaterThan(0);
|
||||
expect(routesOnChallengePort[0].name).toEqual('port-9081-route');
|
||||
|
||||
// Verify the main route has ACME configuration
|
||||
const mainRoute = routes.find((r: any) => r.name === 'auto-cert-route');
|
||||
expect(mainRoute).toBeDefined();
|
||||
expect(mainRoute?.action.tls?.certificate).toEqual('auto');
|
||||
expect(mainRoute?.action.tls?.acme?.email).toEqual('acme@test.local');
|
||||
expect(mainRoute?.action.tls?.acme?.challengePort).toEqual(9081);
|
||||
|
||||
await proxy.stop();
|
||||
});
|
||||
|
||||
tap.test('should renew certificates', async () => {
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'renew-route',
|
||||
match: { ports: 9446, domains: 'renew.local' },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8080 }],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto',
|
||||
acme: {
|
||||
email: 'renew@test.local',
|
||||
useProduction: false,
|
||||
renewBeforeDays: 30
|
||||
}
|
||||
}
|
||||
}
|
||||
}],
|
||||
acme: {
|
||||
port: 9082 // Use high port for ACME challenges
|
||||
}
|
||||
});
|
||||
|
||||
// Mock certificate manager with renewal capability
|
||||
let renewCalled = false;
|
||||
const mockCertStatus = {
|
||||
domain: 'renew-route',
|
||||
status: 'valid' as const,
|
||||
source: 'acme' as const,
|
||||
expiryDate: new Date(Date.now() + 90 * 24 * 60 * 60 * 1000),
|
||||
issueDate: new Date()
|
||||
};
|
||||
|
||||
(proxy as any).certManager = {
|
||||
renewCertificate: async (routeName: string) => {
|
||||
renewCalled = true;
|
||||
expect(routeName).toEqual('renew-route');
|
||||
},
|
||||
getCertificateStatus: () => mockCertStatus,
|
||||
setUpdateRoutesCallback: () => {},
|
||||
setHttpProxy: () => {},
|
||||
setGlobalAcmeDefaults: () => {},
|
||||
setAcmeStateManager: () => {},
|
||||
initialize: async () => {},
|
||||
provisionAllCertificates: async () => {},
|
||||
stop: async () => {},
|
||||
getAcmeOptions: () => ({ email: 'renew@test.local', useProduction: false }),
|
||||
getState: () => ({ challengeRouteActive: false })
|
||||
};
|
||||
|
||||
(proxy as any).createCertificateManager = async function() {
|
||||
return this.certManager;
|
||||
};
|
||||
|
||||
(proxy as any).getCertificateStatus = function(routeName: string) {
|
||||
return this.certManager.getCertificateStatus(routeName);
|
||||
};
|
||||
|
||||
(proxy as any).renewCertificate = async function(routeName: string) {
|
||||
if (this.certManager) {
|
||||
await this.certManager.renewCertificate(routeName);
|
||||
}
|
||||
};
|
||||
|
||||
await proxy.start();
|
||||
|
||||
// Force renewal
|
||||
await proxy.renewCertificate('renew-route');
|
||||
expect(renewCalled).toBeTrue();
|
||||
|
||||
const status = proxy.getCertificateStatus('renew-route');
|
||||
expect(status).toBeDefined();
|
||||
expect(status.status).toEqual('valid');
|
||||
|
||||
await proxy.stop();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,146 +0,0 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
|
||||
tap.test('cleanup queue bug - verify queue processing handles more than batch size', async () => {
|
||||
console.log('\n=== Cleanup Queue Bug Test ===');
|
||||
console.log('Purpose: Verify that the cleanup queue correctly processes all connections');
|
||||
console.log('even when there are more than the batch size (100)');
|
||||
|
||||
// Create proxy
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'test-route',
|
||||
match: { ports: 8588 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 9996 }]
|
||||
}
|
||||
}],
|
||||
enableDetailedLogging: false,
|
||||
});
|
||||
|
||||
await proxy.start();
|
||||
console.log('✓ Proxy started on port 8588');
|
||||
|
||||
// Access connection manager
|
||||
const cm = (proxy as any).connectionManager;
|
||||
|
||||
// Create mock connection records
|
||||
console.log('\n--- Creating 150 mock connections ---');
|
||||
const mockConnections: any[] = [];
|
||||
|
||||
for (let i = 0; i < 150; i++) {
|
||||
// Create mock socket objects with necessary methods
|
||||
const mockIncoming = {
|
||||
destroyed: true,
|
||||
writable: false,
|
||||
remoteAddress: '127.0.0.1',
|
||||
removeAllListeners: () => {},
|
||||
destroy: () => {},
|
||||
end: () => {},
|
||||
on: () => {},
|
||||
once: () => {},
|
||||
emit: () => {},
|
||||
pause: () => {},
|
||||
resume: () => {}
|
||||
};
|
||||
|
||||
const mockOutgoing = {
|
||||
destroyed: true,
|
||||
writable: false,
|
||||
removeAllListeners: () => {},
|
||||
destroy: () => {},
|
||||
end: () => {},
|
||||
on: () => {},
|
||||
once: () => {},
|
||||
emit: () => {}
|
||||
};
|
||||
|
||||
const mockRecord = {
|
||||
id: `mock-${i}`,
|
||||
incoming: mockIncoming,
|
||||
outgoing: mockOutgoing,
|
||||
connectionClosed: false,
|
||||
incomingStartTime: Date.now(),
|
||||
lastActivity: Date.now(),
|
||||
remoteIP: '127.0.0.1',
|
||||
remotePort: 10000 + i,
|
||||
localPort: 8588,
|
||||
bytesReceived: 100,
|
||||
bytesSent: 100,
|
||||
incomingTerminationReason: null,
|
||||
cleanupTimer: null
|
||||
};
|
||||
|
||||
// Add to connection records
|
||||
cm.connectionRecords.set(mockRecord.id, mockRecord);
|
||||
mockConnections.push(mockRecord);
|
||||
}
|
||||
|
||||
console.log(`Created ${cm.getConnectionCount()} mock connections`);
|
||||
expect(cm.getConnectionCount()).toEqual(150);
|
||||
|
||||
// Queue all connections for cleanup
|
||||
console.log('\n--- Queueing all connections for cleanup ---');
|
||||
|
||||
// The cleanup queue processes immediately when it reaches batch size (100)
|
||||
// So after queueing 150, the first 100 will be processed immediately
|
||||
for (const conn of mockConnections) {
|
||||
cm.initiateCleanupOnce(conn, 'test_cleanup');
|
||||
}
|
||||
|
||||
// After queueing 150, the first 100 should have been processed immediately
|
||||
// leaving 50 in the queue
|
||||
console.log(`Cleanup queue size after queueing: ${cm.cleanupQueue.size}`);
|
||||
console.log(`Active connections after initial batch: ${cm.getConnectionCount()}`);
|
||||
|
||||
// The first 100 should have been cleaned up immediately
|
||||
expect(cm.cleanupQueue.size).toEqual(50);
|
||||
expect(cm.getConnectionCount()).toEqual(50);
|
||||
|
||||
// Wait for remaining cleanup to complete
|
||||
console.log('\n--- Waiting for remaining cleanup batches to process ---');
|
||||
|
||||
// The remaining 50 connections should be cleaned up in the next batch
|
||||
let waitTime = 0;
|
||||
let lastCount = cm.getConnectionCount();
|
||||
|
||||
while (cm.getConnectionCount() > 0 || cm.cleanupQueue.size > 0) {
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
waitTime += 100;
|
||||
|
||||
const currentCount = cm.getConnectionCount();
|
||||
if (currentCount !== lastCount) {
|
||||
console.log(`Active connections: ${currentCount}, Queue size: ${cm.cleanupQueue.size}`);
|
||||
lastCount = currentCount;
|
||||
}
|
||||
|
||||
if (waitTime > 5000) {
|
||||
console.log('Timeout waiting for cleanup to complete');
|
||||
break;
|
||||
}
|
||||
}
|
||||
console.log(`All cleanup completed in ${waitTime}ms`);
|
||||
|
||||
// Check final state
|
||||
const finalCount = cm.getConnectionCount();
|
||||
console.log(`\nFinal connection count: ${finalCount}`);
|
||||
console.log(`Final cleanup queue size: ${cm.cleanupQueue.size}`);
|
||||
|
||||
// All connections should be cleaned up
|
||||
expect(finalCount).toEqual(0);
|
||||
expect(cm.cleanupQueue.size).toEqual(0);
|
||||
|
||||
// Verify termination stats - all 150 should have been terminated
|
||||
const stats = cm.getTerminationStats();
|
||||
console.log('Termination stats:', stats);
|
||||
expect(stats.incoming.test_cleanup).toEqual(150);
|
||||
|
||||
// Cleanup
|
||||
console.log('\n--- Stopping proxy ---');
|
||||
await proxy.stop();
|
||||
|
||||
console.log('\n✓ Test complete: Cleanup queue now correctly processes all connections');
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,240 +0,0 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import * as plugins from '../ts/plugins.js';
|
||||
|
||||
// Import SmartProxy and configurations
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
|
||||
tap.test('should handle clients that connect and immediately disconnect without sending data', async () => {
|
||||
console.log('\n=== Testing Connect-Disconnect Cleanup ===');
|
||||
|
||||
// Create a SmartProxy instance
|
||||
const proxy = new SmartProxy({
|
||||
enableDetailedLogging: false,
|
||||
initialDataTimeout: 5000, // 5 second timeout for initial data
|
||||
routes: [{
|
||||
name: 'test-route',
|
||||
match: { ports: 8560 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 9999 // Non-existent port
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
// Start the proxy
|
||||
await proxy.start();
|
||||
console.log('✓ Proxy started on port 8560');
|
||||
|
||||
// Helper to get active connection count
|
||||
const getActiveConnections = () => {
|
||||
const connectionManager = (proxy as any).connectionManager;
|
||||
return connectionManager ? connectionManager.getConnectionCount() : 0;
|
||||
};
|
||||
|
||||
const initialCount = getActiveConnections();
|
||||
console.log(`Initial connection count: ${initialCount}`);
|
||||
|
||||
// Test 1: Connect and immediately disconnect without sending data
|
||||
console.log('\n--- Test 1: Immediate disconnect ---');
|
||||
const connectionCounts: number[] = [];
|
||||
|
||||
for (let i = 0; i < 10; i++) {
|
||||
const client = new net.Socket();
|
||||
|
||||
// Connect and immediately destroy
|
||||
client.connect(8560, 'localhost', () => {
|
||||
// Connected - immediately destroy without sending data
|
||||
client.destroy();
|
||||
});
|
||||
|
||||
// Wait a tiny bit
|
||||
await new Promise(resolve => setTimeout(resolve, 10));
|
||||
|
||||
const count = getActiveConnections();
|
||||
connectionCounts.push(count);
|
||||
if ((i + 1) % 5 === 0) {
|
||||
console.log(`After ${i + 1} connect/disconnect cycles: ${count} active connections`);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait a bit for cleanup
|
||||
await new Promise(resolve => setTimeout(resolve, 500));
|
||||
|
||||
const afterImmediateDisconnect = getActiveConnections();
|
||||
console.log(`After immediate disconnect test: ${afterImmediateDisconnect} active connections`);
|
||||
|
||||
// Test 2: Connect, wait a bit, then disconnect without sending data
|
||||
console.log('\n--- Test 2: Delayed disconnect ---');
|
||||
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const client = new net.Socket();
|
||||
|
||||
client.on('error', () => {
|
||||
// Ignore errors
|
||||
});
|
||||
|
||||
client.connect(8560, 'localhost', () => {
|
||||
// Wait 100ms then disconnect without sending data
|
||||
setTimeout(() => {
|
||||
if (!client.destroyed) {
|
||||
client.destroy();
|
||||
}
|
||||
}, 100);
|
||||
});
|
||||
}
|
||||
|
||||
// Check count immediately
|
||||
const duringDelayed = getActiveConnections();
|
||||
console.log(`During delayed disconnect test: ${duringDelayed} active connections`);
|
||||
|
||||
// Wait for cleanup
|
||||
await new Promise(resolve => setTimeout(resolve, 1000));
|
||||
|
||||
const afterDelayedDisconnect = getActiveConnections();
|
||||
console.log(`After delayed disconnect test: ${afterDelayedDisconnect} active connections`);
|
||||
|
||||
// Test 3: Mix of immediate and delayed disconnects
|
||||
console.log('\n--- Test 3: Mixed disconnect patterns ---');
|
||||
|
||||
const promises = [];
|
||||
for (let i = 0; i < 20; i++) {
|
||||
promises.push(new Promise<void>((resolve) => {
|
||||
const client = new net.Socket();
|
||||
|
||||
client.on('error', () => {
|
||||
resolve();
|
||||
});
|
||||
|
||||
client.on('close', () => {
|
||||
resolve();
|
||||
});
|
||||
|
||||
client.connect(8560, 'localhost', () => {
|
||||
if (i % 2 === 0) {
|
||||
// Half disconnect immediately
|
||||
client.destroy();
|
||||
} else {
|
||||
// Half wait 50ms
|
||||
setTimeout(() => {
|
||||
if (!client.destroyed) {
|
||||
client.destroy();
|
||||
}
|
||||
}, 50);
|
||||
}
|
||||
});
|
||||
|
||||
// Failsafe timeout
|
||||
setTimeout(() => resolve(), 200);
|
||||
}));
|
||||
}
|
||||
|
||||
// Wait for all to complete
|
||||
await Promise.all(promises);
|
||||
|
||||
const duringMixed = getActiveConnections();
|
||||
console.log(`During mixed test: ${duringMixed} active connections`);
|
||||
|
||||
// Final cleanup wait
|
||||
await new Promise(resolve => setTimeout(resolve, 1000));
|
||||
|
||||
const finalCount = getActiveConnections();
|
||||
console.log(`\nFinal connection count: ${finalCount}`);
|
||||
|
||||
// Stop the proxy
|
||||
await proxy.stop();
|
||||
console.log('✓ Proxy stopped');
|
||||
|
||||
// Verify all connections were cleaned up
|
||||
expect(finalCount).toEqual(initialCount);
|
||||
expect(afterImmediateDisconnect).toEqual(initialCount);
|
||||
expect(afterDelayedDisconnect).toEqual(initialCount);
|
||||
|
||||
// Check that connections didn't accumulate during the test
|
||||
const maxCount = Math.max(...connectionCounts);
|
||||
console.log(`\nMax connection count during immediate disconnect test: ${maxCount}`);
|
||||
expect(maxCount).toBeLessThan(3); // Should stay very low
|
||||
|
||||
console.log('\n✅ PASS: Connect-disconnect cleanup working correctly!');
|
||||
});
|
||||
|
||||
tap.test('should handle clients that error during connection', async () => {
|
||||
console.log('\n=== Testing Connection Error Cleanup ===');
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
enableDetailedLogging: false,
|
||||
routes: [{
|
||||
name: 'test-route',
|
||||
match: { ports: 8561 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 9999
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
await proxy.start();
|
||||
console.log('✓ Proxy started on port 8561');
|
||||
|
||||
const getActiveConnections = () => {
|
||||
const connectionManager = (proxy as any).connectionManager;
|
||||
return connectionManager ? connectionManager.getConnectionCount() : 0;
|
||||
};
|
||||
|
||||
const initialCount = getActiveConnections();
|
||||
console.log(`Initial connection count: ${initialCount}`);
|
||||
|
||||
// Create connections that will error
|
||||
const promises = [];
|
||||
for (let i = 0; i < 10; i++) {
|
||||
promises.push(new Promise<void>((resolve) => {
|
||||
const client = new net.Socket();
|
||||
|
||||
client.on('error', () => {
|
||||
resolve();
|
||||
});
|
||||
|
||||
client.on('close', () => {
|
||||
resolve();
|
||||
});
|
||||
|
||||
// Connect to proxy
|
||||
client.connect(8561, 'localhost', () => {
|
||||
// Force an error by writing invalid data then destroying
|
||||
try {
|
||||
client.write(Buffer.alloc(1024 * 1024)); // Large write
|
||||
client.destroy();
|
||||
} catch (e) {
|
||||
// Ignore
|
||||
}
|
||||
});
|
||||
|
||||
// Timeout
|
||||
setTimeout(() => resolve(), 500);
|
||||
}));
|
||||
}
|
||||
|
||||
await Promise.all(promises);
|
||||
console.log('✓ All error connections completed');
|
||||
|
||||
// Wait for cleanup
|
||||
await new Promise(resolve => setTimeout(resolve, 500));
|
||||
|
||||
const finalCount = getActiveConnections();
|
||||
console.log(`Final connection count: ${finalCount}`);
|
||||
|
||||
await proxy.stop();
|
||||
console.log('✓ Proxy stopped');
|
||||
|
||||
expect(finalCount).toEqual(initialCount);
|
||||
|
||||
console.log('\n✅ PASS: Connection error cleanup working correctly!');
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,277 +0,0 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import * as plugins from '../ts/plugins.js';
|
||||
|
||||
// Import SmartProxy and configurations
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
|
||||
tap.test('comprehensive connection cleanup test - all scenarios', async () => {
|
||||
console.log('\n=== Comprehensive Connection Cleanup Test ===');
|
||||
|
||||
// Create a SmartProxy instance
|
||||
const proxy = new SmartProxy({
|
||||
enableDetailedLogging: false,
|
||||
initialDataTimeout: 2000,
|
||||
socketTimeout: 5000,
|
||||
routes: [
|
||||
{
|
||||
name: 'non-tls-route',
|
||||
match: { ports: 8570 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 9999 // Non-existent port
|
||||
}]
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'tls-route',
|
||||
match: { ports: 8571 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 9999 // Non-existent port
|
||||
}],
|
||||
tls: {
|
||||
mode: 'passthrough'
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
// Start the proxy
|
||||
await proxy.start();
|
||||
console.log('✓ Proxy started on ports 8570 (non-TLS) and 8571 (TLS)');
|
||||
|
||||
// Helper to get active connection count
|
||||
const getActiveConnections = () => {
|
||||
const connectionManager = (proxy as any).connectionManager;
|
||||
return connectionManager ? connectionManager.getConnectionCount() : 0;
|
||||
};
|
||||
|
||||
const initialCount = getActiveConnections();
|
||||
console.log(`Initial connection count: ${initialCount}`);
|
||||
|
||||
// Test 1: Rapid ECONNREFUSED retries (from original issue)
|
||||
console.log('\n--- Test 1: Rapid ECONNREFUSED retries ---');
|
||||
for (let i = 0; i < 10; i++) {
|
||||
await new Promise<void>((resolve) => {
|
||||
const client = new net.Socket();
|
||||
|
||||
client.on('error', () => {
|
||||
client.destroy();
|
||||
resolve();
|
||||
});
|
||||
|
||||
client.on('close', () => {
|
||||
resolve();
|
||||
});
|
||||
|
||||
client.connect(8570, 'localhost', () => {
|
||||
// Send data to trigger routing
|
||||
client.write('GET / HTTP/1.1\r\nHost: test.com\r\n\r\n');
|
||||
});
|
||||
|
||||
setTimeout(() => {
|
||||
if (!client.destroyed) {
|
||||
client.destroy();
|
||||
}
|
||||
resolve();
|
||||
}, 100);
|
||||
});
|
||||
|
||||
if ((i + 1) % 5 === 0) {
|
||||
const count = getActiveConnections();
|
||||
console.log(`After ${i + 1} ECONNREFUSED retries: ${count} active connections`);
|
||||
}
|
||||
}
|
||||
|
||||
// Test 2: Connect without sending data (immediate disconnect)
|
||||
console.log('\n--- Test 2: Connect without sending data ---');
|
||||
for (let i = 0; i < 10; i++) {
|
||||
const client = new net.Socket();
|
||||
|
||||
client.on('error', () => {
|
||||
// Ignore
|
||||
});
|
||||
|
||||
// Connect to non-TLS port and immediately disconnect
|
||||
client.connect(8570, 'localhost', () => {
|
||||
client.destroy();
|
||||
});
|
||||
|
||||
await new Promise(resolve => setTimeout(resolve, 10));
|
||||
}
|
||||
|
||||
const afterNoData = getActiveConnections();
|
||||
console.log(`After connect-without-data test: ${afterNoData} active connections`);
|
||||
|
||||
// Test 3: TLS connections that disconnect before handshake
|
||||
console.log('\n--- Test 3: TLS early disconnect ---');
|
||||
for (let i = 0; i < 10; i++) {
|
||||
const client = new net.Socket();
|
||||
|
||||
client.on('error', () => {
|
||||
// Ignore
|
||||
});
|
||||
|
||||
// Connect to TLS port but disconnect before sending handshake
|
||||
client.connect(8571, 'localhost', () => {
|
||||
// Wait 50ms then disconnect (before initial data timeout)
|
||||
setTimeout(() => {
|
||||
client.destroy();
|
||||
}, 50);
|
||||
});
|
||||
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
}
|
||||
|
||||
const afterTlsEarly = getActiveConnections();
|
||||
console.log(`After TLS early disconnect test: ${afterTlsEarly} active connections`);
|
||||
|
||||
// Test 4: Mixed pattern - simulating real-world chaos
|
||||
console.log('\n--- Test 4: Mixed chaos pattern ---');
|
||||
const promises = [];
|
||||
|
||||
for (let i = 0; i < 30; i++) {
|
||||
promises.push(new Promise<void>((resolve) => {
|
||||
const client = new net.Socket();
|
||||
const port = i % 2 === 0 ? 8570 : 8571;
|
||||
|
||||
client.on('error', () => {
|
||||
resolve();
|
||||
});
|
||||
|
||||
client.on('close', () => {
|
||||
resolve();
|
||||
});
|
||||
|
||||
client.connect(port, 'localhost', () => {
|
||||
const scenario = i % 5;
|
||||
|
||||
switch (scenario) {
|
||||
case 0:
|
||||
// Immediate disconnect
|
||||
client.destroy();
|
||||
break;
|
||||
case 1:
|
||||
// Send data then disconnect
|
||||
client.write('GET / HTTP/1.1\r\nHost: test.com\r\n\r\n');
|
||||
setTimeout(() => client.destroy(), 20);
|
||||
break;
|
||||
case 2:
|
||||
// Disconnect after delay
|
||||
setTimeout(() => client.destroy(), 100);
|
||||
break;
|
||||
case 3:
|
||||
// Send partial TLS handshake
|
||||
if (port === 8571) {
|
||||
client.write(Buffer.from([0x16, 0x03, 0x01])); // Partial TLS
|
||||
}
|
||||
setTimeout(() => client.destroy(), 50);
|
||||
break;
|
||||
case 4:
|
||||
// Just let it timeout
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
// Failsafe
|
||||
setTimeout(() => {
|
||||
if (!client.destroyed) {
|
||||
client.destroy();
|
||||
}
|
||||
resolve();
|
||||
}, 500);
|
||||
}));
|
||||
|
||||
// Small delay between connections
|
||||
if (i % 5 === 0) {
|
||||
await new Promise(resolve => setTimeout(resolve, 10));
|
||||
}
|
||||
}
|
||||
|
||||
await Promise.all(promises);
|
||||
console.log('✓ Chaos test completed');
|
||||
|
||||
// Wait for any cleanup
|
||||
await new Promise(resolve => setTimeout(resolve, 1000));
|
||||
|
||||
const afterChaos = getActiveConnections();
|
||||
console.log(`After chaos test: ${afterChaos} active connections`);
|
||||
|
||||
// Test 5: NFTables route (should cleanup properly)
|
||||
console.log('\n--- Test 5: NFTables route cleanup ---');
|
||||
const nftProxy = new SmartProxy({
|
||||
enableDetailedLogging: false,
|
||||
routes: [{
|
||||
name: 'nftables-route',
|
||||
match: { ports: 8572 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
forwardingEngine: 'nftables',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 9999
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
await nftProxy.start();
|
||||
|
||||
const getNftConnections = () => {
|
||||
const connectionManager = (nftProxy as any).connectionManager;
|
||||
return connectionManager ? connectionManager.getConnectionCount() : 0;
|
||||
};
|
||||
|
||||
// Create NFTables connections
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const client = new net.Socket();
|
||||
|
||||
client.on('error', () => {
|
||||
// Ignore
|
||||
});
|
||||
|
||||
client.connect(8572, 'localhost', () => {
|
||||
setTimeout(() => client.destroy(), 50);
|
||||
});
|
||||
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
}
|
||||
|
||||
await new Promise(resolve => setTimeout(resolve, 500));
|
||||
|
||||
const nftFinal = getNftConnections();
|
||||
console.log(`NFTables connections after test: ${nftFinal}`);
|
||||
|
||||
await nftProxy.stop();
|
||||
|
||||
// Final check on main proxy
|
||||
const finalCount = getActiveConnections();
|
||||
console.log(`\nFinal connection count: ${finalCount}`);
|
||||
|
||||
// Stop the proxy
|
||||
await proxy.stop();
|
||||
console.log('✓ Proxy stopped');
|
||||
|
||||
// Verify all connections were cleaned up
|
||||
expect(finalCount).toEqual(initialCount);
|
||||
expect(afterNoData).toEqual(initialCount);
|
||||
expect(afterTlsEarly).toEqual(initialCount);
|
||||
expect(afterChaos).toEqual(initialCount);
|
||||
expect(nftFinal).toEqual(0);
|
||||
|
||||
console.log('\n✅ PASS: Comprehensive connection cleanup test passed!');
|
||||
console.log('All connection scenarios properly cleaned up:');
|
||||
console.log('- ECONNREFUSED rapid retries');
|
||||
console.log('- Connect without sending data');
|
||||
console.log('- TLS early disconnect');
|
||||
console.log('- Mixed chaos patterns');
|
||||
console.log('- NFTables connections');
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,304 +0,0 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
|
||||
import { HttpProxy } from '../ts/proxies/http-proxy/index.js';
|
||||
|
||||
let testServer: net.Server;
|
||||
let smartProxy: SmartProxy;
|
||||
let httpProxy: HttpProxy;
|
||||
const TEST_SERVER_PORT = 5100;
|
||||
const PROXY_PORT = 5101;
|
||||
const HTTP_PROXY_PORT = 5102;
|
||||
|
||||
// Track all created servers and connections for cleanup
|
||||
const allServers: net.Server[] = [];
|
||||
const allProxies: (SmartProxy | HttpProxy)[] = [];
|
||||
const activeConnections: net.Socket[] = [];
|
||||
|
||||
// Helper: Creates a test TCP server
|
||||
function createTestServer(port: number): Promise<net.Server> {
|
||||
return new Promise((resolve) => {
|
||||
const server = net.createServer((socket) => {
|
||||
socket.on('data', (data) => {
|
||||
socket.write(`Echo: ${data.toString()}`);
|
||||
});
|
||||
socket.on('error', () => {});
|
||||
});
|
||||
server.listen(port, 'localhost', () => {
|
||||
console.log(`[Test Server] Listening on localhost:${port}`);
|
||||
allServers.push(server);
|
||||
resolve(server);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Helper: Creates multiple concurrent connections
|
||||
// If waitForData is true, waits for the connection to be fully established (can receive data)
|
||||
async function createConcurrentConnections(
|
||||
port: number,
|
||||
count: number,
|
||||
waitForData: boolean = false
|
||||
): Promise<net.Socket[]> {
|
||||
const connections: net.Socket[] = [];
|
||||
const promises: Promise<net.Socket>[] = [];
|
||||
|
||||
for (let i = 0; i < count; i++) {
|
||||
promises.push(
|
||||
new Promise((resolve, reject) => {
|
||||
const client = new net.Socket();
|
||||
const timeout = setTimeout(() => {
|
||||
client.destroy();
|
||||
reject(new Error(`Connection ${i} timeout`));
|
||||
}, 5000);
|
||||
|
||||
client.connect(port, 'localhost', () => {
|
||||
if (!waitForData) {
|
||||
clearTimeout(timeout);
|
||||
activeConnections.push(client);
|
||||
connections.push(client);
|
||||
resolve(client);
|
||||
}
|
||||
// If waitForData, we wait for the close event to see if connection was rejected
|
||||
});
|
||||
|
||||
if (waitForData) {
|
||||
// Wait a bit to see if connection gets closed by server
|
||||
client.once('close', () => {
|
||||
clearTimeout(timeout);
|
||||
reject(new Error('Connection closed by server'));
|
||||
});
|
||||
|
||||
// If we can write and get a response, connection is truly established
|
||||
setTimeout(() => {
|
||||
if (!client.destroyed) {
|
||||
clearTimeout(timeout);
|
||||
activeConnections.push(client);
|
||||
connections.push(client);
|
||||
resolve(client);
|
||||
}
|
||||
}, 100);
|
||||
}
|
||||
|
||||
client.on('error', (err) => {
|
||||
clearTimeout(timeout);
|
||||
reject(err);
|
||||
});
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
await Promise.all(promises);
|
||||
return connections;
|
||||
}
|
||||
|
||||
// Helper: Clean up connections
|
||||
function cleanupConnections(connections: net.Socket[]): void {
|
||||
connections.forEach(conn => {
|
||||
if (!conn.destroyed) {
|
||||
conn.destroy();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
tap.test('Setup test environment', async () => {
|
||||
testServer = await createTestServer(TEST_SERVER_PORT);
|
||||
|
||||
// Create SmartProxy with low connection limits for testing
|
||||
smartProxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'test-route',
|
||||
match: {
|
||||
ports: PROXY_PORT
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: TEST_SERVER_PORT
|
||||
}]
|
||||
},
|
||||
security: {
|
||||
maxConnections: 5 // Low limit for testing
|
||||
}
|
||||
}],
|
||||
maxConnectionsPerIP: 3, // Low per-IP limit
|
||||
connectionRateLimitPerMinute: 10, // Low rate limit
|
||||
defaults: {
|
||||
security: {
|
||||
maxConnections: 10 // Low global limit
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
await smartProxy.start();
|
||||
allProxies.push(smartProxy);
|
||||
});
|
||||
|
||||
tap.test('Per-IP connection limits', async () => {
|
||||
// Test that we can create up to the per-IP limit
|
||||
const connections1 = await createConcurrentConnections(PROXY_PORT, 3);
|
||||
expect(connections1.length).toEqual(3);
|
||||
|
||||
// Allow server-side processing to complete
|
||||
await new Promise(resolve => setTimeout(resolve, 50));
|
||||
|
||||
// Try to create one more connection - should fail
|
||||
// Use waitForData=true to detect if server closes the connection after accepting it
|
||||
try {
|
||||
await createConcurrentConnections(PROXY_PORT, 1, true);
|
||||
// If we get here, the 4th connection was truly established
|
||||
throw new Error('Should not allow more than 3 connections per IP');
|
||||
} catch (err) {
|
||||
console.log(`Per-IP limit error received: ${err.message}`);
|
||||
// Connection should be rejected - either reset, refused, or closed by server
|
||||
const isRejected = err.message.includes('ECONNRESET') ||
|
||||
err.message.includes('ECONNREFUSED') ||
|
||||
err.message.includes('closed');
|
||||
expect(isRejected).toBeTrue();
|
||||
}
|
||||
|
||||
// Clean up first set of connections
|
||||
cleanupConnections(connections1);
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
|
||||
// Should be able to create new connections after cleanup
|
||||
const connections2 = await createConcurrentConnections(PROXY_PORT, 2);
|
||||
expect(connections2.length).toEqual(2);
|
||||
|
||||
cleanupConnections(connections2);
|
||||
});
|
||||
|
||||
tap.test('Route-level connection limits', async () => {
|
||||
// Create multiple connections up to route limit
|
||||
const connections = await createConcurrentConnections(PROXY_PORT, 5);
|
||||
expect(connections.length).toEqual(5);
|
||||
|
||||
// Try to exceed route limit
|
||||
try {
|
||||
await createConcurrentConnections(PROXY_PORT, 1);
|
||||
throw new Error('Should not allow more than 5 connections for this route');
|
||||
} catch (err) {
|
||||
// Connection should be rejected - either reset or refused
|
||||
console.log('Connection limit error:', err.message);
|
||||
const isRejected = err.message.includes('ECONNRESET') ||
|
||||
err.message.includes('ECONNREFUSED') ||
|
||||
err.message.includes('closed') ||
|
||||
err.message.includes('5 connections');
|
||||
expect(isRejected).toBeTrue();
|
||||
}
|
||||
|
||||
cleanupConnections(connections);
|
||||
});
|
||||
|
||||
tap.test('Connection rate limiting', async () => {
|
||||
// Create connections rapidly
|
||||
const connections: net.Socket[] = [];
|
||||
|
||||
// Create 10 connections rapidly (at rate limit)
|
||||
for (let i = 0; i < 10; i++) {
|
||||
try {
|
||||
const conn = await createConcurrentConnections(PROXY_PORT, 1);
|
||||
connections.push(...conn);
|
||||
// Small delay to avoid per-IP limit
|
||||
if (connections.length >= 3) {
|
||||
cleanupConnections(connections.splice(0, 3));
|
||||
await new Promise(resolve => setTimeout(resolve, 50));
|
||||
}
|
||||
} catch (err) {
|
||||
// Expected to fail at some point due to rate limit
|
||||
expect(i).toBeGreaterThan(0);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
cleanupConnections(connections);
|
||||
});
|
||||
|
||||
tap.test('HttpProxy per-IP validation', async () => {
|
||||
// Skip complex HttpProxy integration test - focus on SmartProxy connection limits
|
||||
// The HttpProxy has its own per-IP validation that's tested separately
|
||||
// This test would require TLS certificates and more complex setup
|
||||
console.log('Skipping HttpProxy per-IP validation - tested separately');
|
||||
});
|
||||
|
||||
tap.test('IP tracking cleanup', async (tools) => {
|
||||
// Wait for any previous test cleanup to complete
|
||||
await tools.delayFor(300);
|
||||
|
||||
// Create and close connections
|
||||
const connections: net.Socket[] = [];
|
||||
|
||||
for (let i = 0; i < 2; i++) {
|
||||
try {
|
||||
const conn = await createConcurrentConnections(PROXY_PORT, 1);
|
||||
connections.push(...conn);
|
||||
} catch {
|
||||
// Ignore rejections
|
||||
}
|
||||
}
|
||||
|
||||
// Close all connections
|
||||
cleanupConnections(connections);
|
||||
|
||||
// Wait for cleanup to process
|
||||
await tools.delayFor(500);
|
||||
|
||||
// Verify that IP tracking has been cleaned up
|
||||
const securityManager = (smartProxy as any).securityManager;
|
||||
const ipCount = securityManager.getConnectionCountByIP('::ffff:127.0.0.1');
|
||||
|
||||
// Should have no connections tracked for this IP after cleanup
|
||||
// Note: Due to asynchronous cleanup, we allow for some variance
|
||||
expect(ipCount).toBeLessThanOrEqual(1);
|
||||
});
|
||||
|
||||
tap.test('Cleanup queue race condition handling', async () => {
|
||||
// Wait for previous test cleanup
|
||||
await new Promise(resolve => setTimeout(resolve, 300));
|
||||
|
||||
// Create connections sequentially to avoid hitting per-IP limit
|
||||
const allConnections: net.Socket[] = [];
|
||||
for (let i = 0; i < 2; i++) {
|
||||
try {
|
||||
const conn = await createConcurrentConnections(PROXY_PORT, 1);
|
||||
allConnections.push(...conn);
|
||||
} catch {
|
||||
// Ignore connection rejections
|
||||
}
|
||||
}
|
||||
|
||||
// Close all connections rapidly
|
||||
allConnections.forEach(conn => conn.destroy());
|
||||
|
||||
// Give cleanup queue time to process
|
||||
await new Promise(resolve => setTimeout(resolve, 500));
|
||||
|
||||
// Verify all connections were cleaned up
|
||||
const connectionManager = (smartProxy as any).connectionManager;
|
||||
const remainingConnections = connectionManager.getConnectionCount();
|
||||
|
||||
// Allow for some variance due to async cleanup
|
||||
expect(remainingConnections).toBeLessThanOrEqual(1);
|
||||
});
|
||||
|
||||
tap.test('Cleanup and shutdown', async () => {
|
||||
// Clean up any remaining connections
|
||||
cleanupConnections(activeConnections);
|
||||
activeConnections.length = 0;
|
||||
|
||||
// Stop all proxies
|
||||
for (const proxy of allProxies) {
|
||||
await proxy.stop();
|
||||
}
|
||||
allProxies.length = 0;
|
||||
|
||||
// Close all test servers
|
||||
for (const server of allServers) {
|
||||
await new Promise<void>((resolve) => {
|
||||
server.close(() => resolve());
|
||||
});
|
||||
}
|
||||
allServers.length = 0;
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,83 +0,0 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
|
||||
tap.test('should verify certificate manager callback is preserved on updateRoutes', async () => {
|
||||
// Create proxy with initial cert routes
|
||||
const proxy = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'cert-route',
|
||||
match: { ports: [18443], domains: ['test.local'] },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 3000 }],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto',
|
||||
acme: { email: 'test@local.test' }
|
||||
}
|
||||
}
|
||||
}],
|
||||
acme: { email: 'test@local.test', port: 18080 }
|
||||
});
|
||||
|
||||
// Track callback preservation
|
||||
let initialCallbackSet = false;
|
||||
let updateCallbackSet = false;
|
||||
|
||||
// Mock certificate manager creation
|
||||
(proxy as any).createCertificateManager = async function(...args: any[]) {
|
||||
const certManager = {
|
||||
updateRoutesCallback: null as any,
|
||||
setUpdateRoutesCallback: function(callback: any) {
|
||||
this.updateRoutesCallback = callback;
|
||||
if (!initialCallbackSet) {
|
||||
initialCallbackSet = true;
|
||||
} else {
|
||||
updateCallbackSet = true;
|
||||
}
|
||||
},
|
||||
setHttpProxy: () => {},
|
||||
setGlobalAcmeDefaults: () => {},
|
||||
setAcmeStateManager: () => {},
|
||||
setRoutes: (routes: any) => {},
|
||||
initialize: async () => {},
|
||||
provisionAllCertificates: async () => {},
|
||||
stop: async () => {},
|
||||
getAcmeOptions: () => ({ email: 'test@local.test' }),
|
||||
getState: () => ({ challengeRouteActive: false })
|
||||
};
|
||||
|
||||
// Set callback as in real implementation
|
||||
certManager.setUpdateRoutesCallback(async (routes) => {
|
||||
await this.updateRoutes(routes);
|
||||
});
|
||||
|
||||
return certManager;
|
||||
};
|
||||
|
||||
await proxy.start();
|
||||
expect(initialCallbackSet).toEqual(true);
|
||||
|
||||
// Update routes - this should preserve the callback
|
||||
await proxy.updateRoutes([{
|
||||
name: 'updated-route',
|
||||
match: { ports: [18444], domains: ['test2.local'] },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 3001 }],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto',
|
||||
acme: { email: 'test@local.test' }
|
||||
}
|
||||
}
|
||||
}]);
|
||||
|
||||
expect(updateCallbackSet).toEqual(true);
|
||||
|
||||
await proxy.stop();
|
||||
|
||||
console.log('Fix verified: Certificate manager callback is preserved on updateRoutes');
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,183 +0,0 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
|
||||
// Unit test for the HTTP forwarding fix
|
||||
tap.test('should forward non-TLS connections on HttpProxy ports', async (tapTest) => {
|
||||
// Test configuration
|
||||
const testPort = 8080;
|
||||
const httpProxyPort = 8844;
|
||||
|
||||
// Track forwarding logic
|
||||
let forwardedToHttpProxy = false;
|
||||
let setupDirectConnection = false;
|
||||
|
||||
// Create mock settings
|
||||
const mockSettings = {
|
||||
useHttpProxy: [testPort],
|
||||
httpProxyPort: httpProxyPort,
|
||||
routes: [{
|
||||
name: 'test-route',
|
||||
match: { ports: testPort },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8181 }]
|
||||
}
|
||||
}]
|
||||
};
|
||||
|
||||
// Create mock connection record
|
||||
const mockRecord = {
|
||||
id: 'test-connection',
|
||||
localPort: testPort,
|
||||
remoteIP: '127.0.0.1',
|
||||
isTLS: false
|
||||
};
|
||||
|
||||
// Mock HttpProxyBridge
|
||||
const mockHttpProxyBridge = {
|
||||
getHttpProxy: () => ({ available: true }),
|
||||
forwardToHttpProxy: async () => {
|
||||
forwardedToHttpProxy = true;
|
||||
}
|
||||
};
|
||||
|
||||
// Test the logic from handleForwardAction
|
||||
const route = mockSettings.routes[0];
|
||||
const action = route.action as any;
|
||||
|
||||
// Simulate the fixed logic
|
||||
if (!action.tls) {
|
||||
// No TLS settings - check if this port should use HttpProxy
|
||||
const isHttpProxyPort = mockSettings.useHttpProxy?.includes(mockRecord.localPort);
|
||||
|
||||
if (isHttpProxyPort && mockHttpProxyBridge.getHttpProxy()) {
|
||||
// Forward non-TLS connections to HttpProxy if configured
|
||||
console.log(`Using HttpProxy for non-TLS connection on port ${mockRecord.localPort}`);
|
||||
await mockHttpProxyBridge.forwardToHttpProxy();
|
||||
} else {
|
||||
// Basic forwarding
|
||||
console.log(`Using basic forwarding`);
|
||||
setupDirectConnection = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the fix works correctly
|
||||
expect(forwardedToHttpProxy).toEqual(true);
|
||||
expect(setupDirectConnection).toEqual(false);
|
||||
|
||||
console.log('Test passed: Non-TLS connections on HttpProxy ports are forwarded correctly');
|
||||
});
|
||||
|
||||
// Test that non-HttpProxy ports still use direct connection
|
||||
tap.test('should use direct connection for non-HttpProxy ports', async (tapTest) => {
|
||||
let forwardedToHttpProxy = false;
|
||||
let setupDirectConnection = false;
|
||||
|
||||
const mockSettings = {
|
||||
useHttpProxy: [80, 443], // Different ports
|
||||
httpProxyPort: 8844,
|
||||
routes: [{
|
||||
name: 'test-route',
|
||||
match: { ports: 8080 }, // Not in useHttpProxy
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8181 }]
|
||||
}
|
||||
}]
|
||||
};
|
||||
|
||||
const mockRecord = {
|
||||
id: 'test-connection-2',
|
||||
localPort: 8080, // Not in useHttpProxy
|
||||
remoteIP: '127.0.0.1',
|
||||
isTLS: false
|
||||
};
|
||||
|
||||
const mockHttpProxyBridge = {
|
||||
getHttpProxy: () => ({ available: true }),
|
||||
forwardToHttpProxy: async () => {
|
||||
forwardedToHttpProxy = true;
|
||||
}
|
||||
};
|
||||
|
||||
const route = mockSettings.routes[0];
|
||||
const action = route.action as any;
|
||||
|
||||
// Test the logic
|
||||
if (!action.tls) {
|
||||
const isHttpProxyPort = mockSettings.useHttpProxy?.includes(mockRecord.localPort);
|
||||
|
||||
if (isHttpProxyPort && mockHttpProxyBridge.getHttpProxy()) {
|
||||
console.log(`Using HttpProxy for non-TLS connection on port ${mockRecord.localPort}`);
|
||||
await mockHttpProxyBridge.forwardToHttpProxy();
|
||||
} else {
|
||||
console.log(`Using basic forwarding for port ${mockRecord.localPort}`);
|
||||
setupDirectConnection = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Verify port 8080 uses direct connection when not in useHttpProxy
|
||||
expect(forwardedToHttpProxy).toEqual(false);
|
||||
expect(setupDirectConnection).toEqual(true);
|
||||
|
||||
console.log('Test passed: Non-HttpProxy ports use direct connection');
|
||||
});
|
||||
|
||||
// Test HTTP-01 ACME challenge scenario
|
||||
tap.test('should handle ACME HTTP-01 challenges on port 80 with HttpProxy', async (tapTest) => {
|
||||
let forwardedToHttpProxy = false;
|
||||
|
||||
const mockSettings = {
|
||||
useHttpProxy: [80], // Port 80 configured for HttpProxy
|
||||
httpProxyPort: 8844,
|
||||
acme: {
|
||||
port: 80,
|
||||
email: 'test@example.com'
|
||||
},
|
||||
routes: [{
|
||||
name: 'acme-challenge',
|
||||
match: {
|
||||
ports: 80,
|
||||
paths: ['/.well-known/acme-challenge/*']
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8080 }]
|
||||
}
|
||||
}]
|
||||
};
|
||||
|
||||
const mockRecord = {
|
||||
id: 'acme-connection',
|
||||
localPort: 80,
|
||||
remoteIP: '127.0.0.1',
|
||||
isTLS: false
|
||||
};
|
||||
|
||||
const mockHttpProxyBridge = {
|
||||
getHttpProxy: () => ({ available: true }),
|
||||
forwardToHttpProxy: async () => {
|
||||
forwardedToHttpProxy = true;
|
||||
}
|
||||
};
|
||||
|
||||
const route = mockSettings.routes[0];
|
||||
const action = route.action as any;
|
||||
|
||||
// Test the fix for ACME HTTP-01 challenges
|
||||
if (!action.tls) {
|
||||
const isHttpProxyPort = mockSettings.useHttpProxy?.includes(mockRecord.localPort);
|
||||
|
||||
if (isHttpProxyPort && mockHttpProxyBridge.getHttpProxy()) {
|
||||
console.log(`Using HttpProxy for ACME challenge on port ${mockRecord.localPort}`);
|
||||
await mockHttpProxyBridge.forwardToHttpProxy();
|
||||
}
|
||||
}
|
||||
|
||||
// Verify HTTP-01 challenges on port 80 go through HttpProxy
|
||||
expect(forwardedToHttpProxy).toEqual(true);
|
||||
|
||||
console.log('Test passed: ACME HTTP-01 challenges on port 80 use HttpProxy');
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,256 +0,0 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import { RouteConnectionHandler } from '../ts/proxies/smart-proxy/route-connection-handler.js';
|
||||
import type { ISmartProxyOptions } from '../ts/proxies/smart-proxy/models/interfaces.js';
|
||||
import * as net from 'net';
|
||||
|
||||
// Direct test of the fix in RouteConnectionHandler
|
||||
tap.test('should detect and forward non-TLS connections on useHttpProxy ports', async (tapTest) => {
|
||||
// Create mock objects
|
||||
const mockSettings: ISmartProxyOptions = {
|
||||
useHttpProxy: [8080],
|
||||
httpProxyPort: 8844,
|
||||
routes: [{
|
||||
name: 'test-route',
|
||||
match: { ports: 8080 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8181 }]
|
||||
}
|
||||
}]
|
||||
};
|
||||
|
||||
let httpProxyForwardCalled = false;
|
||||
let directConnectionCalled = false;
|
||||
|
||||
// Create mocks for dependencies
|
||||
const mockHttpProxyBridge = {
|
||||
getHttpProxy: () => ({ available: true }),
|
||||
forwardToHttpProxy: async (...args: any[]) => {
|
||||
console.log('Mock: forwardToHttpProxy called');
|
||||
httpProxyForwardCalled = true;
|
||||
}
|
||||
};
|
||||
|
||||
// Mock connection manager
|
||||
const mockConnectionManager = {
|
||||
createConnection: (socket: any) => ({
|
||||
id: 'test-connection',
|
||||
localPort: 8080,
|
||||
remoteIP: '127.0.0.1',
|
||||
isTLS: false
|
||||
}),
|
||||
generateConnectionId: () => 'test-connection-id',
|
||||
initiateCleanupOnce: () => {},
|
||||
cleanupConnection: () => {},
|
||||
getConnectionCount: () => 1,
|
||||
trackConnectionByRoute: (routeId: string, connectionId: string) => {},
|
||||
handleError: (type: string, record: any) => {
|
||||
return (error: Error) => {
|
||||
console.log(`Mock: Error handled for ${type}: ${error.message}`);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// Mock route manager that returns a matching route
|
||||
const mockRouteManager = {
|
||||
findMatchingRoute: (criteria: any) => ({
|
||||
route: mockSettings.routes[0]
|
||||
}),
|
||||
getRoutes: () => mockSettings.routes,
|
||||
getRoutesForPort: (port: number) => mockSettings.routes.filter(r => {
|
||||
const ports = Array.isArray(r.match.ports) ? r.match.ports : [r.match.ports];
|
||||
return ports.some(p => {
|
||||
if (typeof p === 'number') {
|
||||
return p === port;
|
||||
} else if (p && typeof p === 'object' && 'from' in p && 'to' in p) {
|
||||
return port >= p.from && port <= p.to;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
})
|
||||
};
|
||||
|
||||
// Mock security manager
|
||||
const mockSecurityManager = {
|
||||
validateAndTrackIP: () => ({ allowed: true })
|
||||
};
|
||||
|
||||
// Create a mock SmartProxy instance with necessary properties
|
||||
const mockSmartProxy = {
|
||||
settings: mockSettings,
|
||||
connectionManager: mockConnectionManager,
|
||||
securityManager: mockSecurityManager,
|
||||
httpProxyBridge: mockHttpProxyBridge,
|
||||
routeManager: mockRouteManager
|
||||
} as any;
|
||||
|
||||
// Create route connection handler instance
|
||||
const handler = new RouteConnectionHandler(mockSmartProxy);
|
||||
|
||||
// Override setupDirectConnection to track if it's called
|
||||
handler['setupDirectConnection'] = (...args: any[]) => {
|
||||
console.log('Mock: setupDirectConnection called');
|
||||
directConnectionCalled = true;
|
||||
};
|
||||
|
||||
// Test: Create a mock socket representing non-TLS connection on port 8080
|
||||
const mockSocket = {
|
||||
localPort: 8080,
|
||||
remoteAddress: '127.0.0.1',
|
||||
on: function(event: string, handler: Function) { return this; },
|
||||
once: function(event: string, handler: Function) {
|
||||
// Capture the data handler
|
||||
if (event === 'data') {
|
||||
this._dataHandler = handler;
|
||||
}
|
||||
return this;
|
||||
},
|
||||
end: () => {},
|
||||
destroy: () => {},
|
||||
pause: () => {},
|
||||
resume: () => {},
|
||||
removeListener: function() { return this; },
|
||||
emit: () => {},
|
||||
setNoDelay: () => {},
|
||||
setKeepAlive: () => {},
|
||||
_dataHandler: null as any
|
||||
} as any;
|
||||
|
||||
// Simulate the handler processing the connection
|
||||
handler.handleConnection(mockSocket);
|
||||
|
||||
// Simulate receiving non-TLS data
|
||||
if (mockSocket._dataHandler) {
|
||||
mockSocket._dataHandler(Buffer.from('GET / HTTP/1.1\r\nHost: test.local\r\n\r\n'));
|
||||
}
|
||||
|
||||
// Give it a moment to process
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
|
||||
// Verify that the connection was forwarded to HttpProxy, not direct connection
|
||||
expect(httpProxyForwardCalled).toEqual(true);
|
||||
expect(directConnectionCalled).toEqual(false);
|
||||
});
|
||||
|
||||
// Test that verifies TLS connections still work normally
|
||||
tap.test('should handle TLS connections normally', async (tapTest) => {
|
||||
const mockSettings: ISmartProxyOptions = {
|
||||
useHttpProxy: [443],
|
||||
httpProxyPort: 8844,
|
||||
routes: [{
|
||||
name: 'tls-route',
|
||||
match: { ports: 443 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8443 }],
|
||||
tls: { mode: 'terminate' }
|
||||
}
|
||||
}]
|
||||
};
|
||||
|
||||
let httpProxyForwardCalled = false;
|
||||
|
||||
const mockHttpProxyBridge = {
|
||||
getHttpProxy: () => ({ available: true }),
|
||||
forwardToHttpProxy: async (...args: any[]) => {
|
||||
httpProxyForwardCalled = true;
|
||||
}
|
||||
};
|
||||
|
||||
const mockConnectionManager = {
|
||||
createConnection: (socket: any) => ({
|
||||
id: 'test-tls-connection',
|
||||
localPort: 443,
|
||||
remoteIP: '127.0.0.1',
|
||||
isTLS: true,
|
||||
tlsHandshakeComplete: false
|
||||
}),
|
||||
generateConnectionId: () => 'test-tls-connection-id',
|
||||
initiateCleanupOnce: () => {},
|
||||
cleanupConnection: () => {},
|
||||
getConnectionCount: () => 1,
|
||||
trackConnectionByRoute: (routeId: string, connectionId: string) => {},
|
||||
handleError: (type: string, record: any) => {
|
||||
return (error: Error) => {
|
||||
console.log(`Mock: Error handled for ${type}: ${error.message}`);
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
const mockTlsManager = {
|
||||
isTlsHandshake: (chunk: Buffer) => true,
|
||||
isClientHello: (chunk: Buffer) => true,
|
||||
extractSNI: (chunk: Buffer) => 'test.local'
|
||||
};
|
||||
|
||||
const mockRouteManager = {
|
||||
findMatchingRoute: (criteria: any) => ({
|
||||
route: mockSettings.routes[0]
|
||||
}),
|
||||
getRoutes: () => mockSettings.routes,
|
||||
getRoutesForPort: (port: number) => mockSettings.routes.filter(r => {
|
||||
const ports = Array.isArray(r.match.ports) ? r.match.ports : [r.match.ports];
|
||||
return ports.some(p => {
|
||||
if (typeof p === 'number') {
|
||||
return p === port;
|
||||
} else if (p && typeof p === 'object' && 'from' in p && 'to' in p) {
|
||||
return port >= p.from && port <= p.to;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
})
|
||||
};
|
||||
|
||||
const mockSecurityManager = {
|
||||
validateAndTrackIP: () => ({ allowed: true })
|
||||
};
|
||||
|
||||
// Create a mock SmartProxy instance with necessary properties
|
||||
const mockSmartProxy = {
|
||||
settings: mockSettings,
|
||||
connectionManager: mockConnectionManager,
|
||||
securityManager: mockSecurityManager,
|
||||
tlsManager: mockTlsManager,
|
||||
httpProxyBridge: mockHttpProxyBridge,
|
||||
routeManager: mockRouteManager
|
||||
} as any;
|
||||
|
||||
const handler = new RouteConnectionHandler(mockSmartProxy);
|
||||
|
||||
const mockSocket = {
|
||||
localPort: 443,
|
||||
remoteAddress: '127.0.0.1',
|
||||
on: function(event: string, handler: Function) { return this; },
|
||||
once: function(event: string, handler: Function) {
|
||||
// Capture the data handler
|
||||
if (event === 'data') {
|
||||
this._dataHandler = handler;
|
||||
}
|
||||
return this;
|
||||
},
|
||||
end: () => {},
|
||||
destroy: () => {},
|
||||
pause: () => {},
|
||||
resume: () => {},
|
||||
removeListener: function() { return this; },
|
||||
emit: () => {},
|
||||
setNoDelay: () => {},
|
||||
setKeepAlive: () => {},
|
||||
_dataHandler: null as any
|
||||
} as any;
|
||||
|
||||
handler.handleConnection(mockSocket);
|
||||
|
||||
// Simulate TLS handshake
|
||||
if (mockSocket._dataHandler) {
|
||||
const tlsHandshake = Buffer.from([0x16, 0x03, 0x01, 0x00, 0x05]);
|
||||
mockSocket._dataHandler(tlsHandshake);
|
||||
}
|
||||
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
|
||||
// TLS connections with 'terminate' mode should go to HttpProxy
|
||||
expect(httpProxyForwardCalled).toEqual(true);
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,189 +0,0 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
import * as net from 'net';
|
||||
|
||||
// Test that verifies HTTP connections on ports configured in useHttpProxy are properly forwarded
|
||||
tap.test('should detect and forward non-TLS connections on HttpProxy ports', async (tapTest) => {
|
||||
// Track whether the connection was forwarded to HttpProxy
|
||||
let forwardedToHttpProxy = false;
|
||||
let connectionPath = '';
|
||||
|
||||
// Create a SmartProxy instance first
|
||||
const proxy = new SmartProxy({
|
||||
useHttpProxy: [8081], // Use different port to avoid conflicts
|
||||
httpProxyPort: 8847, // Use different port to avoid conflicts
|
||||
routes: [{
|
||||
name: 'test-http-forward',
|
||||
match: { ports: 8081 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 8181 }]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
// Add detailed logging to the existing proxy instance
|
||||
proxy.settings.enableDetailedLogging = true;
|
||||
|
||||
// Override the HttpProxy initialization to avoid actual HttpProxy setup
|
||||
proxy['httpProxyBridge'].initialize = async () => {
|
||||
console.log('Mock: HttpProxyBridge initialized');
|
||||
};
|
||||
proxy['httpProxyBridge'].start = async () => {
|
||||
console.log('Mock: HttpProxyBridge started');
|
||||
};
|
||||
proxy['httpProxyBridge'].stop = async () => {
|
||||
console.log('Mock: HttpProxyBridge stopped');
|
||||
return Promise.resolve(); // Ensure it returns a resolved promise
|
||||
};
|
||||
|
||||
await proxy.start();
|
||||
|
||||
// Mock the HttpProxy forwarding AFTER start to ensure it's not overridden
|
||||
const originalForward = (proxy as any).httpProxyBridge.forwardToHttpProxy;
|
||||
(proxy as any).httpProxyBridge.forwardToHttpProxy = async function(...args: any[]) {
|
||||
forwardedToHttpProxy = true;
|
||||
connectionPath = 'httpproxy';
|
||||
console.log('Mock: Connection forwarded to HttpProxy with args:', args[0], 'on port:', args[2]?.localPort);
|
||||
// Properly close the connection for the test
|
||||
const socket = args[1];
|
||||
socket.end();
|
||||
socket.destroy();
|
||||
};
|
||||
|
||||
// Mock getHttpProxy to indicate HttpProxy is available
|
||||
(proxy as any).httpProxyBridge.getHttpProxy = () => ({ available: true });
|
||||
|
||||
// Make a connection to port 8080
|
||||
const client = new net.Socket();
|
||||
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
client.connect(8081, 'localhost', () => {
|
||||
console.log('Client connected to proxy on port 8081');
|
||||
// Send a non-TLS HTTP request
|
||||
client.write('GET / HTTP/1.1\r\nHost: test.local\r\n\r\n');
|
||||
// Add a small delay to ensure data is sent
|
||||
setTimeout(() => resolve(), 50);
|
||||
});
|
||||
|
||||
client.on('error', reject);
|
||||
});
|
||||
|
||||
// Give it a moment to process
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
|
||||
// Verify the connection was forwarded to HttpProxy
|
||||
expect(forwardedToHttpProxy).toEqual(true);
|
||||
expect(connectionPath).toEqual('httpproxy');
|
||||
|
||||
client.destroy();
|
||||
|
||||
// Restore original method before stopping
|
||||
(proxy as any).httpProxyBridge.forwardToHttpProxy = originalForward;
|
||||
|
||||
console.log('About to stop proxy...');
|
||||
await proxy.stop();
|
||||
console.log('Proxy stopped');
|
||||
|
||||
// Wait a bit to ensure port is released
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
});
|
||||
|
||||
// Test that verifies the fix detects non-TLS connections
|
||||
tap.test('should properly detect non-TLS connections on HttpProxy ports', async (tapTest) => {
|
||||
const targetPort = 8182;
|
||||
let receivedConnection = false;
|
||||
|
||||
// Create a target server that never receives the connection (because it goes to HttpProxy)
|
||||
const targetServer = net.createServer((socket) => {
|
||||
receivedConnection = true;
|
||||
socket.end();
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
targetServer.listen(targetPort, () => {
|
||||
console.log(`Target server listening on port ${targetPort}`);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
// Mock HttpProxyBridge to track forwarding
|
||||
let httpProxyForwardCalled = false;
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
useHttpProxy: [8082], // Use different port to avoid conflicts
|
||||
httpProxyPort: 8848, // Use different port to avoid conflicts
|
||||
routes: [{
|
||||
name: 'test-route',
|
||||
match: {
|
||||
ports: 8082
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: targetPort }]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
// Override the forwardToHttpProxy method to track calls
|
||||
const originalForward = proxy['httpProxyBridge'].forwardToHttpProxy;
|
||||
proxy['httpProxyBridge'].forwardToHttpProxy = async function(...args: any[]) {
|
||||
httpProxyForwardCalled = true;
|
||||
console.log('HttpProxy forward called with connectionId:', args[0]);
|
||||
// Properly close the connection
|
||||
const socket = args[1];
|
||||
socket.end();
|
||||
socket.destroy();
|
||||
};
|
||||
|
||||
// Mock HttpProxyBridge methods
|
||||
proxy['httpProxyBridge'].initialize = async () => {
|
||||
console.log('Mock: HttpProxyBridge initialized');
|
||||
};
|
||||
proxy['httpProxyBridge'].start = async () => {
|
||||
console.log('Mock: HttpProxyBridge started');
|
||||
};
|
||||
proxy['httpProxyBridge'].stop = async () => {
|
||||
console.log('Mock: HttpProxyBridge stopped');
|
||||
return Promise.resolve(); // Ensure it returns a resolved promise
|
||||
};
|
||||
|
||||
// Mock getHttpProxy to return a truthy value
|
||||
proxy['httpProxyBridge'].getHttpProxy = () => ({} as any);
|
||||
|
||||
await proxy.start();
|
||||
|
||||
// Make a non-TLS connection
|
||||
const client = new net.Socket();
|
||||
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
client.connect(8082, 'localhost', () => {
|
||||
console.log('Connected to proxy');
|
||||
client.write('GET / HTTP/1.1\r\nHost: test.local\r\n\r\n');
|
||||
// Add a small delay to ensure data is sent
|
||||
setTimeout(() => resolve(), 50);
|
||||
});
|
||||
|
||||
client.on('error', () => resolve()); // Ignore errors since we're ending the connection
|
||||
});
|
||||
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
|
||||
// Verify that HttpProxy was called, not direct connection
|
||||
expect(httpProxyForwardCalled).toEqual(true);
|
||||
expect(receivedConnection).toEqual(false); // Target should not receive direct connection
|
||||
|
||||
client.destroy();
|
||||
await proxy.stop();
|
||||
await new Promise<void>((resolve) => {
|
||||
targetServer.close(() => resolve());
|
||||
});
|
||||
|
||||
// Wait a bit to ensure port is released
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
|
||||
// Restore original method
|
||||
proxy['httpProxyBridge'].forwardToHttpProxy = originalForward;
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,246 +0,0 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
import * as plugins from '../ts/plugins.js';
|
||||
import * as net from 'net';
|
||||
import * as http from 'http';
|
||||
|
||||
/**
|
||||
* This test verifies our improved port binding intelligence for ACME challenges.
|
||||
* It specifically tests:
|
||||
* 1. Using port 8080 instead of 80 for ACME HTTP challenges
|
||||
* 2. Correctly handling shared port bindings between regular routes and challenge routes
|
||||
* 3. Avoiding port conflicts when updating routes
|
||||
*/
|
||||
|
||||
tap.test('should handle ACME challenges on port 8080 with improved port binding intelligence', async (tapTest) => {
|
||||
// Create a simple echo server to act as our target
|
||||
const targetPort = 9001;
|
||||
let receivedData = '';
|
||||
|
||||
const targetServer = net.createServer((socket) => {
|
||||
console.log('Target server received connection');
|
||||
|
||||
socket.on('data', (data) => {
|
||||
receivedData += data.toString();
|
||||
console.log('Target server received data:', data.toString().split('\n')[0]);
|
||||
|
||||
// Send a simple HTTP response
|
||||
const response = 'HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: 13\r\n\r\nHello, World!';
|
||||
socket.write(response);
|
||||
});
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
targetServer.listen(targetPort, () => {
|
||||
console.log(`Target server listening on port ${targetPort}`);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
// In this test we will NOT create a mock ACME server on the same port
|
||||
// as SmartProxy will use, instead we'll let SmartProxy handle it
|
||||
const acmeServerPort = 9009;
|
||||
const acmeRequests: string[] = [];
|
||||
let acmeServer: http.Server | null = null;
|
||||
|
||||
// We'll assume the ACME port is available for SmartProxy
|
||||
let acmePortAvailable = true;
|
||||
|
||||
// Create SmartProxy with ACME configured to use port 8080
|
||||
console.log('Creating SmartProxy with ACME port 8080...');
|
||||
const tempCertDir = './temp-certs';
|
||||
|
||||
try {
|
||||
await plugins.smartfile.fs.ensureDir(tempCertDir);
|
||||
} catch (error) {
|
||||
// Directory may already exist, that's ok
|
||||
}
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
enableDetailedLogging: true,
|
||||
routes: [
|
||||
{
|
||||
name: 'test-route',
|
||||
match: {
|
||||
ports: [9003],
|
||||
domains: ['test.example.com']
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: targetPort }],
|
||||
tls: {
|
||||
mode: 'terminate',
|
||||
certificate: 'auto' // Use ACME for certificate
|
||||
}
|
||||
}
|
||||
},
|
||||
// Also add a route for port 8080 to test port sharing
|
||||
{
|
||||
name: 'http-route',
|
||||
match: {
|
||||
ports: [9009],
|
||||
domains: ['test.example.com']
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: targetPort }]
|
||||
}
|
||||
}
|
||||
],
|
||||
acme: {
|
||||
email: 'test@example.com',
|
||||
useProduction: false,
|
||||
port: 9009, // Use 9009 instead of default 80
|
||||
certificateStore: tempCertDir
|
||||
}
|
||||
});
|
||||
|
||||
// Mock the certificate manager to avoid actual ACME operations
|
||||
console.log('Mocking certificate manager...');
|
||||
const createCertManager = (proxy as any).createCertificateManager;
|
||||
(proxy as any).createCertificateManager = async function(...args: any[]) {
|
||||
// Create a completely mocked certificate manager that doesn't use ACME at all
|
||||
return {
|
||||
initialize: async () => {},
|
||||
getCertPair: async () => {
|
||||
return {
|
||||
publicKey: 'MOCK CERTIFICATE',
|
||||
privateKey: 'MOCK PRIVATE KEY'
|
||||
};
|
||||
},
|
||||
getAcmeOptions: () => {
|
||||
return {
|
||||
port: 9009
|
||||
};
|
||||
},
|
||||
getState: () => {
|
||||
return {
|
||||
initializing: false,
|
||||
ready: true,
|
||||
port: 9009
|
||||
};
|
||||
},
|
||||
provisionAllCertificates: async () => {
|
||||
console.log('Mock: Provisioning certificates');
|
||||
return [];
|
||||
},
|
||||
stop: async () => {},
|
||||
setRoutes: (routes: any) => {},
|
||||
smartAcme: {
|
||||
getCertificateForDomain: async () => {
|
||||
// Return a mock certificate
|
||||
return {
|
||||
publicKey: 'MOCK CERTIFICATE',
|
||||
privateKey: 'MOCK PRIVATE KEY',
|
||||
validUntil: Date.now() + 90 * 24 * 60 * 60 * 1000,
|
||||
created: Date.now()
|
||||
};
|
||||
},
|
||||
start: async () => {},
|
||||
stop: async () => {}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// Track port binding attempts to verify intelligence
|
||||
const portBindAttempts: number[] = [];
|
||||
const originalAddPort = (proxy as any).portManager.addPort;
|
||||
(proxy as any).portManager.addPort = async function(port: number) {
|
||||
portBindAttempts.push(port);
|
||||
return originalAddPort.call(this, port);
|
||||
};
|
||||
|
||||
try {
|
||||
console.log('Starting SmartProxy...');
|
||||
await proxy.start();
|
||||
|
||||
console.log('Port binding attempts:', portBindAttempts);
|
||||
|
||||
// Check that we tried to bind to port 9009
|
||||
// Should attempt to bind to port 9009
|
||||
expect(portBindAttempts.includes(9009)).toEqual(true);
|
||||
// Should attempt to bind to port 9003
|
||||
expect(portBindAttempts.includes(9003)).toEqual(true);
|
||||
|
||||
// Get actual bound ports
|
||||
const boundPorts = proxy.getListeningPorts();
|
||||
console.log('Actually bound ports:', boundPorts);
|
||||
|
||||
// If port 9009 was available, we should be bound to it
|
||||
if (acmePortAvailable) {
|
||||
// Should be bound to port 9009 if available
|
||||
expect(boundPorts.includes(9009)).toEqual(true);
|
||||
}
|
||||
|
||||
// Should be bound to port 9003
|
||||
expect(boundPorts.includes(9003)).toEqual(true);
|
||||
|
||||
// Test adding a new route on port 8080
|
||||
console.log('Testing route update with port reuse...');
|
||||
|
||||
// Reset tracking
|
||||
portBindAttempts.length = 0;
|
||||
|
||||
// Add a new route on port 8080
|
||||
const newRoutes = [
|
||||
...proxy.settings.routes,
|
||||
{
|
||||
name: 'additional-route',
|
||||
match: {
|
||||
ports: [9009],
|
||||
path: '/additional'
|
||||
},
|
||||
action: {
|
||||
type: 'forward' as const,
|
||||
targets: [{ host: 'localhost', port: targetPort }]
|
||||
}
|
||||
}
|
||||
];
|
||||
|
||||
// Update routes - this should NOT try to rebind port 8080
|
||||
await proxy.updateRoutes(newRoutes);
|
||||
|
||||
console.log('Port binding attempts after update:', portBindAttempts);
|
||||
|
||||
// We should not try to rebind port 9009 since it's already bound
|
||||
// Should not attempt to rebind port 9009
|
||||
expect(portBindAttempts.includes(9009)).toEqual(false);
|
||||
|
||||
// We should still be listening on both ports
|
||||
const portsAfterUpdate = proxy.getListeningPorts();
|
||||
console.log('Bound ports after update:', portsAfterUpdate);
|
||||
|
||||
if (acmePortAvailable) {
|
||||
// Should still be bound to port 9009
|
||||
expect(portsAfterUpdate.includes(9009)).toEqual(true);
|
||||
}
|
||||
// Should still be bound to port 9003
|
||||
expect(portsAfterUpdate.includes(9003)).toEqual(true);
|
||||
|
||||
// The test is successful at this point - we've verified the port binding intelligence
|
||||
console.log('Port binding intelligence verified successfully!');
|
||||
// We'll skip the actual connection test to avoid timeouts
|
||||
} finally {
|
||||
// Clean up
|
||||
console.log('Cleaning up...');
|
||||
await proxy.stop();
|
||||
|
||||
if (targetServer) {
|
||||
await new Promise<void>((resolve) => {
|
||||
targetServer.close(() => resolve());
|
||||
});
|
||||
}
|
||||
|
||||
// No acmeServer to close in this test
|
||||
|
||||
// Clean up temp directory
|
||||
try {
|
||||
// Remove temp directory
|
||||
await plugins.smartfile.fs.remove(tempCertDir);
|
||||
} catch (error) {
|
||||
console.error('Failed to remove temp directory:', error);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,114 +0,0 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import { SecurityManager } from '../ts/proxies/http-proxy/security-manager.js';
|
||||
import { createLogger } from '../ts/proxies/http-proxy/models/types.js';
|
||||
|
||||
let securityManager: SecurityManager;
|
||||
const logger = createLogger('error'); // Quiet logger for tests
|
||||
|
||||
tap.test('Setup HttpProxy SecurityManager', async () => {
|
||||
securityManager = new SecurityManager(logger, [], 3, 10); // Low limits for testing
|
||||
});
|
||||
|
||||
tap.test('HttpProxy IP connection tracking', async () => {
|
||||
const testIP = '10.0.0.1';
|
||||
|
||||
// Track connections
|
||||
securityManager.trackConnectionByIP(testIP, 'http-conn1');
|
||||
securityManager.trackConnectionByIP(testIP, 'http-conn2');
|
||||
|
||||
expect(securityManager.getConnectionCountByIP(testIP)).toEqual(2);
|
||||
|
||||
// Validate IP should pass
|
||||
let result = securityManager.validateIP(testIP);
|
||||
expect(result.allowed).toBeTrue();
|
||||
|
||||
// Add one more to reach limit
|
||||
securityManager.trackConnectionByIP(testIP, 'http-conn3');
|
||||
|
||||
// Should now reject new connections
|
||||
result = securityManager.validateIP(testIP);
|
||||
expect(result.allowed).toBeFalse();
|
||||
expect(result.reason).toInclude('Maximum connections per IP (3) exceeded');
|
||||
|
||||
// Remove a connection
|
||||
securityManager.removeConnectionByIP(testIP, 'http-conn1');
|
||||
|
||||
// Should allow connections again
|
||||
result = securityManager.validateIP(testIP);
|
||||
expect(result.allowed).toBeTrue();
|
||||
|
||||
// Clean up
|
||||
securityManager.removeConnectionByIP(testIP, 'http-conn2');
|
||||
securityManager.removeConnectionByIP(testIP, 'http-conn3');
|
||||
});
|
||||
|
||||
tap.test('HttpProxy connection rate limiting', async () => {
|
||||
const testIP = '10.0.0.2';
|
||||
|
||||
// Make 10 connection attempts rapidly (at rate limit)
|
||||
// Note: We don't track connections here as we're testing rate limiting, not per-IP limiting
|
||||
for (let i = 0; i < 10; i++) {
|
||||
const result = securityManager.validateIP(testIP);
|
||||
expect(result.allowed).toBeTrue();
|
||||
}
|
||||
|
||||
// 11th connection should be rate limited
|
||||
const result = securityManager.validateIP(testIP);
|
||||
expect(result.allowed).toBeFalse();
|
||||
expect(result.reason).toInclude('Connection rate limit (10/min) exceeded');
|
||||
});
|
||||
|
||||
tap.test('HttpProxy CLIENT_IP header handling', async () => {
|
||||
// This tests the scenario where SmartProxy forwards the real client IP
|
||||
const realClientIP = '203.0.113.1';
|
||||
const proxyIP = '127.0.0.1';
|
||||
|
||||
// Simulate SmartProxy tracking the real client IP
|
||||
securityManager.trackConnectionByIP(realClientIP, 'forwarded-conn1');
|
||||
securityManager.trackConnectionByIP(realClientIP, 'forwarded-conn2');
|
||||
securityManager.trackConnectionByIP(realClientIP, 'forwarded-conn3');
|
||||
|
||||
// Real client IP should be at limit
|
||||
let result = securityManager.validateIP(realClientIP);
|
||||
expect(result.allowed).toBeFalse();
|
||||
|
||||
// But proxy IP should still be allowed
|
||||
result = securityManager.validateIP(proxyIP);
|
||||
expect(result.allowed).toBeTrue();
|
||||
|
||||
// Clean up
|
||||
securityManager.removeConnectionByIP(realClientIP, 'forwarded-conn1');
|
||||
securityManager.removeConnectionByIP(realClientIP, 'forwarded-conn2');
|
||||
securityManager.removeConnectionByIP(realClientIP, 'forwarded-conn3');
|
||||
});
|
||||
|
||||
tap.test('HttpProxy automatic cleanup', async (tools) => {
|
||||
const testIP = '10.0.0.3';
|
||||
|
||||
// Create and immediately remove connections
|
||||
for (let i = 0; i < 5; i++) {
|
||||
securityManager.trackConnectionByIP(testIP, `cleanup-conn${i}`);
|
||||
securityManager.removeConnectionByIP(testIP, `cleanup-conn${i}`);
|
||||
}
|
||||
|
||||
// Add rate limit entries
|
||||
for (let i = 0; i < 5; i++) {
|
||||
securityManager.validateIP(testIP);
|
||||
}
|
||||
|
||||
// Wait a bit (cleanup runs every 60 seconds in production)
|
||||
// For testing, we'll just verify the cleanup logic works
|
||||
await tools.delayFor(100);
|
||||
|
||||
// Manually trigger cleanup (in production this happens automatically)
|
||||
(securityManager as any).performIpCleanup();
|
||||
|
||||
// IP should be cleaned up
|
||||
expect(securityManager.getConnectionCountByIP(testIP)).toEqual(0);
|
||||
});
|
||||
|
||||
tap.test('Cleanup HttpProxy SecurityManager', async () => {
|
||||
securityManager.clearIPTracking();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,405 +0,0 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import * as plugins from '../ts/plugins.js';
|
||||
import { HttpProxy } from '../ts/proxies/http-proxy/index.js';
|
||||
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
|
||||
import type { IRouteContext } from '../ts/core/models/route-context.js';
|
||||
|
||||
// Declare variables for tests
|
||||
let httpProxy: HttpProxy;
|
||||
let testServer: plugins.http.Server;
|
||||
let testServerHttp2: plugins.http2.Http2Server;
|
||||
let serverPort: number;
|
||||
let serverPortHttp2: number;
|
||||
|
||||
// Setup test environment
|
||||
tap.test('setup HttpProxy function-based targets test environment', async (tools) => {
|
||||
// Set a reasonable timeout for the test
|
||||
tools.timeout(30000); // 30 seconds
|
||||
// Create simple HTTP server to respond to requests
|
||||
testServer = plugins.http.createServer((req, res) => {
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify({
|
||||
url: req.url,
|
||||
headers: req.headers,
|
||||
method: req.method,
|
||||
message: 'HTTP/1.1 Response'
|
||||
}));
|
||||
});
|
||||
|
||||
// Create simple HTTP/2 server to respond to requests
|
||||
testServerHttp2 = plugins.http2.createServer();
|
||||
testServerHttp2.on('stream', (stream, headers) => {
|
||||
stream.respond({
|
||||
'content-type': 'application/json',
|
||||
':status': 200
|
||||
});
|
||||
stream.end(JSON.stringify({
|
||||
path: headers[':path'],
|
||||
headers,
|
||||
method: headers[':method'],
|
||||
message: 'HTTP/2 Response'
|
||||
}));
|
||||
});
|
||||
|
||||
// Handle HTTP/2 errors
|
||||
testServerHttp2.on('error', (err) => {
|
||||
console.error('HTTP/2 server error:', err);
|
||||
});
|
||||
|
||||
// Start the servers
|
||||
await new Promise<void>(resolve => {
|
||||
testServer.listen(0, () => {
|
||||
const address = testServer.address() as { port: number };
|
||||
serverPort = address.port;
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
await new Promise<void>(resolve => {
|
||||
testServerHttp2.listen(0, () => {
|
||||
const address = testServerHttp2.address() as { port: number };
|
||||
serverPortHttp2 = address.port;
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
// Create HttpProxy instance
|
||||
httpProxy = new HttpProxy({
|
||||
port: 0, // Use dynamic port
|
||||
logLevel: 'info', // Use info level to see more logs
|
||||
// Disable ACME to avoid trying to bind to port 80
|
||||
acme: {
|
||||
enabled: false
|
||||
}
|
||||
});
|
||||
|
||||
await httpProxy.start();
|
||||
|
||||
// Log the actual port being used
|
||||
const actualPort = httpProxy.getListeningPort();
|
||||
console.log(`HttpProxy actual listening port: ${actualPort}`);
|
||||
});
|
||||
|
||||
// Test static host/port routes
|
||||
tap.test('should support static host/port routes', async () => {
|
||||
// Get proxy port first
|
||||
const proxyPort = httpProxy.getListeningPort();
|
||||
|
||||
const routes: IRouteConfig[] = [
|
||||
{
|
||||
name: 'static-route',
|
||||
priority: 100,
|
||||
match: {
|
||||
domains: 'example.com',
|
||||
ports: proxyPort
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: serverPort
|
||||
}]
|
||||
}
|
||||
}
|
||||
];
|
||||
|
||||
await httpProxy.updateRouteConfigs(routes);
|
||||
|
||||
// Make request to proxy
|
||||
const response = await makeRequest({
|
||||
hostname: 'localhost',
|
||||
port: proxyPort,
|
||||
path: '/test',
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Host': 'example.com'
|
||||
}
|
||||
});
|
||||
|
||||
expect(response.statusCode).toEqual(200);
|
||||
const body = JSON.parse(response.body);
|
||||
expect(body.url).toEqual('/test');
|
||||
expect(body.headers.host).toEqual(`localhost:${serverPort}`);
|
||||
});
|
||||
|
||||
// Test function-based host
|
||||
tap.test('should support function-based host', async () => {
|
||||
const proxyPort = httpProxy.getListeningPort();
|
||||
const routes: IRouteConfig[] = [
|
||||
{
|
||||
name: 'function-host-route',
|
||||
priority: 100,
|
||||
match: {
|
||||
domains: 'function.example.com',
|
||||
ports: proxyPort
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: (context: IRouteContext) => {
|
||||
// Return localhost always in this test
|
||||
return 'localhost';
|
||||
},
|
||||
port: serverPort
|
||||
}]
|
||||
}
|
||||
}
|
||||
];
|
||||
|
||||
await httpProxy.updateRouteConfigs(routes);
|
||||
|
||||
// Make request to proxy
|
||||
const response = await makeRequest({
|
||||
hostname: 'localhost',
|
||||
port: proxyPort,
|
||||
path: '/function-host',
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Host': 'function.example.com'
|
||||
}
|
||||
});
|
||||
|
||||
expect(response.statusCode).toEqual(200);
|
||||
const body = JSON.parse(response.body);
|
||||
expect(body.url).toEqual('/function-host');
|
||||
expect(body.headers.host).toEqual(`localhost:${serverPort}`);
|
||||
});
|
||||
|
||||
// Test function-based port
|
||||
tap.test('should support function-based port', async () => {
|
||||
const proxyPort = httpProxy.getListeningPort();
|
||||
const routes: IRouteConfig[] = [
|
||||
{
|
||||
name: 'function-port-route',
|
||||
priority: 100,
|
||||
match: {
|
||||
domains: 'function-port.example.com',
|
||||
ports: proxyPort
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: (context: IRouteContext) => {
|
||||
// Return test server port
|
||||
return serverPort;
|
||||
}
|
||||
}]
|
||||
}
|
||||
}
|
||||
];
|
||||
|
||||
await httpProxy.updateRouteConfigs(routes);
|
||||
|
||||
// Make request to proxy
|
||||
const response = await makeRequest({
|
||||
hostname: 'localhost',
|
||||
port: proxyPort,
|
||||
path: '/function-port',
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Host': 'function-port.example.com'
|
||||
}
|
||||
});
|
||||
|
||||
expect(response.statusCode).toEqual(200);
|
||||
const body = JSON.parse(response.body);
|
||||
expect(body.url).toEqual('/function-port');
|
||||
expect(body.headers.host).toEqual(`localhost:${serverPort}`);
|
||||
});
|
||||
|
||||
// Test function-based host AND port
|
||||
tap.test('should support function-based host AND port', async () => {
|
||||
const proxyPort = httpProxy.getListeningPort();
|
||||
const routes: IRouteConfig[] = [
|
||||
{
|
||||
name: 'function-both-route',
|
||||
priority: 100,
|
||||
match: {
|
||||
domains: 'function-both.example.com',
|
||||
ports: proxyPort
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: (context: IRouteContext) => {
|
||||
return 'localhost';
|
||||
},
|
||||
port: (context: IRouteContext) => {
|
||||
return serverPort;
|
||||
}
|
||||
}]
|
||||
}
|
||||
}
|
||||
];
|
||||
|
||||
await httpProxy.updateRouteConfigs(routes);
|
||||
|
||||
// Make request to proxy
|
||||
const response = await makeRequest({
|
||||
hostname: 'localhost',
|
||||
port: proxyPort,
|
||||
path: '/function-both',
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Host': 'function-both.example.com'
|
||||
}
|
||||
});
|
||||
|
||||
expect(response.statusCode).toEqual(200);
|
||||
const body = JSON.parse(response.body);
|
||||
expect(body.url).toEqual('/function-both');
|
||||
expect(body.headers.host).toEqual(`localhost:${serverPort}`);
|
||||
});
|
||||
|
||||
// Test context-based routing with path
|
||||
tap.test('should support context-based routing with path', async () => {
|
||||
const proxyPort = httpProxy.getListeningPort();
|
||||
const routes: IRouteConfig[] = [
|
||||
{
|
||||
name: 'context-path-route',
|
||||
priority: 100,
|
||||
match: {
|
||||
domains: 'context.example.com',
|
||||
ports: proxyPort
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: (context: IRouteContext) => {
|
||||
// Use path to determine host
|
||||
if (context.path?.startsWith('/api')) {
|
||||
return 'localhost';
|
||||
} else {
|
||||
return '127.0.0.1'; // Another way to reference localhost
|
||||
}
|
||||
},
|
||||
port: serverPort
|
||||
}]
|
||||
}
|
||||
}
|
||||
];
|
||||
|
||||
await httpProxy.updateRouteConfigs(routes);
|
||||
|
||||
// Make request to proxy with /api path
|
||||
const apiResponse = await makeRequest({
|
||||
hostname: 'localhost',
|
||||
port: proxyPort,
|
||||
path: '/api/test',
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Host': 'context.example.com'
|
||||
}
|
||||
});
|
||||
|
||||
expect(apiResponse.statusCode).toEqual(200);
|
||||
const apiBody = JSON.parse(apiResponse.body);
|
||||
expect(apiBody.url).toEqual('/api/test');
|
||||
|
||||
// Make request to proxy with non-api path
|
||||
const nonApiResponse = await makeRequest({
|
||||
hostname: 'localhost',
|
||||
port: proxyPort,
|
||||
path: '/web/test',
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Host': 'context.example.com'
|
||||
}
|
||||
});
|
||||
|
||||
expect(nonApiResponse.statusCode).toEqual(200);
|
||||
const nonApiBody = JSON.parse(nonApiResponse.body);
|
||||
expect(nonApiBody.url).toEqual('/web/test');
|
||||
});
|
||||
|
||||
// Cleanup test environment
|
||||
tap.test('cleanup HttpProxy function-based targets test environment', async () => {
|
||||
// Skip cleanup if setup failed
|
||||
if (!httpProxy && !testServer && !testServerHttp2) {
|
||||
console.log('Skipping cleanup - setup failed');
|
||||
return;
|
||||
}
|
||||
|
||||
// Stop test servers first
|
||||
if (testServer) {
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
testServer.close((err) => {
|
||||
if (err) {
|
||||
console.error('Error closing test server:', err);
|
||||
reject(err);
|
||||
} else {
|
||||
console.log('Test server closed successfully');
|
||||
resolve();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
if (testServerHttp2) {
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
testServerHttp2.close((err) => {
|
||||
if (err) {
|
||||
console.error('Error closing HTTP/2 test server:', err);
|
||||
reject(err);
|
||||
} else {
|
||||
console.log('HTTP/2 test server closed successfully');
|
||||
resolve();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Stop HttpProxy last
|
||||
if (httpProxy) {
|
||||
console.log('Stopping HttpProxy...');
|
||||
await httpProxy.stop();
|
||||
console.log('HttpProxy stopped successfully');
|
||||
}
|
||||
|
||||
// Force exit after a short delay to ensure cleanup
|
||||
const cleanupTimeout = setTimeout(() => {
|
||||
console.log('Cleanup completed, exiting');
|
||||
}, 100);
|
||||
|
||||
// Don't keep the process alive just for this timeout
|
||||
if (cleanupTimeout.unref) {
|
||||
cleanupTimeout.unref();
|
||||
}
|
||||
});
|
||||
|
||||
// Helper function to make HTTPS requests with self-signed certificate support
|
||||
async function makeRequest(options: plugins.http.RequestOptions): Promise<{ statusCode: number, headers: plugins.http.IncomingHttpHeaders, body: string }> {
|
||||
return new Promise((resolve, reject) => {
|
||||
// Use HTTPS with rejectUnauthorized: false to accept self-signed certificates
|
||||
const req = plugins.https.request({
|
||||
...options,
|
||||
rejectUnauthorized: false, // Accept self-signed certificates
|
||||
}, (res) => {
|
||||
let body = '';
|
||||
res.on('data', (chunk) => {
|
||||
body += chunk;
|
||||
});
|
||||
res.on('end', () => {
|
||||
resolve({
|
||||
statusCode: res.statusCode || 0,
|
||||
headers: res.headers,
|
||||
body
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
req.on('error', (err) => {
|
||||
console.error(`Request error: ${err.message}`);
|
||||
reject(err);
|
||||
});
|
||||
|
||||
req.end();
|
||||
});
|
||||
}
|
||||
|
||||
// Start the tests
|
||||
tap.start().then(() => {
|
||||
// Ensure process exits after tests complete
|
||||
process.exit(0);
|
||||
});
|
||||
@@ -1,596 +0,0 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import * as smartproxy from '../ts/index.js';
|
||||
import { loadTestCertificates } from './helpers/certificates.js';
|
||||
import * as https from 'https';
|
||||
import * as http from 'http';
|
||||
import { WebSocket, WebSocketServer } from 'ws';
|
||||
|
||||
let testProxy: smartproxy.HttpProxy;
|
||||
let testServer: http.Server;
|
||||
let wsServer: WebSocketServer;
|
||||
let testCertificates: { privateKey: string; publicKey: string };
|
||||
|
||||
// Helper function to make HTTPS requests
|
||||
async function makeHttpsRequest(
|
||||
options: https.RequestOptions,
|
||||
): Promise<{ statusCode: number; headers: http.IncomingHttpHeaders; body: string }> {
|
||||
console.log('[TEST] Making HTTPS request:', {
|
||||
hostname: options.hostname,
|
||||
port: options.port,
|
||||
path: options.path,
|
||||
method: options.method,
|
||||
headers: options.headers,
|
||||
});
|
||||
return new Promise((resolve, reject) => {
|
||||
const req = https.request(options, (res) => {
|
||||
console.log('[TEST] Received HTTPS response:', {
|
||||
statusCode: res.statusCode,
|
||||
headers: res.headers,
|
||||
});
|
||||
let data = '';
|
||||
res.on('data', (chunk) => (data += chunk));
|
||||
res.on('end', () => {
|
||||
console.log('[TEST] Response completed:', { data });
|
||||
// Ensure the socket is destroyed to prevent hanging connections
|
||||
res.socket?.destroy();
|
||||
resolve({
|
||||
statusCode: res.statusCode!,
|
||||
headers: res.headers,
|
||||
body: data,
|
||||
});
|
||||
});
|
||||
});
|
||||
req.on('error', (error) => {
|
||||
console.error('[TEST] Request error:', error);
|
||||
reject(error);
|
||||
});
|
||||
req.end();
|
||||
});
|
||||
}
|
||||
|
||||
// Setup test environment
|
||||
tap.test('setup test environment', async () => {
|
||||
// Load and validate certificates
|
||||
console.log('[TEST] Loading and validating certificates');
|
||||
testCertificates = loadTestCertificates();
|
||||
console.log('[TEST] Certificates loaded and validated');
|
||||
|
||||
// Create a test HTTP server
|
||||
testServer = http.createServer((req, res) => {
|
||||
console.log('[TEST SERVER] Received HTTP request:', {
|
||||
url: req.url,
|
||||
method: req.method,
|
||||
headers: req.headers,
|
||||
});
|
||||
res.writeHead(200, { 'Content-Type': 'text/plain' });
|
||||
res.end('Hello from test server!');
|
||||
});
|
||||
|
||||
// Handle WebSocket upgrade requests
|
||||
testServer.on('upgrade', (request, socket, head) => {
|
||||
console.log('[TEST SERVER] Received WebSocket upgrade request:', {
|
||||
url: request.url,
|
||||
method: request.method,
|
||||
headers: {
|
||||
host: request.headers.host,
|
||||
upgrade: request.headers.upgrade,
|
||||
connection: request.headers.connection,
|
||||
'sec-websocket-key': request.headers['sec-websocket-key'],
|
||||
'sec-websocket-version': request.headers['sec-websocket-version'],
|
||||
'sec-websocket-protocol': request.headers['sec-websocket-protocol'],
|
||||
},
|
||||
});
|
||||
|
||||
if (request.headers.upgrade?.toLowerCase() !== 'websocket') {
|
||||
console.log('[TEST SERVER] Not a WebSocket upgrade request');
|
||||
socket.destroy();
|
||||
return;
|
||||
}
|
||||
|
||||
console.log('[TEST SERVER] Handling WebSocket upgrade');
|
||||
wsServer.handleUpgrade(request, socket, head, (ws) => {
|
||||
console.log('[TEST SERVER] WebSocket connection upgraded');
|
||||
wsServer.emit('connection', ws, request);
|
||||
});
|
||||
});
|
||||
|
||||
// Create a WebSocket server (for the test HTTP server)
|
||||
console.log('[TEST SERVER] Creating WebSocket server');
|
||||
wsServer = new WebSocketServer({
|
||||
noServer: true,
|
||||
perMessageDeflate: false,
|
||||
clientTracking: true,
|
||||
handleProtocols: () => 'echo-protocol',
|
||||
});
|
||||
|
||||
wsServer.on('connection', (ws, request) => {
|
||||
console.log('[TEST SERVER] WebSocket connection established:', {
|
||||
url: request.url,
|
||||
headers: {
|
||||
host: request.headers.host,
|
||||
upgrade: request.headers.upgrade,
|
||||
connection: request.headers.connection,
|
||||
'sec-websocket-key': request.headers['sec-websocket-key'],
|
||||
'sec-websocket-version': request.headers['sec-websocket-version'],
|
||||
'sec-websocket-protocol': request.headers['sec-websocket-protocol'],
|
||||
},
|
||||
});
|
||||
|
||||
// Set up connection timeout
|
||||
const connectionTimeout = setTimeout(() => {
|
||||
console.error('[TEST SERVER] WebSocket connection timed out');
|
||||
ws.terminate();
|
||||
}, 5000);
|
||||
|
||||
// Clear timeout when connection is properly closed
|
||||
const clearConnectionTimeout = () => {
|
||||
clearTimeout(connectionTimeout);
|
||||
};
|
||||
|
||||
ws.on('message', (message) => {
|
||||
const msg = message.toString();
|
||||
console.log('[TEST SERVER] Received WebSocket message:', msg);
|
||||
try {
|
||||
const response = `Echo: ${msg}`;
|
||||
console.log('[TEST SERVER] Sending WebSocket response:', response);
|
||||
ws.send(response);
|
||||
// Clear timeout on successful message exchange
|
||||
clearConnectionTimeout();
|
||||
} catch (error) {
|
||||
console.error('[TEST SERVER] Error sending WebSocket message:', error);
|
||||
}
|
||||
});
|
||||
|
||||
ws.on('error', (error) => {
|
||||
console.error('[TEST SERVER] WebSocket error:', error);
|
||||
clearConnectionTimeout();
|
||||
});
|
||||
|
||||
ws.on('close', (code, reason) => {
|
||||
console.log('[TEST SERVER] WebSocket connection closed:', {
|
||||
code,
|
||||
reason: reason.toString(),
|
||||
wasClean: code === 1000 || code === 1001,
|
||||
});
|
||||
clearConnectionTimeout();
|
||||
});
|
||||
|
||||
ws.on('ping', (data) => {
|
||||
try {
|
||||
console.log('[TEST SERVER] Received ping, sending pong');
|
||||
ws.pong(data);
|
||||
} catch (error) {
|
||||
console.error('[TEST SERVER] Error sending pong:', error);
|
||||
}
|
||||
});
|
||||
|
||||
ws.on('pong', (data) => {
|
||||
console.log('[TEST SERVER] Received pong');
|
||||
});
|
||||
});
|
||||
|
||||
wsServer.on('error', (error) => {
|
||||
console.error('Test server: WebSocket server error:', error);
|
||||
});
|
||||
|
||||
wsServer.on('headers', (headers) => {
|
||||
console.log('Test server: WebSocket headers:', headers);
|
||||
});
|
||||
|
||||
wsServer.on('close', () => {
|
||||
console.log('Test server: WebSocket server closed');
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => testServer.listen(3100, resolve));
|
||||
console.log('Test server listening on port 3100');
|
||||
});
|
||||
|
||||
tap.test('should create proxy instance', async () => {
|
||||
// Test with the original minimal options (only port)
|
||||
testProxy = new smartproxy.HttpProxy({
|
||||
port: 3001,
|
||||
});
|
||||
expect(testProxy).toEqual(testProxy); // Instance equality check
|
||||
});
|
||||
|
||||
tap.test('should create proxy instance with extended options', async () => {
|
||||
// Test with extended options to verify backward compatibility
|
||||
testProxy = new smartproxy.HttpProxy({
|
||||
port: 3001,
|
||||
maxConnections: 5000,
|
||||
keepAliveTimeout: 120000,
|
||||
headersTimeout: 60000,
|
||||
logLevel: 'info',
|
||||
cors: {
|
||||
allowOrigin: '*',
|
||||
allowMethods: 'GET, POST, OPTIONS',
|
||||
allowHeaders: 'Content-Type',
|
||||
maxAge: 3600
|
||||
}
|
||||
});
|
||||
expect(testProxy).toEqual(testProxy); // Instance equality check
|
||||
expect(testProxy.options.port).toEqual(3001);
|
||||
});
|
||||
|
||||
tap.test('should start the proxy server', async () => {
|
||||
// Create a new proxy instance
|
||||
testProxy = new smartproxy.HttpProxy({
|
||||
port: 3001,
|
||||
maxConnections: 5000,
|
||||
backendProtocol: 'http1',
|
||||
acme: {
|
||||
enabled: false // Disable ACME for testing
|
||||
}
|
||||
});
|
||||
|
||||
// Configure routes for the proxy
|
||||
await testProxy.updateRouteConfigs([
|
||||
{
|
||||
match: {
|
||||
ports: [3001],
|
||||
domains: ['push.rocks', 'localhost']
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 3100
|
||||
}],
|
||||
tls: {
|
||||
mode: 'terminate'
|
||||
},
|
||||
websocket: {
|
||||
enabled: true,
|
||||
subprotocols: ['echo-protocol']
|
||||
}
|
||||
}
|
||||
}
|
||||
]);
|
||||
|
||||
// Start the proxy
|
||||
await testProxy.start();
|
||||
|
||||
// Verify the proxy is listening on the correct port
|
||||
expect(testProxy.getListeningPort()).toEqual(3001);
|
||||
});
|
||||
|
||||
tap.test('should route HTTPS requests based on host header', async () => {
|
||||
// IMPORTANT: Connect to localhost (where the proxy is listening) but use the Host header "push.rocks"
|
||||
const response = await makeHttpsRequest({
|
||||
hostname: 'localhost', // changed from 'push.rocks' to 'localhost'
|
||||
port: 3001,
|
||||
path: '/',
|
||||
method: 'GET',
|
||||
headers: {
|
||||
host: 'push.rocks', // virtual host for routing
|
||||
},
|
||||
rejectUnauthorized: false,
|
||||
});
|
||||
|
||||
expect(response.statusCode).toEqual(200);
|
||||
expect(response.body).toEqual('Hello from test server!');
|
||||
});
|
||||
|
||||
tap.test('should handle unknown host headers', async () => {
|
||||
// Connect to localhost but use an unknown host header.
|
||||
const response = await makeHttpsRequest({
|
||||
hostname: 'localhost', // connecting to localhost
|
||||
port: 3001,
|
||||
path: '/',
|
||||
method: 'GET',
|
||||
headers: {
|
||||
host: 'unknown.host', // this should not match any proxy config
|
||||
},
|
||||
rejectUnauthorized: false,
|
||||
});
|
||||
|
||||
// Expect a 404 response with the appropriate error message.
|
||||
expect(response.statusCode).toEqual(404);
|
||||
});
|
||||
|
||||
tap.test('should support WebSocket connections', async () => {
|
||||
// Create a WebSocket client
|
||||
console.log('[TEST] Testing WebSocket connection');
|
||||
|
||||
console.log('[TEST] Creating WebSocket to wss://localhost:3001/ with host header: push.rocks');
|
||||
const ws = new WebSocket('wss://localhost:3001/', {
|
||||
protocol: 'echo-protocol',
|
||||
rejectUnauthorized: false,
|
||||
headers: {
|
||||
host: 'push.rocks'
|
||||
}
|
||||
});
|
||||
|
||||
const connectionTimeout = setTimeout(() => {
|
||||
console.error('[TEST] WebSocket connection timeout');
|
||||
ws.terminate();
|
||||
}, 5000);
|
||||
|
||||
const timeouts: NodeJS.Timeout[] = [connectionTimeout];
|
||||
|
||||
try {
|
||||
// Wait for connection with timeout
|
||||
await Promise.race([
|
||||
new Promise<void>((resolve, reject) => {
|
||||
ws.on('open', () => {
|
||||
console.log('[TEST] WebSocket connected');
|
||||
clearTimeout(connectionTimeout);
|
||||
resolve();
|
||||
});
|
||||
ws.on('error', (err) => {
|
||||
console.error('[TEST] WebSocket connection error:', err);
|
||||
clearTimeout(connectionTimeout);
|
||||
reject(err);
|
||||
});
|
||||
}),
|
||||
new Promise<void>((_, reject) => {
|
||||
const timeout = setTimeout(() => reject(new Error('Connection timeout')), 3000);
|
||||
timeouts.push(timeout);
|
||||
})
|
||||
]);
|
||||
|
||||
// Send a message and receive echo with timeout
|
||||
await Promise.race([
|
||||
new Promise<void>((resolve, reject) => {
|
||||
const testMessage = 'Hello WebSocket!';
|
||||
let messageReceived = false;
|
||||
|
||||
ws.on('message', (data) => {
|
||||
messageReceived = true;
|
||||
const message = data.toString();
|
||||
console.log('[TEST] Received WebSocket message:', message);
|
||||
expect(message).toEqual(`Echo: ${testMessage}`);
|
||||
resolve();
|
||||
});
|
||||
|
||||
ws.on('error', (err) => {
|
||||
console.error('[TEST] WebSocket message error:', err);
|
||||
reject(err);
|
||||
});
|
||||
|
||||
console.log('[TEST] Sending WebSocket message:', testMessage);
|
||||
ws.send(testMessage);
|
||||
|
||||
// Add additional debug logging
|
||||
const debugTimeout = setTimeout(() => {
|
||||
if (!messageReceived) {
|
||||
console.log('[TEST] No message received after 2 seconds');
|
||||
}
|
||||
}, 2000);
|
||||
timeouts.push(debugTimeout);
|
||||
}),
|
||||
new Promise<void>((_, reject) => {
|
||||
const timeout = setTimeout(() => reject(new Error('Message timeout')), 3000);
|
||||
timeouts.push(timeout);
|
||||
})
|
||||
]);
|
||||
|
||||
// Close the connection properly
|
||||
await Promise.race([
|
||||
new Promise<void>((resolve) => {
|
||||
ws.on('close', () => {
|
||||
console.log('[TEST] WebSocket closed');
|
||||
resolve();
|
||||
});
|
||||
ws.close();
|
||||
}),
|
||||
new Promise<void>((resolve) => {
|
||||
const timeout = setTimeout(() => {
|
||||
console.log('[TEST] Force closing WebSocket');
|
||||
ws.terminate();
|
||||
resolve();
|
||||
}, 2000);
|
||||
timeouts.push(timeout);
|
||||
})
|
||||
]);
|
||||
} catch (error) {
|
||||
console.error('[TEST] WebSocket test error:', error);
|
||||
try {
|
||||
ws.terminate();
|
||||
} catch (terminateError) {
|
||||
console.error('[TEST] Error during terminate:', terminateError);
|
||||
}
|
||||
// Skip if WebSocket fails for now
|
||||
console.log('[TEST] WebSocket test failed, continuing with other tests');
|
||||
} finally {
|
||||
// Clean up all timeouts
|
||||
timeouts.forEach(timeout => clearTimeout(timeout));
|
||||
}
|
||||
});
|
||||
|
||||
tap.test('should handle custom headers', async () => {
|
||||
await testProxy.addDefaultHeaders({
|
||||
'X-Proxy-Header': 'test-value',
|
||||
});
|
||||
|
||||
const response = await makeHttpsRequest({
|
||||
hostname: 'localhost', // changed to 'localhost'
|
||||
port: 3001,
|
||||
path: '/',
|
||||
method: 'GET',
|
||||
headers: {
|
||||
host: 'push.rocks', // still routing to push.rocks
|
||||
},
|
||||
rejectUnauthorized: false,
|
||||
});
|
||||
|
||||
expect(response.headers['x-proxy-header']).toEqual('test-value');
|
||||
});
|
||||
|
||||
tap.test('should handle CORS preflight requests', async () => {
|
||||
// Test OPTIONS request (CORS preflight)
|
||||
const response = await makeHttpsRequest({
|
||||
hostname: 'localhost',
|
||||
port: 3001,
|
||||
path: '/',
|
||||
method: 'OPTIONS',
|
||||
headers: {
|
||||
host: 'push.rocks',
|
||||
origin: 'https://example.com',
|
||||
'access-control-request-method': 'POST',
|
||||
'access-control-request-headers': 'content-type'
|
||||
},
|
||||
rejectUnauthorized: false,
|
||||
});
|
||||
|
||||
// Should get appropriate CORS headers
|
||||
expect(response.statusCode).toBeLessThan(300); // 200 or 204
|
||||
expect(response.headers['access-control-allow-origin']).toEqual('*');
|
||||
expect(response.headers['access-control-allow-methods']).toContain('GET');
|
||||
expect(response.headers['access-control-allow-methods']).toContain('POST');
|
||||
});
|
||||
|
||||
tap.test('should track connections and metrics', async () => {
|
||||
// Get metrics from the proxy
|
||||
const metrics = testProxy.getMetrics();
|
||||
|
||||
// Verify metrics structure and some values
|
||||
expect(metrics).toHaveProperty('activeConnections');
|
||||
expect(metrics).toHaveProperty('totalRequests');
|
||||
expect(metrics).toHaveProperty('failedRequests');
|
||||
expect(metrics).toHaveProperty('uptime');
|
||||
expect(metrics).toHaveProperty('memoryUsage');
|
||||
expect(metrics).toHaveProperty('activeWebSockets');
|
||||
|
||||
// Should have served at least some requests from previous tests
|
||||
expect(metrics.totalRequests).toBeGreaterThan(0);
|
||||
expect(metrics.uptime).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
tap.test('should update capacity settings', async () => {
|
||||
// Update proxy capacity settings
|
||||
testProxy.updateCapacity(2000, 60000, 25);
|
||||
|
||||
// Verify settings were updated
|
||||
expect(testProxy.options.maxConnections).toEqual(2000);
|
||||
expect(testProxy.options.keepAliveTimeout).toEqual(60000);
|
||||
expect(testProxy.options.connectionPoolSize).toEqual(25);
|
||||
});
|
||||
|
||||
tap.test('should handle certificate requests', async () => {
|
||||
// Test certificate request (this won't actually issue a cert in test mode)
|
||||
const result = await testProxy.requestCertificate('test.example.com');
|
||||
|
||||
// In test mode with ACME disabled, this should return false
|
||||
expect(result).toEqual(false);
|
||||
});
|
||||
|
||||
tap.test('should update certificates directly', async () => {
|
||||
// Test certificate update
|
||||
const testCert = '-----BEGIN CERTIFICATE-----\nMIIB...test...';
|
||||
const testKey = '-----BEGIN PRIVATE KEY-----\nMIIE...test...';
|
||||
|
||||
// This should not throw
|
||||
expect(() => {
|
||||
testProxy.updateCertificate('test.example.com', testCert, testKey);
|
||||
}).not.toThrow();
|
||||
});
|
||||
|
||||
tap.test('cleanup', async () => {
|
||||
console.log('[TEST] Starting cleanup');
|
||||
|
||||
try {
|
||||
// 1. Close WebSocket clients if server exists
|
||||
if (wsServer && wsServer.clients) {
|
||||
console.log(`[TEST] Terminating ${wsServer.clients.size} WebSocket clients`);
|
||||
wsServer.clients.forEach((client) => {
|
||||
try {
|
||||
client.terminate();
|
||||
} catch (err) {
|
||||
console.error('[TEST] Error terminating client:', err);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 2. Close WebSocket server with timeout
|
||||
if (wsServer) {
|
||||
console.log('[TEST] Closing WebSocket server');
|
||||
await Promise.race([
|
||||
new Promise<void>((resolve, reject) => {
|
||||
wsServer.close((err) => {
|
||||
if (err) {
|
||||
console.error('[TEST] Error closing WebSocket server:', err);
|
||||
reject(err);
|
||||
} else {
|
||||
console.log('[TEST] WebSocket server closed');
|
||||
resolve();
|
||||
}
|
||||
});
|
||||
}).catch((err) => {
|
||||
console.error('[TEST] Caught error closing WebSocket server:', err);
|
||||
}),
|
||||
new Promise<void>((resolve) => {
|
||||
setTimeout(() => {
|
||||
console.log('[TEST] WebSocket server close timeout');
|
||||
resolve();
|
||||
}, 1000);
|
||||
})
|
||||
]);
|
||||
}
|
||||
|
||||
// 3. Close test server with timeout
|
||||
if (testServer) {
|
||||
console.log('[TEST] Closing test server');
|
||||
// First close all connections
|
||||
testServer.closeAllConnections();
|
||||
|
||||
await Promise.race([
|
||||
new Promise<void>((resolve, reject) => {
|
||||
testServer.close((err) => {
|
||||
if (err) {
|
||||
console.error('[TEST] Error closing test server:', err);
|
||||
reject(err);
|
||||
} else {
|
||||
console.log('[TEST] Test server closed');
|
||||
resolve();
|
||||
}
|
||||
});
|
||||
}).catch((err) => {
|
||||
console.error('[TEST] Caught error closing test server:', err);
|
||||
}),
|
||||
new Promise<void>((resolve) => {
|
||||
setTimeout(() => {
|
||||
console.log('[TEST] Test server close timeout');
|
||||
resolve();
|
||||
}, 1000);
|
||||
})
|
||||
]);
|
||||
}
|
||||
|
||||
// 4. Stop the proxy with timeout
|
||||
if (testProxy) {
|
||||
console.log('[TEST] Stopping proxy');
|
||||
await Promise.race([
|
||||
testProxy.stop()
|
||||
.then(() => {
|
||||
console.log('[TEST] Proxy stopped successfully');
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('[TEST] Error stopping proxy:', error);
|
||||
}),
|
||||
new Promise<void>((resolve) => {
|
||||
setTimeout(() => {
|
||||
console.log('[TEST] Proxy stop timeout');
|
||||
resolve();
|
||||
}, 2000);
|
||||
})
|
||||
]);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('[TEST] Error during cleanup:', error);
|
||||
}
|
||||
|
||||
console.log('[TEST] Cleanup complete');
|
||||
|
||||
// Add debugging to see what might be keeping the process alive
|
||||
if (process.env.DEBUG_HANDLES) {
|
||||
console.log('[TEST] Active handles:', (process as any)._getActiveHandles?.().length);
|
||||
console.log('[TEST] Active requests:', (process as any)._getActiveRequests?.().length);
|
||||
}
|
||||
});
|
||||
|
||||
// Exit handler removed to prevent interference with test cleanup
|
||||
|
||||
// Teardown test removed - let tap handle proper cleanup
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,250 +0,0 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
import * as plugins from '../ts/plugins.js';
|
||||
|
||||
tap.test('keepalive support - verify keepalive connections are properly handled', async (tools) => {
|
||||
console.log('\n=== KeepAlive Support Test ===');
|
||||
console.log('Purpose: Verify that keepalive connections are not prematurely cleaned up');
|
||||
|
||||
// Create a simple echo backend
|
||||
const echoBackend = net.createServer((socket) => {
|
||||
socket.on('data', (data) => {
|
||||
// Echo back received data
|
||||
try {
|
||||
socket.write(data);
|
||||
} catch (err) {
|
||||
// Ignore write errors during shutdown
|
||||
}
|
||||
});
|
||||
|
||||
socket.on('error', (err: NodeJS.ErrnoException) => {
|
||||
// Ignore errors from backend sockets
|
||||
console.log(`Backend socket error (expected during cleanup): ${err.code}`);
|
||||
});
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
echoBackend.listen(9998, () => {
|
||||
console.log('✓ Echo backend started on port 9998');
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
// Test 1: Standard keepalive treatment
|
||||
console.log('\n--- Test 1: Standard KeepAlive Treatment ---');
|
||||
|
||||
const proxy1 = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'keepalive-route',
|
||||
match: { ports: 8590 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 9998 }]
|
||||
}
|
||||
}],
|
||||
keepAlive: true,
|
||||
keepAliveTreatment: 'standard',
|
||||
inactivityTimeout: 5000, // 5 seconds for faster testing
|
||||
enableDetailedLogging: false,
|
||||
});
|
||||
|
||||
await proxy1.start();
|
||||
console.log('✓ Proxy with standard keepalive started on port 8590');
|
||||
|
||||
// Create a keepalive connection
|
||||
const client1 = net.connect(8590, 'localhost');
|
||||
|
||||
// Add error handler to prevent unhandled errors
|
||||
client1.on('error', (err: NodeJS.ErrnoException) => {
|
||||
console.log(`Client1 error (expected during cleanup): ${err.code}`);
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
client1.on('connect', () => {
|
||||
console.log('Client connected');
|
||||
client1.setKeepAlive(true, 1000);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
// Send initial data
|
||||
client1.write('Hello keepalive\n');
|
||||
|
||||
// Wait for echo
|
||||
await new Promise<void>((resolve) => {
|
||||
client1.once('data', (data) => {
|
||||
console.log(`Received echo: ${data.toString().trim()}`);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
// Check connection is marked as keepalive
|
||||
const cm1 = (proxy1 as any).connectionManager;
|
||||
const connections1 = cm1.getConnections();
|
||||
let keepAliveCount = 0;
|
||||
|
||||
for (const [id, record] of connections1) {
|
||||
if (record.hasKeepAlive) {
|
||||
keepAliveCount++;
|
||||
console.log(`KeepAlive connection ${id}: hasKeepAlive=${record.hasKeepAlive}`);
|
||||
}
|
||||
}
|
||||
|
||||
expect(keepAliveCount).toEqual(1);
|
||||
|
||||
// Wait to ensure it's not cleaned up prematurely
|
||||
await plugins.smartdelay.delayFor(6000);
|
||||
|
||||
const afterWaitCount1 = cm1.getConnectionCount();
|
||||
console.log(`Connections after 6s wait: ${afterWaitCount1}`);
|
||||
expect(afterWaitCount1).toEqual(1); // Should still be connected
|
||||
|
||||
// Send more data to keep it alive
|
||||
client1.write('Still alive\n');
|
||||
|
||||
// Clean up test 1
|
||||
client1.destroy();
|
||||
await proxy1.stop();
|
||||
await plugins.smartdelay.delayFor(500); // Wait for port to be released
|
||||
|
||||
// Test 2: Extended keepalive treatment
|
||||
console.log('\n--- Test 2: Extended KeepAlive Treatment ---');
|
||||
|
||||
const proxy2 = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'keepalive-extended',
|
||||
match: { ports: 8591 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 9998 }]
|
||||
}
|
||||
}],
|
||||
keepAlive: true,
|
||||
keepAliveTreatment: 'extended',
|
||||
keepAliveInactivityMultiplier: 6,
|
||||
inactivityTimeout: 2000, // 2 seconds base, 12 seconds with multiplier
|
||||
enableDetailedLogging: false,
|
||||
});
|
||||
|
||||
await proxy2.start();
|
||||
console.log('✓ Proxy with extended keepalive started on port 8591');
|
||||
|
||||
const client2 = net.connect(8591, 'localhost');
|
||||
|
||||
// Add error handler to prevent unhandled errors
|
||||
client2.on('error', (err: NodeJS.ErrnoException) => {
|
||||
console.log(`Client2 error (expected during cleanup): ${err.code}`);
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
client2.on('connect', () => {
|
||||
console.log('Client connected with extended timeout');
|
||||
client2.setKeepAlive(true, 1000);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
// Send initial data
|
||||
client2.write('Extended keepalive\n');
|
||||
|
||||
// Check connection
|
||||
const cm2 = (proxy2 as any).connectionManager;
|
||||
await plugins.smartdelay.delayFor(1000);
|
||||
|
||||
const connections2 = cm2.getConnections();
|
||||
for (const [id, record] of connections2) {
|
||||
console.log(`Extended connection ${id}: hasKeepAlive=${record.hasKeepAlive}, treatment=extended`);
|
||||
}
|
||||
|
||||
// Wait 3 seconds (would timeout with standard treatment)
|
||||
await plugins.smartdelay.delayFor(3000);
|
||||
|
||||
const midWaitCount = cm2.getConnectionCount();
|
||||
console.log(`Connections after 3s (base timeout exceeded): ${midWaitCount}`);
|
||||
expect(midWaitCount).toEqual(1); // Should still be connected due to extended treatment
|
||||
|
||||
// Clean up test 2
|
||||
client2.destroy();
|
||||
await proxy2.stop();
|
||||
await plugins.smartdelay.delayFor(500); // Wait for port to be released
|
||||
|
||||
// Test 3: Immortal keepalive treatment
|
||||
console.log('\n--- Test 3: Immortal KeepAlive Treatment ---');
|
||||
|
||||
const proxy3 = new SmartProxy({
|
||||
routes: [{
|
||||
name: 'keepalive-immortal',
|
||||
match: { ports: 8592 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 9998 }]
|
||||
}
|
||||
}],
|
||||
keepAlive: true,
|
||||
keepAliveTreatment: 'immortal',
|
||||
inactivityTimeout: 1000, // 1 second - should be ignored for immortal
|
||||
enableDetailedLogging: false,
|
||||
});
|
||||
|
||||
await proxy3.start();
|
||||
console.log('✓ Proxy with immortal keepalive started on port 8592');
|
||||
|
||||
const client3 = net.connect(8592, 'localhost');
|
||||
|
||||
// Add error handler to prevent unhandled errors
|
||||
client3.on('error', (err: NodeJS.ErrnoException) => {
|
||||
console.log(`Client3 error (expected during cleanup): ${err.code}`);
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
client3.on('connect', () => {
|
||||
console.log('Client connected with immortal treatment');
|
||||
client3.setKeepAlive(true, 1000);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
// Send initial data
|
||||
client3.write('Immortal connection\n');
|
||||
|
||||
// Wait well beyond normal timeout
|
||||
await plugins.smartdelay.delayFor(5000);
|
||||
|
||||
const cm3 = (proxy3 as any).connectionManager;
|
||||
const immortalCount = cm3.getConnectionCount();
|
||||
console.log(`Immortal connections after 5s inactivity: ${immortalCount}`);
|
||||
expect(immortalCount).toEqual(1); // Should never timeout
|
||||
|
||||
// Verify zombie detection doesn't affect immortal connections
|
||||
console.log('\n--- Verifying zombie detection respects keepalive ---');
|
||||
|
||||
// Manually trigger inactivity check
|
||||
cm3.performOptimizedInactivityCheck();
|
||||
|
||||
await plugins.smartdelay.delayFor(1000);
|
||||
|
||||
const afterCheckCount = cm3.getConnectionCount();
|
||||
console.log(`Connections after manual inactivity check: ${afterCheckCount}`);
|
||||
expect(afterCheckCount).toEqual(1); // Should still be alive
|
||||
|
||||
// Clean up
|
||||
client3.destroy();
|
||||
await proxy3.stop();
|
||||
|
||||
// Close backend and wait for it to fully close
|
||||
await new Promise<void>((resolve) => {
|
||||
echoBackend.close(() => {
|
||||
console.log('Echo backend closed');
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
console.log('\n✓ All keepalive tests passed:');
|
||||
console.log(' - Standard treatment works correctly');
|
||||
console.log(' - Extended treatment applies multiplier');
|
||||
console.log(' - Immortal treatment never times out');
|
||||
console.log(' - Zombie detection respects keepalive settings');
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,151 +0,0 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import { SmartProxy, createHttpRoute } from '../ts/index.js';
|
||||
import * as http from 'http';
|
||||
|
||||
tap.test('should not have memory leaks in long-running operations', async (tools) => {
|
||||
// Get initial memory usage
|
||||
const getMemoryUsage = () => {
|
||||
if (global.gc) {
|
||||
global.gc();
|
||||
}
|
||||
const usage = process.memoryUsage();
|
||||
return {
|
||||
heapUsed: Math.round(usage.heapUsed / 1024 / 1024), // MB
|
||||
external: Math.round(usage.external / 1024 / 1024), // MB
|
||||
rss: Math.round(usage.rss / 1024 / 1024) // MB
|
||||
};
|
||||
};
|
||||
|
||||
// Create a target server
|
||||
const targetServer = http.createServer((req, res) => {
|
||||
res.writeHead(200, { 'Content-Type': 'text/plain' });
|
||||
res.end('OK');
|
||||
});
|
||||
await new Promise<void>((resolve) => targetServer.listen(3100, resolve));
|
||||
|
||||
// Create the proxy - use non-privileged port
|
||||
const routes = [
|
||||
createHttpRoute(['test1.local', 'test2.local', 'test3.local'], { host: 'localhost', port: 3100 }),
|
||||
];
|
||||
// Update route to use port 8080
|
||||
routes[0].match.ports = 8080;
|
||||
|
||||
const proxy = new SmartProxy({
|
||||
routes: routes
|
||||
});
|
||||
await proxy.start();
|
||||
|
||||
console.log('Starting memory leak test...');
|
||||
const initialMemory = getMemoryUsage();
|
||||
console.log('Initial memory:', initialMemory);
|
||||
|
||||
// Function to make requests
|
||||
const makeRequest = (domain: string): Promise<void> => {
|
||||
return new Promise((resolve, reject) => {
|
||||
const req = http.request({
|
||||
hostname: 'localhost',
|
||||
port: 8080,
|
||||
path: '/',
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Host': domain
|
||||
}
|
||||
}, (res) => {
|
||||
res.on('data', () => {});
|
||||
res.on('end', resolve);
|
||||
});
|
||||
req.on('error', reject);
|
||||
req.end();
|
||||
});
|
||||
};
|
||||
|
||||
// Test 1: Many requests to the same routes
|
||||
console.log('Test 1: Making 1000 requests to same routes...');
|
||||
for (let i = 0; i < 1000; i++) {
|
||||
await makeRequest(`test${(i % 3) + 1}.local`);
|
||||
if (i % 100 === 0) {
|
||||
console.log(` Progress: ${i}/1000`);
|
||||
}
|
||||
}
|
||||
|
||||
const afterSameRoutesMemory = getMemoryUsage();
|
||||
console.log('Memory after same routes:', afterSameRoutesMemory);
|
||||
|
||||
// Test 2: Many requests to different routes (tests routeContextCache)
|
||||
console.log('Test 2: Making 1000 requests to different routes...');
|
||||
for (let i = 0; i < 1000; i++) {
|
||||
// Create unique domain to test cache growth
|
||||
await makeRequest(`test${i}.local`);
|
||||
if (i % 100 === 0) {
|
||||
console.log(` Progress: ${i}/1000`);
|
||||
}
|
||||
}
|
||||
|
||||
const afterDifferentRoutesMemory = getMemoryUsage();
|
||||
console.log('Memory after different routes:', afterDifferentRoutesMemory);
|
||||
|
||||
// Test 3: Check metrics collector memory
|
||||
console.log('Test 3: Checking metrics collector...');
|
||||
const metrics = proxy.getMetrics();
|
||||
console.log(`Active connections: ${metrics.connections.active()}`);
|
||||
console.log(`Total connections: ${metrics.connections.total()}`);
|
||||
console.log(`RPS: ${metrics.requests.perSecond()}`);
|
||||
|
||||
// Test 4: Many rapid connections (tests requestTimestamps array)
|
||||
console.log('Test 4: Making 500 rapid requests...');
|
||||
const rapidRequests = [];
|
||||
for (let i = 0; i < 500; i++) {
|
||||
rapidRequests.push(makeRequest('test1.local'));
|
||||
if (i % 50 === 0) {
|
||||
// Wait a bit to let some complete
|
||||
await Promise.all(rapidRequests);
|
||||
rapidRequests.length = 0;
|
||||
// Add delay to allow connections to close
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
console.log(` Progress: ${i}/500`);
|
||||
}
|
||||
}
|
||||
await Promise.all(rapidRequests);
|
||||
|
||||
const afterRapidMemory = getMemoryUsage();
|
||||
console.log('Memory after rapid requests:', afterRapidMemory);
|
||||
|
||||
// Force garbage collection and check final memory
|
||||
await new Promise(resolve => setTimeout(resolve, 1000));
|
||||
const finalMemory = getMemoryUsage();
|
||||
console.log('Final memory:', finalMemory);
|
||||
|
||||
// Memory leak checks
|
||||
const memoryGrowth = finalMemory.heapUsed - initialMemory.heapUsed;
|
||||
console.log(`Total memory growth: ${memoryGrowth} MB`);
|
||||
|
||||
// Check for excessive memory growth
|
||||
// Allow some growth but not excessive (e.g., more than 50MB for this test)
|
||||
expect(memoryGrowth).toBeLessThan(50);
|
||||
|
||||
// Check specific potential leaks
|
||||
// 1. Route context cache should not grow unbounded
|
||||
const routeHandler = proxy.routeConnectionHandler as any;
|
||||
if (routeHandler.routeContextCache) {
|
||||
console.log(`Route context cache size: ${routeHandler.routeContextCache.size}`);
|
||||
// Should not have 1000 entries from different routes test
|
||||
expect(routeHandler.routeContextCache.size).toBeLessThan(100);
|
||||
}
|
||||
|
||||
// 2. Metrics collector should clean up old timestamps
|
||||
const metricsCollector = (proxy as any).metricsCollector;
|
||||
if (metricsCollector && metricsCollector.requestTimestamps) {
|
||||
console.log(`Request timestamps array length: ${metricsCollector.requestTimestamps.length}`);
|
||||
// Should clean up old timestamps periodically
|
||||
expect(metricsCollector.requestTimestamps.length).toBeLessThanOrEqual(10000);
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
await proxy.stop();
|
||||
await new Promise<void>((resolve) => targetServer.close(() => resolve()));
|
||||
|
||||
console.log('Memory leak test completed successfully');
|
||||
});
|
||||
|
||||
// Run with: node --expose-gc test.memory-leak-check.node.ts
|
||||
export default tap.start();
|
||||
@@ -1,59 +0,0 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import { SmartProxy, createHttpRoute } from '../ts/index.js';
|
||||
import * as http from 'http';
|
||||
|
||||
tap.test('memory leak fixes verification', async () => {
|
||||
// Test 1: MetricsCollector requestTimestamps cleanup
|
||||
console.log('\n=== Test 1: MetricsCollector requestTimestamps cleanup ===');
|
||||
const proxy = new SmartProxy({
|
||||
routes: [
|
||||
createHttpRoute('test.local', { host: 'localhost', port: 3200 }, {
|
||||
match: {
|
||||
ports: 8081,
|
||||
domains: 'test.local'
|
||||
}
|
||||
}),
|
||||
]
|
||||
});
|
||||
|
||||
await proxy.start();
|
||||
|
||||
const metricsCollector = (proxy as any).metricsCollector;
|
||||
|
||||
// Check initial state
|
||||
console.log('Initial timestamps:', metricsCollector.requestTimestamps.length);
|
||||
|
||||
// Simulate many requests to test cleanup
|
||||
for (let i = 0; i < 6000; i++) {
|
||||
metricsCollector.recordRequest();
|
||||
}
|
||||
|
||||
// Should be cleaned up to MAX_TIMESTAMPS (5000)
|
||||
console.log('After 6000 requests:', metricsCollector.requestTimestamps.length);
|
||||
expect(metricsCollector.requestTimestamps.length).toBeLessThanOrEqual(5000);
|
||||
|
||||
await proxy.stop();
|
||||
|
||||
// Test 2: Verify intervals are cleaned up
|
||||
console.log('\n=== Test 2: Verify cleanup methods exist ===');
|
||||
|
||||
// Check RequestHandler has destroy method
|
||||
const { RequestHandler } = await import('../ts/proxies/http-proxy/request-handler.js');
|
||||
const requestHandler = new RequestHandler({ port: 8080 }, null as any);
|
||||
expect(typeof requestHandler.destroy).toEqual('function');
|
||||
console.log('✓ RequestHandler has destroy method');
|
||||
|
||||
// Check FunctionCache has destroy method
|
||||
const { FunctionCache } = await import('../ts/proxies/http-proxy/function-cache.js');
|
||||
const functionCache = new FunctionCache({ debug: () => {}, info: () => {} } as any);
|
||||
expect(typeof functionCache.destroy).toEqual('function');
|
||||
console.log('✓ FunctionCache has destroy method');
|
||||
|
||||
// Cleanup
|
||||
requestHandler.destroy();
|
||||
functionCache.destroy();
|
||||
|
||||
console.log('\n✅ All memory leak fixes verified!');
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,131 +0,0 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
|
||||
tap.test('memory leak fixes - unit tests', async () => {
|
||||
console.log('\n=== Testing MetricsCollector memory management ===');
|
||||
|
||||
// Import and test MetricsCollector directly
|
||||
const { MetricsCollector } = await import('../ts/proxies/smart-proxy/metrics-collector.js');
|
||||
|
||||
// Create a mock SmartProxy with minimal required properties
|
||||
const mockProxy = {
|
||||
connectionManager: {
|
||||
getConnectionCount: () => 0,
|
||||
getConnections: () => new Map(),
|
||||
getTerminationStats: () => ({ incoming: {} })
|
||||
},
|
||||
routeConnectionHandler: {
|
||||
newConnectionSubject: {
|
||||
subscribe: () => ({ unsubscribe: () => {} })
|
||||
}
|
||||
},
|
||||
settings: {}
|
||||
};
|
||||
|
||||
const collector = new MetricsCollector(mockProxy as any);
|
||||
collector.start();
|
||||
|
||||
// Test timestamp cleanup
|
||||
console.log('Testing requestTimestamps cleanup...');
|
||||
|
||||
// Add 6000 timestamps
|
||||
for (let i = 0; i < 6000; i++) {
|
||||
collector.recordRequest(`conn-${i}`, 'test-route', '127.0.0.1');
|
||||
}
|
||||
|
||||
// Access private property for testing
|
||||
let timestamps = (collector as any).requestTimestamps;
|
||||
console.log(`Timestamps after 6000 requests: ${timestamps.length}`);
|
||||
|
||||
// Force one more request to trigger cleanup
|
||||
collector.recordRequest('conn-final', 'test-route', '127.0.0.1');
|
||||
timestamps = (collector as any).requestTimestamps;
|
||||
console.log(`Timestamps after cleanup trigger: ${timestamps.length}`);
|
||||
|
||||
// Now check the RPS window - all timestamps are within 1 minute so they won't be cleaned
|
||||
const now = Date.now();
|
||||
const oldestTimestamp = Math.min(...timestamps);
|
||||
const windowAge = now - oldestTimestamp;
|
||||
console.log(`Window age: ${windowAge}ms (should be < 60000ms for all to be kept)`);
|
||||
|
||||
// Since all timestamps are recent (within RPS window), they won't be cleaned by window
|
||||
// But the array size should still be limited
|
||||
console.log(`MAX_TIMESTAMPS: ${(collector as any).MAX_TIMESTAMPS}`);
|
||||
|
||||
// The issue is our rapid-fire test - all timestamps are within the window
|
||||
// Let's test with older timestamps
|
||||
console.log('\nTesting with mixed old/new timestamps...');
|
||||
(collector as any).requestTimestamps = [];
|
||||
|
||||
// Add some old timestamps (older than window)
|
||||
const oldTime = now - 70000; // 70 seconds ago
|
||||
for (let i = 0; i < 3000; i++) {
|
||||
(collector as any).requestTimestamps.push(oldTime);
|
||||
}
|
||||
|
||||
// Add new timestamps to exceed limit
|
||||
for (let i = 0; i < 3000; i++) {
|
||||
collector.recordRequest(`conn-new-${i}`, 'test-route', '127.0.0.1');
|
||||
}
|
||||
|
||||
timestamps = (collector as any).requestTimestamps;
|
||||
console.log(`After mixed timestamps: ${timestamps.length} (old ones should be cleaned)`);
|
||||
|
||||
// Old timestamps should be cleaned when we exceed MAX_TIMESTAMPS
|
||||
expect(timestamps.length).toBeLessThanOrEqual(5000);
|
||||
|
||||
// Stop the collector
|
||||
collector.stop();
|
||||
|
||||
console.log('\n=== Testing FunctionCache cleanup ===');
|
||||
|
||||
const { FunctionCache } = await import('../ts/proxies/http-proxy/function-cache.js');
|
||||
|
||||
const mockLogger = {
|
||||
debug: () => {},
|
||||
info: () => {},
|
||||
warn: () => {},
|
||||
error: () => {}
|
||||
};
|
||||
|
||||
const cache = new FunctionCache(mockLogger as any);
|
||||
|
||||
// Check that cleanup interval was set
|
||||
expect((cache as any).cleanupInterval).toBeTruthy();
|
||||
|
||||
// Test destroy method
|
||||
cache.destroy();
|
||||
|
||||
// Cleanup interval should be cleared
|
||||
expect((cache as any).cleanupInterval).toBeNull();
|
||||
|
||||
console.log('✓ FunctionCache properly cleans up interval');
|
||||
|
||||
console.log('\n=== Testing RequestHandler cleanup ===');
|
||||
|
||||
const { RequestHandler } = await import('../ts/proxies/http-proxy/request-handler.js');
|
||||
|
||||
const mockConnectionPool = {
|
||||
getConnection: () => null,
|
||||
releaseConnection: () => {}
|
||||
};
|
||||
|
||||
const handler = new RequestHandler(
|
||||
{ port: 8080, logLevel: 'error' },
|
||||
mockConnectionPool as any
|
||||
);
|
||||
|
||||
// Check that cleanup interval was set
|
||||
expect((handler as any).rateLimitCleanupInterval).toBeTruthy();
|
||||
|
||||
// Test destroy method
|
||||
handler.destroy();
|
||||
|
||||
// Cleanup interval should be cleared
|
||||
expect((handler as any).rateLimitCleanupInterval).toBeNull();
|
||||
|
||||
console.log('✓ RequestHandler properly cleans up interval');
|
||||
|
||||
console.log('\n✅ All memory leak fixes verified!');
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,280 +0,0 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
import * as net from 'net';
|
||||
import * as plugins from '../ts/plugins.js';
|
||||
|
||||
tap.test('MetricsCollector provides accurate metrics', async (tools) => {
|
||||
console.log('\n=== MetricsCollector Test ===');
|
||||
|
||||
// Create a simple echo server for testing
|
||||
const echoServer = net.createServer((socket) => {
|
||||
socket.on('data', (data) => {
|
||||
socket.write(data);
|
||||
});
|
||||
socket.on('error', () => {}); // Ignore errors
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
echoServer.listen(9995, () => {
|
||||
console.log('✓ Echo server started on port 9995');
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
// Create SmartProxy with test routes
|
||||
const proxy = new SmartProxy({
|
||||
routes: [
|
||||
{
|
||||
name: 'test-route-1',
|
||||
match: { ports: 8700 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 9995 }]
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'test-route-2',
|
||||
match: { ports: 8701 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 9995 }]
|
||||
}
|
||||
}
|
||||
],
|
||||
enableDetailedLogging: true,
|
||||
});
|
||||
|
||||
await proxy.start();
|
||||
console.log('✓ Proxy started on ports 8700 and 8701');
|
||||
|
||||
// Get metrics interface
|
||||
const metrics = proxy.getMetrics();
|
||||
|
||||
// Test 1: Initial state
|
||||
console.log('\n--- Test 1: Initial State ---');
|
||||
expect(metrics.connections.active()).toEqual(0);
|
||||
expect(metrics.connections.total()).toEqual(0);
|
||||
expect(metrics.requests.perSecond()).toEqual(0);
|
||||
expect(metrics.connections.byRoute().size).toEqual(0);
|
||||
expect(metrics.connections.byIP().size).toEqual(0);
|
||||
|
||||
const throughput = metrics.throughput.instant();
|
||||
expect(throughput.in).toEqual(0);
|
||||
expect(throughput.out).toEqual(0);
|
||||
console.log('✓ Initial metrics are all zero');
|
||||
|
||||
// Test 2: Create connections and verify metrics
|
||||
console.log('\n--- Test 2: Active Connections ---');
|
||||
const clients: net.Socket[] = [];
|
||||
|
||||
// Create 3 connections to route 1
|
||||
for (let i = 0; i < 3; i++) {
|
||||
const client = net.connect(8700, 'localhost');
|
||||
clients.push(client);
|
||||
await new Promise<void>((resolve) => {
|
||||
client.on('connect', resolve);
|
||||
client.on('error', () => resolve());
|
||||
});
|
||||
}
|
||||
|
||||
// Create 2 connections to route 2
|
||||
for (let i = 0; i < 2; i++) {
|
||||
const client = net.connect(8701, 'localhost');
|
||||
clients.push(client);
|
||||
await new Promise<void>((resolve) => {
|
||||
client.on('connect', resolve);
|
||||
client.on('error', () => resolve());
|
||||
});
|
||||
}
|
||||
|
||||
// Wait for connections to be fully established and routed
|
||||
await plugins.smartdelay.delayFor(300);
|
||||
|
||||
// Verify connection counts
|
||||
expect(metrics.connections.active()).toEqual(5);
|
||||
expect(metrics.connections.total()).toEqual(5);
|
||||
console.log(`✓ Active connections: ${metrics.connections.active()}`);
|
||||
console.log(`✓ Total connections: ${metrics.connections.total()}`);
|
||||
|
||||
// Test 3: Connections by route
|
||||
console.log('\n--- Test 3: Connections by Route ---');
|
||||
const routeConnections = metrics.connections.byRoute();
|
||||
console.log('Route connections:', Array.from(routeConnections.entries()));
|
||||
|
||||
// Check if we have the expected counts
|
||||
let route1Count = 0;
|
||||
let route2Count = 0;
|
||||
for (const [routeName, count] of routeConnections) {
|
||||
if (routeName === 'test-route-1') route1Count = count;
|
||||
if (routeName === 'test-route-2') route2Count = count;
|
||||
}
|
||||
|
||||
expect(route1Count).toEqual(3);
|
||||
expect(route2Count).toEqual(2);
|
||||
console.log('✓ Route test-route-1 has 3 connections');
|
||||
console.log('✓ Route test-route-2 has 2 connections');
|
||||
|
||||
// Test 4: Connections by IP
|
||||
console.log('\n--- Test 4: Connections by IP ---');
|
||||
const ipConnections = metrics.connections.byIP();
|
||||
// All connections are from localhost (127.0.0.1 or ::1)
|
||||
let totalIPConnections = 0;
|
||||
for (const [ip, count] of ipConnections) {
|
||||
console.log(` IP ${ip}: ${count} connections`);
|
||||
totalIPConnections += count;
|
||||
}
|
||||
expect(totalIPConnections).toEqual(5);
|
||||
console.log('✓ Total connections by IP matches active connections');
|
||||
|
||||
// Test 5: RPS calculation
|
||||
console.log('\n--- Test 5: Requests Per Second ---');
|
||||
const rps = metrics.requests.perSecond();
|
||||
console.log(` Current RPS: ${rps.toFixed(2)}`);
|
||||
// We created 5 connections, so RPS should be > 0
|
||||
expect(rps).toBeGreaterThan(0);
|
||||
console.log('✓ RPS is greater than 0');
|
||||
|
||||
// Test 6: Throughput
|
||||
console.log('\n--- Test 6: Throughput ---');
|
||||
// Send some data through connections
|
||||
for (const client of clients) {
|
||||
if (!client.destroyed) {
|
||||
client.write('Hello metrics!\n');
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for data to be transmitted and for sampling to occur
|
||||
await plugins.smartdelay.delayFor(1100); // Wait for at least one sampling interval
|
||||
|
||||
const throughputAfter = metrics.throughput.instant();
|
||||
console.log(` Bytes in: ${throughputAfter.in}`);
|
||||
console.log(` Bytes out: ${throughputAfter.out}`);
|
||||
// Throughput might still be 0 if no samples were taken, so just check it's defined
|
||||
expect(throughputAfter.in).toBeDefined();
|
||||
expect(throughputAfter.out).toBeDefined();
|
||||
console.log('✓ Throughput shows bytes transferred');
|
||||
|
||||
// Test 7: Close some connections
|
||||
console.log('\n--- Test 7: Connection Cleanup ---');
|
||||
// Close first 2 clients
|
||||
clients[0].destroy();
|
||||
clients[1].destroy();
|
||||
|
||||
await plugins.smartdelay.delayFor(100);
|
||||
|
||||
expect(metrics.connections.active()).toEqual(3);
|
||||
// Note: total() includes active connections + terminated connections from stats
|
||||
// The terminated connections might not be counted immediately
|
||||
const totalConns = metrics.connections.total();
|
||||
expect(totalConns).toBeGreaterThanOrEqual(3); // At least the active connections
|
||||
console.log(`✓ Active connections reduced to ${metrics.connections.active()}`);
|
||||
console.log(`✓ Total connections: ${totalConns}`);
|
||||
|
||||
// Test 8: Helper methods
|
||||
console.log('\n--- Test 8: Helper Methods ---');
|
||||
|
||||
// Test getTopIPs
|
||||
const topIPs = metrics.connections.topIPs(5);
|
||||
expect(topIPs.length).toBeGreaterThan(0);
|
||||
console.log('✓ getTopIPs returns IP list');
|
||||
|
||||
// Test throughput rate
|
||||
const throughputRate = metrics.throughput.recent();
|
||||
console.log(` Throughput rate: ${throughputRate.in} bytes/sec in, ${throughputRate.out} bytes/sec out`);
|
||||
console.log('✓ Throughput rates calculated');
|
||||
|
||||
// Cleanup
|
||||
console.log('\n--- Cleanup ---');
|
||||
for (const client of clients) {
|
||||
if (!client.destroyed) {
|
||||
client.destroy();
|
||||
}
|
||||
}
|
||||
|
||||
await proxy.stop();
|
||||
echoServer.close();
|
||||
|
||||
console.log('\n✓ All MetricsCollector tests passed');
|
||||
});
|
||||
|
||||
// Test with mock data for unit testing
|
||||
tap.test('MetricsCollector unit test with mock data', async () => {
|
||||
console.log('\n=== MetricsCollector Unit Test ===');
|
||||
|
||||
// Create a mock SmartProxy with mock ConnectionManager
|
||||
const mockConnections = new Map([
|
||||
['conn1', {
|
||||
remoteIP: '192.168.1.1',
|
||||
routeName: 'api',
|
||||
bytesReceived: 1000,
|
||||
bytesSent: 500,
|
||||
incomingStartTime: Date.now() - 5000
|
||||
}],
|
||||
['conn2', {
|
||||
remoteIP: '192.168.1.1',
|
||||
routeName: 'web',
|
||||
bytesReceived: 2000,
|
||||
bytesSent: 1500,
|
||||
incomingStartTime: Date.now() - 10000
|
||||
}],
|
||||
['conn3', {
|
||||
remoteIP: '192.168.1.2',
|
||||
routeName: 'api',
|
||||
bytesReceived: 500,
|
||||
bytesSent: 250,
|
||||
incomingStartTime: Date.now() - 3000
|
||||
}]
|
||||
]);
|
||||
|
||||
const mockSmartProxy = {
|
||||
connectionManager: {
|
||||
getConnectionCount: () => mockConnections.size,
|
||||
getConnections: () => mockConnections,
|
||||
getTerminationStats: () => ({
|
||||
incoming: { normal: 10, timeout: 2, error: 1 }
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
// Import MetricsCollector directly
|
||||
const { MetricsCollector } = await import('../ts/proxies/smart-proxy/metrics-collector.js');
|
||||
const metrics = new MetricsCollector(mockSmartProxy as any);
|
||||
|
||||
// Test metrics calculation
|
||||
console.log('\n--- Testing with Mock Data ---');
|
||||
|
||||
expect(metrics.connections.active()).toEqual(3);
|
||||
console.log(`✓ Active connections: ${metrics.connections.active()}`);
|
||||
|
||||
expect(metrics.connections.total()).toEqual(16); // 3 active + 13 terminated
|
||||
console.log(`✓ Total connections: ${metrics.connections.total()}`);
|
||||
|
||||
const routeConns = metrics.connections.byRoute();
|
||||
expect(routeConns.get('api')).toEqual(2);
|
||||
expect(routeConns.get('web')).toEqual(1);
|
||||
console.log('✓ Connections by route calculated correctly');
|
||||
|
||||
const ipConns = metrics.connections.byIP();
|
||||
expect(ipConns.get('192.168.1.1')).toEqual(2);
|
||||
expect(ipConns.get('192.168.1.2')).toEqual(1);
|
||||
console.log('✓ Connections by IP calculated correctly');
|
||||
|
||||
// Throughput tracker returns rates, not totals - just verify it returns something
|
||||
const throughput = metrics.throughput.instant();
|
||||
expect(throughput.in).toBeDefined();
|
||||
expect(throughput.out).toBeDefined();
|
||||
console.log(`✓ Throughput rates calculated: ${throughput.in} bytes/sec in, ${throughput.out} bytes/sec out`);
|
||||
|
||||
// Test RPS tracking
|
||||
metrics.recordRequest('test-1', 'test-route', '192.168.1.1');
|
||||
metrics.recordRequest('test-2', 'test-route', '192.168.1.1');
|
||||
metrics.recordRequest('test-3', 'test-route', '192.168.1.2');
|
||||
|
||||
const rps = metrics.requests.perSecond();
|
||||
expect(rps).toBeGreaterThan(0);
|
||||
console.log(`✓ RPS tracking works: ${rps.toFixed(2)} req/sec`);
|
||||
|
||||
console.log('\n✓ All unit tests passed');
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,188 +0,0 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import { NFTablesManager } from '../ts/proxies/smart-proxy/nftables-manager.js';
|
||||
import type { IRouteConfig } from '../ts/proxies/smart-proxy/models/route-types.js';
|
||||
import type { ISmartProxyOptions } from '../ts/proxies/smart-proxy/models/interfaces.js';
|
||||
import * as child_process from 'child_process';
|
||||
import { promisify } from 'util';
|
||||
|
||||
const exec = promisify(child_process.exec);
|
||||
|
||||
// Check if we have root privileges
|
||||
async function checkRootPrivileges(): Promise<boolean> {
|
||||
try {
|
||||
const { stdout } = await exec('id -u');
|
||||
return stdout.trim() === '0';
|
||||
} catch (err) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Skip tests if not root
|
||||
const isRoot = await checkRootPrivileges();
|
||||
if (!isRoot) {
|
||||
console.log('');
|
||||
console.log('========================================');
|
||||
console.log('NFTablesManager tests require root privileges');
|
||||
console.log('Skipping NFTablesManager tests');
|
||||
console.log('========================================');
|
||||
console.log('');
|
||||
// Skip tests when not running as root - tests are marked with tap.skip.test
|
||||
}
|
||||
|
||||
/**
|
||||
* Tests for the NFTablesManager class
|
||||
*/
|
||||
|
||||
// Sample route configurations for testing
|
||||
const sampleRoute: IRouteConfig = {
|
||||
name: 'test-nftables-route',
|
||||
match: {
|
||||
ports: 8080,
|
||||
domains: 'test.example.com'
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 8000
|
||||
}],
|
||||
forwardingEngine: 'nftables',
|
||||
nftables: {
|
||||
protocol: 'tcp',
|
||||
preserveSourceIP: true,
|
||||
useIPSets: true
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Sample SmartProxy options
|
||||
const sampleOptions: ISmartProxyOptions = {
|
||||
routes: [sampleRoute],
|
||||
enableDetailedLogging: true
|
||||
};
|
||||
|
||||
// Instance of NFTablesManager for testing
|
||||
let manager: NFTablesManager;
|
||||
|
||||
// Skip these tests by default since they require root privileges to run NFTables commands
|
||||
// When running as root, change this to false
|
||||
const SKIP_TESTS = true;
|
||||
|
||||
tap.skip.test('NFTablesManager setup test', async () => {
|
||||
// Test will be skipped if not running as root due to tap.skip.test
|
||||
|
||||
// Create a SmartProxy instance first
|
||||
const { SmartProxy } = await import('../ts/proxies/smart-proxy/smart-proxy.js');
|
||||
const proxy = new SmartProxy(sampleOptions);
|
||||
|
||||
// Create a new instance of NFTablesManager
|
||||
manager = new NFTablesManager(proxy);
|
||||
|
||||
// Verify the instance was created successfully
|
||||
expect(manager).toBeTruthy();
|
||||
});
|
||||
|
||||
tap.skip.test('NFTablesManager route provisioning test', async () => {
|
||||
// Test will be skipped if not running as root due to tap.skip.test
|
||||
|
||||
// Provision the sample route
|
||||
const result = await manager.provisionRoute(sampleRoute);
|
||||
|
||||
// Verify the route was provisioned successfully
|
||||
expect(result).toEqual(true);
|
||||
|
||||
// Verify the route is listed as provisioned
|
||||
expect(manager.isRouteProvisioned(sampleRoute)).toEqual(true);
|
||||
});
|
||||
|
||||
tap.skip.test('NFTablesManager status test', async () => {
|
||||
// Test will be skipped if not running as root due to tap.skip.test
|
||||
|
||||
// Get the status of the managed rules
|
||||
const status = await manager.getStatus();
|
||||
|
||||
// Verify status includes our route
|
||||
const keys = Object.keys(status);
|
||||
expect(keys.length).toBeGreaterThan(0);
|
||||
|
||||
// Check the status of the first rule
|
||||
const firstStatus = status[keys[0]];
|
||||
expect(firstStatus.active).toEqual(true);
|
||||
expect(firstStatus.ruleCount.added).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
tap.skip.test('NFTablesManager route updating test', async () => {
|
||||
// Test will be skipped if not running as root due to tap.skip.test
|
||||
|
||||
// Create an updated version of the sample route
|
||||
const updatedRoute: IRouteConfig = {
|
||||
...sampleRoute,
|
||||
action: {
|
||||
...sampleRoute.action,
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 9000 // Different port
|
||||
}],
|
||||
nftables: {
|
||||
...sampleRoute.action.nftables,
|
||||
protocol: 'all' // Different protocol
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Update the route
|
||||
const result = await manager.updateRoute(sampleRoute, updatedRoute);
|
||||
|
||||
// Verify the route was updated successfully
|
||||
expect(result).toEqual(true);
|
||||
|
||||
// Verify the old route is no longer provisioned
|
||||
expect(manager.isRouteProvisioned(sampleRoute)).toEqual(false);
|
||||
|
||||
// Verify the new route is provisioned
|
||||
expect(manager.isRouteProvisioned(updatedRoute)).toEqual(true);
|
||||
});
|
||||
|
||||
tap.skip.test('NFTablesManager route deprovisioning test', async () => {
|
||||
// Test will be skipped if not running as root due to tap.skip.test
|
||||
|
||||
// Create an updated version of the sample route from the previous test
|
||||
const updatedRoute: IRouteConfig = {
|
||||
...sampleRoute,
|
||||
action: {
|
||||
...sampleRoute.action,
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 9000 // Different port from original test
|
||||
}],
|
||||
nftables: {
|
||||
...sampleRoute.action.nftables,
|
||||
protocol: 'all' // Different protocol from original test
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Deprovision the route
|
||||
const result = await manager.deprovisionRoute(updatedRoute);
|
||||
|
||||
// Verify the route was deprovisioned successfully
|
||||
expect(result).toEqual(true);
|
||||
|
||||
// Verify the route is no longer provisioned
|
||||
expect(manager.isRouteProvisioned(updatedRoute)).toEqual(false);
|
||||
});
|
||||
|
||||
tap.skip.test('NFTablesManager cleanup test', async () => {
|
||||
// Test will be skipped if not running as root due to tap.skip.test
|
||||
|
||||
// Stop all NFTables rules
|
||||
await manager.stop();
|
||||
|
||||
// Get the status of the managed rules
|
||||
const status = await manager.getStatus();
|
||||
|
||||
// Verify there are no active rules
|
||||
expect(Object.keys(status).length).toEqual(0);
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,166 +0,0 @@
|
||||
import { SmartProxy } from '../ts/proxies/smart-proxy/index.js';
|
||||
import { NFTablesManager } from '../ts/proxies/smart-proxy/nftables-manager.js';
|
||||
import { createNfTablesRoute } from '../ts/proxies/smart-proxy/utils/route-helpers.js';
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import * as child_process from 'child_process';
|
||||
import { promisify } from 'util';
|
||||
|
||||
const exec = promisify(child_process.exec);
|
||||
|
||||
// Check if we have root privileges
|
||||
async function checkRootPrivileges(): Promise<boolean> {
|
||||
try {
|
||||
const { stdout } = await exec('id -u');
|
||||
return stdout.trim() === '0';
|
||||
} catch (err) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Skip tests if not root
|
||||
const isRoot = await checkRootPrivileges();
|
||||
if (!isRoot) {
|
||||
console.log('');
|
||||
console.log('========================================');
|
||||
console.log('NFTables status tests require root privileges');
|
||||
console.log('Skipping NFTables status tests');
|
||||
console.log('========================================');
|
||||
console.log('');
|
||||
}
|
||||
|
||||
// Define the test function based on root privileges
|
||||
const testFn = isRoot ? tap.test : tap.skip.test;
|
||||
|
||||
testFn('NFTablesManager status functionality', async () => {
|
||||
const { SmartProxy } = await import('../ts/proxies/smart-proxy/smart-proxy.js');
|
||||
const proxy = new SmartProxy({ routes: [] });
|
||||
const nftablesManager = new NFTablesManager(proxy);
|
||||
|
||||
// Create test routes
|
||||
const testRoutes = [
|
||||
createNfTablesRoute('test-route-1', { host: 'localhost', port: 8080 }, { ports: 9080 }),
|
||||
createNfTablesRoute('test-route-2', { host: 'localhost', port: 8081 }, { ports: 9081 }),
|
||||
createNfTablesRoute('test-route-3', { host: 'localhost', port: 8082 }, {
|
||||
ports: 9082,
|
||||
ipAllowList: ['127.0.0.1', '192.168.1.0/24']
|
||||
})
|
||||
];
|
||||
|
||||
// Get initial status (should be empty)
|
||||
let status = await nftablesManager.getStatus();
|
||||
expect(Object.keys(status).length).toEqual(0);
|
||||
|
||||
// Provision routes
|
||||
for (const route of testRoutes) {
|
||||
await nftablesManager.provisionRoute(route);
|
||||
}
|
||||
|
||||
// Get status after provisioning
|
||||
status = await nftablesManager.getStatus();
|
||||
expect(Object.keys(status).length).toEqual(3);
|
||||
|
||||
// Check status structure
|
||||
for (const routeStatus of Object.values(status)) {
|
||||
expect(routeStatus).toHaveProperty('active');
|
||||
expect(routeStatus).toHaveProperty('ruleCount');
|
||||
expect(routeStatus).toHaveProperty('lastUpdate');
|
||||
expect(routeStatus.active).toBeTrue();
|
||||
}
|
||||
|
||||
// Deprovision one route
|
||||
await nftablesManager.deprovisionRoute(testRoutes[0]);
|
||||
|
||||
// Check status after deprovisioning
|
||||
status = await nftablesManager.getStatus();
|
||||
expect(Object.keys(status).length).toEqual(2);
|
||||
|
||||
// Cleanup remaining routes
|
||||
await nftablesManager.stop();
|
||||
|
||||
// Final status should be empty
|
||||
status = await nftablesManager.getStatus();
|
||||
expect(Object.keys(status).length).toEqual(0);
|
||||
});
|
||||
|
||||
testFn('SmartProxy getNfTablesStatus functionality', async () => {
|
||||
const smartProxy = new SmartProxy({
|
||||
routes: [
|
||||
createNfTablesRoute('proxy-test-1', { host: 'localhost', port: 3000 }, { ports: 3001 }),
|
||||
createNfTablesRoute('proxy-test-2', { host: 'localhost', port: 3002 }, { ports: 3003 }),
|
||||
// Include a non-NFTables route to ensure it's not included in the status
|
||||
{
|
||||
name: 'non-nftables-route',
|
||||
match: { ports: 3004 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{ host: 'localhost', port: 3005 }]
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
// Start the proxy
|
||||
await smartProxy.start();
|
||||
|
||||
// Get NFTables status
|
||||
const status = await smartProxy.getNfTablesStatus();
|
||||
|
||||
// Should only have 2 NFTables routes
|
||||
const statusKeys = Object.keys(status);
|
||||
expect(statusKeys.length).toEqual(2);
|
||||
|
||||
// Check that both NFTables routes are in the status
|
||||
const routeIds = statusKeys.sort();
|
||||
expect(routeIds).toContain('proxy-test-1:3001');
|
||||
expect(routeIds).toContain('proxy-test-2:3003');
|
||||
|
||||
// Verify status structure
|
||||
for (const [routeId, routeStatus] of Object.entries(status)) {
|
||||
expect(routeStatus).toHaveProperty('active', true);
|
||||
expect(routeStatus).toHaveProperty('ruleCount');
|
||||
expect(routeStatus.ruleCount).toHaveProperty('total');
|
||||
expect(routeStatus.ruleCount.total).toBeGreaterThan(0);
|
||||
}
|
||||
|
||||
// Stop the proxy
|
||||
await smartProxy.stop();
|
||||
|
||||
// After stopping, status should be empty
|
||||
const finalStatus = await smartProxy.getNfTablesStatus();
|
||||
expect(Object.keys(finalStatus).length).toEqual(0);
|
||||
});
|
||||
|
||||
testFn('NFTables route update status tracking', async () => {
|
||||
const smartProxy = new SmartProxy({
|
||||
routes: [
|
||||
createNfTablesRoute('update-test', { host: 'localhost', port: 4000 }, { ports: 4001 })
|
||||
]
|
||||
});
|
||||
|
||||
await smartProxy.start();
|
||||
|
||||
// Get initial status
|
||||
let status = await smartProxy.getNfTablesStatus();
|
||||
expect(Object.keys(status).length).toEqual(1);
|
||||
const initialUpdate = status['update-test:4001'].lastUpdate;
|
||||
|
||||
// Wait a moment
|
||||
await new Promise(resolve => setTimeout(resolve, 10));
|
||||
|
||||
// Update the route
|
||||
await smartProxy.updateRoutes([
|
||||
createNfTablesRoute('update-test', { host: 'localhost', port: 4002 }, { ports: 4001 })
|
||||
]);
|
||||
|
||||
// Get status after update
|
||||
status = await smartProxy.getNfTablesStatus();
|
||||
expect(Object.keys(status).length).toEqual(1);
|
||||
const updatedTime = status['update-test:4001'].lastUpdate;
|
||||
|
||||
// The update time should be different
|
||||
expect(updatedTime.getTime()).toBeGreaterThan(initialUpdate.getTime());
|
||||
|
||||
await smartProxy.stop();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,281 +0,0 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
|
||||
/**
|
||||
* Test that verifies port 80 is not double-registered when both
|
||||
* user routes and ACME challenges use the same port
|
||||
*/
|
||||
tap.test('should not double-register port 80 when user route and ACME use same port', async (tools) => {
|
||||
tools.timeout(5000);
|
||||
|
||||
let port80AddCount = 0;
|
||||
const activePorts = new Set<number>();
|
||||
|
||||
const settings = {
|
||||
port: 9901,
|
||||
routes: [
|
||||
{
|
||||
name: 'user-route',
|
||||
match: {
|
||||
ports: [80]
|
||||
},
|
||||
action: {
|
||||
type: 'forward' as const,
|
||||
targets: [{ host: 'localhost', port: 3000 }]
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'secure-route',
|
||||
match: {
|
||||
ports: [443]
|
||||
},
|
||||
action: {
|
||||
type: 'forward' as const,
|
||||
targets: [{ host: 'localhost', port: 3001 }],
|
||||
tls: {
|
||||
mode: 'terminate' as const,
|
||||
certificate: 'auto' as const
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
acme: {
|
||||
email: 'test@test.com',
|
||||
port: 80 // ACME on same port as user route
|
||||
}
|
||||
};
|
||||
|
||||
const proxy = new SmartProxy(settings);
|
||||
|
||||
// Mock the port manager to track port additions
|
||||
const mockPortManager = {
|
||||
addPort: async (port: number) => {
|
||||
if (activePorts.has(port)) {
|
||||
return; // Simulate deduplication
|
||||
}
|
||||
activePorts.add(port);
|
||||
if (port === 80) {
|
||||
port80AddCount++;
|
||||
}
|
||||
},
|
||||
addPorts: async (ports: number[]) => {
|
||||
for (const port of ports) {
|
||||
await mockPortManager.addPort(port);
|
||||
}
|
||||
},
|
||||
updatePorts: async (requiredPorts: Set<number>) => {
|
||||
for (const port of requiredPorts) {
|
||||
await mockPortManager.addPort(port);
|
||||
}
|
||||
},
|
||||
setShuttingDown: () => {},
|
||||
closeAll: async () => { activePorts.clear(); },
|
||||
stop: async () => { await mockPortManager.closeAll(); }
|
||||
};
|
||||
|
||||
// Inject mock
|
||||
(proxy as any).portManager = mockPortManager;
|
||||
|
||||
// Mock certificate manager to prevent ACME calls
|
||||
(proxy as any).createCertificateManager = async function(routes: any[], certDir: string, acmeOptions: any, initialState?: any) {
|
||||
const mockCertManager = {
|
||||
setUpdateRoutesCallback: function(callback: any) { /* noop */ },
|
||||
setHttpProxy: function() {},
|
||||
setGlobalAcmeDefaults: function() {},
|
||||
setAcmeStateManager: function() {},
|
||||
initialize: async function() {
|
||||
// Simulate ACME route addition
|
||||
const challengeRoute = {
|
||||
name: 'acme-challenge',
|
||||
priority: 1000,
|
||||
match: {
|
||||
ports: acmeOptions?.port || 80,
|
||||
path: '/.well-known/acme-challenge/*'
|
||||
},
|
||||
action: {
|
||||
type: 'static'
|
||||
}
|
||||
};
|
||||
// This would trigger route update in real implementation
|
||||
},
|
||||
provisionAllCertificates: async function() {
|
||||
// Mock implementation to satisfy the call in SmartProxy.start()
|
||||
// Add the ACME challenge port here too in case initialize was skipped
|
||||
const challengePort = acmeOptions?.port || 80;
|
||||
await mockPortManager.addPort(challengePort);
|
||||
console.log(`Added ACME challenge port from provisionAllCertificates: ${challengePort}`);
|
||||
},
|
||||
getAcmeOptions: () => acmeOptions,
|
||||
getState: () => ({ challengeRouteActive: false }),
|
||||
stop: async () => {}
|
||||
};
|
||||
return mockCertManager;
|
||||
};
|
||||
|
||||
// Mock NFTables
|
||||
(proxy as any).nftablesManager = {
|
||||
ensureNFTablesSetup: async () => {},
|
||||
stop: async () => {}
|
||||
};
|
||||
|
||||
// Mock admin server
|
||||
(proxy as any).startAdminServer = async function() {
|
||||
(this as any).servers.set(this.settings.port, {
|
||||
port: this.settings.port,
|
||||
close: async () => {}
|
||||
});
|
||||
};
|
||||
|
||||
await proxy.start();
|
||||
|
||||
// Verify that port 80 was added only once
|
||||
expect(port80AddCount).toEqual(1);
|
||||
|
||||
await proxy.stop();
|
||||
});
|
||||
|
||||
/**
|
||||
* Test that verifies ACME can use a different port than user routes
|
||||
*/
|
||||
tap.test('should handle ACME on different port than user routes', async (tools) => {
|
||||
tools.timeout(5000);
|
||||
|
||||
const portAddHistory: number[] = [];
|
||||
const activePorts = new Set<number>();
|
||||
|
||||
const settings = {
|
||||
port: 9902,
|
||||
routes: [
|
||||
{
|
||||
name: 'user-route',
|
||||
match: {
|
||||
ports: [80]
|
||||
},
|
||||
action: {
|
||||
type: 'forward' as const,
|
||||
targets: [{ host: 'localhost', port: 3000 }]
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'secure-route',
|
||||
match: {
|
||||
ports: [443]
|
||||
},
|
||||
action: {
|
||||
type: 'forward' as const,
|
||||
targets: [{ host: 'localhost', port: 3001 }],
|
||||
tls: {
|
||||
mode: 'terminate' as const,
|
||||
certificate: 'auto' as const
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
acme: {
|
||||
email: 'test@test.com',
|
||||
port: 8080 // ACME on different port than user routes
|
||||
}
|
||||
};
|
||||
|
||||
const proxy = new SmartProxy(settings);
|
||||
|
||||
// Mock the port manager
|
||||
const mockPortManager = {
|
||||
addPort: async (port: number) => {
|
||||
console.log(`Attempting to add port: ${port}`);
|
||||
if (!activePorts.has(port)) {
|
||||
activePorts.add(port);
|
||||
portAddHistory.push(port);
|
||||
console.log(`Port ${port} added to history`);
|
||||
} else {
|
||||
console.log(`Port ${port} already active, not adding to history`);
|
||||
}
|
||||
},
|
||||
addPorts: async (ports: number[]) => {
|
||||
for (const port of ports) {
|
||||
await mockPortManager.addPort(port);
|
||||
}
|
||||
},
|
||||
updatePorts: async (requiredPorts: Set<number>) => {
|
||||
for (const port of requiredPorts) {
|
||||
await mockPortManager.addPort(port);
|
||||
}
|
||||
},
|
||||
setShuttingDown: () => {},
|
||||
closeAll: async () => { activePorts.clear(); },
|
||||
stop: async () => { await mockPortManager.closeAll(); }
|
||||
};
|
||||
|
||||
// Inject mocks
|
||||
(proxy as any).portManager = mockPortManager;
|
||||
|
||||
// Mock certificate manager
|
||||
(proxy as any).createCertificateManager = async function(routes: any[], certDir: string, acmeOptions: any, initialState?: any) {
|
||||
const mockCertManager = {
|
||||
setUpdateRoutesCallback: function(callback: any) { /* noop */ },
|
||||
setHttpProxy: function() {},
|
||||
setGlobalAcmeDefaults: function() {},
|
||||
setAcmeStateManager: function() {},
|
||||
initialize: async function() {
|
||||
// Simulate ACME route addition on different port
|
||||
const challengePort = acmeOptions?.port || 80;
|
||||
const challengeRoute = {
|
||||
name: 'acme-challenge',
|
||||
priority: 1000,
|
||||
match: {
|
||||
ports: challengePort,
|
||||
path: '/.well-known/acme-challenge/*'
|
||||
},
|
||||
action: {
|
||||
type: 'static'
|
||||
}
|
||||
};
|
||||
|
||||
// Add the ACME port to our port tracking
|
||||
await mockPortManager.addPort(challengePort);
|
||||
|
||||
// For debugging
|
||||
console.log(`Added ACME challenge port: ${challengePort}`);
|
||||
},
|
||||
provisionAllCertificates: async function() {
|
||||
// Mock implementation to satisfy the call in SmartProxy.start()
|
||||
// Add the ACME challenge port here too in case initialize was skipped
|
||||
const challengePort = acmeOptions?.port || 80;
|
||||
await mockPortManager.addPort(challengePort);
|
||||
console.log(`Added ACME challenge port from provisionAllCertificates: ${challengePort}`);
|
||||
},
|
||||
getAcmeOptions: () => acmeOptions,
|
||||
getState: () => ({ challengeRouteActive: false }),
|
||||
stop: async () => {}
|
||||
};
|
||||
return mockCertManager;
|
||||
};
|
||||
|
||||
// Mock NFTables
|
||||
(proxy as any).nftablesManager = {
|
||||
ensureNFTablesSetup: async () => {},
|
||||
stop: async () => {}
|
||||
};
|
||||
|
||||
// Mock admin server
|
||||
(proxy as any).startAdminServer = async function() {
|
||||
(this as any).servers.set(this.settings.port, {
|
||||
port: this.settings.port,
|
||||
close: async () => {}
|
||||
});
|
||||
};
|
||||
|
||||
await proxy.start();
|
||||
|
||||
// Log the port history for debugging
|
||||
console.log('Port add history:', portAddHistory);
|
||||
|
||||
// Verify that all expected ports were added
|
||||
expect(portAddHistory.includes(80)).toBeTrue(); // User route
|
||||
expect(portAddHistory.includes(443)).toBeTrue(); // TLS route
|
||||
expect(portAddHistory.includes(8080)).toBeTrue(); // ACME challenge on different port
|
||||
|
||||
await proxy.stop();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,182 +0,0 @@
|
||||
import { expect, tap } from '@git.zone/tstest/tapbundle';
|
||||
import * as plugins from '../ts/plugins.js';
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
|
||||
let outerProxy: SmartProxy;
|
||||
let innerProxy: SmartProxy;
|
||||
|
||||
tap.test('setup two smartproxies in a chain configuration', async () => {
|
||||
// Setup inner proxy (backend proxy)
|
||||
innerProxy = new SmartProxy({
|
||||
routes: [
|
||||
{
|
||||
name: 'inner-backend',
|
||||
match: {
|
||||
ports: 8002
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'httpbin.org',
|
||||
port: 443
|
||||
}]
|
||||
}
|
||||
}
|
||||
],
|
||||
defaults: {
|
||||
target: {
|
||||
host: 'httpbin.org',
|
||||
port: 443
|
||||
}
|
||||
},
|
||||
acceptProxyProtocol: true,
|
||||
sendProxyProtocol: false,
|
||||
enableDetailedLogging: true,
|
||||
inactivityTimeout: 10000 // Shorter timeout for testing
|
||||
});
|
||||
await innerProxy.start();
|
||||
|
||||
// Setup outer proxy (frontend proxy)
|
||||
outerProxy = new SmartProxy({
|
||||
routes: [
|
||||
{
|
||||
name: 'outer-frontend',
|
||||
match: {
|
||||
ports: 8001
|
||||
},
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 8002
|
||||
}],
|
||||
sendProxyProtocol: true
|
||||
}
|
||||
}
|
||||
],
|
||||
defaults: {
|
||||
target: {
|
||||
host: 'localhost',
|
||||
port: 8002
|
||||
}
|
||||
},
|
||||
sendProxyProtocol: true,
|
||||
enableDetailedLogging: true,
|
||||
inactivityTimeout: 10000 // Shorter timeout for testing
|
||||
});
|
||||
await outerProxy.start();
|
||||
});
|
||||
|
||||
tap.test('should properly cleanup connections in proxy chain', async (tools) => {
|
||||
const testDuration = 30000; // 30 seconds
|
||||
const connectionInterval = 500; // Create new connection every 500ms
|
||||
const connectionDuration = 2000; // Each connection lasts 2 seconds
|
||||
|
||||
let connectionsCreated = 0;
|
||||
let connectionsCompleted = 0;
|
||||
|
||||
// Function to create a test connection
|
||||
const createTestConnection = async () => {
|
||||
connectionsCreated++;
|
||||
const connectionId = connectionsCreated;
|
||||
|
||||
try {
|
||||
const socket = plugins.net.connect({
|
||||
port: 8001,
|
||||
host: 'localhost'
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
socket.on('connect', () => {
|
||||
console.log(`Connection ${connectionId} established`);
|
||||
|
||||
// Send TLS Client Hello for httpbin.org
|
||||
const clientHello = Buffer.from([
|
||||
0x16, 0x03, 0x01, 0x00, 0xc8, // TLS handshake header
|
||||
0x01, 0x00, 0x00, 0xc4, // Client Hello
|
||||
0x03, 0x03, // TLS 1.2
|
||||
...Array(32).fill(0), // Random bytes
|
||||
0x00, // Session ID length
|
||||
0x00, 0x02, 0x13, 0x01, // Cipher suites
|
||||
0x01, 0x00, // Compression methods
|
||||
0x00, 0x97, // Extensions length
|
||||
0x00, 0x00, 0x00, 0x0f, 0x00, 0x0d, // SNI extension
|
||||
0x00, 0x00, 0x0a, 0x68, 0x74, 0x74, 0x70, 0x62, 0x69, 0x6e, 0x2e, 0x6f, 0x72, 0x67 // "httpbin.org"
|
||||
]);
|
||||
|
||||
socket.write(clientHello);
|
||||
|
||||
// Keep connection alive for specified duration
|
||||
setTimeout(() => {
|
||||
socket.destroy();
|
||||
connectionsCompleted++;
|
||||
console.log(`Connection ${connectionId} closed (completed: ${connectionsCompleted}/${connectionsCreated})`);
|
||||
resolve();
|
||||
}, connectionDuration);
|
||||
});
|
||||
|
||||
socket.on('error', (err) => {
|
||||
console.log(`Connection ${connectionId} error: ${err.message}`);
|
||||
connectionsCompleted++;
|
||||
reject(err);
|
||||
});
|
||||
});
|
||||
} catch (err) {
|
||||
console.log(`Failed to create connection ${connectionId}: ${err.message}`);
|
||||
connectionsCompleted++;
|
||||
}
|
||||
};
|
||||
|
||||
// Start creating connections
|
||||
const startTime = Date.now();
|
||||
const connectionTimer = setInterval(() => {
|
||||
if (Date.now() - startTime < testDuration) {
|
||||
createTestConnection().catch(() => {});
|
||||
} else {
|
||||
clearInterval(connectionTimer);
|
||||
}
|
||||
}, connectionInterval);
|
||||
|
||||
// Monitor connection counts
|
||||
const monitorInterval = setInterval(() => {
|
||||
const outerConnections = (outerProxy as any).connectionManager.getConnectionCount();
|
||||
const innerConnections = (innerProxy as any).connectionManager.getConnectionCount();
|
||||
|
||||
console.log(`Active connections - Outer: ${outerConnections}, Inner: ${innerConnections}, Created: ${connectionsCreated}, Completed: ${connectionsCompleted}`);
|
||||
}, 2000);
|
||||
|
||||
// Wait for test duration + cleanup time
|
||||
await tools.delayFor(testDuration + 10000);
|
||||
|
||||
clearInterval(connectionTimer);
|
||||
clearInterval(monitorInterval);
|
||||
|
||||
// Wait for all connections to complete
|
||||
while (connectionsCompleted < connectionsCreated) {
|
||||
await tools.delayFor(100);
|
||||
}
|
||||
|
||||
// Give some time for cleanup
|
||||
await tools.delayFor(5000);
|
||||
|
||||
// Check final connection counts
|
||||
const finalOuterConnections = (outerProxy as any).connectionManager.getConnectionCount();
|
||||
const finalInnerConnections = (innerProxy as any).connectionManager.getConnectionCount();
|
||||
|
||||
console.log(`\nFinal connection counts:`);
|
||||
console.log(`Outer proxy: ${finalOuterConnections}`);
|
||||
console.log(`Inner proxy: ${finalInnerConnections}`);
|
||||
console.log(`Total created: ${connectionsCreated}`);
|
||||
console.log(`Total completed: ${connectionsCompleted}`);
|
||||
|
||||
// Both proxies should have cleaned up all connections
|
||||
expect(finalOuterConnections).toEqual(0);
|
||||
expect(finalInnerConnections).toEqual(0);
|
||||
});
|
||||
|
||||
tap.test('cleanup proxies', async () => {
|
||||
await outerProxy.stop();
|
||||
await innerProxy.stop();
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
@@ -1,193 +0,0 @@
|
||||
import { tap, expect } from '@git.zone/tstest/tapbundle';
|
||||
import * as net from 'net';
|
||||
import * as plugins from '../ts/plugins.js';
|
||||
|
||||
// Import SmartProxy and configurations
|
||||
import { SmartProxy } from '../ts/index.js';
|
||||
|
||||
tap.test('simple proxy chain test - identify connection accumulation', async () => {
|
||||
console.log('\n=== Simple Proxy Chain Test ===');
|
||||
console.log('Setup: Client → SmartProxy1 (8590) → SmartProxy2 (8591) → Backend (down)');
|
||||
|
||||
// Create backend server that accepts and immediately closes connections
|
||||
const backend = net.createServer((socket) => {
|
||||
console.log('Backend: Connection received, closing immediately');
|
||||
socket.destroy();
|
||||
});
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
backend.listen(9998, () => {
|
||||
console.log('✓ Backend server started on port 9998 (closes connections immediately)');
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
// Create SmartProxy2 (downstream)
|
||||
const proxy2 = new SmartProxy({
|
||||
enableDetailedLogging: true,
|
||||
socketTimeout: 5000,
|
||||
routes: [{
|
||||
name: 'to-backend',
|
||||
match: { ports: 8591 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 9998 // Backend that closes immediately
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
// Create SmartProxy1 (upstream)
|
||||
const proxy1 = new SmartProxy({
|
||||
enableDetailedLogging: true,
|
||||
socketTimeout: 5000,
|
||||
routes: [{
|
||||
name: 'to-proxy2',
|
||||
match: { ports: 8590 },
|
||||
action: {
|
||||
type: 'forward',
|
||||
targets: [{
|
||||
host: 'localhost',
|
||||
port: 8591 // Forward to proxy2
|
||||
}]
|
||||
}
|
||||
}]
|
||||
});
|
||||
|
||||
await proxy2.start();
|
||||
console.log('✓ SmartProxy2 started on port 8591');
|
||||
|
||||
await proxy1.start();
|
||||
console.log('✓ SmartProxy1 started on port 8590');
|
||||
|
||||
// Helper to get connection counts
|
||||
const getConnectionCounts = () => {
|
||||
const conn1 = (proxy1 as any).connectionManager;
|
||||
const conn2 = (proxy2 as any).connectionManager;
|
||||
return {
|
||||
proxy1: conn1 ? conn1.getConnectionCount() : 0,
|
||||
proxy2: conn2 ? conn2.getConnectionCount() : 0
|
||||
};
|
||||
};
|
||||
|
||||
console.log('\n--- Making 5 sequential connections ---');
|
||||
|
||||
for (let i = 0; i < 5; i++) {
|
||||
console.log(`\n=== Connection ${i + 1} ===`);
|
||||
|
||||
const counts = getConnectionCounts();
|
||||
console.log(`Before: Proxy1=${counts.proxy1}, Proxy2=${counts.proxy2}`);
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
const client = new net.Socket();
|
||||
let dataReceived = false;
|
||||
|
||||
client.on('data', (data) => {
|
||||
console.log(`Client received data: ${data.toString()}`);
|
||||
dataReceived = true;
|
||||
});
|
||||
|
||||
client.on('error', (err: NodeJS.ErrnoException) => {
|
||||
console.log(`Client error: ${err.code}`);
|
||||
resolve();
|
||||
});
|
||||
|
||||
client.on('close', () => {
|
||||
console.log(`Client closed (data received: ${dataReceived})`);
|
||||
resolve();
|
||||
});
|
||||
|
||||
client.connect(8590, 'localhost', () => {
|
||||
console.log('Client connected to Proxy1');
|
||||
// Send HTTP request
|
||||
client.write('GET / HTTP/1.1\r\nHost: test.com\r\n\r\n');
|
||||
});
|
||||
|
||||
// Timeout
|
||||
setTimeout(() => {
|
||||
if (!client.destroyed) {
|
||||
console.log('Client timeout, destroying');
|
||||
client.destroy();
|
||||
}
|
||||
resolve();
|
||||
}, 2000);
|
||||
});
|
||||
|
||||
// Wait a bit and check counts
|
||||
await new Promise(resolve => setTimeout(resolve, 500));
|
||||
|
||||
const afterCounts = getConnectionCounts();
|
||||
console.log(`After: Proxy1=${afterCounts.proxy1}, Proxy2=${afterCounts.proxy2}`);
|
||||
|
||||
if (afterCounts.proxy1 > 0 || afterCounts.proxy2 > 0) {
|
||||
console.log('⚠️ WARNING: Connections not cleaned up!');
|
||||
}
|
||||
}
|
||||
|
||||
console.log('\n--- Test with backend completely down ---');
|
||||
|
||||
// Stop backend
|
||||
backend.close();
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
console.log('✓ Backend stopped');
|
||||
|
||||
// Make more connections with backend down
|
||||
for (let i = 0; i < 3; i++) {
|
||||
console.log(`\n=== Connection ${i + 6} (backend down) ===`);
|
||||
|
||||
const counts = getConnectionCounts();
|
||||
console.log(`Before: Proxy1=${counts.proxy1}, Proxy2=${counts.proxy2}`);
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
const client = new net.Socket();
|
||||
|
||||
client.on('error', () => {
|
||||
resolve();
|
||||
});
|
||||
|
||||
client.on('close', () => {
|
||||
resolve();
|
||||
});
|
||||
|
||||
client.connect(8590, 'localhost', () => {
|
||||
client.write('GET / HTTP/1.1\r\nHost: test.com\r\n\r\n');
|
||||
});
|
||||
|
||||
setTimeout(() => {
|
||||
if (!client.destroyed) {
|
||||
client.destroy();
|
||||
}
|
||||
resolve();
|
||||
}, 1000);
|
||||
});
|
||||
|
||||
await new Promise(resolve => setTimeout(resolve, 500));
|
||||
|
||||
const afterCounts = getConnectionCounts();
|
||||
console.log(`After: Proxy1=${afterCounts.proxy1}, Proxy2=${afterCounts.proxy2}`);
|
||||
}
|
||||
|
||||
// Final check
|
||||
console.log('\n--- Final Check ---');
|
||||
await new Promise(resolve => setTimeout(resolve, 1000));
|
||||
|
||||
const finalCounts = getConnectionCounts();
|
||||
console.log(`Final counts: Proxy1=${finalCounts.proxy1}, Proxy2=${finalCounts.proxy2}`);
|
||||
|
||||
await proxy1.stop();
|
||||
await proxy2.stop();
|
||||
|
||||
// Verify
|
||||
if (finalCounts.proxy1 > 0 || finalCounts.proxy2 > 0) {
|
||||
console.log('\n❌ FAIL: Connections accumulated!');
|
||||
} else {
|
||||
console.log('\n✅ PASS: No connection accumulation');
|
||||
}
|
||||
|
||||
expect(finalCounts.proxy1).toEqual(0);
|
||||
expect(finalCounts.proxy2).toEqual(0);
|
||||
});
|
||||
|
||||
export default tap.start();
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user