From d47b0485170072c31560866f295fd4111628d329 Mon Sep 17 00:00:00 2001 From: Juergen Kunz Date: Mon, 21 Jul 2025 19:40:01 +0000 Subject: [PATCH] feat(detection): add centralized protocol detection module - Created ts/detection module for unified protocol detection - Implemented TLS and HTTP detectors with fragmentation support - Moved TLS detection logic from existing code to centralized module - Updated RouteConnectionHandler to use ProtocolDetector for both TLS and HTTP - Refactored ACME HTTP parsing to use detection module - Added comprehensive tests for detection functionality - Eliminated duplicate protocol detection code across codebase This centralizes all non-destructive protocol detection into a single module, improving code organization and reducing duplication between ACME and routing. --- test/test.detection.ts | 131 ++++++++ ts/detection/detectors/http-detector.ts | 281 ++++++++++++++++++ ts/detection/detectors/tls-detector.ts | 257 ++++++++++++++++ ts/detection/index.ts | 22 ++ ts/detection/models/detection-types.ts | 102 +++++++ ts/detection/models/interfaces.ts | 115 +++++++ ts/detection/protocol-detector.ts | 222 ++++++++++++++ ts/detection/utils/buffer-utils.ts | 174 +++++++++++ ts/detection/utils/parser-utils.ts | 141 +++++++++ ts/index.ts | 3 +- ts/proxies/smart-proxy/models/interfaces.ts | 7 + .../smart-proxy/route-connection-handler.ts | 161 +++++----- ts/proxies/smart-proxy/tls-manager.ts | 1 + ts/proxies/smart-proxy/utils/route-helpers.ts | 130 ++++---- 14 files changed, 1620 insertions(+), 127 deletions(-) create mode 100644 test/test.detection.ts create mode 100644 ts/detection/detectors/http-detector.ts create mode 100644 ts/detection/detectors/tls-detector.ts create mode 100644 ts/detection/index.ts create mode 100644 ts/detection/models/detection-types.ts create mode 100644 ts/detection/models/interfaces.ts create mode 100644 ts/detection/protocol-detector.ts create mode 100644 ts/detection/utils/buffer-utils.ts create mode 100644 ts/detection/utils/parser-utils.ts diff --git a/test/test.detection.ts b/test/test.detection.ts new file mode 100644 index 0000000..1f43b08 --- /dev/null +++ b/test/test.detection.ts @@ -0,0 +1,131 @@ +import { expect, tap } from '@git.zone/tstest/tapbundle'; +import * as smartproxy from '../ts/index.js'; + +tap.test('Protocol Detection - TLS Detection', async () => { + // Test TLS handshake detection + const tlsHandshake = Buffer.from([ + 0x16, // Handshake record type + 0x03, 0x01, // TLS 1.0 + 0x00, 0x05, // Length: 5 bytes + 0x01, // ClientHello + 0x00, 0x00, 0x01, 0x00 // Handshake length and data + ]); + + const detector = new smartproxy.detection.TlsDetector(); + expect(detector.canHandle(tlsHandshake)).toEqual(true); + + const result = detector.detect(tlsHandshake); + expect(result).toBeDefined(); + expect(result?.protocol).toEqual('tls'); + expect(result?.connectionInfo.tlsVersion).toEqual('TLSv1.0'); +}); + +tap.test('Protocol Detection - HTTP Detection', async () => { + // Test HTTP request detection + const httpRequest = Buffer.from( + 'GET /test HTTP/1.1\r\n' + + 'Host: example.com\r\n' + + 'User-Agent: TestClient/1.0\r\n' + + '\r\n' + ); + + const detector = new smartproxy.detection.HttpDetector(); + expect(detector.canHandle(httpRequest)).toEqual(true); + + const result = detector.detect(httpRequest); + expect(result).toBeDefined(); + expect(result?.protocol).toEqual('http'); + expect(result?.connectionInfo.method).toEqual('GET'); + expect(result?.connectionInfo.path).toEqual('/test'); + expect(result?.connectionInfo.domain).toEqual('example.com'); +}); + +tap.test('Protocol Detection - Main Detector TLS', async () => { + const tlsHandshake = Buffer.from([ + 0x16, // Handshake record type + 0x03, 0x03, // TLS 1.2 + 0x00, 0x05, // Length: 5 bytes + 0x01, // ClientHello + 0x00, 0x00, 0x01, 0x00 // Handshake length and data + ]); + + const result = await smartproxy.detection.ProtocolDetector.detect(tlsHandshake); + expect(result.protocol).toEqual('tls'); + expect(result.connectionInfo.tlsVersion).toEqual('TLSv1.2'); +}); + +tap.test('Protocol Detection - Main Detector HTTP', async () => { + const httpRequest = Buffer.from( + 'POST /api/test HTTP/1.1\r\n' + + 'Host: api.example.com\r\n' + + 'Content-Type: application/json\r\n' + + 'Content-Length: 2\r\n' + + '\r\n' + + '{}' + ); + + const result = await smartproxy.detection.ProtocolDetector.detect(httpRequest); + expect(result.protocol).toEqual('http'); + expect(result.connectionInfo.method).toEqual('POST'); + expect(result.connectionInfo.path).toEqual('/api/test'); + expect(result.connectionInfo.domain).toEqual('api.example.com'); +}); + +tap.test('Protocol Detection - Unknown Protocol', async () => { + const unknownData = Buffer.from('UNKNOWN PROTOCOL DATA\r\n'); + + const result = await smartproxy.detection.ProtocolDetector.detect(unknownData); + expect(result.protocol).toEqual('unknown'); + expect(result.isComplete).toEqual(true); +}); + +tap.test('Protocol Detection - Fragmented HTTP', async () => { + const connectionId = 'test-connection-1'; + + // First fragment + const fragment1 = Buffer.from('GET /test HT'); + let result = await smartproxy.detection.ProtocolDetector.detectWithConnectionTracking( + fragment1, + connectionId + ); + expect(result.protocol).toEqual('http'); + expect(result.isComplete).toEqual(false); + + // Second fragment + const fragment2 = Buffer.from('TP/1.1\r\nHost: example.com\r\n\r\n'); + result = await smartproxy.detection.ProtocolDetector.detectWithConnectionTracking( + fragment2, + connectionId + ); + expect(result.protocol).toEqual('http'); + expect(result.isComplete).toEqual(true); + expect(result.connectionInfo.method).toEqual('GET'); + expect(result.connectionInfo.path).toEqual('/test'); + expect(result.connectionInfo.domain).toEqual('example.com'); +}); + +tap.test('Protocol Detection - HTTP Methods', async () => { + const methods = ['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD', 'OPTIONS']; + + for (const method of methods) { + const request = Buffer.from( + `${method} /test HTTP/1.1\r\n` + + 'Host: example.com\r\n' + + '\r\n' + ); + + const detector = new smartproxy.detection.HttpDetector(); + const result = detector.detect(request); + expect(result?.connectionInfo.method).toEqual(method); + } +}); + +tap.test('Protocol Detection - Invalid Data', async () => { + // Binary data that's not a valid protocol + const binaryData = Buffer.from([0xFF, 0xFE, 0xFD, 0xFC, 0xFB]); + + const result = await smartproxy.detection.ProtocolDetector.detect(binaryData); + expect(result.protocol).toEqual('unknown'); +}); + +tap.start(); \ No newline at end of file diff --git a/ts/detection/detectors/http-detector.ts b/ts/detection/detectors/http-detector.ts new file mode 100644 index 0000000..110ff5c --- /dev/null +++ b/ts/detection/detectors/http-detector.ts @@ -0,0 +1,281 @@ +/** + * HTTP protocol detector + */ + +import type { IProtocolDetector } from '../models/interfaces.js'; +import type { IDetectionResult, IDetectionOptions, IConnectionInfo, THttpMethod } from '../models/detection-types.js'; +import { extractLine, isPrintableAscii, BufferAccumulator } from '../utils/buffer-utils.js'; +import { parseHttpRequestLine, parseHttpHeaders, extractDomainFromHost, isHttpMethod } from '../utils/parser-utils.js'; + +/** + * HTTP detector implementation + */ +export class HttpDetector implements IProtocolDetector { + /** + * Minimum bytes needed to identify HTTP method + */ + private static readonly MIN_HTTP_METHOD_SIZE = 3; // GET + + /** + * Maximum reasonable HTTP header size + */ + private static readonly MAX_HEADER_SIZE = 8192; + + /** + * Fragment tracking for incomplete headers + */ + private static fragmentedBuffers = new Map(); + + /** + * Detect HTTP protocol from buffer + */ + detect(buffer: Buffer, options?: IDetectionOptions): IDetectionResult | null { + // Check if buffer is too small + if (buffer.length < HttpDetector.MIN_HTTP_METHOD_SIZE) { + return null; + } + + // Quick check: first bytes should be printable ASCII + if (!isPrintableAscii(buffer, Math.min(20, buffer.length))) { + return null; + } + + // Try to extract the first line + const firstLineResult = extractLine(buffer, 0); + if (!firstLineResult) { + // No complete line yet + return { + protocol: 'http', + connectionInfo: { protocol: 'http' }, + isComplete: false, + bytesNeeded: buffer.length + 100 // Estimate + }; + } + + // Parse the request line + const requestLine = parseHttpRequestLine(firstLineResult.line); + if (!requestLine) { + // Not a valid HTTP request line + return null; + } + + // Initialize connection info + const connectionInfo: IConnectionInfo = { + protocol: 'http', + method: requestLine.method, + path: requestLine.path, + httpVersion: requestLine.version + }; + + // Check if we want to extract headers + if (options?.extractFullHeaders !== false) { + // Look for the end of headers (double CRLF) + const headerEndSequence = Buffer.from('\r\n\r\n'); + const headerEndIndex = buffer.indexOf(headerEndSequence); + + if (headerEndIndex === -1) { + // Headers not complete yet + const maxSize = options?.maxBufferSize || HttpDetector.MAX_HEADER_SIZE; + if (buffer.length >= maxSize) { + // Headers too large, reject + return null; + } + + return { + protocol: 'http', + connectionInfo, + isComplete: false, + bytesNeeded: buffer.length + 200 // Estimate + }; + } + + // Extract all header lines + const headerLines: string[] = []; + let currentOffset = firstLineResult.nextOffset; + + while (currentOffset < headerEndIndex) { + const lineResult = extractLine(buffer, currentOffset); + if (!lineResult) { + break; + } + + if (lineResult.line.length === 0) { + // Empty line marks end of headers + break; + } + + headerLines.push(lineResult.line); + currentOffset = lineResult.nextOffset; + } + + // Parse headers + const headers = parseHttpHeaders(headerLines); + connectionInfo.headers = headers; + + // Extract domain from Host header + const hostHeader = headers['host']; + if (hostHeader) { + connectionInfo.domain = extractDomainFromHost(hostHeader); + } + + // Calculate remaining buffer + const bodyStartIndex = headerEndIndex + 4; // After \r\n\r\n + const remainingBuffer = buffer.length > bodyStartIndex + ? buffer.slice(bodyStartIndex) + : undefined; + + return { + protocol: 'http', + connectionInfo, + remainingBuffer, + isComplete: true + }; + } else { + // Just extract Host header for domain + let currentOffset = firstLineResult.nextOffset; + const maxLines = 50; // Reasonable limit + + for (let i = 0; i < maxLines && currentOffset < buffer.length; i++) { + const lineResult = extractLine(buffer, currentOffset); + if (!lineResult) { + // Need more data + return { + protocol: 'http', + connectionInfo, + isComplete: false, + bytesNeeded: buffer.length + 50 + }; + } + + if (lineResult.line.length === 0) { + // End of headers + break; + } + + // Quick check for Host header + if (lineResult.line.toLowerCase().startsWith('host:')) { + const colonIndex = lineResult.line.indexOf(':'); + const hostValue = lineResult.line.slice(colonIndex + 1).trim(); + connectionInfo.domain = extractDomainFromHost(hostValue); + + // If we only needed the domain, we can return early + return { + protocol: 'http', + connectionInfo, + isComplete: true + }; + } + + currentOffset = lineResult.nextOffset; + } + + // If we reach here, no Host header found yet + return { + protocol: 'http', + connectionInfo, + isComplete: false, + bytesNeeded: buffer.length + 100 + }; + } + } + + /** + * Check if buffer can be handled by this detector + */ + canHandle(buffer: Buffer): boolean { + if (buffer.length < HttpDetector.MIN_HTTP_METHOD_SIZE) { + return false; + } + + // Check if first bytes could be an HTTP method + const firstWord = buffer.slice(0, Math.min(10, buffer.length)).toString('ascii').split(' ')[0]; + return isHttpMethod(firstWord); + } + + /** + * Get minimum bytes needed for detection + */ + getMinimumBytes(): number { + return HttpDetector.MIN_HTTP_METHOD_SIZE; + } + + /** + * Quick check if buffer starts with HTTP method + */ + static quickCheck(buffer: Buffer): boolean { + if (buffer.length < 3) { + return false; + } + + // Check common HTTP methods + const start = buffer.slice(0, 7).toString('ascii'); + return start.startsWith('GET ') || + start.startsWith('POST ') || + start.startsWith('PUT ') || + start.startsWith('DELETE ') || + start.startsWith('HEAD ') || + start.startsWith('OPTIONS') || + start.startsWith('PATCH ') || + start.startsWith('CONNECT') || + start.startsWith('TRACE '); + } + + /** + * Handle fragmented HTTP detection with connection tracking + */ + static detectWithFragments( + buffer: Buffer, + connectionId: string, + options?: IDetectionOptions + ): IDetectionResult | null { + const detector = new HttpDetector(); + + // Try direct detection first + const directResult = detector.detect(buffer, options); + if (directResult && directResult.isComplete) { + // Clean up any tracked fragments for this connection + this.fragmentedBuffers.delete(connectionId); + return directResult; + } + + // Handle fragmentation + let accumulator = this.fragmentedBuffers.get(connectionId); + if (!accumulator) { + accumulator = new BufferAccumulator(); + this.fragmentedBuffers.set(connectionId, accumulator); + } + + accumulator.append(buffer); + const fullBuffer = accumulator.getBuffer(); + + // Check size limit + const maxSize = options?.maxBufferSize || this.MAX_HEADER_SIZE; + if (fullBuffer.length > maxSize) { + // Too large, clean up and reject + this.fragmentedBuffers.delete(connectionId); + return null; + } + + // Try detection on accumulated buffer + const result = detector.detect(fullBuffer, options); + + if (result && result.isComplete) { + // Success - clean up + this.fragmentedBuffers.delete(connectionId); + return result; + } + + return result; + } + + /** + * Clean up old fragment buffers + */ + static cleanupFragments(maxAge: number = 5000): void { + // TODO: Add timestamp tracking to BufferAccumulator for cleanup + // For now, just clear if too many connections + if (this.fragmentedBuffers.size > 1000) { + this.fragmentedBuffers.clear(); + } + } +} \ No newline at end of file diff --git a/ts/detection/detectors/tls-detector.ts b/ts/detection/detectors/tls-detector.ts new file mode 100644 index 0000000..05e1fdd --- /dev/null +++ b/ts/detection/detectors/tls-detector.ts @@ -0,0 +1,257 @@ +/** + * TLS protocol detector + */ + +// TLS detector doesn't need plugins imports +import type { IProtocolDetector } from '../models/interfaces.js'; +import type { IDetectionResult, IDetectionOptions, IConnectionInfo } from '../models/detection-types.js'; +import { readUInt16BE, readUInt24BE, BufferAccumulator } from '../utils/buffer-utils.js'; +import { tlsVersionToString } from '../utils/parser-utils.js'; + +// Import existing TLS utilities +import { TlsUtils, TlsRecordType, TlsHandshakeType, TlsExtensionType } from '../../tls/utils/tls-utils.js'; +import { SniExtraction } from '../../tls/sni/sni-extraction.js'; +import { ClientHelloParser } from '../../tls/sni/client-hello-parser.js'; + +/** + * TLS detector implementation + */ +export class TlsDetector implements IProtocolDetector { + /** + * Minimum bytes needed to identify TLS (record header) + */ + private static readonly MIN_TLS_HEADER_SIZE = 5; + + /** + * Fragment tracking for incomplete handshakes + */ + private static fragmentedBuffers = new Map(); + + /** + * Detect TLS protocol from buffer + */ + detect(buffer: Buffer, options?: IDetectionOptions): IDetectionResult | null { + // Check if buffer is too small + if (buffer.length < TlsDetector.MIN_TLS_HEADER_SIZE) { + return null; + } + + // Check if this is a TLS record + if (!this.isTlsRecord(buffer)) { + return null; + } + + // Extract basic TLS info + const recordType = buffer[0]; + const tlsMajor = buffer[1]; + const tlsMinor = buffer[2]; + const recordLength = readUInt16BE(buffer, 3); + + // Initialize connection info + const connectionInfo: IConnectionInfo = { + protocol: 'tls', + tlsVersion: tlsVersionToString(tlsMajor, tlsMinor) || undefined + }; + + // If it's a handshake, try to extract more info + if (recordType === TlsRecordType.HANDSHAKE && buffer.length >= 6) { + const handshakeType = buffer[5]; + + // For ClientHello, extract SNI and other info + if (handshakeType === TlsHandshakeType.CLIENT_HELLO) { + // Check if we have the complete handshake + const totalRecordLength = recordLength + 5; // Including TLS header + if (buffer.length >= totalRecordLength) { + // Extract SNI using existing logic + const sni = SniExtraction.extractSNI(buffer); + if (sni) { + connectionInfo.domain = sni; + connectionInfo.sni = sni; + } + + // Parse ClientHello for additional info + const parseResult = ClientHelloParser.parseClientHello(buffer); + if (parseResult.isValid) { + // Extract ALPN if present + const alpnExtension = parseResult.extensions.find( + ext => ext.type === TlsExtensionType.APPLICATION_LAYER_PROTOCOL_NEGOTIATION + ); + + if (alpnExtension) { + connectionInfo.alpn = this.parseAlpnExtension(alpnExtension.data); + } + + // Store cipher suites if needed + if (parseResult.cipherSuites && options?.extractFullHeaders) { + connectionInfo.cipherSuites = this.parseCipherSuites(parseResult.cipherSuites); + } + } + + // Return complete result + return { + protocol: 'tls', + connectionInfo, + remainingBuffer: buffer.length > totalRecordLength + ? buffer.slice(totalRecordLength) + : undefined, + isComplete: true + }; + } else { + // Incomplete handshake + return { + protocol: 'tls', + connectionInfo, + isComplete: false, + bytesNeeded: totalRecordLength + }; + } + } + } + + // For other TLS record types, just return basic info + return { + protocol: 'tls', + connectionInfo, + isComplete: true, + remainingBuffer: buffer.length > recordLength + 5 + ? buffer.slice(recordLength + 5) + : undefined + }; + } + + /** + * Check if buffer can be handled by this detector + */ + canHandle(buffer: Buffer): boolean { + return buffer.length >= TlsDetector.MIN_TLS_HEADER_SIZE && + this.isTlsRecord(buffer); + } + + /** + * Get minimum bytes needed for detection + */ + getMinimumBytes(): number { + return TlsDetector.MIN_TLS_HEADER_SIZE; + } + + /** + * Check if buffer contains a valid TLS record + */ + private isTlsRecord(buffer: Buffer): boolean { + const recordType = buffer[0]; + + // Check for valid record type + const validTypes = [ + TlsRecordType.CHANGE_CIPHER_SPEC, + TlsRecordType.ALERT, + TlsRecordType.HANDSHAKE, + TlsRecordType.APPLICATION_DATA, + TlsRecordType.HEARTBEAT + ]; + + if (!validTypes.includes(recordType)) { + return false; + } + + // Check TLS version bytes (should be 0x03 0x0X) + if (buffer[1] !== 0x03) { + return false; + } + + // Check record length is reasonable + const recordLength = readUInt16BE(buffer, 3); + if (recordLength > 16384) { // Max TLS record size + return false; + } + + return true; + } + + /** + * Parse ALPN extension data + */ + private parseAlpnExtension(data: Buffer): string[] { + const protocols: string[] = []; + + if (data.length < 2) { + return protocols; + } + + const listLength = readUInt16BE(data, 0); + let offset = 2; + + while (offset < Math.min(2 + listLength, data.length)) { + const protoLength = data[offset]; + offset++; + + if (offset + protoLength <= data.length) { + const protocol = data.slice(offset, offset + protoLength).toString('ascii'); + protocols.push(protocol); + offset += protoLength; + } else { + break; + } + } + + return protocols; + } + + /** + * Parse cipher suites + */ + private parseCipherSuites(data: Buffer): number[] { + const suites: number[] = []; + + for (let i = 0; i + 1 < data.length; i += 2) { + const suite = readUInt16BE(data, i); + suites.push(suite); + } + + return suites; + } + + /** + * Handle fragmented TLS detection with connection tracking + */ + static detectWithFragments( + buffer: Buffer, + connectionId: string, + options?: IDetectionOptions + ): IDetectionResult | null { + const detector = new TlsDetector(); + + // Try direct detection first + const directResult = detector.detect(buffer, options); + if (directResult && directResult.isComplete) { + // Clean up any tracked fragments for this connection + this.fragmentedBuffers.delete(connectionId); + return directResult; + } + + // Handle fragmentation + let accumulator = this.fragmentedBuffers.get(connectionId); + if (!accumulator) { + accumulator = new BufferAccumulator(); + this.fragmentedBuffers.set(connectionId, accumulator); + } + + accumulator.append(buffer); + const fullBuffer = accumulator.getBuffer(); + + // Try detection on accumulated buffer + const result = detector.detect(fullBuffer, options); + + if (result && result.isComplete) { + // Success - clean up + this.fragmentedBuffers.delete(connectionId); + return result; + } + + // Check timeout + if (options?.timeout) { + // TODO: Implement timeout handling + } + + return result; + } +} \ No newline at end of file diff --git a/ts/detection/index.ts b/ts/detection/index.ts new file mode 100644 index 0000000..21c384a --- /dev/null +++ b/ts/detection/index.ts @@ -0,0 +1,22 @@ +/** + * Centralized Protocol Detection Module + * + * This module provides unified protocol detection capabilities for + * both TLS and HTTP protocols, extracting connection information + * without consuming the data stream. + */ + +// Main detector +export * from './protocol-detector.js'; + +// Models +export * from './models/detection-types.js'; +export * from './models/interfaces.js'; + +// Individual detectors +export * from './detectors/tls-detector.js'; +export * from './detectors/http-detector.js'; + +// Utilities +export * from './utils/buffer-utils.js'; +export * from './utils/parser-utils.js'; \ No newline at end of file diff --git a/ts/detection/models/detection-types.ts b/ts/detection/models/detection-types.ts new file mode 100644 index 0000000..68930f0 --- /dev/null +++ b/ts/detection/models/detection-types.ts @@ -0,0 +1,102 @@ +/** + * Type definitions for protocol detection + */ + +/** + * Supported protocol types that can be detected + */ +export type TProtocolType = 'tls' | 'http' | 'unknown'; + +/** + * HTTP method types + */ +export type THttpMethod = 'GET' | 'POST' | 'PUT' | 'DELETE' | 'PATCH' | 'HEAD' | 'OPTIONS' | 'CONNECT' | 'TRACE'; + +/** + * TLS version identifiers + */ +export type TTlsVersion = 'SSLv3' | 'TLSv1.0' | 'TLSv1.1' | 'TLSv1.2' | 'TLSv1.3'; + +/** + * Connection information extracted from protocol detection + */ +export interface IConnectionInfo { + /** + * The detected protocol type + */ + protocol: TProtocolType; + + /** + * Domain/hostname extracted from the connection + * - For TLS: from SNI extension + * - For HTTP: from Host header + */ + domain?: string; + + /** + * HTTP-specific fields + */ + method?: THttpMethod; + path?: string; + httpVersion?: string; + headers?: Record; + + /** + * TLS-specific fields + */ + tlsVersion?: TTlsVersion; + sni?: string; + alpn?: string[]; + cipherSuites?: number[]; +} + +/** + * Result of protocol detection + */ +export interface IDetectionResult { + /** + * The detected protocol type + */ + protocol: TProtocolType; + + /** + * Extracted connection information + */ + connectionInfo: IConnectionInfo; + + /** + * Any remaining buffer data after detection headers + * This can be used to continue processing the stream + */ + remainingBuffer?: Buffer; + + /** + * Whether the detection is complete or needs more data + */ + isComplete: boolean; + + /** + * Minimum bytes needed for complete detection (if incomplete) + */ + bytesNeeded?: number; +} + +/** + * Options for protocol detection + */ +export interface IDetectionOptions { + /** + * Maximum bytes to buffer for detection (default: 8192) + */ + maxBufferSize?: number; + + /** + * Timeout for detection in milliseconds (default: 5000) + */ + timeout?: number; + + /** + * Whether to extract full headers or just essential info + */ + extractFullHeaders?: boolean; +} \ No newline at end of file diff --git a/ts/detection/models/interfaces.ts b/ts/detection/models/interfaces.ts new file mode 100644 index 0000000..1299f9e --- /dev/null +++ b/ts/detection/models/interfaces.ts @@ -0,0 +1,115 @@ +/** + * Interface definitions for protocol detection components + */ + +import type { IDetectionResult, IDetectionOptions } from './detection-types.js'; + +/** + * Interface for protocol detectors + */ +export interface IProtocolDetector { + /** + * Detect protocol from buffer data + * @param buffer The buffer to analyze + * @param options Detection options + * @returns Detection result or null if protocol cannot be determined + */ + detect(buffer: Buffer, options?: IDetectionOptions): IDetectionResult | null; + + /** + * Check if buffer potentially contains this protocol + * @param buffer The buffer to check + * @returns True if buffer might contain this protocol + */ + canHandle(buffer: Buffer): boolean; + + /** + * Get the minimum bytes needed for detection + */ + getMinimumBytes(): number; +} + +/** + * Interface for connection tracking during fragmented detection + */ +export interface IConnectionTracker { + /** + * Connection identifier + */ + id: string; + + /** + * Accumulated buffer data + */ + buffer: Buffer; + + /** + * Timestamp of first data + */ + startTime: number; + + /** + * Current detection state + */ + state: 'detecting' | 'complete' | 'failed'; + + /** + * Partial detection result (if any) + */ + partialResult?: Partial; +} + +/** + * Interface for buffer accumulator (handles fragmented data) + */ +export interface IBufferAccumulator { + /** + * Add data to accumulator + */ + append(data: Buffer): void; + + /** + * Get accumulated buffer + */ + getBuffer(): Buffer; + + /** + * Get buffer length + */ + length(): number; + + /** + * Clear accumulated data + */ + clear(): void; + + /** + * Check if accumulator has enough data + */ + hasMinimumBytes(minBytes: number): boolean; +} + +/** + * Detection events + */ +export interface IDetectionEvents { + /** + * Emitted when protocol is successfully detected + */ + detected: (result: IDetectionResult) => void; + + /** + * Emitted when detection fails + */ + failed: (error: Error) => void; + + /** + * Emitted when detection times out + */ + timeout: () => void; + + /** + * Emitted when more data is needed + */ + needMoreData: (bytesNeeded: number) => void; +} \ No newline at end of file diff --git a/ts/detection/protocol-detector.ts b/ts/detection/protocol-detector.ts new file mode 100644 index 0000000..0995718 --- /dev/null +++ b/ts/detection/protocol-detector.ts @@ -0,0 +1,222 @@ +/** + * Main protocol detector that orchestrates detection across different protocols + */ + +import type { IDetectionResult, IDetectionOptions, IConnectionInfo } from './models/detection-types.js'; +import { TlsDetector } from './detectors/tls-detector.js'; +import { HttpDetector } from './detectors/http-detector.js'; + +/** + * Main protocol detector class + */ +export class ProtocolDetector { + /** + * Connection tracking for fragmented detection + */ + private static connectionTracking = new Map(); + + /** + * Detect protocol from buffer data + * + * @param buffer The buffer to analyze + * @param options Detection options + * @returns Detection result with protocol information + */ + static async detect( + buffer: Buffer, + options?: IDetectionOptions + ): Promise { + // Quick sanity check + if (!buffer || buffer.length === 0) { + return { + protocol: 'unknown', + connectionInfo: { protocol: 'unknown' }, + isComplete: true + }; + } + + // Try TLS detection first (more specific) + const tlsDetector = new TlsDetector(); + if (tlsDetector.canHandle(buffer)) { + const tlsResult = tlsDetector.detect(buffer, options); + if (tlsResult) { + return tlsResult; + } + } + + // Try HTTP detection + const httpDetector = new HttpDetector(); + if (httpDetector.canHandle(buffer)) { + const httpResult = httpDetector.detect(buffer, options); + if (httpResult) { + return httpResult; + } + } + + // Neither TLS nor HTTP + return { + protocol: 'unknown', + connectionInfo: { protocol: 'unknown' }, + isComplete: true + }; + } + + /** + * Detect protocol with connection tracking for fragmented data + * + * @param buffer The buffer to analyze + * @param connectionId Unique connection identifier + * @param options Detection options + * @returns Detection result with protocol information + */ + static async detectWithConnectionTracking( + buffer: Buffer, + connectionId: string, + options?: IDetectionOptions + ): Promise { + // Initialize or get connection tracking + let tracking = this.connectionTracking.get(connectionId); + if (!tracking) { + tracking = { startTime: Date.now() }; + this.connectionTracking.set(connectionId, tracking); + } + + // Check timeout + if (options?.timeout) { + const elapsed = Date.now() - tracking.startTime; + if (elapsed > options.timeout) { + // Timeout - clean up and return unknown + this.connectionTracking.delete(connectionId); + TlsDetector.detectWithFragments(Buffer.alloc(0), connectionId); // Force cleanup + HttpDetector.detectWithFragments(Buffer.alloc(0), connectionId); // Force cleanup + + return { + protocol: 'unknown', + connectionInfo: { protocol: 'unknown' }, + isComplete: true + }; + } + } + + // If we already know the protocol, use the appropriate detector + if (tracking.protocol === 'tls') { + const result = TlsDetector.detectWithFragments(buffer, connectionId, options); + if (result && result.isComplete) { + this.connectionTracking.delete(connectionId); + } + return result || { + protocol: 'unknown', + connectionInfo: { protocol: 'unknown' }, + isComplete: true + }; + } else if (tracking.protocol === 'http') { + const result = HttpDetector.detectWithFragments(buffer, connectionId, options); + if (result && result.isComplete) { + this.connectionTracking.delete(connectionId); + } + return result || { + protocol: 'unknown', + connectionInfo: { protocol: 'unknown' }, + isComplete: true + }; + } + + // First time detection - try to determine protocol + // Quick checks first + if (buffer.length > 0) { + // TLS always starts with specific byte values + if (buffer[0] >= 0x14 && buffer[0] <= 0x18) { + tracking.protocol = 'tls'; + const result = TlsDetector.detectWithFragments(buffer, connectionId, options); + if (result) { + if (result.isComplete) { + this.connectionTracking.delete(connectionId); + } + return result; + } + } + // HTTP starts with ASCII text + else if (HttpDetector.quickCheck(buffer)) { + tracking.protocol = 'http'; + const result = HttpDetector.detectWithFragments(buffer, connectionId, options); + if (result) { + if (result.isComplete) { + this.connectionTracking.delete(connectionId); + } + return result; + } + } + } + + // Can't determine protocol yet + return { + protocol: 'unknown', + connectionInfo: { protocol: 'unknown' }, + isComplete: false, + bytesNeeded: 10 // Need more data to determine protocol + }; + } + + /** + * Clean up old connection tracking entries + * + * @param maxAge Maximum age in milliseconds (default: 30 seconds) + */ + static cleanupConnections(maxAge: number = 30000): void { + const now = Date.now(); + const toDelete: string[] = []; + + for (const [connectionId, tracking] of this.connectionTracking.entries()) { + if (now - tracking.startTime > maxAge) { + toDelete.push(connectionId); + } + } + + for (const connectionId of toDelete) { + this.connectionTracking.delete(connectionId); + // Also clean up detector-specific buffers + TlsDetector.detectWithFragments(Buffer.alloc(0), connectionId); // Force cleanup + HttpDetector.detectWithFragments(Buffer.alloc(0), connectionId); // Force cleanup + } + + // Also trigger cleanup in detectors + HttpDetector.cleanupFragments(maxAge); + } + + /** + * Extract domain from connection info + * + * @param connectionInfo Connection information from detection + * @returns The domain/hostname if found + */ + static extractDomain(connectionInfo: IConnectionInfo): string | undefined { + // For both TLS and HTTP, domain is stored in the domain field + return connectionInfo.domain; + } + + /** + * Create a connection ID from connection parameters + * + * @param params Connection parameters + * @returns A unique connection identifier + */ + static createConnectionId(params: { + sourceIp?: string; + sourcePort?: number; + destIp?: string; + destPort?: number; + socketId?: string; + }): string { + // If socketId is provided, use it + if (params.socketId) { + return params.socketId; + } + + // Otherwise create from connection tuple + const { sourceIp = 'unknown', sourcePort = 0, destIp = 'unknown', destPort = 0 } = params; + return `${sourceIp}:${sourcePort}-${destIp}:${destPort}`; + } +} \ No newline at end of file diff --git a/ts/detection/utils/buffer-utils.ts b/ts/detection/utils/buffer-utils.ts new file mode 100644 index 0000000..722990d --- /dev/null +++ b/ts/detection/utils/buffer-utils.ts @@ -0,0 +1,174 @@ +/** + * Buffer manipulation utilities for protocol detection + */ + +/** + * BufferAccumulator class for handling fragmented data + */ +export class BufferAccumulator { + private chunks: Buffer[] = []; + private totalLength = 0; + + /** + * Append data to the accumulator + */ + append(data: Buffer): void { + this.chunks.push(data); + this.totalLength += data.length; + } + + /** + * Get the accumulated buffer + */ + getBuffer(): Buffer { + if (this.chunks.length === 0) { + return Buffer.alloc(0); + } + if (this.chunks.length === 1) { + return this.chunks[0]; + } + return Buffer.concat(this.chunks, this.totalLength); + } + + /** + * Get current buffer length + */ + length(): number { + return this.totalLength; + } + + /** + * Clear all accumulated data + */ + clear(): void { + this.chunks = []; + this.totalLength = 0; + } + + /** + * Check if accumulator has minimum bytes + */ + hasMinimumBytes(minBytes: number): boolean { + return this.totalLength >= minBytes; + } +} + +/** + * Read a big-endian 16-bit integer from buffer + */ +export function readUInt16BE(buffer: Buffer, offset: number): number { + if (offset + 2 > buffer.length) { + throw new Error('Buffer too short for UInt16BE read'); + } + return (buffer[offset] << 8) | buffer[offset + 1]; +} + +/** + * Read a big-endian 24-bit integer from buffer + */ +export function readUInt24BE(buffer: Buffer, offset: number): number { + if (offset + 3 > buffer.length) { + throw new Error('Buffer too short for UInt24BE read'); + } + return (buffer[offset] << 16) | (buffer[offset + 1] << 8) | buffer[offset + 2]; +} + +/** + * Find a byte sequence in a buffer + */ +export function findSequence(buffer: Buffer, sequence: Buffer, startOffset = 0): number { + if (sequence.length === 0) { + return startOffset; + } + + const searchLength = buffer.length - sequence.length + 1; + for (let i = startOffset; i < searchLength; i++) { + let found = true; + for (let j = 0; j < sequence.length; j++) { + if (buffer[i + j] !== sequence[j]) { + found = false; + break; + } + } + if (found) { + return i; + } + } + return -1; +} + +/** + * Extract a line from buffer (up to CRLF or LF) + */ +export function extractLine(buffer: Buffer, startOffset = 0): { line: string; nextOffset: number } | null { + let lineEnd = -1; + let skipBytes = 1; + + // Look for CRLF first + const crlfPos = findSequence(buffer, Buffer.from('\r\n'), startOffset); + if (crlfPos !== -1) { + lineEnd = crlfPos; + skipBytes = 2; + } else { + // Look for LF only + for (let i = startOffset; i < buffer.length; i++) { + if (buffer[i] === 0x0A) { // LF + lineEnd = i; + break; + } + } + } + + if (lineEnd === -1) { + return null; + } + + const line = buffer.slice(startOffset, lineEnd).toString('utf8'); + return { + line, + nextOffset: lineEnd + skipBytes + }; +} + +/** + * Check if buffer starts with a string (case-insensitive) + */ +export function startsWithString(buffer: Buffer, str: string, offset = 0): boolean { + if (offset + str.length > buffer.length) { + return false; + } + + const bufferStr = buffer.slice(offset, offset + str.length).toString('utf8'); + return bufferStr.toLowerCase() === str.toLowerCase(); +} + +/** + * Safe buffer slice that doesn't throw on out-of-bounds + */ +export function safeSlice(buffer: Buffer, start: number, end?: number): Buffer { + const safeStart = Math.max(0, Math.min(start, buffer.length)); + const safeEnd = end === undefined + ? buffer.length + : Math.max(safeStart, Math.min(end, buffer.length)); + + return buffer.slice(safeStart, safeEnd); +} + +/** + * Check if buffer contains printable ASCII + */ +export function isPrintableAscii(buffer: Buffer, length?: number): boolean { + const checkLength = length || buffer.length; + + for (let i = 0; i < checkLength && i < buffer.length; i++) { + const byte = buffer[i]; + // Check if byte is printable ASCII (0x20-0x7E) or tab/newline/carriage return + if (byte < 0x20 || byte > 0x7E) { + if (byte !== 0x09 && byte !== 0x0A && byte !== 0x0D) { + return false; + } + } + } + + return true; +} \ No newline at end of file diff --git a/ts/detection/utils/parser-utils.ts b/ts/detection/utils/parser-utils.ts new file mode 100644 index 0000000..d381ca1 --- /dev/null +++ b/ts/detection/utils/parser-utils.ts @@ -0,0 +1,141 @@ +/** + * Parser utilities for protocol detection + */ + +import type { THttpMethod, TTlsVersion } from '../models/detection-types.js'; + +/** + * Valid HTTP methods + */ +export const HTTP_METHODS: THttpMethod[] = [ + 'GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD', 'OPTIONS', 'CONNECT', 'TRACE' +]; + +/** + * HTTP version strings + */ +export const HTTP_VERSIONS = ['HTTP/1.0', 'HTTP/1.1', 'HTTP/2', 'HTTP/3']; + +/** + * Parse HTTP request line + */ +export function parseHttpRequestLine(line: string): { + method: THttpMethod; + path: string; + version: string; +} | null { + const parts = line.trim().split(' '); + + if (parts.length !== 3) { + return null; + } + + const [method, path, version] = parts; + + // Validate method + if (!HTTP_METHODS.includes(method as THttpMethod)) { + return null; + } + + // Validate version + if (!version.startsWith('HTTP/')) { + return null; + } + + return { + method: method as THttpMethod, + path, + version + }; +} + +/** + * Parse HTTP header line + */ +export function parseHttpHeader(line: string): { name: string; value: string } | null { + const colonIndex = line.indexOf(':'); + + if (colonIndex === -1) { + return null; + } + + const name = line.slice(0, colonIndex).trim(); + const value = line.slice(colonIndex + 1).trim(); + + if (!name) { + return null; + } + + return { name, value }; +} + +/** + * Parse HTTP headers from lines + */ +export function parseHttpHeaders(lines: string[]): Record { + const headers: Record = {}; + + for (const line of lines) { + const header = parseHttpHeader(line); + if (header) { + // Convert header names to lowercase for consistency + headers[header.name.toLowerCase()] = header.value; + } + } + + return headers; +} + +/** + * Convert TLS version bytes to version string + */ +export function tlsVersionToString(major: number, minor: number): TTlsVersion | null { + if (major === 0x03) { + switch (minor) { + case 0x00: return 'SSLv3'; + case 0x01: return 'TLSv1.0'; + case 0x02: return 'TLSv1.1'; + case 0x03: return 'TLSv1.2'; + case 0x04: return 'TLSv1.3'; + } + } + return null; +} + +/** + * Extract domain from Host header value + */ +export function extractDomainFromHost(hostHeader: string): string { + // Remove port if present + const colonIndex = hostHeader.lastIndexOf(':'); + if (colonIndex !== -1) { + // Check if it's not part of IPv6 address + const beforeColon = hostHeader.slice(0, colonIndex); + if (!beforeColon.includes(']')) { + return beforeColon; + } + } + return hostHeader; +} + +/** + * Validate domain name + */ +export function isValidDomain(domain: string): boolean { + // Basic domain validation + if (!domain || domain.length > 253) { + return false; + } + + // Check for valid characters and structure + const domainRegex = /^(?!-)[A-Za-z0-9-]{1,63}(?; + }; } \ No newline at end of file diff --git a/ts/proxies/smart-proxy/route-connection-handler.ts b/ts/proxies/smart-proxy/route-connection-handler.ts index cfc5e00..8f88cb8 100644 --- a/ts/proxies/smart-proxy/route-connection-handler.ts +++ b/ts/proxies/smart-proxy/route-connection-handler.ts @@ -10,6 +10,7 @@ import { WrappedSocket } from '../../core/models/wrapped-socket.js'; import { getUnderlyingSocket } from '../../core/models/socket-types.js'; import { ProxyProtocolParser } from '../../core/utils/proxy-protocol.js'; import type { SmartProxy } from './smart-proxy.js'; +import { ProtocolDetector } from '../../detection/index.js'; /** * Handles new connection processing and setup logic with support for route-based configuration @@ -301,11 +302,27 @@ export class RouteConnectionHandler { }); // Handler for processing initial data (after potential PROXY protocol) - const processInitialData = (chunk: Buffer) => { + const processInitialData = async (chunk: Buffer) => { + // Use ProtocolDetector to identify protocol + const connectionId = ProtocolDetector.createConnectionId({ + sourceIp: record.remoteIP, + sourcePort: socket.remotePort, + destIp: socket.localAddress, + destPort: socket.localPort, + socketId: record.id + }); + + const detectionResult = await ProtocolDetector.detectWithConnectionTracking( + chunk, + connectionId, + { extractFullHeaders: false } // Only extract essential info for routing + ); + // Block non-TLS connections on port 443 - if (!this.smartProxy.tlsManager.isTlsHandshake(chunk) && localPort === 443) { - logger.log('warn', `Non-TLS connection ${connectionId} detected on port 443. Terminating connection - only TLS traffic is allowed on standard HTTPS port.`, { - connectionId, + if (localPort === 443 && detectionResult.protocol !== 'tls') { + logger.log('warn', `Non-TLS connection ${record.id} detected on port 443. Terminating connection - only TLS traffic is allowed on standard HTTPS port.`, { + connectionId: record.id, + detectedProtocol: detectionResult.protocol, message: 'Terminating connection - only TLS traffic is allowed on standard HTTPS port.', component: 'route-handler' }); @@ -318,71 +335,78 @@ export class RouteConnectionHandler { return; } - // Check if this looks like a TLS handshake + // Extract domain and protocol info let serverName = ''; - if (this.smartProxy.tlsManager.isTlsHandshake(chunk)) { + if (detectionResult.protocol === 'tls') { record.isTLS = true; + serverName = detectionResult.connectionInfo.domain || ''; + + // Lock the connection to the negotiated SNI + record.lockedDomain = serverName; - // Check for ClientHello to extract SNI - if (this.smartProxy.tlsManager.isClientHello(chunk)) { - // Create connection info for SNI extraction - const connInfo = { - sourceIp: record.remoteIP, - sourcePort: socket.remotePort || 0, - destIp: socket.localAddress || '', - destPort: socket.localPort || 0, - }; - - // Extract SNI - serverName = this.smartProxy.tlsManager.extractSNI(chunk, connInfo) || ''; - - // Lock the connection to the negotiated SNI - record.lockedDomain = serverName; - - // Check if we should reject connections without SNI - if (!serverName && this.smartProxy.settings.allowSessionTicket === false) { - logger.log('warn', `No SNI detected in TLS ClientHello for connection ${connectionId}; sending TLS alert`, { - connectionId, - component: 'route-handler' - }); - if (record.incomingTerminationReason === null) { - record.incomingTerminationReason = 'session_ticket_blocked_no_sni'; - this.smartProxy.connectionManager.incrementTerminationStat( - 'incoming', - 'session_ticket_blocked_no_sni' - ); - } - const alert = Buffer.from([0x15, 0x03, 0x03, 0x00, 0x02, 0x01, 0x70]); - try { - // Count the alert bytes being sent - record.bytesSent += alert.length; - if (this.smartProxy.metricsCollector) { - this.smartProxy.metricsCollector.recordBytes(record.id, 0, alert.length); - } - - socket.cork(); - socket.write(alert); - socket.uncork(); - socket.end(); - } catch { - socket.end(); - } - this.smartProxy.connectionManager.cleanupConnection(record, 'session_ticket_blocked_no_sni'); - return; + // Check if we should reject connections without SNI + if (!serverName && this.smartProxy.settings.allowSessionTicket === false) { + logger.log('warn', `No SNI detected in TLS ClientHello for connection ${record.id}; sending TLS alert`, { + connectionId: record.id, + component: 'route-handler' + }); + if (record.incomingTerminationReason === null) { + record.incomingTerminationReason = 'session_ticket_blocked_no_sni'; + this.smartProxy.connectionManager.incrementTerminationStat( + 'incoming', + 'session_ticket_blocked_no_sni' + ); } - - if (this.smartProxy.settings.enableDetailedLogging) { - logger.log('info', `TLS connection with SNI`, { - connectionId, - serverName: serverName || '(empty)', - component: 'route-handler' - }); + const alert = Buffer.from([0x15, 0x03, 0x03, 0x00, 0x02, 0x01, 0x70]); + try { + // Count the alert bytes being sent + record.bytesSent += alert.length; + if (this.smartProxy.metricsCollector) { + this.smartProxy.metricsCollector.recordBytes(record.id, 0, alert.length); + } + + socket.cork(); + socket.write(alert); + socket.uncork(); + socket.end(); + } catch { + socket.end(); } + this.smartProxy.connectionManager.cleanupConnection(record, 'session_ticket_blocked_no_sni'); + return; + } + + if (this.smartProxy.settings.enableDetailedLogging) { + logger.log('info', `TLS connection with SNI`, { + connectionId: record.id, + serverName: serverName || '(empty)', + component: 'route-handler' + }); + } + } else if (detectionResult.protocol === 'http') { + // For HTTP, extract domain from Host header + serverName = detectionResult.connectionInfo.domain || ''; + + // Store HTTP-specific info for later use + record.httpInfo = { + method: detectionResult.connectionInfo.method, + path: detectionResult.connectionInfo.path, + headers: detectionResult.connectionInfo.headers + }; + + if (this.smartProxy.settings.enableDetailedLogging) { + logger.log('info', `HTTP connection detected`, { + connectionId: record.id, + domain: serverName || '(no host header)', + method: detectionResult.connectionInfo.method, + path: detectionResult.connectionInfo.path, + component: 'route-handler' + }); } } // Find the appropriate route for this connection - this.routeConnection(socket, record, serverName, chunk); + this.routeConnection(socket, record, serverName, chunk, detectionResult); }; // First data handler to capture initial TLS handshake or PROXY protocol @@ -454,7 +478,8 @@ export class RouteConnectionHandler { socket: plugins.net.Socket | WrappedSocket, record: IConnectionRecord, serverName: string, - initialChunk?: Buffer + initialChunk?: Buffer, + detectionResult?: any // Using any temporarily to avoid circular dependency issues ): void { const connectionId = record.id; const localPort = record.localPort; @@ -635,7 +660,7 @@ export class RouteConnectionHandler { // Handle the route based on its action type switch (route.action.type) { case 'forward': - return this.handleForwardAction(socket, record, route, initialChunk); + return this.handleForwardAction(socket, record, route, initialChunk, detectionResult); case 'socket-handler': logger.log('info', `Handling socket-handler action for route ${route.name}`, { @@ -738,7 +763,8 @@ export class RouteConnectionHandler { socket: plugins.net.Socket | WrappedSocket, record: IConnectionRecord, route: IRouteConfig, - initialChunk?: Buffer + initialChunk?: Buffer, + detectionResult?: any // Using any temporarily to avoid circular dependency issues ): void { const connectionId = record.id; const action = route.action as IRouteAction; @@ -819,14 +845,11 @@ export class RouteConnectionHandler { // Create context for target selection const targetSelectionContext = { port: record.localPort, - path: undefined, // Will be populated from HTTP headers if available - headers: undefined, // Will be populated from HTTP headers if available - method: undefined // Will be populated from HTTP headers if available + path: record.httpInfo?.path, + headers: record.httpInfo?.headers, + method: record.httpInfo?.method }; - // TODO: Extract path, headers, and method from initialChunk if it's HTTP - // For now, we'll select based on port only - const selectedTarget = this.selectTarget(action.targets, targetSelectionContext); if (!selectedTarget) { logger.log('error', `No matching target found for connection ${connectionId}`, { diff --git a/ts/proxies/smart-proxy/tls-manager.ts b/ts/proxies/smart-proxy/tls-manager.ts index 2b6e36e..437df24 100644 --- a/ts/proxies/smart-proxy/tls-manager.ts +++ b/ts/proxies/smart-proxy/tls-manager.ts @@ -1,5 +1,6 @@ import * as plugins from '../../plugins.js'; import { SniHandler } from '../../tls/sni/sni-handler.js'; +import { ProtocolDetector, TlsDetector } from '../../detection/index.js'; import type { SmartProxy } from './smart-proxy.js'; /** diff --git a/ts/proxies/smart-proxy/utils/route-helpers.ts b/ts/proxies/smart-proxy/utils/route-helpers.ts index b416039..29192fb 100644 --- a/ts/proxies/smart-proxy/utils/route-helpers.ts +++ b/ts/proxies/smart-proxy/utils/route-helpers.ts @@ -21,6 +21,7 @@ import * as plugins from '../../../plugins.js'; import type { IRouteConfig, IRouteMatch, IRouteAction, IRouteTarget, TPortRange, IRouteContext } from '../models/route-types.js'; import { mergeRouteConfigs } from './route-utils.js'; +import { ProtocolDetector, HttpDetector } from '../../../detection/index.js'; /** * Create an HTTP-only route configuration @@ -956,83 +957,91 @@ export const SocketHandlers = { /** * HTTP redirect handler + * Now uses the centralized detection module for HTTP parsing */ httpRedirect: (locationTemplate: string, statusCode: number = 301) => (socket: plugins.net.Socket, context: IRouteContext) => { - let buffer = ''; + const connectionId = ProtocolDetector.createConnectionId({ + socketId: context.connectionId || `${Date.now()}-${Math.random()}` + }); - socket.once('data', (data) => { - buffer += data.toString(); + socket.once('data', async (data) => { + // Use detection module for parsing + const detectionResult = await ProtocolDetector.detectWithConnectionTracking( + data, + connectionId, + { extractFullHeaders: false } // We only need method and path + ); - const lines = buffer.split('\r\n'); - const requestLine = lines[0]; - const [method, path] = requestLine.split(' '); + if (detectionResult.protocol === 'http' && detectionResult.connectionInfo.path) { + const method = detectionResult.connectionInfo.method || 'GET'; + const path = detectionResult.connectionInfo.path || '/'; + + const domain = context.domain || 'localhost'; + const port = context.port; + + let finalLocation = locationTemplate + .replace('{domain}', domain) + .replace('{port}', String(port)) + .replace('{path}', path) + .replace('{clientIp}', context.clientIp); + + const message = `Redirecting to ${finalLocation}`; + const response = [ + `HTTP/1.1 ${statusCode} ${statusCode === 301 ? 'Moved Permanently' : 'Found'}`, + `Location: ${finalLocation}`, + 'Content-Type: text/plain', + `Content-Length: ${message.length}`, + 'Connection: close', + '', + message + ].join('\r\n'); + + socket.write(response); + } else { + // Not a valid HTTP request, close connection + socket.write('HTTP/1.1 400 Bad Request\r\nConnection: close\r\n\r\n'); + } - const domain = context.domain || 'localhost'; - const port = context.port; - - let finalLocation = locationTemplate - .replace('{domain}', domain) - .replace('{port}', String(port)) - .replace('{path}', path) - .replace('{clientIp}', context.clientIp); - - const message = `Redirecting to ${finalLocation}`; - const response = [ - `HTTP/1.1 ${statusCode} ${statusCode === 301 ? 'Moved Permanently' : 'Found'}`, - `Location: ${finalLocation}`, - 'Content-Type: text/plain', - `Content-Length: ${message.length}`, - 'Connection: close', - '', - message - ].join('\r\n'); - - socket.write(response); socket.end(); + // Clean up detection state + ProtocolDetector.cleanupConnections(); }); }, /** * HTTP server handler for ACME challenges and other HTTP needs + * Now uses the centralized detection module for HTTP parsing */ httpServer: (handler: (req: { method: string; url: string; headers: Record; body?: string }, res: { status: (code: number) => void; header: (name: string, value: string) => void; send: (data: string) => void; end: () => void }) => void) => (socket: plugins.net.Socket, context: IRouteContext) => { - let buffer = ''; let requestParsed = false; + const connectionId = ProtocolDetector.createConnectionId({ + socketId: context.connectionId || `${Date.now()}-${Math.random()}` + }); - socket.on('data', (data) => { + const processData = async (data: Buffer) => { if (requestParsed) return; // Only handle the first request - buffer += data.toString(); + // Use HttpDetector for parsing + const detectionResult = await ProtocolDetector.detectWithConnectionTracking( + data, + connectionId, + { extractFullHeaders: true } + ); - // Check if we have a complete HTTP request - const headerEndIndex = buffer.indexOf('\r\n\r\n'); - if (headerEndIndex === -1) return; // Need more data - - requestParsed = true; - - // Parse the HTTP request - const headerPart = buffer.substring(0, headerEndIndex); - const bodyPart = buffer.substring(headerEndIndex + 4); - - const lines = headerPart.split('\r\n'); - const [method, url] = lines[0].split(' '); - - const headers: Record = {}; - for (let i = 1; i < lines.length; i++) { - const colonIndex = lines[i].indexOf(':'); - if (colonIndex > 0) { - const name = lines[i].substring(0, colonIndex).trim().toLowerCase(); - const value = lines[i].substring(colonIndex + 1).trim(); - headers[name] = value; - } + if (detectionResult.protocol !== 'http' || !detectionResult.isComplete) { + // Not a complete HTTP request yet + return; } - // Create request object + requestParsed = true; + const connInfo = detectionResult.connectionInfo; + + // Create request object from detection result const req = { - method: method || 'GET', - url: url || '/', - headers, - body: bodyPart + method: connInfo.method || 'GET', + url: connInfo.path || '/', + headers: connInfo.headers || {}, + body: detectionResult.remainingBuffer?.toString() || '' }; // Create response object @@ -1093,13 +1102,20 @@ export const SocketHandlers = { res.send('Internal Server Error'); } } - }); + }; + + socket.on('data', processData); socket.on('error', () => { if (!requestParsed) { socket.end(); } }); + + socket.on('close', () => { + // Clean up detection state + ProtocolDetector.cleanupConnections(); + }); } };