import * as plugins from '../../plugins.js';
import '../../core/models/socket-augmentation.js';
import {
  type IHttpProxyOptions,
  type IWebSocketWithHeartbeat,
  type ILogger,
  createLogger,
  type IReverseProxyConfig
} from './models/types.js';
import { ConnectionPool } from './connection-pool.js';
import { ProxyRouter, RouteRouter } from '../../routing/router/index.js';
import type { IRouteConfig } from '../smart-proxy/models/route-types.js';
import type { IRouteContext } from '../../core/models/route-context.js';
import { toBaseContext } from '../../core/models/route-context.js';
import { ContextCreator } from './context-creator.js';
import { SecurityManager } from './security-manager.js';
import { TemplateUtils } from '../../core/utils/template-utils.js';
import { getMessageSize, toBuffer } from '../../core/utils/websocket-utils.js';

/**
 * Handles WebSocket connections and proxying.
 *
 * A `ws` WebSocketServer is attached to an existing HTTPS server. Each accepted
 * connection is matched against the route-based router (falling back to the
 * legacy router) and bridged to a newly opened outgoing WebSocket.
 */
export class WebSocketHandler {
  /**
   * Upper bound on client message bytes buffered while the upstream handshake
   * is still in progress; exceeding it closes the client with 1013.
   */
  private static readonly MAX_PENDING_UPSTREAM_BYTES = 1024 * 1024;

  /**
   * Incoming headers that belong to the client<->proxy handshake and must not be
   * copied to the upstream request; `ws` generates its own values for these.
   */
  private static readonly HANDSHAKE_HEADERS: ReadonlySet<string> = new Set([
    'connection',
    'upgrade',
    'sec-websocket-key',
    'sec-websocket-version',
    'sec-websocket-protocol',
    'sec-websocket-extensions'
  ]);

  private heartbeatInterval: NodeJS.Timeout | null = null;
  private wsServer: plugins.ws.WebSocketServer | null = null;
  private logger: ILogger;
  private contextCreator: ContextCreator = new ContextCreator();
  private routeRouter: RouteRouter | null = null;
  private securityManager: SecurityManager;

  constructor(
    private options: IHttpProxyOptions,
    private connectionPool: ConnectionPool,
    private legacyRouter: ProxyRouter, // Legacy router for backward compatibility
    private routes: IRouteConfig[] = [] // Routes for modern router
  ) {
    this.logger = createLogger(options.logLevel || 'info');
    this.securityManager = new SecurityManager(this.logger, routes);

    // Initialize modern router if we have routes
    if (routes.length > 0) {
      this.routeRouter = new RouteRouter(routes, this.logger);
    }
  }

  /**
   * Set the route configurations, creating the route router on first use and
   * propagating the routes to the security manager.
   */
  public setRoutes(routes: IRouteConfig[]): void {
    this.routes = routes;

    if (!this.routeRouter) {
      this.routeRouter = new RouteRouter(routes, this.logger);
    } else {
      this.routeRouter.setRoutes(routes);
    }

    this.securityManager.setRoutes(routes);
  }

  /**
   * Initialize the WebSocket server on an existing HTTPS server and start the
   * liveness heartbeat.
   */
  public initialize(server: plugins.https.Server): void {
    this.wsServer = new plugins.ws.WebSocketServer({
      server: server,
      clientTracking: true
    });

    this.wsServer.on('connection', (wsIncoming: IWebSocketWithHeartbeat, req: plugins.http.IncomingMessage) => {
      this.handleWebSocketConnection(wsIncoming, req);
    });

    this.startHeartbeat();

    this.logger.info('WebSocket handler initialized');
  }

  /**
   * Start the heartbeat interval (every 30 s). Each tick terminates clients that
   * did not answer the previous ping and pings the rest.
   */
  private startHeartbeat(): void {
    if (this.heartbeatInterval) {
      clearInterval(this.heartbeatInterval);
    }

    this.heartbeatInterval = setInterval(() => {
      if (!this.wsServer || this.wsServer.clients.size === 0) {
        return; // Skip if no active connections
      }

      this.logger.debug(`WebSocket heartbeat check for ${this.wsServer.clients.size} clients`);

      this.wsServer.clients.forEach((ws: plugins.wsDefault) => {
        const wsWithHeartbeat = ws as IWebSocketWithHeartbeat;

        if (wsWithHeartbeat.isAlive === false) {
          this.logger.debug('Terminating inactive WebSocket connection');
          return wsWithHeartbeat.terminate();
        }

        wsWithHeartbeat.isAlive = false;
        wsWithHeartbeat.ping();
      });
    }, 30000);

    // Make sure the interval doesn't keep the process alive
    if (this.heartbeatInterval.unref) {
      this.heartbeatInterval.unref();
    }
  }

  /**
   * Handle a new client WebSocket: route it, enforce route policy, open the
   * upstream WebSocket and bridge messages, pings and closes between the two.
   */
  private handleWebSocketConnection(wsIncoming: IWebSocketWithHeartbeat, req: plugins.http.IncomingMessage): void {
    this.logger.debug(`WebSocket connection initiated from ${req.headers.host}`);

    try {
      // Liveness bookkeeping consumed by startHeartbeat()
      wsIncoming.isAlive = true;
      wsIncoming.lastPong = Date.now();
      wsIncoming.on('pong', () => {
        wsIncoming.isAlive = true;
        wsIncoming.lastPong = Date.now();
      });

      // ws emits 'error' on protocol violations (invalid frames, oversized payloads, ...).
      // An EventEmitter 'error' without a listener is thrown and would crash the process;
      // ws tears the socket down itself, so logging is enough here.
      wsIncoming.on('error', (err: Error) => {
        this.logger.warn(`WebSocket client connection error: ${err.message}`);
      });

      const connectionId = `ws-${Date.now()}-${Math.floor(Math.random() * 10000)}`;
      const routeContext = this.contextCreator.createHttpRouteContext(req, {
        connectionId,
        clientIp: req.socket.remoteAddress?.replace('::ffff:', '') || '0.0.0.0',
        serverIp: req.socket.localAddress?.replace('::ffff:', '') || '0.0.0.0',
        tlsVersion: req.socket.getTLSVersion?.() || undefined
      });
      const resolveTemplate = (template: string): string =>
        TemplateUtils.resolveTemplateVariables(template, routeContext);

      // Try modern router first if available
      const route: IRouteConfig | undefined = this.routeRouter ? this.routeRouter.routeReq(req) : undefined;
      const routeName = route?.name || 'unnamed';
      const wsConfig = route?.action.websocket;

      let destination: { host: string; port: number };

      if (route && route.action.type === 'forward' && route.action.target) {
        this.logger.debug(`Found matching WebSocket route: ${routeName}`);

        if (wsConfig?.enabled === false) {
          this.logger.debug(`WebSockets are disabled for route: ${routeName}`);
          wsIncoming.close(1003, 'WebSockets not supported for this route');
          return;
        }

        if (wsConfig?.authenticateRequest !== false) {
          if (route.security && !this.securityManager.isAllowed(route, toBaseContext(routeContext))) {
            this.logger.warn(`WebSocket connection denied by security policy for ${routeContext.clientIp}`);
            wsIncoming.close(1008, 'Access denied by security policy');
            return;
          }

          // The origin allow-list is enforced whenever it is configured, not only
          // when the route also carries `security` settings.
          const origin = req.headers.origin;
          const allowedOrigins = wsConfig?.allowedOrigins;
          if (
            origin &&
            allowedOrigins &&
            allowedOrigins.length > 0 &&
            !WebSocketHandler.isOriginAllowed(origin, allowedOrigins, resolveTemplate)
          ) {
            this.logger.warn(`WebSocket origin ${origin} not allowed for route: ${routeName}`);
            wsIncoming.close(1008, 'Origin not allowed');
            return;
          }
        }

        // Extract target information, resolving functions if needed
        try {
          const target = route.action.target;

          let targetHost: string | string[];
          if (typeof target.host === 'function') {
            targetHost = target.host(toBaseContext(routeContext));
            this.logger.debug(
              `Resolved function-based host for WebSocket: ${Array.isArray(targetHost) ? targetHost.join(', ') : targetHost}`
            );
          } else {
            targetHost = target.host;
          }

          let targetPort: number;
          if (typeof target.port === 'function') {
            targetPort = target.port(toBaseContext(routeContext));
            this.logger.debug(`Resolved function-based port for WebSocket: ${targetPort}`);
          } else {
            targetPort = target.port === 'preserve' ? routeContext.port : (target.port as number);
          }

          // Select a single host if an array was provided
          const selectedHost = Array.isArray(targetHost)
            ? targetHost[Math.floor(Math.random() * targetHost.length)]
            : targetHost;
          if (!selectedHost) {
            // An empty host list would otherwise produce "ws://undefined:<port>"
            throw new Error('target host resolved to an empty value');
          }

          destination = { host: selectedHost, port: targetPort };
          this.logger.debug(`WebSocket destination resolved: ${selectedHost}:${targetPort}`);
        } catch (err) {
          this.logger.error(`Error evaluating function-based target for WebSocket: ${err}`);
          wsIncoming.close(1011, 'Internal server error');
          return;
        }
      } else {
        // Fall back to legacy routing if no matching route found via modern router
        const proxyConfig = this.legacyRouter.routeReq(req);

        if (!proxyConfig) {
          this.logger.warn(`No proxy configuration for WebSocket host: ${req.headers.host}`);
          wsIncoming.close(1008, 'No proxy configuration for this host');
          return;
        }

        // Get destination target using round-robin if multiple targets
        destination = this.connectionPool.getNextTarget(
          proxyConfig.destinationIps,
          proxyConfig.destinationPorts[0]
        );
      }

      // NOTE: the upstream scheme is inferred from the port alone (443 => wss);
      // a TLS backend on any other port is contacted with plain ws.
      const protocol = destination.port === 443 ? 'wss' : 'ws';

      let targetPath = req.url || '/';
      if (wsConfig?.rewritePath) {
        const originalPath = targetPath;
        targetPath = TemplateUtils.resolveTemplateVariables(
          wsConfig.rewritePath,
          { ...routeContext, path: targetPath }
        );
        this.logger.debug(`WebSocket path rewritten: ${originalPath} -> ${targetPath}`);
      }

      const targetUrl = `${protocol}://${destination.host}:${destination.port}${targetPath}`;
      this.logger.debug(`WebSocket connection from ${req.socket.remoteAddress} to ${targetUrl}`);

      const headers = WebSocketHandler.buildUpstreamHeaders(
        req,
        destination,
        wsConfig?.customHeaders,
        resolveTemplate
      );

      // Route-configured subprotocols win; otherwise pass through what the client requested.
      let protocols: string[] = [];
      const requestedProtocols = req.headers['sec-websocket-protocol'];
      if (wsConfig?.subprotocols && wsConfig.subprotocols.length > 0) {
        protocols = [...wsConfig.subprotocols];
      } else if (requestedProtocols) {
        protocols = requestedProtocols
          .split(',')
          .map((p) => p.trim())
          .filter((p) => p.length > 0);
      }

      this.logger.debug(`Creating WebSocket connection to ${targetUrl} with options:`, {
        headers,
        protocols
      });
      // `ws` only honours subprotocols passed as the second constructor argument;
      // a `protocols` property on the options object is silently ignored.
      const wsOutgoing = new plugins.wsDefault(targetUrl, protocols, {
        headers,
        followRedirects: true
      });
      this.logger.debug(`WebSocket instance created, waiting for connection...`);

      const maxSize = wsConfig?.maxPayloadSize || 0;

      // Per-connection timers, created once the upstream is open.
      let pingInterval: NodeJS.Timeout | null = null;
      let pingTimeout: NodeJS.Timeout | null = null;
      const clearTimers = (): void => {
        if (pingInterval) clearInterval(pingInterval);
        if (pingTimeout) clearTimeout(pingTimeout);
        pingInterval = null;
        pingTimeout = null;
      };

      // Client messages that arrive while the upstream handshake is still in progress.
      const pendingSends: Array<() => void> = [];
      let pendingBytes = 0;

      wsOutgoing.on('error', (err) => {
        this.logger.error(`WebSocket target connection error: ${err.message}`);
        if (wsIncoming.readyState === wsIncoming.OPEN) {
          wsIncoming.close(1011, 'Internal server error');
        }
      });

      // Registered immediately (not on upstream 'open') so that messages sent right
      // after the client handshake are not silently dropped.
      wsIncoming.on('message', (data, isBinary) => {
        const messageSize = getMessageSize(data);
        if (maxSize > 0 && messageSize > maxSize) {
          this.logger.warn(`WebSocket message exceeds max size (${messageSize} > ${maxSize})`);
          wsIncoming.close(1009, 'Message too big');
          return;
        }

        const forward = (): void => {
          // Log the size only: payloads may be binary or sensitive.
          this.logger.debug(`WebSocket forwarding message from client to target (${messageSize} bytes)`);
          wsOutgoing.send(data, { binary: isBinary });
        };

        if (wsOutgoing.readyState === wsOutgoing.OPEN) {
          forward();
        } else if (wsOutgoing.readyState === wsOutgoing.CONNECTING) {
          pendingBytes += messageSize;
          if (pendingBytes > WebSocketHandler.MAX_PENDING_UPSTREAM_BYTES) {
            this.logger.warn(`WebSocket client sent ${pendingBytes} bytes before target connection opened`);
            wsIncoming.close(1013, 'Target not ready');
            return;
          }
          pendingSends.push(forward);
        } else {
          this.logger.warn(`WebSocket target connection not open (state: ${wsOutgoing.readyState})`);
        }
      });

      wsOutgoing.on('message', (data, isBinary) => {
        if (wsIncoming.readyState === wsIncoming.OPEN) {
          this.logger.debug(`WebSocket forwarding message from target to client (${getMessageSize(data)} bytes)`);
          wsIncoming.send(data, { binary: isBinary });
        } else {
          this.logger.warn(`WebSocket client connection not open (state: ${wsIncoming.readyState})`);
        }
      });

      // Registered immediately so a client leaving during the upstream handshake
      // aborts that handshake instead of leaking the upstream connection.
      wsIncoming.on('close', (code, reason) => {
        this.logger.debug(`WebSocket client connection closed: ${code} ${reason}`);
        clearTimers();
        pendingSends.length = 0;

        if (wsOutgoing.readyState === wsOutgoing.OPEN) {
          this.closeWithPeerCode(wsOutgoing, code, reason, 'wsOutgoing');
        } else if (wsOutgoing.readyState === wsOutgoing.CONNECTING) {
          wsOutgoing.terminate();
        }
      });

      wsOutgoing.on('close', (code, reason) => {
        this.logger.debug(`WebSocket target connection closed: ${code} ${reason}`);
        clearTimers();

        if (wsIncoming.readyState === wsIncoming.OPEN) {
          this.closeWithPeerCode(wsIncoming, code, reason, 'wsIncoming');
        }
      });

      wsOutgoing.on('open', () => {
        this.logger.debug(`WebSocket target connection opened to ${targetUrl}`);

        if (wsIncoming.readyState !== wsIncoming.OPEN) {
          // The client started closing while we were connecting; don't keep the upstream.
          pendingSends.length = 0;
          wsOutgoing.close();
          return;
        }

        // Flush messages received during the handshake, preserving arrival order.
        for (const send of pendingSends.splice(0)) {
          send();
        }
        pendingBytes = 0;

        // Set up custom ping interval if configured
        if (wsConfig?.pingInterval && wsConfig.pingInterval > 0) {
          pingInterval = setInterval(() => {
            if (wsIncoming.readyState === wsIncoming.OPEN) {
              wsIncoming.ping();
              this.logger.debug(`Sent WebSocket ping to client for route: ${routeName}`);
            }
          }, wsConfig.pingInterval);

          // Don't keep process alive just for pings
          if (pingInterval.unref) pingInterval.unref();
        }

        // Terminate the client if no pong arrives within the timeout (default 60 s).
        const pingTimeoutMs = wsConfig?.pingTimeout || 60000;
        const resetPingTimeout = (): void => {
          if (pingTimeout) clearTimeout(pingTimeout);
          pingTimeout = setTimeout(() => {
            this.logger.debug(`WebSocket ping timeout for client connection on route: ${routeName}`);
            wsIncoming.terminate();
          }, pingTimeoutMs);

          // Don't keep process alive just for timeouts
          if (pingTimeout.unref) pingTimeout.unref();
        };

        wsIncoming.on('pong', resetPingTimeout);
        resetPingTimeout();

        this.logger.debug(
          `WebSocket connection established: ${req.headers.host} -> ${destination.host}:${destination.port}`
        );
      });
    } catch (error) {
      const message = error instanceof Error ? error.message : String(error);
      this.logger.error(`Error handling WebSocket connection: ${message}`);
      if (wsIncoming.readyState === wsIncoming.OPEN) {
        wsIncoming.close(1011, 'Internal server error');
      }
    }
  }

  /**
   * Build the handshake headers for the upstream WebSocket request.
   *
   * Single-valued client headers are copied (handshake-specific ones excluded),
   * `host` is rewritten to the destination, then route `customHeaders` are applied:
   * a plain value is only set when the header is absent, `!value` forces an
   * override (the `!` marker itself is not sent), and `!delete` removes the header.
   * Template variables in custom values are resolved via `resolveTemplate`.
   */
  private static buildUpstreamHeaders(
    req: plugins.http.IncomingMessage,
    destination: { host: string; port: number },
    customHeaders: Record<string, string> | undefined,
    resolveTemplate: (template: string) => string
  ): Record<string, string> {
    const headers: Record<string, string> = {};

    for (const [key, value] of Object.entries(req.headers)) {
      if (value && typeof value === 'string' && !WebSocketHandler.HANDSHAKE_HEADERS.has(key.toLowerCase())) {
        headers[key] = value;
      }
    }

    // Always rewrite host header for WebSockets for consistency
    headers['host'] = `${destination.host}:${destination.port}`;

    if (customHeaders) {
      for (const [rawKey, value] of Object.entries(customHeaders)) {
        const key = rawKey.toLowerCase();
        const isDirective = value.startsWith('!');

        // Skip if header already exists and we're not overriding
        if (headers[key] && !isDirective) {
          continue;
        }

        if (value === '!delete') {
          delete headers[key];
          continue;
        }

        headers[key] = resolveTemplate(isDirective ? value.substring(1) : value);
      }
    }

    return headers;
  }

  /**
   * Check an `Origin` header against an allow-list.
   *
   * Entries without `*` or `{` must match exactly. Otherwise each `*` matches any
   * character sequence and every other character (including resolved template
   * values) matches literally, so `https://*.example.com` does not accept
   * `https://evilexample.com`.
   */
  private static isOriginAllowed(
    origin: string,
    allowedOrigins: readonly string[],
    resolveTemplate: (template: string) => string
  ): boolean {
    return allowedOrigins.some((allowedOrigin) => {
      if (!allowedOrigin.includes('*') && !allowedOrigin.includes('{')) {
        return allowedOrigin === origin;
      }
      const pattern = allowedOrigin
        .split('*')
        .map((segment) => WebSocketHandler.escapeRegExp(resolveTemplate(segment)))
        .join('.*');
      return new RegExp(`^${pattern}$`).test(origin);
    });
  }

  /** Escape every RegExp metacharacter in `text` so it matches literally. */
  private static escapeRegExp(text: string): string {
    return text.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  }

  /**
   * Whether `code` may be sent in a close frame (RFC 6455 §7.4): 1000-1014 except
   * 1004/1005/1006, or 3000-4999. `ws` throws for anything else.
   */
  private static isSendableCloseCode(code: number): boolean {
    return (
      (code >= 1000 && code <= 1014 && code !== 1004 && code !== 1005 && code !== 1006) ||
      (code >= 3000 && code <= 4999)
    );
  }

  /**
   * Relay a close received on one side of the bridge to the other side.
   *
   * Codes that cannot be sent (e.g. 1005 "no status" or 1006 "abnormal") are mapped
   * to 1000. If `close()` still throws, the socket is terminated: `ws` marks the
   * socket CLOSING before validating, so a failed close would otherwise leave it
   * half-closed indefinitely.
   */
  private closeWithPeerCode(
    socket: { close(code: number, reason: string): void; terminate(): void },
    code: number,
    reason: Parameters<typeof toBuffer>[0] | undefined,
    label: string
  ): void {
    const closeCode = WebSocketHandler.isSendableCloseCode(code) ? code : 1000;
    try {
      const reasonString = reason ? toBuffer(reason).toString() : '';
      socket.close(closeCode, reasonString);
    } catch (err) {
      this.logger.error(`Error closing ${label}:`, err);
      socket.terminate();
    }
  }

  /**
   * Get information about active WebSocket connections.
   */
  public getConnectionInfo(): {
    activeConnections: number;
  } {
    return {
      activeConnections: this.wsServer ? this.wsServer.clients.size : 0
    };
  }

  /**
   * Shutdown the WebSocket handler: stop the heartbeat, terminate every client
   * and close the WebSocket server.
   */
  public shutdown(): void {
    if (this.heartbeatInterval) {
      clearInterval(this.heartbeatInterval);
      this.heartbeatInterval = null;
    }

    if (this.wsServer) {
      this.logger.info(`Closing ${this.wsServer.clients.size} WebSocket connections`);

      for (const client of this.wsServer.clients) {
        try {
          client.terminate();
        } catch (error) {
          this.logger.error('Error terminating WebSocket client', error);
        }
      }

      this.wsServer.close();
      this.wsServer = null;
    }
  }
}