feat(protocols): refactor protocol utilities into centralized protocols module
Some checks failed
Default (tags) / security (push) Successful in 55s
Default (tags) / test (push) Failing after 30m45s
Default (tags) / release (push) Has been skipped
Default (tags) / metadata (push) Has been skipped

This commit is contained in:
Juergen Kunz
2025-07-21 22:37:45 +00:00
parent d47b048517
commit 36068a6d92
32 changed files with 1155 additions and 394 deletions

View File

@@ -0,0 +1,219 @@
/**
* HTTP Protocol Constants
*/
/**
* HTTP methods
*/
export const HTTP_METHODS = [
'GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD', 'OPTIONS', 'CONNECT', 'TRACE'
] as const;
export type THttpMethod = typeof HTTP_METHODS[number];
/**
* HTTP version strings
*/
export const HTTP_VERSIONS = ['HTTP/1.0', 'HTTP/1.1', 'HTTP/2', 'HTTP/3'] as const;
export type THttpVersion = typeof HTTP_VERSIONS[number];
/**
* HTTP status codes
*/
export enum HttpStatus {
// 1xx Informational
CONTINUE = 100,
SWITCHING_PROTOCOLS = 101,
PROCESSING = 102,
EARLY_HINTS = 103,
// 2xx Success
OK = 200,
CREATED = 201,
ACCEPTED = 202,
NON_AUTHORITATIVE_INFORMATION = 203,
NO_CONTENT = 204,
RESET_CONTENT = 205,
PARTIAL_CONTENT = 206,
MULTI_STATUS = 207,
ALREADY_REPORTED = 208,
IM_USED = 226,
// 3xx Redirection
MULTIPLE_CHOICES = 300,
MOVED_PERMANENTLY = 301,
FOUND = 302,
SEE_OTHER = 303,
NOT_MODIFIED = 304,
USE_PROXY = 305,
TEMPORARY_REDIRECT = 307,
PERMANENT_REDIRECT = 308,
// 4xx Client Error
BAD_REQUEST = 400,
UNAUTHORIZED = 401,
PAYMENT_REQUIRED = 402,
FORBIDDEN = 403,
NOT_FOUND = 404,
METHOD_NOT_ALLOWED = 405,
NOT_ACCEPTABLE = 406,
PROXY_AUTHENTICATION_REQUIRED = 407,
REQUEST_TIMEOUT = 408,
CONFLICT = 409,
GONE = 410,
LENGTH_REQUIRED = 411,
PRECONDITION_FAILED = 412,
PAYLOAD_TOO_LARGE = 413,
URI_TOO_LONG = 414,
UNSUPPORTED_MEDIA_TYPE = 415,
RANGE_NOT_SATISFIABLE = 416,
EXPECTATION_FAILED = 417,
IM_A_TEAPOT = 418,
MISDIRECTED_REQUEST = 421,
UNPROCESSABLE_ENTITY = 422,
LOCKED = 423,
FAILED_DEPENDENCY = 424,
TOO_EARLY = 425,
UPGRADE_REQUIRED = 426,
PRECONDITION_REQUIRED = 428,
TOO_MANY_REQUESTS = 429,
REQUEST_HEADER_FIELDS_TOO_LARGE = 431,
UNAVAILABLE_FOR_LEGAL_REASONS = 451,
// 5xx Server Error
INTERNAL_SERVER_ERROR = 500,
NOT_IMPLEMENTED = 501,
BAD_GATEWAY = 502,
SERVICE_UNAVAILABLE = 503,
GATEWAY_TIMEOUT = 504,
HTTP_VERSION_NOT_SUPPORTED = 505,
VARIANT_ALSO_NEGOTIATES = 506,
INSUFFICIENT_STORAGE = 507,
LOOP_DETECTED = 508,
NOT_EXTENDED = 510,
NETWORK_AUTHENTICATION_REQUIRED = 511,
}
/**
* HTTP status text mapping
*/
export const HTTP_STATUS_TEXT: Record<HttpStatus, string> = {
// 1xx
[HttpStatus.CONTINUE]: 'Continue',
[HttpStatus.SWITCHING_PROTOCOLS]: 'Switching Protocols',
[HttpStatus.PROCESSING]: 'Processing',
[HttpStatus.EARLY_HINTS]: 'Early Hints',
// 2xx
[HttpStatus.OK]: 'OK',
[HttpStatus.CREATED]: 'Created',
[HttpStatus.ACCEPTED]: 'Accepted',
[HttpStatus.NON_AUTHORITATIVE_INFORMATION]: 'Non-Authoritative Information',
[HttpStatus.NO_CONTENT]: 'No Content',
[HttpStatus.RESET_CONTENT]: 'Reset Content',
[HttpStatus.PARTIAL_CONTENT]: 'Partial Content',
[HttpStatus.MULTI_STATUS]: 'Multi-Status',
[HttpStatus.ALREADY_REPORTED]: 'Already Reported',
[HttpStatus.IM_USED]: 'IM Used',
// 3xx
[HttpStatus.MULTIPLE_CHOICES]: 'Multiple Choices',
[HttpStatus.MOVED_PERMANENTLY]: 'Moved Permanently',
[HttpStatus.FOUND]: 'Found',
[HttpStatus.SEE_OTHER]: 'See Other',
[HttpStatus.NOT_MODIFIED]: 'Not Modified',
[HttpStatus.USE_PROXY]: 'Use Proxy',
[HttpStatus.TEMPORARY_REDIRECT]: 'Temporary Redirect',
[HttpStatus.PERMANENT_REDIRECT]: 'Permanent Redirect',
// 4xx
[HttpStatus.BAD_REQUEST]: 'Bad Request',
[HttpStatus.UNAUTHORIZED]: 'Unauthorized',
[HttpStatus.PAYMENT_REQUIRED]: 'Payment Required',
[HttpStatus.FORBIDDEN]: 'Forbidden',
[HttpStatus.NOT_FOUND]: 'Not Found',
[HttpStatus.METHOD_NOT_ALLOWED]: 'Method Not Allowed',
[HttpStatus.NOT_ACCEPTABLE]: 'Not Acceptable',
[HttpStatus.PROXY_AUTHENTICATION_REQUIRED]: 'Proxy Authentication Required',
[HttpStatus.REQUEST_TIMEOUT]: 'Request Timeout',
[HttpStatus.CONFLICT]: 'Conflict',
[HttpStatus.GONE]: 'Gone',
[HttpStatus.LENGTH_REQUIRED]: 'Length Required',
[HttpStatus.PRECONDITION_FAILED]: 'Precondition Failed',
[HttpStatus.PAYLOAD_TOO_LARGE]: 'Payload Too Large',
[HttpStatus.URI_TOO_LONG]: 'URI Too Long',
[HttpStatus.UNSUPPORTED_MEDIA_TYPE]: 'Unsupported Media Type',
[HttpStatus.RANGE_NOT_SATISFIABLE]: 'Range Not Satisfiable',
[HttpStatus.EXPECTATION_FAILED]: 'Expectation Failed',
[HttpStatus.IM_A_TEAPOT]: "I'm a teapot",
[HttpStatus.MISDIRECTED_REQUEST]: 'Misdirected Request',
[HttpStatus.UNPROCESSABLE_ENTITY]: 'Unprocessable Entity',
[HttpStatus.LOCKED]: 'Locked',
[HttpStatus.FAILED_DEPENDENCY]: 'Failed Dependency',
[HttpStatus.TOO_EARLY]: 'Too Early',
[HttpStatus.UPGRADE_REQUIRED]: 'Upgrade Required',
[HttpStatus.PRECONDITION_REQUIRED]: 'Precondition Required',
[HttpStatus.TOO_MANY_REQUESTS]: 'Too Many Requests',
[HttpStatus.REQUEST_HEADER_FIELDS_TOO_LARGE]: 'Request Header Fields Too Large',
[HttpStatus.UNAVAILABLE_FOR_LEGAL_REASONS]: 'Unavailable For Legal Reasons',
// 5xx
[HttpStatus.INTERNAL_SERVER_ERROR]: 'Internal Server Error',
[HttpStatus.NOT_IMPLEMENTED]: 'Not Implemented',
[HttpStatus.BAD_GATEWAY]: 'Bad Gateway',
[HttpStatus.SERVICE_UNAVAILABLE]: 'Service Unavailable',
[HttpStatus.GATEWAY_TIMEOUT]: 'Gateway Timeout',
[HttpStatus.HTTP_VERSION_NOT_SUPPORTED]: 'HTTP Version Not Supported',
[HttpStatus.VARIANT_ALSO_NEGOTIATES]: 'Variant Also Negotiates',
[HttpStatus.INSUFFICIENT_STORAGE]: 'Insufficient Storage',
[HttpStatus.LOOP_DETECTED]: 'Loop Detected',
[HttpStatus.NOT_EXTENDED]: 'Not Extended',
[HttpStatus.NETWORK_AUTHENTICATION_REQUIRED]: 'Network Authentication Required',
};
/**
* Common HTTP headers
*/
export const HTTP_HEADERS = {
// Request headers
HOST: 'host',
USER_AGENT: 'user-agent',
ACCEPT: 'accept',
ACCEPT_LANGUAGE: 'accept-language',
ACCEPT_ENCODING: 'accept-encoding',
AUTHORIZATION: 'authorization',
CACHE_CONTROL: 'cache-control',
CONNECTION: 'connection',
CONTENT_TYPE: 'content-type',
CONTENT_LENGTH: 'content-length',
COOKIE: 'cookie',
// Response headers
SET_COOKIE: 'set-cookie',
LOCATION: 'location',
SERVER: 'server',
DATE: 'date',
EXPIRES: 'expires',
LAST_MODIFIED: 'last-modified',
ETAG: 'etag',
// CORS headers
ACCESS_CONTROL_ALLOW_ORIGIN: 'access-control-allow-origin',
ACCESS_CONTROL_ALLOW_METHODS: 'access-control-allow-methods',
ACCESS_CONTROL_ALLOW_HEADERS: 'access-control-allow-headers',
// Security headers
STRICT_TRANSPORT_SECURITY: 'strict-transport-security',
X_CONTENT_TYPE_OPTIONS: 'x-content-type-options',
X_FRAME_OPTIONS: 'x-frame-options',
X_XSS_PROTECTION: 'x-xss-protection',
CONTENT_SECURITY_POLICY: 'content-security-policy',
} as const;
/**
* Get HTTP status text
*/
export function getStatusText(status: HttpStatus): string {
return HTTP_STATUS_TEXT[status] || 'Unknown';
}

View File

@@ -0,0 +1,8 @@
/**
* HTTP Protocol Module
* Generic HTTP protocol knowledge and parsing utilities
*/
export * from './constants.js';
export * from './types.js';
export * from './parser.js';

219
ts/protocols/http/parser.ts Normal file
View File

@@ -0,0 +1,219 @@
/**
* HTTP Protocol Parser
* Generic HTTP parsing utilities
*/
import { HTTP_METHODS, type THttpMethod, type THttpVersion } from './constants.js';
import type { IHttpRequestLine, IHttpHeader } from './types.js';
/**
* HTTP parser utilities
*/
export class HttpParser {
/**
* Check if string is a valid HTTP method
*/
static isHttpMethod(str: string): str is THttpMethod {
return HTTP_METHODS.includes(str as THttpMethod);
}
/**
* Parse HTTP request line
*/
static parseRequestLine(line: string): IHttpRequestLine | null {
const parts = line.trim().split(' ');
if (parts.length !== 3) {
return null;
}
const [method, path, version] = parts;
// Validate method
if (!this.isHttpMethod(method)) {
return null;
}
// Validate version
if (!version.startsWith('HTTP/')) {
return null;
}
return {
method: method as THttpMethod,
path,
version: version as THttpVersion
};
}
/**
* Parse HTTP header line
*/
static parseHeaderLine(line: string): IHttpHeader | null {
const colonIndex = line.indexOf(':');
if (colonIndex === -1) {
return null;
}
const name = line.slice(0, colonIndex).trim();
const value = line.slice(colonIndex + 1).trim();
if (!name) {
return null;
}
return { name, value };
}
/**
* Parse HTTP headers from lines
*/
static parseHeaders(lines: string[]): Record<string, string> {
const headers: Record<string, string> = {};
for (const line of lines) {
const header = this.parseHeaderLine(line);
if (header) {
// Convert header names to lowercase for consistency
headers[header.name.toLowerCase()] = header.value;
}
}
return headers;
}
/**
* Extract domain from Host header value
*/
static extractDomainFromHost(hostHeader: string): string {
// Remove port if present
const colonIndex = hostHeader.lastIndexOf(':');
if (colonIndex !== -1) {
// Check if it's not part of IPv6 address
const beforeColon = hostHeader.slice(0, colonIndex);
if (!beforeColon.includes(']')) {
return beforeColon;
}
}
return hostHeader;
}
/**
* Validate domain name
*/
static isValidDomain(domain: string): boolean {
// Basic domain validation
if (!domain || domain.length > 253) {
return false;
}
// Check for valid characters and structure
const domainRegex = /^(?!-)[A-Za-z0-9-]{1,63}(?<!-)(\.[A-Za-z0-9-]{1,63})*$/;
return domainRegex.test(domain);
}
/**
* Extract line from buffer
*/
static extractLine(buffer: Buffer, offset: number = 0): { line: string; nextOffset: number } | null {
// Look for CRLF
const crlfIndex = buffer.indexOf('\r\n', offset);
if (crlfIndex === -1) {
// Look for just LF
const lfIndex = buffer.indexOf('\n', offset);
if (lfIndex === -1) {
return null;
}
return {
line: buffer.slice(offset, lfIndex).toString('utf8'),
nextOffset: lfIndex + 1
};
}
return {
line: buffer.slice(offset, crlfIndex).toString('utf8'),
nextOffset: crlfIndex + 2
};
}
/**
* Check if buffer contains printable ASCII
*/
static isPrintableAscii(buffer: Buffer, length?: number): boolean {
const checkLength = Math.min(length || buffer.length, buffer.length);
for (let i = 0; i < checkLength; i++) {
const byte = buffer[i];
// Allow printable ASCII (32-126) plus tab (9), LF (10), and CR (13)
if (byte < 32 || byte > 126) {
if (byte !== 9 && byte !== 10 && byte !== 13) {
return false;
}
}
}
return true;
}
/**
* Quick check if buffer starts with HTTP method
*/
static quickCheck(buffer: Buffer): boolean {
if (buffer.length < 3) {
return false;
}
// Check common HTTP methods
const start = buffer.slice(0, 7).toString('ascii');
return start.startsWith('GET ') ||
start.startsWith('POST ') ||
start.startsWith('PUT ') ||
start.startsWith('DELETE ') ||
start.startsWith('HEAD ') ||
start.startsWith('OPTIONS') ||
start.startsWith('PATCH ') ||
start.startsWith('CONNECT') ||
start.startsWith('TRACE ');
}
/**
* Parse query string
*/
static parseQueryString(queryString: string): Record<string, string> {
const params: Record<string, string> = {};
if (!queryString) {
return params;
}
// Remove leading '?' if present
if (queryString.startsWith('?')) {
queryString = queryString.slice(1);
}
const pairs = queryString.split('&');
for (const pair of pairs) {
const [key, value] = pair.split('=');
if (key) {
params[decodeURIComponent(key)] = value ? decodeURIComponent(value) : '';
}
}
return params;
}
/**
* Build query string from params
*/
static buildQueryString(params: Record<string, string>): string {
const pairs: string[] = [];
for (const [key, value] of Object.entries(params)) {
pairs.push(`${encodeURIComponent(key)}=${encodeURIComponent(value)}`);
}
return pairs.length > 0 ? '?' + pairs.join('&') : '';
}
}

View File

@@ -0,0 +1,70 @@
/**
* HTTP Protocol Type Definitions
*/
import type { THttpMethod, THttpVersion, HttpStatus } from './constants.js';
/**
* HTTP request line structure
*/
export interface IHttpRequestLine {
method: THttpMethod;
path: string;
version: THttpVersion;
}
/**
* HTTP response line structure
*/
export interface IHttpResponseLine {
version: THttpVersion;
status: HttpStatus;
statusText: string;
}
/**
* HTTP header structure
*/
export interface IHttpHeader {
name: string;
value: string;
}
/**
* HTTP message structure (base for request and response)
*/
export interface IHttpMessage {
headers: Record<string, string>;
body?: Buffer;
}
/**
* HTTP request structure
*/
export interface IHttpRequest extends IHttpMessage {
method: THttpMethod;
path: string;
version: THttpVersion;
query?: Record<string, string>;
}
/**
* HTTP response structure
*/
export interface IHttpResponse extends IHttpMessage {
status: HttpStatus;
statusText: string;
version: THttpVersion;
}
/**
* Parsed URL structure
*/
export interface IParsedUrl {
protocol?: string;
hostname?: string;
port?: number;
path?: string;
query?: string;
fragment?: string;
}

11
ts/protocols/index.ts Normal file
View File

@@ -0,0 +1,11 @@
/**
* Protocol-specific modules for smartproxy
*
* This directory contains generic protocol knowledge separated from
* smartproxy-specific implementation details.
*/
export * as tls from './tls/index.js';
export * as http from './http/index.js';
export * as proxy from './proxy/index.js';
export * as websocket from './websocket/index.js';

View File

@@ -0,0 +1,7 @@
/**
* PROXY Protocol Module
* HAProxy PROXY protocol implementation
*/
export * from './types.js';
export * from './parser.js';

View File

@@ -0,0 +1,183 @@
/**
* PROXY Protocol Parser
* Implementation of HAProxy PROXY protocol v1 (text format)
* Spec: https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
*/
import type { IProxyInfo, IProxyParseResult, TProxyProtocol } from './types.js';
/**
* PROXY protocol parser
*/
export class ProxyProtocolParser {
static readonly PROXY_V1_SIGNATURE = 'PROXY ';
static readonly MAX_HEADER_LENGTH = 107; // Max length for v1 header
static readonly HEADER_TERMINATOR = '\r\n';
/**
* Parse PROXY protocol v1 header from buffer
* Returns proxy info and remaining data after header
*/
static parse(data: Buffer): IProxyParseResult {
// Check if buffer starts with PROXY signature
if (!data.toString('ascii', 0, 6).startsWith(this.PROXY_V1_SIGNATURE)) {
return {
proxyInfo: null,
remainingData: data
};
}
// Find header terminator
const headerEndIndex = data.indexOf(this.HEADER_TERMINATOR);
if (headerEndIndex === -1) {
// Header incomplete, need more data
if (data.length > this.MAX_HEADER_LENGTH) {
// Header too long, invalid
throw new Error('PROXY protocol header exceeds maximum length');
}
return {
proxyInfo: null,
remainingData: data
};
}
// Extract header line
const headerLine = data.toString('ascii', 0, headerEndIndex);
const remainingData = data.slice(headerEndIndex + 2); // Skip \r\n
// Parse header
const parts = headerLine.split(' ');
if (parts.length < 2) {
throw new Error(`Invalid PROXY protocol header format: ${headerLine}`);
}
const [signature, protocol] = parts;
// Validate protocol
if (!['TCP4', 'TCP6', 'UNKNOWN'].includes(protocol)) {
throw new Error(`Invalid PROXY protocol: ${protocol}`);
}
// For UNKNOWN protocol, ignore addresses
if (protocol === 'UNKNOWN') {
return {
proxyInfo: {
protocol: 'UNKNOWN',
sourceIP: '',
sourcePort: 0,
destinationIP: '',
destinationPort: 0
},
remainingData
};
}
// For TCP4/TCP6, we need all 6 parts
if (parts.length !== 6) {
throw new Error(`Invalid PROXY protocol header format: ${headerLine}`);
}
const [, , srcIP, dstIP, srcPort, dstPort] = parts;
// Validate and parse ports
const sourcePort = parseInt(srcPort, 10);
const destinationPort = parseInt(dstPort, 10);
if (isNaN(sourcePort) || sourcePort < 0 || sourcePort > 65535) {
throw new Error(`Invalid source port: ${srcPort}`);
}
if (isNaN(destinationPort) || destinationPort < 0 || destinationPort > 65535) {
throw new Error(`Invalid destination port: ${dstPort}`);
}
// Validate IP addresses
const protocolType = protocol as TProxyProtocol;
if (!this.isValidIP(srcIP, protocolType)) {
throw new Error(`Invalid source IP for ${protocol}: ${srcIP}`);
}
if (!this.isValidIP(dstIP, protocolType)) {
throw new Error(`Invalid destination IP for ${protocol}: ${dstIP}`);
}
return {
proxyInfo: {
protocol: protocolType,
sourceIP: srcIP,
sourcePort,
destinationIP: dstIP,
destinationPort
},
remainingData
};
}
/**
* Generate PROXY protocol v1 header
*/
static generate(info: IProxyInfo): Buffer {
if (info.protocol === 'UNKNOWN') {
return Buffer.from(`PROXY UNKNOWN\r\n`, 'ascii');
}
const header = `PROXY ${info.protocol} ${info.sourceIP} ${info.destinationIP} ${info.sourcePort} ${info.destinationPort}\r\n`;
if (header.length > this.MAX_HEADER_LENGTH) {
throw new Error('Generated PROXY protocol header exceeds maximum length');
}
return Buffer.from(header, 'ascii');
}
/**
* Validate IP address format
*/
static isValidIP(ip: string, protocol: TProxyProtocol): boolean {
if (protocol === 'TCP4') {
return this.isIPv4(ip);
} else if (protocol === 'TCP6') {
return this.isIPv6(ip);
}
return false;
}
/**
* Check if string is valid IPv4
*/
static isIPv4(ip: string): boolean {
const parts = ip.split('.');
if (parts.length !== 4) return false;
for (const part of parts) {
const num = parseInt(part, 10);
if (isNaN(num) || num < 0 || num > 255 || part !== num.toString()) {
return false;
}
}
return true;
}
/**
* Check if string is valid IPv6
*/
static isIPv6(ip: string): boolean {
// Basic IPv6 validation
const ipv6Regex = /^(([0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,7}:|([0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,5}(:[0-9a-fA-F]{1,4}){1,2}|([0-9a-fA-F]{1,4}:){1,4}(:[0-9a-fA-F]{1,4}){1,3}|([0-9a-fA-F]{1,4}:){1,3}(:[0-9a-fA-F]{1,4}){1,4}|([0-9a-fA-F]{1,4}:){1,2}(:[0-9a-fA-F]{1,4}){1,5}|[0-9a-fA-F]{1,4}:((:[0-9a-fA-F]{1,4}){1,6})|:((:[0-9a-fA-F]{1,4}){1,7}|:)|fe80:(:[0-9a-fA-F]{0,4}){0,4}%[0-9a-zA-Z]{1,}|::(ffff(:0{1,4}){0,1}:){0,1}((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])|([0-9a-fA-F]{1,4}:){1,4}:((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9]))$/;
return ipv6Regex.test(ip);
}
/**
* Create a connection ID string for tracking
*/
static createConnectionId(connectionInfo: {
sourceIp?: string;
sourcePort?: number;
destIp?: string;
destPort?: number;
}): string {
const { sourceIp, sourcePort, destIp, destPort } = connectionInfo;
return `${sourceIp}:${sourcePort}-${destIp}:${destPort}`;
}
}

View File

@@ -0,0 +1,53 @@
/**
* PROXY Protocol Type Definitions
* Based on HAProxy PROXY protocol specification
*/
/**
* PROXY protocol version
*/
export type TProxyProtocolVersion = 'v1' | 'v2';
/**
* Connection protocol type
*/
export type TProxyProtocol = 'TCP4' | 'TCP6' | 'UNKNOWN';
/**
* Interface representing parsed PROXY protocol information
*/
export interface IProxyInfo {
protocol: TProxyProtocol;
sourceIP: string;
sourcePort: number;
destinationIP: string;
destinationPort: number;
}
/**
* Interface for parse result including remaining data
*/
export interface IProxyParseResult {
proxyInfo: IProxyInfo | null;
remainingData: Buffer;
}
/**
* PROXY protocol v2 header format
*/
export interface IProxyV2Header {
signature: Buffer;
versionCommand: number;
family: number;
length: number;
}
/**
* Connection information for PROXY protocol
*/
export interface IProxyConnectionInfo {
sourceIp?: string;
sourcePort?: number;
destIp?: string;
destPort?: number;
}

View File

@@ -0,0 +1,3 @@
/**
* TLS alerts
*/

View File

@@ -0,0 +1,259 @@
import * as plugins from '../../../plugins.js';
import { TlsAlertLevel, TlsAlertDescription, TlsVersion } from '../utils/tls-utils.js';
/**
* TlsAlert class for creating and sending TLS alert messages
*/
export class TlsAlert {
// Use enum values from TlsAlertLevel
static readonly LEVEL_WARNING = TlsAlertLevel.WARNING;
static readonly LEVEL_FATAL = TlsAlertLevel.FATAL;
// Use enum values from TlsAlertDescription
static readonly CLOSE_NOTIFY = TlsAlertDescription.CLOSE_NOTIFY;
static readonly UNEXPECTED_MESSAGE = TlsAlertDescription.UNEXPECTED_MESSAGE;
static readonly BAD_RECORD_MAC = TlsAlertDescription.BAD_RECORD_MAC;
static readonly DECRYPTION_FAILED = TlsAlertDescription.DECRYPTION_FAILED;
static readonly RECORD_OVERFLOW = TlsAlertDescription.RECORD_OVERFLOW;
static readonly DECOMPRESSION_FAILURE = TlsAlertDescription.DECOMPRESSION_FAILURE;
static readonly HANDSHAKE_FAILURE = TlsAlertDescription.HANDSHAKE_FAILURE;
static readonly NO_CERTIFICATE = TlsAlertDescription.NO_CERTIFICATE;
static readonly BAD_CERTIFICATE = TlsAlertDescription.BAD_CERTIFICATE;
static readonly UNSUPPORTED_CERTIFICATE = TlsAlertDescription.UNSUPPORTED_CERTIFICATE;
static readonly CERTIFICATE_REVOKED = TlsAlertDescription.CERTIFICATE_REVOKED;
static readonly CERTIFICATE_EXPIRED = TlsAlertDescription.CERTIFICATE_EXPIRED;
static readonly CERTIFICATE_UNKNOWN = TlsAlertDescription.CERTIFICATE_UNKNOWN;
static readonly ILLEGAL_PARAMETER = TlsAlertDescription.ILLEGAL_PARAMETER;
static readonly UNKNOWN_CA = TlsAlertDescription.UNKNOWN_CA;
static readonly ACCESS_DENIED = TlsAlertDescription.ACCESS_DENIED;
static readonly DECODE_ERROR = TlsAlertDescription.DECODE_ERROR;
static readonly DECRYPT_ERROR = TlsAlertDescription.DECRYPT_ERROR;
static readonly EXPORT_RESTRICTION = TlsAlertDescription.EXPORT_RESTRICTION;
static readonly PROTOCOL_VERSION = TlsAlertDescription.PROTOCOL_VERSION;
static readonly INSUFFICIENT_SECURITY = TlsAlertDescription.INSUFFICIENT_SECURITY;
static readonly INTERNAL_ERROR = TlsAlertDescription.INTERNAL_ERROR;
static readonly INAPPROPRIATE_FALLBACK = TlsAlertDescription.INAPPROPRIATE_FALLBACK;
static readonly USER_CANCELED = TlsAlertDescription.USER_CANCELED;
static readonly NO_RENEGOTIATION = TlsAlertDescription.NO_RENEGOTIATION;
static readonly MISSING_EXTENSION = TlsAlertDescription.MISSING_EXTENSION;
static readonly UNSUPPORTED_EXTENSION = TlsAlertDescription.UNSUPPORTED_EXTENSION;
static readonly CERTIFICATE_REQUIRED = TlsAlertDescription.CERTIFICATE_REQUIRED;
static readonly UNRECOGNIZED_NAME = TlsAlertDescription.UNRECOGNIZED_NAME;
static readonly BAD_CERTIFICATE_STATUS_RESPONSE = TlsAlertDescription.BAD_CERTIFICATE_STATUS_RESPONSE;
static readonly BAD_CERTIFICATE_HASH_VALUE = TlsAlertDescription.BAD_CERTIFICATE_HASH_VALUE;
static readonly UNKNOWN_PSK_IDENTITY = TlsAlertDescription.UNKNOWN_PSK_IDENTITY;
static readonly CERTIFICATE_REQUIRED_1_3 = TlsAlertDescription.CERTIFICATE_REQUIRED_1_3;
static readonly NO_APPLICATION_PROTOCOL = TlsAlertDescription.NO_APPLICATION_PROTOCOL;
/**
* Create a TLS alert buffer with the specified level and description code
*
* @param level Alert level (warning or fatal)
* @param description Alert description code
* @param tlsVersion TLS version bytes (default is TLS 1.2: 0x0303)
* @returns Buffer containing the TLS alert message
*/
static create(
level: number,
description: number,
tlsVersion: [number, number] = [TlsVersion.TLS1_2[0], TlsVersion.TLS1_2[1]]
): Buffer {
return Buffer.from([
0x15, // Alert record type
tlsVersion[0],
tlsVersion[1], // TLS version (default to TLS 1.2: 0x0303)
0x00,
0x02, // Length
level, // Alert level
description, // Alert description
]);
}
/**
* Create a warning-level TLS alert
*
* @param description Alert description code
* @returns Buffer containing the warning-level TLS alert message
*/
static createWarning(description: number): Buffer {
return this.create(this.LEVEL_WARNING, description);
}
/**
* Create a fatal-level TLS alert
*
* @param description Alert description code
* @returns Buffer containing the fatal-level TLS alert message
*/
static createFatal(description: number): Buffer {
return this.create(this.LEVEL_FATAL, description);
}
/**
* Send a TLS alert to a socket and optionally close the connection
*
* @param socket The socket to send the alert to
* @param level Alert level (warning or fatal)
* @param description Alert description code
* @param closeAfterSend Whether to close the connection after sending the alert
* @param closeDelay Milliseconds to wait before closing the connection (default: 200ms)
* @returns Promise that resolves when the alert has been sent
*/
static async send(
socket: plugins.net.Socket,
level: number,
description: number,
closeAfterSend: boolean = false,
closeDelay: number = 200
): Promise<void> {
const alert = this.create(level, description);
return new Promise<void>((resolve, reject) => {
try {
// Ensure the alert is written as a single packet
socket.cork();
const writeSuccessful = socket.write(alert, (err) => {
if (err) {
reject(err);
return;
}
if (closeAfterSend) {
setTimeout(() => {
socket.end();
resolve();
}, closeDelay);
} else {
resolve();
}
});
socket.uncork();
// If write wasn't successful immediately, wait for drain
if (!writeSuccessful && !closeAfterSend) {
socket.once('drain', () => {
resolve();
});
}
} catch (err) {
reject(err);
}
});
}
/**
* Pre-defined TLS alert messages
*/
static readonly alerts = {
// Warning level alerts
closeNotify: TlsAlert.createWarning(TlsAlert.CLOSE_NOTIFY),
unsupportedExtension: TlsAlert.createWarning(TlsAlert.UNSUPPORTED_EXTENSION),
certificateRequired: TlsAlert.createWarning(TlsAlert.CERTIFICATE_REQUIRED),
unrecognizedName: TlsAlert.createWarning(TlsAlert.UNRECOGNIZED_NAME),
noRenegotiation: TlsAlert.createWarning(TlsAlert.NO_RENEGOTIATION),
userCanceled: TlsAlert.createWarning(TlsAlert.USER_CANCELED),
// Warning level alerts for session resumption
certificateExpiredWarning: TlsAlert.createWarning(TlsAlert.CERTIFICATE_EXPIRED),
handshakeFailureWarning: TlsAlert.createWarning(TlsAlert.HANDSHAKE_FAILURE),
insufficientSecurityWarning: TlsAlert.createWarning(TlsAlert.INSUFFICIENT_SECURITY),
// Fatal level alerts
unexpectedMessage: TlsAlert.createFatal(TlsAlert.UNEXPECTED_MESSAGE),
badRecordMac: TlsAlert.createFatal(TlsAlert.BAD_RECORD_MAC),
recordOverflow: TlsAlert.createFatal(TlsAlert.RECORD_OVERFLOW),
handshakeFailure: TlsAlert.createFatal(TlsAlert.HANDSHAKE_FAILURE),
badCertificate: TlsAlert.createFatal(TlsAlert.BAD_CERTIFICATE),
certificateExpired: TlsAlert.createFatal(TlsAlert.CERTIFICATE_EXPIRED),
certificateUnknown: TlsAlert.createFatal(TlsAlert.CERTIFICATE_UNKNOWN),
illegalParameter: TlsAlert.createFatal(TlsAlert.ILLEGAL_PARAMETER),
unknownCA: TlsAlert.createFatal(TlsAlert.UNKNOWN_CA),
accessDenied: TlsAlert.createFatal(TlsAlert.ACCESS_DENIED),
decodeError: TlsAlert.createFatal(TlsAlert.DECODE_ERROR),
decryptError: TlsAlert.createFatal(TlsAlert.DECRYPT_ERROR),
protocolVersion: TlsAlert.createFatal(TlsAlert.PROTOCOL_VERSION),
insufficientSecurity: TlsAlert.createFatal(TlsAlert.INSUFFICIENT_SECURITY),
internalError: TlsAlert.createFatal(TlsAlert.INTERNAL_ERROR),
unrecognizedNameFatal: TlsAlert.createFatal(TlsAlert.UNRECOGNIZED_NAME),
};
/**
* Utility method to send a warning-level unrecognized_name alert
* Specifically designed for SNI issues to encourage the client to retry with SNI
*
* @param socket The socket to send the alert to
* @returns Promise that resolves when the alert has been sent
*/
static async sendSniRequired(socket: plugins.net.Socket): Promise<void> {
return this.send(socket, this.LEVEL_WARNING, this.UNRECOGNIZED_NAME);
}
/**
* Utility method to send a close_notify alert and close the connection
*
* @param socket The socket to send the alert to
* @param closeDelay Milliseconds to wait before closing the connection (default: 200ms)
* @returns Promise that resolves when the alert has been sent and the connection closed
*/
static async sendCloseNotify(socket: plugins.net.Socket, closeDelay: number = 200): Promise<void> {
return this.send(socket, this.LEVEL_WARNING, this.CLOSE_NOTIFY, true, closeDelay);
}
/**
* Utility method to send a certificate_expired alert to force new TLS session
*
* @param socket The socket to send the alert to
* @param fatal Whether to send as a fatal alert (default: false)
* @param closeAfterSend Whether to close the connection after sending the alert (default: true)
* @param closeDelay Milliseconds to wait before closing the connection (default: 200ms)
* @returns Promise that resolves when the alert has been sent
*/
static async sendCertificateExpired(
socket: plugins.net.Socket,
fatal: boolean = false,
closeAfterSend: boolean = true,
closeDelay: number = 200
): Promise<void> {
const level = fatal ? this.LEVEL_FATAL : this.LEVEL_WARNING;
return this.send(socket, level, this.CERTIFICATE_EXPIRED, closeAfterSend, closeDelay);
}
/**
* Send a sequence of alerts to force SNI from clients
* This combines multiple alerts to ensure maximum browser compatibility
*
* @param socket The socket to send the alerts to
* @returns Promise that resolves when all alerts have been sent
*/
static async sendForceSniSequence(socket: plugins.net.Socket): Promise<void> {
try {
// Send unrecognized_name (warning)
socket.cork();
socket.write(this.alerts.unrecognizedName);
socket.uncork();
// Give the socket time to send the alert
return new Promise((resolve) => {
setTimeout(resolve, 50);
});
} catch (err) {
return Promise.reject(err);
}
}
/**
* Send a fatal level alert that immediately terminates the connection
*
* @param socket The socket to send the alert to
* @param description Alert description code
* @param closeDelay Milliseconds to wait before closing the connection (default: 100ms)
* @returns Promise that resolves when the alert has been sent and the connection closed
*/
static async sendFatalAndClose(
socket: plugins.net.Socket,
description: number,
closeDelay: number = 100
): Promise<void> {
return this.send(socket, this.LEVEL_FATAL, description, true, closeDelay);
}
}

37
ts/protocols/tls/index.ts Normal file
View File

@@ -0,0 +1,37 @@
/**
* TLS Protocol Module
* Contains generic TLS protocol knowledge including parsers, constants, and utilities
*/
// Export all sub-modules
export * from './alerts/index.js';
export * from './sni/index.js';
export * from './utils/index.js';
// Re-export main utilities and types for convenience
export {
TlsUtils,
TlsRecordType,
TlsHandshakeType,
TlsExtensionType,
TlsAlertLevel,
TlsAlertDescription,
TlsVersion
} from './utils/tls-utils.js';
export { TlsAlert } from './alerts/tls-alert.js';
export { ClientHelloParser } from './sni/client-hello-parser.js';
export { SniExtraction } from './sni/sni-extraction.js';
// Export tlsVersionToString helper
export function tlsVersionToString(major: number, minor: number): string | null {
if (major === 0x03) {
switch (minor) {
case 0x00: return 'SSLv3';
case 0x01: return 'TLSv1.0';
case 0x02: return 'TLSv1.1';
case 0x03: return 'TLSv1.2';
case 0x04: return 'TLSv1.3';
}
}
return null;
}

View File

@@ -0,0 +1,629 @@
import { Buffer } from 'buffer';
import {
TlsRecordType,
TlsHandshakeType,
TlsExtensionType,
TlsUtils
} from '../utils/tls-utils.js';
/**
* Interface for logging functions used by the parser
*/
export type LoggerFunction = (message: string) => void;
/**
* Result of a session resumption check
*/
export interface SessionResumptionResult {
isResumption: boolean;
hasSNI: boolean;
}
/**
* Information about parsed TLS extensions
*/
export interface ExtensionInfo {
type: number;
length: number;
data: Buffer;
}
/**
* Result of a ClientHello parse operation
*/
export interface ClientHelloParseResult {
isValid: boolean;
version?: [number, number];
random?: Buffer;
sessionId?: Buffer;
hasSessionId: boolean;
cipherSuites?: Buffer;
compressionMethods?: Buffer;
extensions: ExtensionInfo[];
serverNameList?: string[];
hasSessionTicket: boolean;
hasPsk: boolean;
hasEarlyData: boolean;
error?: string;
}
/**
* Fragment tracking information
*/
export interface FragmentTrackingInfo {
buffer: Buffer;
timestamp: number;
connectionId: string;
}
/**
* Class for parsing TLS ClientHello messages
*/
export class ClientHelloParser {
// Buffer for handling fragmented ClientHello messages
private static fragmentedBuffers: Map<string, FragmentTrackingInfo> = new Map();
private static fragmentTimeout: number = 1000; // ms to wait for fragments before cleanup
/**
* Clean up expired fragments
*/
private static cleanupExpiredFragments(): void {
const now = Date.now();
for (const [connectionId, info] of this.fragmentedBuffers.entries()) {
if (now - info.timestamp > this.fragmentTimeout) {
this.fragmentedBuffers.delete(connectionId);
}
}
}
/**
* Handles potential fragmented ClientHello messages by buffering and reassembling
* TLS record fragments that might span multiple TCP packets.
*
* @param buffer The current buffer fragment
* @param connectionId Unique identifier for the connection
* @param logger Optional logging function
* @returns A complete buffer if reassembly is successful, or undefined if more fragments are needed
*/
public static handleFragmentedClientHello(
buffer: Buffer,
connectionId: string,
logger?: LoggerFunction
): Buffer | undefined {
const log = logger || (() => {});
// Periodically clean up expired fragments
this.cleanupExpiredFragments();
// Check if we've seen this connection before
if (!this.fragmentedBuffers.has(connectionId)) {
// New connection, start with this buffer
this.fragmentedBuffers.set(connectionId, {
buffer,
timestamp: Date.now(),
connectionId
});
// Evaluate if this buffer already contains a complete ClientHello
try {
if (buffer.length >= 5) {
// Get the record length from TLS header
const recordLength = (buffer[3] << 8) + buffer[4] + 5; // +5 for the TLS record header itself
log(`Initial buffer size: ${buffer.length}, expected record length: ${recordLength}`);
// Check if this buffer already contains a complete TLS record
if (buffer.length >= recordLength) {
log(`Initial buffer contains complete ClientHello, length: ${buffer.length}`);
return buffer;
}
} else {
log(
`Initial buffer too small (${buffer.length} bytes), needs at least 5 bytes for TLS header`
);
}
} catch (e) {
log(`Error checking initial buffer completeness: ${e}`);
}
log(`Started buffering connection ${connectionId}, initial size: ${buffer.length}`);
return undefined; // Need more fragments
} else {
// Existing connection, append this buffer
const existingInfo = this.fragmentedBuffers.get(connectionId)!;
const newBuffer = Buffer.concat([existingInfo.buffer, buffer]);
// Update the buffer and timestamp
this.fragmentedBuffers.set(connectionId, {
...existingInfo,
buffer: newBuffer,
timestamp: Date.now()
});
log(`Appended to buffer for ${connectionId}, new size: ${newBuffer.length}`);
// Check if we now have a complete ClientHello
try {
if (newBuffer.length >= 5) {
// Get the record length from TLS header
const recordLength = (newBuffer[3] << 8) + newBuffer[4] + 5; // +5 for the TLS record header itself
log(
`Reassembled buffer size: ${newBuffer.length}, expected record length: ${recordLength}`
);
// Check if we have a complete TLS record now
if (newBuffer.length >= recordLength) {
log(
`Assembled complete ClientHello, length: ${newBuffer.length}, needed: ${recordLength}`
);
// Extract the complete TLS record (might be followed by more data)
const completeRecord = newBuffer.slice(0, recordLength);
// Check if this record is indeed a ClientHello (type 1) at position 5
if (
completeRecord.length > 5 &&
completeRecord[5] === TlsHandshakeType.CLIENT_HELLO
) {
log(`Verified record is a ClientHello handshake message`);
// Complete message received, remove from tracking
this.fragmentedBuffers.delete(connectionId);
return completeRecord;
} else {
log(`Record is complete but not a ClientHello handshake, continuing to buffer`);
// This might be another TLS record type preceding the ClientHello
// Try checking for a ClientHello starting at the end of this record
if (newBuffer.length > recordLength + 5) {
const nextRecordType = newBuffer[recordLength];
log(
`Next record type: ${nextRecordType} (looking for ${TlsRecordType.HANDSHAKE})`
);
if (nextRecordType === TlsRecordType.HANDSHAKE) {
const handshakeType = newBuffer[recordLength + 5];
log(
`Next handshake type: ${handshakeType} (looking for ${TlsHandshakeType.CLIENT_HELLO})`
);
if (handshakeType === TlsHandshakeType.CLIENT_HELLO) {
// Found a ClientHello in the next record, return the entire buffer
log(`Found ClientHello in subsequent record, returning full buffer`);
this.fragmentedBuffers.delete(connectionId);
return newBuffer;
}
}
}
}
}
}
} catch (e) {
log(`Error checking reassembled buffer completeness: ${e}`);
}
return undefined; // Still need more fragments
}
}
/**
* Parses a TLS ClientHello message and extracts all components
*
* @param buffer The buffer containing the ClientHello message
* @param logger Optional logging function
* @returns Parsed ClientHello or undefined if parsing failed
*/
public static parseClientHello(
buffer: Buffer,
logger?: LoggerFunction
): ClientHelloParseResult {
const log = logger || (() => {});
const result: ClientHelloParseResult = {
isValid: false,
hasSessionId: false,
extensions: [],
hasSessionTicket: false,
hasPsk: false,
hasEarlyData: false
};
try {
// Check basic validity
if (buffer.length < 5) {
result.error = 'Buffer too small for TLS record header';
return result;
}
// Check record type (must be HANDSHAKE)
if (buffer[0] !== TlsRecordType.HANDSHAKE) {
result.error = `Not a TLS handshake record: ${buffer[0]}`;
return result;
}
// Get TLS version from record header
const majorVersion = buffer[1];
const minorVersion = buffer[2];
result.version = [majorVersion, minorVersion];
log(`TLS record version: ${majorVersion}.${minorVersion}`);
// Parse record length (bytes 3-4, big-endian)
const recordLength = (buffer[3] << 8) + buffer[4];
log(`Record length: ${recordLength}`);
// Validate record length against buffer size
if (buffer.length < recordLength + 5) {
result.error = 'Buffer smaller than expected record length';
return result;
}
// Start of handshake message in the buffer
let pos = 5;
// Check handshake type (must be CLIENT_HELLO)
if (buffer[pos] !== TlsHandshakeType.CLIENT_HELLO) {
result.error = `Not a ClientHello message: ${buffer[pos]}`;
return result;
}
// Skip handshake type (1 byte)
pos += 1;
// Parse handshake length (3 bytes, big-endian)
const handshakeLength = (buffer[pos] << 16) + (buffer[pos + 1] << 8) + buffer[pos + 2];
log(`Handshake length: ${handshakeLength}`);
// Skip handshake length (3 bytes)
pos += 3;
// Check client version (2 bytes)
const clientMajorVersion = buffer[pos];
const clientMinorVersion = buffer[pos + 1];
log(`Client version: ${clientMajorVersion}.${clientMinorVersion}`);
// Skip client version (2 bytes)
pos += 2;
// Extract client random (32 bytes)
if (pos + 32 > buffer.length) {
result.error = 'Buffer too small for client random';
return result;
}
result.random = buffer.slice(pos, pos + 32);
log(`Client random: ${result.random.toString('hex')}`);
// Skip client random (32 bytes)
pos += 32;
// Parse session ID
if (pos + 1 > buffer.length) {
result.error = 'Buffer too small for session ID length';
return result;
}
const sessionIdLength = buffer[pos];
log(`Session ID length: ${sessionIdLength}`);
pos += 1;
result.hasSessionId = sessionIdLength > 0;
if (sessionIdLength > 0) {
if (pos + sessionIdLength > buffer.length) {
result.error = 'Buffer too small for session ID';
return result;
}
result.sessionId = buffer.slice(pos, pos + sessionIdLength);
log(`Session ID: ${result.sessionId.toString('hex')}`);
}
// Skip session ID
pos += sessionIdLength;
// Check if we have enough bytes left for cipher suites
if (pos + 2 > buffer.length) {
result.error = 'Buffer too small for cipher suites length';
return result;
}
// Parse cipher suites length (2 bytes, big-endian)
const cipherSuitesLength = (buffer[pos] << 8) + buffer[pos + 1];
log(`Cipher suites length: ${cipherSuitesLength}`);
pos += 2;
// Extract cipher suites
if (pos + cipherSuitesLength > buffer.length) {
result.error = 'Buffer too small for cipher suites';
return result;
}
result.cipherSuites = buffer.slice(pos, pos + cipherSuitesLength);
// Skip cipher suites
pos += cipherSuitesLength;
// Check if we have enough bytes left for compression methods
if (pos + 1 > buffer.length) {
result.error = 'Buffer too small for compression methods length';
return result;
}
// Parse compression methods length (1 byte)
const compressionMethodsLength = buffer[pos];
log(`Compression methods length: ${compressionMethodsLength}`);
pos += 1;
// Extract compression methods
if (pos + compressionMethodsLength > buffer.length) {
result.error = 'Buffer too small for compression methods';
return result;
}
result.compressionMethods = buffer.slice(pos, pos + compressionMethodsLength);
// Skip compression methods
pos += compressionMethodsLength;
// Check if we have enough bytes for extensions length
if (pos + 2 > buffer.length) {
// No extensions present - this is valid for older TLS versions
result.isValid = true;
return result;
}
// Parse extensions length (2 bytes, big-endian)
const extensionsLength = (buffer[pos] << 8) + buffer[pos + 1];
log(`Extensions length: ${extensionsLength}`);
pos += 2;
// Extensions end position
const extensionsEnd = pos + extensionsLength;
// Check if extensions length is valid
if (extensionsEnd > buffer.length) {
result.error = 'Extensions length exceeds buffer size';
return result;
}
// Iterate through extensions
const serverNames: string[] = [];
while (pos + 4 <= extensionsEnd) {
// Parse extension type (2 bytes, big-endian)
const extensionType = (buffer[pos] << 8) + buffer[pos + 1];
log(`Extension type: 0x${extensionType.toString(16).padStart(4, '0')}`);
pos += 2;
// Parse extension length (2 bytes, big-endian)
const extensionLength = (buffer[pos] << 8) + buffer[pos + 1];
log(`Extension length: ${extensionLength}`);
pos += 2;
// Extract extension data
if (pos + extensionLength > extensionsEnd) {
result.error = `Extension ${extensionType} data exceeds bounds`;
return result;
}
const extensionData = buffer.slice(pos, pos + extensionLength);
// Record all extensions
result.extensions.push({
type: extensionType,
length: extensionLength,
data: extensionData
});
// Track specific extension types
if (extensionType === TlsExtensionType.SERVER_NAME) {
// Server Name Indication (SNI)
this.parseServerNameExtension(extensionData, serverNames, logger);
} else if (extensionType === TlsExtensionType.SESSION_TICKET) {
// Session ticket
result.hasSessionTicket = true;
} else if (extensionType === TlsExtensionType.PRE_SHARED_KEY) {
// TLS 1.3 PSK
result.hasPsk = true;
} else if (extensionType === TlsExtensionType.EARLY_DATA) {
// TLS 1.3 Early Data (0-RTT)
result.hasEarlyData = true;
}
// Move to next extension
pos += extensionLength;
}
// Store any server names found
if (serverNames.length > 0) {
result.serverNameList = serverNames;
}
// Mark as valid if we get here
result.isValid = true;
return result;
} catch (error) {
const errorMessage = error instanceof Error ? error.message : String(error);
log(`Error parsing ClientHello: ${errorMessage}`);
result.error = errorMessage;
return result;
}
}
/**
* Parses the server name extension data and extracts hostnames
*
* @param data Extension data buffer
* @param serverNames Array to populate with found server names
* @param logger Optional logging function
* @returns true if parsing succeeded
*/
private static parseServerNameExtension(
data: Buffer,
serverNames: string[],
logger?: LoggerFunction
): boolean {
const log = logger || (() => {});
try {
// Need at least 2 bytes for server name list length
if (data.length < 2) {
log('SNI extension too small for server name list length');
return false;
}
// Parse server name list length (2 bytes)
const listLength = (data[0] << 8) + data[1];
// Skip to first name entry
let pos = 2;
// End of list
const listEnd = pos + listLength;
// Validate length
if (listEnd > data.length) {
log('SNI server name list exceeds extension data');
return false;
}
// Process all name entries
while (pos + 3 <= listEnd) {
// Name type (1 byte)
const nameType = data[pos];
pos += 1;
// For hostname, type must be 0
if (nameType !== 0) {
// Skip this entry
if (pos + 2 <= listEnd) {
const nameLength = (data[pos] << 8) + data[pos + 1];
pos += 2 + nameLength;
continue;
} else {
log('Malformed SNI entry');
return false;
}
}
// Parse hostname length (2 bytes)
if (pos + 2 > listEnd) {
log('SNI extension truncated');
return false;
}
const nameLength = (data[pos] << 8) + data[pos + 1];
pos += 2;
// Extract hostname
if (pos + nameLength > listEnd) {
log('SNI hostname truncated');
return false;
}
// Extract the hostname as UTF-8
try {
const hostname = data.slice(pos, pos + nameLength).toString('utf8');
log(`Found SNI hostname: ${hostname}`);
serverNames.push(hostname);
} catch (err) {
log(`Error extracting hostname: ${err}`);
}
// Move to next entry
pos += nameLength;
}
return serverNames.length > 0;
} catch (error) {
log(`Error parsing SNI extension: ${error}`);
return false;
}
}
/**
* Determines if a ClientHello contains session resumption indicators
*
* @param buffer The ClientHello buffer
* @param logger Optional logging function
* @returns Session resumption result
*/
public static hasSessionResumption(
buffer: Buffer,
logger?: LoggerFunction
): SessionResumptionResult {
const log = logger || (() => {});
if (!TlsUtils.isClientHello(buffer)) {
return { isResumption: false, hasSNI: false };
}
const parseResult = this.parseClientHello(buffer, logger);
if (!parseResult.isValid) {
log(`ClientHello parse failed: ${parseResult.error}`);
return { isResumption: false, hasSNI: false };
}
// Check resumption indicators
const hasSessionId = parseResult.hasSessionId;
const hasSessionTicket = parseResult.hasSessionTicket;
const hasPsk = parseResult.hasPsk;
const hasEarlyData = parseResult.hasEarlyData;
// Check for SNI
const hasSNI = !!parseResult.serverNameList && parseResult.serverNameList.length > 0;
// Consider it a resumption if any resumption mechanism is present
const isResumption = hasSessionTicket || hasPsk || hasEarlyData ||
(hasSessionId && !hasPsk); // Legacy resumption
// Log details
if (isResumption) {
log(
'Session resumption detected: ' +
(hasSessionTicket ? 'session ticket, ' : '') +
(hasPsk ? 'PSK, ' : '') +
(hasEarlyData ? 'early data, ' : '') +
(hasSessionId ? 'session ID' : '') +
(hasSNI ? ', with SNI' : ', without SNI')
);
}
return { isResumption, hasSNI };
}
/**
* Checks if a ClientHello appears to be from a tab reactivation
*
* @param buffer The ClientHello buffer
* @param logger Optional logging function
* @returns true if it appears to be a tab reactivation
*/
public static isTabReactivationHandshake(
buffer: Buffer,
logger?: LoggerFunction
): boolean {
const log = logger || (() => {});
if (!TlsUtils.isClientHello(buffer)) {
return false;
}
// Parse the ClientHello
const parseResult = this.parseClientHello(buffer, logger);
if (!parseResult.isValid) {
return false;
}
// Tab reactivation pattern: session identifier + (ticket or PSK) but no SNI
const hasSessionId = parseResult.hasSessionId;
const hasSessionTicket = parseResult.hasSessionTicket;
const hasPsk = parseResult.hasPsk;
const hasSNI = !!parseResult.serverNameList && parseResult.serverNameList.length > 0;
if ((hasSessionId && (hasSessionTicket || hasPsk)) && !hasSNI) {
log('Detected tab reactivation pattern: session resumption without SNI');
return true;
}
return false;
}
}

View File

@@ -0,0 +1,6 @@
/**
* TLS SNI (Server Name Indication) protocol utilities
*/
export * from './client-hello-parser.js';
export * from './sni-extraction.js';

View File

@@ -0,0 +1,353 @@
import { Buffer } from 'buffer';
import { TlsExtensionType, TlsUtils } from '../utils/tls-utils.js';
import {
ClientHelloParser,
type LoggerFunction
} from './client-hello-parser.js';
/**
* Connection tracking information
*/
export interface ConnectionInfo {
sourceIp: string;
sourcePort: number;
destIp: string;
destPort: number;
timestamp?: number;
}
/**
* Utilities for extracting SNI information from TLS handshakes
*/
export class SniExtraction {
/**
* Extracts the SNI (Server Name Indication) from a TLS ClientHello message.
*
* @param buffer The buffer containing the TLS ClientHello message
* @param logger Optional logging function
* @returns The extracted server name or undefined if not found
*/
public static extractSNI(buffer: Buffer, logger?: LoggerFunction): string | undefined {
const log = logger || (() => {});
try {
// Parse the ClientHello
const parseResult = ClientHelloParser.parseClientHello(buffer, logger);
if (!parseResult.isValid) {
log(`Failed to parse ClientHello: ${parseResult.error}`);
return undefined;
}
// Check if ServerName extension was found
if (parseResult.serverNameList && parseResult.serverNameList.length > 0) {
// Use the first hostname (most common case)
const serverName = parseResult.serverNameList[0];
log(`Found SNI: ${serverName}`);
return serverName;
}
log('No SNI extension found in ClientHello');
return undefined;
} catch (error) {
log(`Error extracting SNI: ${error instanceof Error ? error.message : String(error)}`);
return undefined;
}
}
/**
* Attempts to extract SNI from the PSK extension in a TLS 1.3 ClientHello.
*
* In TLS 1.3, when a client attempts to resume a session, it may include
* the server name in the PSK identity hint rather than in the SNI extension.
*
* @param buffer The buffer containing the TLS ClientHello message
* @param logger Optional logging function
* @returns The extracted server name or undefined if not found
*/
public static extractSNIFromPSKExtension(
buffer: Buffer,
logger?: LoggerFunction
): string | undefined {
const log = logger || (() => {});
try {
// Ensure this is a ClientHello
if (!TlsUtils.isClientHello(buffer)) {
log('Not a ClientHello message');
return undefined;
}
// Parse the ClientHello to find PSK extension
const parseResult = ClientHelloParser.parseClientHello(buffer, logger);
if (!parseResult.isValid || !parseResult.extensions) {
return undefined;
}
// Find the PSK extension
const pskExtension = parseResult.extensions.find(ext =>
ext.type === TlsExtensionType.PRE_SHARED_KEY);
if (!pskExtension) {
log('No PSK extension found');
return undefined;
}
// Parse the PSK extension data
const data = pskExtension.data;
// PSK extension structure:
// 2 bytes: identities list length
if (data.length < 2) return undefined;
const identitiesLength = (data[0] << 8) + data[1];
let pos = 2;
// End of identities list
const identitiesEnd = pos + identitiesLength;
if (identitiesEnd > data.length) return undefined;
// Process each PSK identity
while (pos + 2 <= identitiesEnd) {
// Identity length (2 bytes)
if (pos + 2 > identitiesEnd) break;
const identityLength = (data[pos] << 8) + data[pos + 1];
pos += 2;
if (pos + identityLength > identitiesEnd) break;
// Try to extract hostname from identity
// Chrome often embeds the hostname in the PSK identity
// This is a heuristic as there's no standard format
if (identityLength > 0) {
const identity = data.slice(pos, pos + identityLength);
// Skip identity bytes
pos += identityLength;
// Skip obfuscated ticket age (4 bytes)
if (pos + 4 <= identitiesEnd) {
pos += 4;
} else {
break;
}
// Try to parse the identity as UTF-8
try {
const identityStr = identity.toString('utf8');
log(`PSK identity: ${identityStr}`);
// Check if the identity contains hostname hints
// Chrome often embeds the hostname in a known format
// Try to extract using common patterns
// Pattern 1: Look for domain name pattern
const domainPattern =
/([a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?\.)+[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?/i;
const domainMatch = identityStr.match(domainPattern);
if (domainMatch && domainMatch[0]) {
log(`Found domain in PSK identity: ${domainMatch[0]}`);
return domainMatch[0];
}
// Pattern 2: Chrome sometimes uses a specific format with delimiters
// This is a heuristic approach since the format isn't standardized
const parts = identityStr.split('|');
if (parts.length > 1) {
for (const part of parts) {
if (part.includes('.') && !part.includes('/')) {
const possibleDomain = part.trim();
if (/^[a-z0-9.-]+$/i.test(possibleDomain)) {
log(`Found possible domain in PSK delimiter format: ${possibleDomain}`);
return possibleDomain;
}
}
}
}
} catch (e) {
log('Failed to parse PSK identity as UTF-8');
}
}
}
log('No hostname found in PSK extension');
return undefined;
} catch (error) {
log(`Error parsing PSK: ${error instanceof Error ? error.message : String(error)}`);
return undefined;
}
}
/**
* Main entry point for SNI extraction with support for fragmented messages
* and session resumption edge cases.
*
* @param buffer The buffer containing TLS data
* @param connectionInfo Connection tracking information
* @param logger Optional logging function
* @param cachedSni Optional previously cached SNI value
* @returns The extracted server name or undefined
*/
public static extractSNIWithResumptionSupport(
buffer: Buffer,
connectionInfo?: ConnectionInfo,
logger?: LoggerFunction,
cachedSni?: string
): string | undefined {
const log = logger || (() => {});
// Log buffer details for debugging
if (logger) {
log(`Buffer size: ${buffer.length} bytes`);
log(`Buffer starts with: ${buffer.slice(0, Math.min(10, buffer.length)).toString('hex')}`);
if (buffer.length >= 5) {
const recordType = buffer[0];
const majorVersion = buffer[1];
const minorVersion = buffer[2];
const recordLength = (buffer[3] << 8) + buffer[4];
log(
`TLS Record: type=${recordType}, version=${majorVersion}.${minorVersion}, length=${recordLength}`
);
}
}
// Check if we need to handle fragmented packets
let processBuffer = buffer;
if (connectionInfo) {
const connectionId = TlsUtils.createConnectionId(connectionInfo);
const reassembledBuffer = ClientHelloParser.handleFragmentedClientHello(
buffer,
connectionId,
logger
);
if (!reassembledBuffer) {
log(`Waiting for more fragments on connection ${connectionId}`);
return undefined; // Need more fragments to complete ClientHello
}
processBuffer = reassembledBuffer;
log(`Using reassembled buffer of length ${processBuffer.length}`);
}
// First try the standard SNI extraction
const standardSni = this.extractSNI(processBuffer, logger);
if (standardSni) {
log(`Found standard SNI: ${standardSni}`);
return standardSni;
}
// Check for session resumption when standard SNI extraction fails
if (TlsUtils.isClientHello(processBuffer)) {
const resumptionInfo = ClientHelloParser.hasSessionResumption(processBuffer, logger);
if (resumptionInfo.isResumption) {
log(`Detected session resumption in ClientHello without standard SNI`);
// Try to extract SNI from PSK extension
const pskSni = this.extractSNIFromPSKExtension(processBuffer, logger);
if (pskSni) {
log(`Extracted SNI from PSK extension: ${pskSni}`);
return pskSni;
}
}
}
// If cached SNI was provided, use it for application data packets
if (cachedSni && TlsUtils.isTlsApplicationData(buffer)) {
log(`Using provided cached SNI for application data: ${cachedSni}`);
return cachedSni;
}
return undefined;
}
/**
* Unified method for processing a TLS packet and extracting SNI.
* Main entry point for SNI extraction that handles all edge cases.
*
* @param buffer The buffer containing TLS data
* @param connectionInfo Connection tracking information
* @param logger Optional logging function
* @param cachedSni Optional previously cached SNI value
* @returns The extracted server name or undefined
*/
public static processTlsPacket(
buffer: Buffer,
connectionInfo: ConnectionInfo,
logger?: LoggerFunction,
cachedSni?: string
): string | undefined {
const log = logger || (() => {});
// Add timestamp if not provided
if (!connectionInfo.timestamp) {
connectionInfo.timestamp = Date.now();
}
// Check if this is a TLS handshake or application data
if (!TlsUtils.isTlsHandshake(buffer) && !TlsUtils.isTlsApplicationData(buffer)) {
log('Not a TLS handshake or application data packet');
return undefined;
}
// Create connection ID for tracking
const connectionId = TlsUtils.createConnectionId(connectionInfo);
log(`Processing TLS packet for connection ${connectionId}, buffer length: ${buffer.length}`);
// Handle application data with cached SNI (for connection racing)
if (TlsUtils.isTlsApplicationData(buffer)) {
// If explicit cachedSni was provided, use it
if (cachedSni) {
log(`Using provided cached SNI for application data: ${cachedSni}`);
return cachedSni;
}
log('Application data packet without cached SNI, cannot determine hostname');
return undefined;
}
// Enhanced session resumption detection
if (TlsUtils.isClientHello(buffer)) {
const resumptionInfo = ClientHelloParser.hasSessionResumption(buffer, logger);
if (resumptionInfo.isResumption) {
log(`Session resumption detected in TLS packet`);
// Always try standard SNI extraction first
const standardSni = this.extractSNI(buffer, logger);
if (standardSni) {
log(`Found standard SNI in session resumption: ${standardSni}`);
return standardSni;
}
// Enhanced session resumption SNI extraction
// Try extracting from PSK identity
const pskSni = this.extractSNIFromPSKExtension(buffer, logger);
if (pskSni) {
log(`Extracted SNI from PSK extension: ${pskSni}`);
return pskSni;
}
log(`Session resumption without extractable SNI`);
}
}
// For handshake messages, try the full extraction process
const sni = this.extractSNIWithResumptionSupport(buffer, connectionInfo, logger);
if (sni) {
log(`Successfully extracted SNI: ${sni}`);
return sni;
}
// If we couldn't extract an SNI, check if this is a valid ClientHello
if (TlsUtils.isClientHello(buffer)) {
log('Valid ClientHello detected, but no SNI extracted - might need more data');
}
return undefined;
}
}

View File

@@ -0,0 +1,3 @@
/**
* TLS utilities
*/

View File

@@ -0,0 +1,201 @@
import * as plugins from '../../../plugins.js';
/**
* TLS record types as defined in various RFCs
*/
export enum TlsRecordType {
CHANGE_CIPHER_SPEC = 20,
ALERT = 21,
HANDSHAKE = 22,
APPLICATION_DATA = 23,
HEARTBEAT = 24, // RFC 6520
}
/**
* TLS handshake message types
*/
export enum TlsHandshakeType {
HELLO_REQUEST = 0,
CLIENT_HELLO = 1,
SERVER_HELLO = 2,
NEW_SESSION_TICKET = 4,
ENCRYPTED_EXTENSIONS = 8, // TLS 1.3
CERTIFICATE = 11,
SERVER_KEY_EXCHANGE = 12,
CERTIFICATE_REQUEST = 13,
SERVER_HELLO_DONE = 14,
CERTIFICATE_VERIFY = 15,
CLIENT_KEY_EXCHANGE = 16,
FINISHED = 20,
}
/**
* TLS extension types
*/
export enum TlsExtensionType {
SERVER_NAME = 0, // SNI
MAX_FRAGMENT_LENGTH = 1,
CLIENT_CERTIFICATE_URL = 2,
TRUSTED_CA_KEYS = 3,
TRUNCATED_HMAC = 4,
STATUS_REQUEST = 5, // OCSP
SUPPORTED_GROUPS = 10, // Previously named "elliptic_curves"
EC_POINT_FORMATS = 11,
SIGNATURE_ALGORITHMS = 13,
APPLICATION_LAYER_PROTOCOL_NEGOTIATION = 16, // ALPN
SIGNED_CERTIFICATE_TIMESTAMP = 18, // Certificate Transparency
PADDING = 21,
SESSION_TICKET = 35,
PRE_SHARED_KEY = 41, // TLS 1.3
EARLY_DATA = 42, // TLS 1.3 0-RTT
SUPPORTED_VERSIONS = 43, // TLS 1.3
COOKIE = 44, // TLS 1.3
PSK_KEY_EXCHANGE_MODES = 45, // TLS 1.3
CERTIFICATE_AUTHORITIES = 47, // TLS 1.3
POST_HANDSHAKE_AUTH = 49, // TLS 1.3
SIGNATURE_ALGORITHMS_CERT = 50, // TLS 1.3
KEY_SHARE = 51, // TLS 1.3
}
/**
* TLS alert levels
*/
export enum TlsAlertLevel {
WARNING = 1,
FATAL = 2,
}
/**
* TLS alert description codes
*/
export enum TlsAlertDescription {
CLOSE_NOTIFY = 0,
UNEXPECTED_MESSAGE = 10,
BAD_RECORD_MAC = 20,
DECRYPTION_FAILED = 21, // TLS 1.0 only
RECORD_OVERFLOW = 22,
DECOMPRESSION_FAILURE = 30, // TLS 1.2 and below
HANDSHAKE_FAILURE = 40,
NO_CERTIFICATE = 41, // SSLv3 only
BAD_CERTIFICATE = 42,
UNSUPPORTED_CERTIFICATE = 43,
CERTIFICATE_REVOKED = 44,
CERTIFICATE_EXPIRED = 45,
CERTIFICATE_UNKNOWN = 46,
ILLEGAL_PARAMETER = 47,
UNKNOWN_CA = 48,
ACCESS_DENIED = 49,
DECODE_ERROR = 50,
DECRYPT_ERROR = 51,
EXPORT_RESTRICTION = 60, // TLS 1.0 only
PROTOCOL_VERSION = 70,
INSUFFICIENT_SECURITY = 71,
INTERNAL_ERROR = 80,
INAPPROPRIATE_FALLBACK = 86,
USER_CANCELED = 90,
NO_RENEGOTIATION = 100, // TLS 1.2 and below
MISSING_EXTENSION = 109, // TLS 1.3
UNSUPPORTED_EXTENSION = 110, // TLS 1.3
CERTIFICATE_REQUIRED = 111, // TLS 1.3
UNRECOGNIZED_NAME = 112,
BAD_CERTIFICATE_STATUS_RESPONSE = 113,
BAD_CERTIFICATE_HASH_VALUE = 114, // TLS 1.2 and below
UNKNOWN_PSK_IDENTITY = 115,
CERTIFICATE_REQUIRED_1_3 = 116, // TLS 1.3
NO_APPLICATION_PROTOCOL = 120,
}
/**
* TLS version codes (major.minor)
*/
export const TlsVersion = {
SSL3: [0x03, 0x00],
TLS1_0: [0x03, 0x01],
TLS1_1: [0x03, 0x02],
TLS1_2: [0x03, 0x03],
TLS1_3: [0x03, 0x04],
};
/**
* Utility functions for TLS protocol operations
*/
export class TlsUtils {
/**
* Checks if a buffer contains a TLS handshake record
* @param buffer The buffer to check
* @returns true if the buffer starts with a TLS handshake record
*/
public static isTlsHandshake(buffer: Buffer): boolean {
return buffer.length > 0 && buffer[0] === TlsRecordType.HANDSHAKE;
}
/**
* Checks if a buffer contains TLS application data
* @param buffer The buffer to check
* @returns true if the buffer starts with a TLS application data record
*/
public static isTlsApplicationData(buffer: Buffer): boolean {
return buffer.length > 0 && buffer[0] === TlsRecordType.APPLICATION_DATA;
}
/**
* Checks if a buffer contains a TLS alert record
* @param buffer The buffer to check
* @returns true if the buffer starts with a TLS alert record
*/
public static isTlsAlert(buffer: Buffer): boolean {
return buffer.length > 0 && buffer[0] === TlsRecordType.ALERT;
}
/**
* Checks if a buffer contains a TLS ClientHello message
* @param buffer The buffer to check
* @returns true if the buffer appears to be a ClientHello message
*/
public static isClientHello(buffer: Buffer): boolean {
// Minimum ClientHello size (TLS record header + handshake header)
if (buffer.length < 9) {
return false;
}
// Check record type (must be TLS_HANDSHAKE_RECORD_TYPE)
if (buffer[0] !== TlsRecordType.HANDSHAKE) {
return false;
}
// Skip version and length in TLS record header (5 bytes total)
// Check handshake type at byte 5 (must be CLIENT_HELLO)
return buffer[5] === TlsHandshakeType.CLIENT_HELLO;
}
/**
* Gets the record length from a TLS record header
* @param buffer Buffer containing a TLS record
* @returns The record length if the buffer is valid, -1 otherwise
*/
public static getTlsRecordLength(buffer: Buffer): number {
if (buffer.length < 5) {
return -1;
}
// Bytes 3-4 contain the record length (big-endian)
return (buffer[3] << 8) + buffer[4];
}
/**
* Creates a connection ID based on source/destination information
* Used to track fragmented ClientHello messages across multiple packets
*
* @param connectionInfo Object containing connection identifiers
* @returns A string ID for the connection
*/
public static createConnectionId(connectionInfo: {
sourceIp?: string;
sourcePort?: number;
destIp?: string;
destPort?: number;
}): string {
const { sourceIp, sourcePort, destIp, destPort } = connectionInfo;
return `${sourceIp}:${sourcePort}-${destIp}:${destPort}`;
}
}

View File

@@ -0,0 +1,60 @@
/**
* WebSocket Protocol Constants
* Based on RFC 6455
*/
/**
* WebSocket opcode types
*/
export enum WebSocketOpcode {
CONTINUATION = 0x0,
TEXT = 0x1,
BINARY = 0x2,
CLOSE = 0x8,
PING = 0x9,
PONG = 0xa,
}
/**
* WebSocket close codes
*/
export enum WebSocketCloseCode {
NORMAL_CLOSURE = 1000,
GOING_AWAY = 1001,
PROTOCOL_ERROR = 1002,
UNSUPPORTED_DATA = 1003,
NO_STATUS_RECEIVED = 1005,
ABNORMAL_CLOSURE = 1006,
INVALID_FRAME_PAYLOAD_DATA = 1007,
POLICY_VIOLATION = 1008,
MESSAGE_TOO_BIG = 1009,
MISSING_EXTENSION = 1010,
INTERNAL_ERROR = 1011,
SERVICE_RESTART = 1012,
TRY_AGAIN_LATER = 1013,
BAD_GATEWAY = 1014,
TLS_HANDSHAKE = 1015,
}
/**
* WebSocket protocol version
*/
export const WEBSOCKET_VERSION = 13;
/**
* WebSocket magic string for handshake
*/
export const WEBSOCKET_MAGIC_STRING = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11';
/**
* WebSocket headers
*/
export const WEBSOCKET_HEADERS = {
UPGRADE: 'upgrade',
CONNECTION: 'connection',
SEC_WEBSOCKET_KEY: 'sec-websocket-key',
SEC_WEBSOCKET_VERSION: 'sec-websocket-version',
SEC_WEBSOCKET_ACCEPT: 'sec-websocket-accept',
SEC_WEBSOCKET_PROTOCOL: 'sec-websocket-protocol',
SEC_WEBSOCKET_EXTENSIONS: 'sec-websocket-extensions',
} as const;

View File

@@ -0,0 +1,8 @@
/**
* WebSocket Protocol Module
* WebSocket protocol utilities and constants
*/
export * from './constants.js';
export * from './types.js';
export * from './utils.js';

View File

@@ -0,0 +1,53 @@
/**
* WebSocket Protocol Type Definitions
*/
import type { WebSocketOpcode, WebSocketCloseCode } from './constants.js';
/**
* WebSocket frame header
*/
export interface IWebSocketFrameHeader {
fin: boolean;
rsv1: boolean;
rsv2: boolean;
rsv3: boolean;
opcode: WebSocketOpcode;
masked: boolean;
payloadLength: number;
maskingKey?: Buffer;
}
/**
* WebSocket frame
*/
export interface IWebSocketFrame {
header: IWebSocketFrameHeader;
payload: Buffer;
}
/**
* WebSocket close frame payload
*/
export interface IWebSocketClosePayload {
code: WebSocketCloseCode;
reason?: string;
}
/**
* WebSocket handshake request headers
*/
export interface IWebSocketHandshakeHeaders {
upgrade: string;
connection: string;
'sec-websocket-key': string;
'sec-websocket-version': string;
'sec-websocket-protocol'?: string;
'sec-websocket-extensions'?: string;
[key: string]: string | undefined;
}
/**
* Type for WebSocket raw data (matching ws library)
*/
export type RawData = Buffer | ArrayBuffer | Buffer[] | any;

View File

@@ -0,0 +1,98 @@
/**
* WebSocket Protocol Utilities
*/
import * as crypto from 'crypto';
import { WEBSOCKET_MAGIC_STRING } from './constants.js';
import type { RawData } from './types.js';
/**
* Get the length of a WebSocket message regardless of its type
* (handles all possible WebSocket message data types)
*/
export function getMessageSize(data: RawData): number {
if (typeof data === 'string') {
// For string data, get the byte length
return Buffer.from(data, 'utf8').length;
} else if (data instanceof Buffer) {
// For Node.js Buffer
return data.length;
} else if (data instanceof ArrayBuffer) {
// For ArrayBuffer
return data.byteLength;
} else if (Array.isArray(data)) {
// For array of buffers, sum their lengths
return data.reduce((sum, chunk) => {
if (chunk instanceof Buffer) {
return sum + chunk.length;
} else if (chunk instanceof ArrayBuffer) {
return sum + chunk.byteLength;
}
return sum;
}, 0);
} else {
// For other types, try to determine the size or return 0
try {
return Buffer.from(data).length;
} catch (e) {
return 0;
}
}
}
/**
* Convert any raw WebSocket data to Buffer for consistent handling
*/
export function toBuffer(data: RawData): Buffer {
if (typeof data === 'string') {
return Buffer.from(data, 'utf8');
} else if (data instanceof Buffer) {
return data;
} else if (data instanceof ArrayBuffer) {
return Buffer.from(data);
} else if (Array.isArray(data)) {
// For array of buffers, concatenate them
return Buffer.concat(data.map(chunk => {
if (chunk instanceof Buffer) {
return chunk;
} else if (chunk instanceof ArrayBuffer) {
return Buffer.from(chunk);
}
return Buffer.from(chunk);
}));
} else {
// For other types, try to convert to Buffer or return empty Buffer
try {
return Buffer.from(data);
} catch (e) {
return Buffer.alloc(0);
}
}
}
/**
* Generate WebSocket accept key from client key
*/
export function generateAcceptKey(clientKey: string): string {
const hash = crypto.createHash('sha1');
hash.update(clientKey + WEBSOCKET_MAGIC_STRING);
return hash.digest('base64');
}
/**
* Validate WebSocket upgrade request
*/
export function isWebSocketUpgrade(headers: Record<string, string>): boolean {
const upgrade = headers['upgrade'];
const connection = headers['connection'];
return upgrade?.toLowerCase() === 'websocket' &&
connection?.toLowerCase().includes('upgrade');
}
/**
* Generate random WebSocket key for client handshake
*/
export function generateWebSocketKey(): string {
return crypto.randomBytes(16).toString('base64');
}