469 lines
18 KiB
TypeScript
469 lines
18 KiB
TypeScript
import * as plugins from '../../plugins.js';
import '../../core/models/socket-augmentation.js';
import { type INetworkProxyOptions, type IWebSocketWithHeartbeat, type ILogger, createLogger, type IReverseProxyConfig } from './models/types.js';
import { ConnectionPool } from './connection-pool.js';
import { ProxyRouter, RouteRouter } from '../../http/router/index.js';
import type { IRouteConfig } from '../smart-proxy/models/route-types.js';
import type { IRouteContext } from '../../core/models/route-context.js';
import { toBaseContext } from '../../core/models/route-context.js';
import { ContextCreator } from './context-creator.js';
import { SecurityManager } from './security-manager.js';
import { TemplateUtils } from '../../core/utils/template-utils.js';
import { getMessageSize, toBuffer } from '../../core/utils/websocket-utils.js';

/**
 * Handles WebSocket connections and proxying.
 *
 * The handler attaches a WebSocket server to an existing HTTPS server. For each
 * accepted client connection it picks a destination (modern route router first,
 * legacy router as fallback), opens a second WebSocket to that destination and
 * relays messages and close events in both directions.
 */
export class WebSocketHandler {
  private heartbeatInterval: NodeJS.Timeout | null = null;
  private wsServer: plugins.ws.WebSocketServer | null = null;
  private logger: ILogger;
  private contextCreator: ContextCreator = new ContextCreator();
  private routeRouter: RouteRouter | null = null;
  private securityManager: SecurityManager;

  /**
   * @param options Proxy options; only `logLevel` is read here (defaults to 'info').
   * @param connectionPool Used to pick the next target for legacy proxy configs.
   * @param legacyRouter Legacy router, kept for backward compatibility.
   * @param routes Routes for the modern router; the router is only created when non-empty.
   */
  constructor(
    private options: INetworkProxyOptions,
    private connectionPool: ConnectionPool,
    private legacyRouter: ProxyRouter, // Legacy router for backward compatibility
    private routes: IRouteConfig[] = [] // Routes for modern router
  ) {
    this.logger = createLogger(options.logLevel || 'info');
    this.securityManager = new SecurityManager(this.logger, routes);

    // Initialize modern router if we have routes
    if (routes.length > 0) {
      this.routeRouter = new RouteRouter(routes, this.logger);
    }
  }

  /**
   * Map a received close code to one that may be sent in a close frame.
   *
   * Codes 1004, 1005 (no status), 1006 (abnormal closure) and 1015 are reserved
   * and must never be put on the wire; the `ws` library throws when asked to send
   * them. 1006 is what a terminated socket reports, so it is common in practice.
   *
   * @param code Close code reported by the peer's 'close' event.
   * @returns The same code when it is sendable, otherwise `undefined`
   *          (meaning: close without a status code).
   */
  private static toSendableCloseCode(code: number): number | undefined {
    const isProtocolCode =
      code >= 1000 && code <= 1014 && code !== 1004 && code !== 1005 && code !== 1006;
    const isApplicationCode = code >= 3000 && code <= 4999;
    return isProtocolCode || isApplicationCode ? code : undefined;
  }

  /**
   * Compile an allowed-origin pattern into an anchored regular expression.
   *
   * Only `*` is treated as a wildcard (any run of characters). Every other
   * character is matched literally, so the dots in `https://*.example.com`
   * cannot match arbitrary characters.
   *
   * @param pattern Origin pattern with template variables already resolved.
   * @returns A RegExp matching the whole origin string.
   */
  private static compileOriginPattern(pattern: string): RegExp {
    const escaped = pattern
      .split('*')
      .map(part => part.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'))
      .join('.*');
    return new RegExp(`^${escaped}$`);
  }

  /**
   * Forward a close event to the other side of the proxied pair.
   *
   * Does nothing unless the peer is OPEN. Reserved codes are dropped (the peer
   * is closed without a status). If closing still throws, the peer is terminated
   * so the failure cannot escape the event handler.
   *
   * @param peer Socket to close.
   * @param code Close code received from the side that closed first.
   * @param reason Close reason received from the side that closed first.
   */
  private relayClose(
    peer: {
      readyState: number;
      OPEN: number;
      close(code?: number, reason?: string | Buffer): void;
      terminate(): void;
    },
    code: number,
    reason: string | Buffer
  ): void {
    if (peer.readyState !== peer.OPEN) {
      return;
    }

    const sendableCode = WebSocketHandler.toSendableCloseCode(code);
    try {
      if (sendableCode === undefined) {
        peer.close();
      } else {
        peer.close(sendableCode, reason);
      }
    } catch (err) {
      this.logger.error(`Error relaying WebSocket close (${code}): ${err}`);
      peer.terminate();
    }
  }

  /**
   * Set the route configurations.
   *
   * Creates the route router on first use, otherwise updates it, and passes the
   * same routes to the security manager.
   */
  public setRoutes(routes: IRouteConfig[]): void {
    this.routes = routes;

    // Initialize or update the route router
    if (!this.routeRouter) {
      this.routeRouter = new RouteRouter(routes, this.logger);
    } else {
      this.routeRouter.setRoutes(routes);
    }

    // Update the security manager
    this.securityManager.setRoutes(routes);
  }

  /**
   * Initialize WebSocket server on an existing HTTPS server
   * and start the heartbeat check.
   */
  public initialize(server: plugins.https.Server): void {
    // Create WebSocket server
    this.wsServer = new plugins.ws.WebSocketServer({
      server: server,
      clientTracking: true
    });

    // Handle WebSocket connections
    this.wsServer.on('connection', (wsIncoming: IWebSocketWithHeartbeat, req: plugins.http.IncomingMessage) => {
      this.handleWebSocketConnection(wsIncoming, req);
    });

    // Start the heartbeat interval
    this.startHeartbeat();

    this.logger.info('WebSocket handler initialized');
  }

  /**
   * Start the heartbeat interval to check for inactive WebSocket connections.
   *
   * Every 30 seconds each tracked client is either terminated (it never
   * answered the previous ping, so `isAlive` is still false) or marked not-alive
   * and pinged again; the 'pong' listener flips `isAlive` back to true.
   */
  private startHeartbeat(): void {
    // Clean up existing interval if any
    if (this.heartbeatInterval) {
      clearInterval(this.heartbeatInterval);
    }

    // Set up the heartbeat interval (check every 30 seconds)
    this.heartbeatInterval = setInterval(() => {
      const server = this.wsServer;
      if (!server || server.clients.size === 0) {
        return; // Skip if no active connections
      }

      this.logger.debug(`WebSocket heartbeat check for ${server.clients.size} clients`);

      server.clients.forEach((ws: plugins.wsDefault) => {
        const wsWithHeartbeat = ws as IWebSocketWithHeartbeat;

        if (wsWithHeartbeat.isAlive === false) {
          this.logger.debug('Terminating inactive WebSocket connection');
          return wsWithHeartbeat.terminate();
        }

        wsWithHeartbeat.isAlive = false;
        wsWithHeartbeat.ping();
      });
    }, 30000);

    // Make sure the interval doesn't keep the process alive
    if (this.heartbeatInterval.unref) {
      this.heartbeatInterval.unref();
    }
  }

  /**
   * Handle a new WebSocket connection.
   *
   * Steps: set up liveness tracking, resolve the destination, apply the route's
   * WebSocket settings (path rewrite, custom headers, subprotocols), open the
   * outgoing connection and, once it is open, relay messages and close events.
   * Client messages arriving before the outgoing connection is open are not
   * relayed, because the message listener is only attached on 'open'.
   *
   * @param wsIncoming Accepted client socket.
   * @param req The HTTP upgrade request that produced the socket.
   */
  private handleWebSocketConnection(wsIncoming: IWebSocketWithHeartbeat, req: plugins.http.IncomingMessage): void {
    try {
      // Initialize heartbeat tracking
      wsIncoming.isAlive = true;
      wsIncoming.lastPong = Date.now();

      // Handle pong messages to track liveness
      wsIncoming.on('pong', () => {
        wsIncoming.isAlive = true;
        wsIncoming.lastPong = Date.now();
      });

      // An 'error' event without a listener is thrown as an uncaught exception,
      // so always attach one to the client socket.
      wsIncoming.on('error', (err: Error) => {
        this.logger.error(`WebSocket client connection error: ${err.message}`);
      });

      // Create a context for routing
      const connectionId = `ws-${Date.now()}-${Math.floor(Math.random() * 10000)}`;
      const routeContext = this.contextCreator.createHttpRouteContext(req, {
        connectionId,
        clientIp: req.socket.remoteAddress?.replace('::ffff:', '') || '0.0.0.0',
        serverIp: req.socket.localAddress?.replace('::ffff:', '') || '0.0.0.0',
        tlsVersion: req.socket.getTLSVersion?.() || undefined
      });

      // Try modern router first if available
      let route: IRouteConfig | undefined;
      if (this.routeRouter) {
        route = this.routeRouter.routeReq(req);
      }

      // Define destination variables
      let destination: { host: string; port: number };

      // If we found a route with the modern router, use it
      if (route && route.action.type === 'forward' && route.action.target) {
        this.logger.debug(`Found matching WebSocket route: ${route.name || 'unnamed'}`);

        // Check if WebSockets are enabled for this route
        if (route.action.websocket?.enabled === false) {
          this.logger.debug(`WebSockets are disabled for route: ${route.name || 'unnamed'}`);
          wsIncoming.close(1003, 'WebSockets not supported for this route');
          return;
        }

        // Check security restrictions if configured to authenticate WebSocket requests
        if (route.action.websocket?.authenticateRequest !== false && route.security) {
          if (!this.securityManager.isAllowed(route, toBaseContext(routeContext))) {
            this.logger.warn(`WebSocket connection denied by security policy for ${routeContext.clientIp}`);
            wsIncoming.close(1008, 'Access denied by security policy');
            return;
          }

          // Check origin restrictions if configured
          // NOTE(review): this check only runs when `route.security` is set and
          // `authenticateRequest` is not false; `allowedOrigins` on a route
          // without `security` is not enforced — confirm this is intended.
          const origin = req.headers.origin;
          const allowedOrigins = route.action.websocket?.allowedOrigins;
          if (origin && allowedOrigins && allowedOrigins.length > 0) {
            const isAllowed = allowedOrigins.some(allowedOrigin => {
              // Handle wildcards and template variables
              if (allowedOrigin.includes('*') || allowedOrigin.includes('{')) {
                // Resolve templates first, then escape: regex metacharacters in
                // the pattern (notably '.') must match literally.
                const resolvedPattern = TemplateUtils.resolveTemplateVariables(allowedOrigin, routeContext);
                return WebSocketHandler.compileOriginPattern(resolvedPattern).test(origin);
              }
              return allowedOrigin === origin;
            });

            if (!isAllowed) {
              this.logger.warn(`WebSocket origin ${origin} not allowed for route: ${route.name || 'unnamed'}`);
              wsIncoming.close(1008, 'Origin not allowed');
              return;
            }
          }
        }

        // Extract target information, resolving functions if needed
        let targetHost: string | string[];
        let targetPort: number;

        try {
          // Resolve host if it's a function
          if (typeof route.action.target.host === 'function') {
            const resolvedHost = route.action.target.host(toBaseContext(routeContext));
            targetHost = resolvedHost;
            this.logger.debug(`Resolved function-based host for WebSocket: ${Array.isArray(resolvedHost) ? resolvedHost.join(', ') : resolvedHost}`);
          } else {
            targetHost = route.action.target.host;
          }

          // Resolve port if it's a function
          if (typeof route.action.target.port === 'function') {
            targetPort = route.action.target.port(toBaseContext(routeContext));
            this.logger.debug(`Resolved function-based port for WebSocket: ${targetPort}`);
          } else {
            targetPort = route.action.target.port === 'preserve' ? routeContext.port : route.action.target.port as number;
          }

          // Select a single host if an array was provided
          const selectedHost = Array.isArray(targetHost)
            ? targetHost[Math.floor(Math.random() * targetHost.length)]
            : targetHost;

          // Create a destination for the WebSocket connection
          destination = {
            host: selectedHost,
            port: targetPort
          };
        } catch (err) {
          this.logger.error(`Error evaluating function-based target for WebSocket: ${err}`);
          wsIncoming.close(1011, 'Internal server error');
          return;
        }
      } else {
        // Fall back to legacy routing if no matching route found via modern router
        const proxyConfig = this.legacyRouter.routeReq(req);

        if (!proxyConfig) {
          this.logger.warn(`No proxy configuration for WebSocket host: ${req.headers.host}`);
          wsIncoming.close(1008, 'No proxy configuration for this host');
          return;
        }

        // Get destination target using round-robin if multiple targets
        destination = this.connectionPool.getNextTarget(
          proxyConfig.destinationIps,
          proxyConfig.destinationPorts[0]
        );
      }

      // Snapshot the route's WebSocket settings so the event callbacks below do
      // not depend on the mutable `route` binding.
      const wsConfig = route?.action.websocket;
      const routeLabel = route?.name || 'unnamed';

      // Build target URL with potential path rewriting
      const protocol = (req.socket as any).encrypted ? 'wss' : 'ws';
      let targetPath = req.url || '/';

      // Apply path rewriting if configured
      if (wsConfig?.rewritePath) {
        const originalPath = targetPath;
        targetPath = TemplateUtils.resolveTemplateVariables(
          wsConfig.rewritePath,
          {...routeContext, path: targetPath}
        );
        this.logger.debug(`WebSocket path rewritten: ${originalPath} -> ${targetPath}`);
      }

      const targetUrl = `${protocol}://${destination.host}:${destination.port}${targetPath}`;

      this.logger.debug(`WebSocket connection from ${req.socket.remoteAddress} to ${targetUrl}`);

      // Create headers for outgoing WebSocket connection
      const headers: { [key: string]: string } = {};

      // Copy relevant headers from incoming request; the handshake headers are
      // skipped so they are not carried over from the client's handshake.
      for (const [key, value] of Object.entries(req.headers)) {
        if (value && typeof value === 'string' &&
            key.toLowerCase() !== 'connection' &&
            key.toLowerCase() !== 'upgrade' &&
            key.toLowerCase() !== 'sec-websocket-key' &&
            key.toLowerCase() !== 'sec-websocket-version') {
          headers[key] = value;
        }
      }

      // Always rewrite host header for WebSockets for consistency
      headers['host'] = `${destination.host}:${destination.port}`;

      // Add custom headers from route configuration
      if (wsConfig?.customHeaders) {
        for (const [key, value] of Object.entries(wsConfig.customHeaders)) {
          // Skip if header already exists and we're not overriding
          if (headers[key.toLowerCase()] && !value.startsWith('!')) {
            continue;
          }

          // Handle special delete directive (!delete)
          if (value === '!delete') {
            delete headers[key.toLowerCase()];
            continue;
          }

          // Handle forced override (!value)
          let finalValue: string;
          if (value.startsWith('!') && value !== '!delete') {
            // Keep the ! but resolve any templates in the rest
            // NOTE(review): the leading '!' stays in the value that is sent to
            // the target — confirm whether the marker should be stripped.
            const templateValue = value.substring(1);
            finalValue = '!' + TemplateUtils.resolveTemplateVariables(templateValue, routeContext);
          } else {
            // Resolve templates in the entire value
            finalValue = TemplateUtils.resolveTemplateVariables(value, routeContext);
          }

          // Set the header
          headers[key.toLowerCase()] = finalValue;
        }
      }

      // Create WebSocket connection options
      const wsOptions: any = {
        headers: headers,
        followRedirects: true
      };

      // Add subprotocols if configured
      if (wsConfig?.subprotocols && wsConfig.subprotocols.length > 0) {
        wsOptions.protocols = wsConfig.subprotocols;
      } else if (req.headers['sec-websocket-protocol']) {
        // Pass through client requested protocols
        wsOptions.protocols = req.headers['sec-websocket-protocol'].split(',').map(p => p.trim());
      }

      // Create outgoing WebSocket connection
      const wsOutgoing = new plugins.wsDefault(targetUrl, wsOptions);

      // Handle connection errors
      wsOutgoing.on('error', (err) => {
        this.logger.error(`WebSocket target connection error: ${err.message}`);
        if (wsIncoming.readyState === wsIncoming.OPEN) {
          wsIncoming.close(1011, 'Internal server error');
        }
      });

      // Handle outgoing connection open
      wsOutgoing.on('open', () => {
        // The client may have gone away while the target was still connecting.
        // Its 'close' event has already fired, so nothing registered below
        // would ever close the target; close it here instead.
        if (wsIncoming.readyState !== wsIncoming.OPEN) {
          this.logger.debug('WebSocket client disconnected before target connection opened; closing target connection');
          wsOutgoing.close();
          return;
        }

        // Set up custom ping interval if configured
        let pingInterval: NodeJS.Timeout | null = null;
        if (wsConfig?.pingInterval && wsConfig.pingInterval > 0) {
          pingInterval = setInterval(() => {
            if (wsIncoming.readyState === wsIncoming.OPEN) {
              wsIncoming.ping();
              this.logger.debug(`Sent WebSocket ping to client for route: ${routeLabel}`);
            }
          }, wsConfig.pingInterval);

          // Don't keep process alive just for pings
          if (pingInterval.unref) pingInterval.unref();
        }

        // Set up custom ping timeout if configured
        let pingTimeout: NodeJS.Timeout | null = null;
        const pingTimeoutMs = wsConfig?.pingTimeout || 60000; // Default 60s

        // (Re)arm the timeout; the client is terminated when no pong arrives in time
        const resetPingTimeout = () => {
          if (pingTimeout) clearTimeout(pingTimeout);
          pingTimeout = setTimeout(() => {
            this.logger.debug(`WebSocket ping timeout for client connection on route: ${routeLabel}`);
            wsIncoming.terminate();
          }, pingTimeoutMs);

          // Don't keep process alive just for timeouts
          if (pingTimeout.unref) pingTimeout.unref();
        };

        const clearTimers = () => {
          if (pingInterval) clearInterval(pingInterval);
          if (pingTimeout) clearTimeout(pingTimeout);
        };

        // Reset timeout on pong
        wsIncoming.on('pong', () => {
          wsIncoming.isAlive = true;
          wsIncoming.lastPong = Date.now();
          resetPingTimeout();
        });

        // Initial ping timeout
        resetPingTimeout();

        // Handle potential message size limits (0 means no limit)
        const maxSize = wsConfig?.maxPayloadSize || 0;

        // Forward incoming messages to outgoing connection
        wsIncoming.on('message', (data, isBinary) => {
          if (wsOutgoing.readyState === wsOutgoing.OPEN) {
            // Check message size if limit is set
            const messageSize = getMessageSize(data);
            if (maxSize > 0 && messageSize > maxSize) {
              this.logger.warn(`WebSocket message exceeds max size (${messageSize} > ${maxSize})`);
              wsIncoming.close(1009, 'Message too big');
              return;
            }

            wsOutgoing.send(data, { binary: isBinary });
          }
        });

        // Forward outgoing messages to incoming connection
        wsOutgoing.on('message', (data, isBinary) => {
          if (wsIncoming.readyState === wsIncoming.OPEN) {
            wsIncoming.send(data, { binary: isBinary });
          }
        });

        // Handle closing of connections. Timers are cleared first so cleanup
        // happens regardless of how the close relay goes.
        wsIncoming.on('close', (code, reason) => {
          this.logger.debug(`WebSocket client connection closed: ${code} ${reason}`);
          clearTimers();
          this.relayClose(wsOutgoing, code, reason);
        });

        wsOutgoing.on('close', (code, reason) => {
          this.logger.debug(`WebSocket target connection closed: ${code} ${reason}`);
          clearTimers();
          this.relayClose(wsIncoming, code, reason);
        });

        this.logger.debug(`WebSocket connection established: ${req.headers.host} -> ${destination.host}:${destination.port}`);
      });

    } catch (error: unknown) {
      const message = error instanceof Error ? error.message : String(error);
      this.logger.error(`Error handling WebSocket connection: ${message}`);
      if (wsIncoming.readyState === wsIncoming.OPEN) {
        wsIncoming.close(1011, 'Internal server error');
      }
    }
  }

  /**
   * Get information about active WebSocket connections.
   *
   * @returns The number of clients tracked by the WebSocket server, or 0 when
   *          the server has not been initialized.
   */
  public getConnectionInfo(): { activeConnections: number } {
    return {
      activeConnections: this.wsServer ? this.wsServer.clients.size : 0
    };
  }

  /**
   * Shutdown the WebSocket handler: stop the heartbeat, terminate every tracked
   * client and close the WebSocket server.
   */
  public shutdown(): void {
    // Stop heartbeat interval
    if (this.heartbeatInterval) {
      clearInterval(this.heartbeatInterval);
      this.heartbeatInterval = null;
    }

    // Close all WebSocket connections
    if (this.wsServer) {
      this.logger.info(`Closing ${this.wsServer.clients.size} WebSocket connections`);

      for (const client of this.wsServer.clients) {
        try {
          client.terminate();
        } catch (error) {
          this.logger.error('Error terminating WebSocket client', error);
        }
      }

      // Close the server
      this.wsServer.close();
      this.wsServer = null;
    }
  }
}