start fixing tests
This commit is contained in:
@ -6,3 +6,4 @@ export * from './common-types.js';
|
||||
export * from './socket-augmentation.js';
|
||||
export * from './route-context.js';
|
||||
export * from './wrapped-socket.js';
|
||||
export * from './socket-types.js';
|
||||
|
21
ts/core/models/socket-types.ts
Normal file
21
ts/core/models/socket-types.ts
Normal file
@ -0,0 +1,21 @@
|
||||
import * as net from 'net';
|
||||
import { WrappedSocket } from './wrapped-socket.js';
|
||||
|
||||
/**
|
||||
* Type guard to check if a socket is a WrappedSocket
|
||||
*/
|
||||
export function isWrappedSocket(socket: net.Socket | WrappedSocket): socket is WrappedSocket {
|
||||
return socket instanceof WrappedSocket || 'socket' in socket;
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to get the underlying socket from either a Socket or WrappedSocket
|
||||
*/
|
||||
export function getUnderlyingSocket(socket: net.Socket | WrappedSocket): net.Socket {
|
||||
return isWrappedSocket(socket) ? socket.socket : socket;
|
||||
}
|
||||
|
||||
/**
|
||||
* Type that represents either a regular socket or a wrapped socket
|
||||
*/
|
||||
export type AnySocket = net.Socket | WrappedSocket;
|
@ -1,4 +1,3 @@
|
||||
import { EventEmitter } from 'events';
|
||||
import * as plugins from '../../plugins.js';
|
||||
|
||||
/**
|
||||
@ -7,22 +6,66 @@ import * as plugins from '../../plugins.js';
|
||||
*
|
||||
* This is the FOUNDATION for all PROXY protocol support and must be implemented
|
||||
* before any protocol parsing can occur.
|
||||
*
|
||||
* This implementation uses a Proxy to delegate all properties and methods
|
||||
* to the underlying socket while allowing override of specific properties.
|
||||
*/
|
||||
export class WrappedSocket extends EventEmitter {
|
||||
export class WrappedSocket {
|
||||
public readonly socket: plugins.net.Socket;
|
||||
private realClientIP?: string;
|
||||
private realClientPort?: number;
|
||||
|
||||
// Make TypeScript happy by declaring the Socket methods that will be proxied
|
||||
[key: string]: any;
|
||||
|
||||
constructor(
|
||||
public readonly socket: plugins.net.Socket,
|
||||
socket: plugins.net.Socket,
|
||||
realClientIP?: string,
|
||||
realClientPort?: number
|
||||
) {
|
||||
super();
|
||||
this.socket = socket;
|
||||
this.realClientIP = realClientIP;
|
||||
this.realClientPort = realClientPort;
|
||||
|
||||
// Forward all socket events
|
||||
this.forwardSocketEvents();
|
||||
// Create a proxy that delegates everything to the underlying socket
|
||||
return new Proxy(this, {
|
||||
get(target, prop, receiver) {
|
||||
// Override specific properties
|
||||
if (prop === 'remoteAddress') {
|
||||
return target.remoteAddress;
|
||||
}
|
||||
if (prop === 'remotePort') {
|
||||
return target.remotePort;
|
||||
}
|
||||
if (prop === 'socket') {
|
||||
return target.socket;
|
||||
}
|
||||
if (prop === 'realClientIP') {
|
||||
return target.realClientIP;
|
||||
}
|
||||
if (prop === 'realClientPort') {
|
||||
return target.realClientPort;
|
||||
}
|
||||
if (prop === 'isFromTrustedProxy') {
|
||||
return target.isFromTrustedProxy;
|
||||
}
|
||||
if (prop === 'setProxyInfo') {
|
||||
return target.setProxyInfo.bind(target);
|
||||
}
|
||||
|
||||
// For all other properties/methods, delegate to the underlying socket
|
||||
const value = target.socket[prop as keyof plugins.net.Socket];
|
||||
if (typeof value === 'function') {
|
||||
return value.bind(target.socket);
|
||||
}
|
||||
return value;
|
||||
},
|
||||
set(target, prop, value) {
|
||||
// Set on the underlying socket
|
||||
(target.socket as any)[prop] = value;
|
||||
return true;
|
||||
}
|
||||
}) as any;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -39,35 +82,6 @@ export class WrappedSocket extends EventEmitter {
|
||||
return this.realClientPort || this.socket.remotePort;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the remote family (IPv4 or IPv6)
|
||||
*/
|
||||
get remoteFamily(): string | undefined {
|
||||
// If we have a real client IP, determine the family
|
||||
if (this.realClientIP) {
|
||||
if (this.realClientIP.includes(':')) {
|
||||
return 'IPv6';
|
||||
} else {
|
||||
return 'IPv4';
|
||||
}
|
||||
}
|
||||
return this.socket.remoteFamily;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the local address of the socket
|
||||
*/
|
||||
get localAddress(): string | undefined {
|
||||
return this.socket.localAddress;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the local port of the socket
|
||||
*/
|
||||
get localPort(): number | undefined {
|
||||
return this.socket.localPort;
|
||||
}
|
||||
|
||||
/**
|
||||
* Indicates if this connection came through a trusted proxy
|
||||
*/
|
||||
@ -82,178 +96,4 @@ export class WrappedSocket extends EventEmitter {
|
||||
this.realClientIP = ip;
|
||||
this.realClientPort = port;
|
||||
}
|
||||
|
||||
// Pass-through all socket methods
|
||||
write(data: any, encoding?: any, callback?: any): boolean {
|
||||
return this.socket.write(data, encoding, callback);
|
||||
}
|
||||
|
||||
end(data?: any, encoding?: any, callback?: any): this {
|
||||
this.socket.end(data, encoding, callback);
|
||||
return this;
|
||||
}
|
||||
|
||||
destroy(error?: Error): this {
|
||||
this.socket.destroy(error);
|
||||
return this;
|
||||
}
|
||||
|
||||
pause(): this {
|
||||
this.socket.pause();
|
||||
return this;
|
||||
}
|
||||
|
||||
resume(): this {
|
||||
this.socket.resume();
|
||||
return this;
|
||||
}
|
||||
|
||||
setTimeout(timeout: number, callback?: () => void): this {
|
||||
this.socket.setTimeout(timeout, callback);
|
||||
return this;
|
||||
}
|
||||
|
||||
setNoDelay(noDelay?: boolean): this {
|
||||
this.socket.setNoDelay(noDelay);
|
||||
return this;
|
||||
}
|
||||
|
||||
setKeepAlive(enable?: boolean, initialDelay?: number): this {
|
||||
this.socket.setKeepAlive(enable, initialDelay);
|
||||
return this;
|
||||
}
|
||||
|
||||
ref(): this {
|
||||
this.socket.ref();
|
||||
return this;
|
||||
}
|
||||
|
||||
unref(): this {
|
||||
this.socket.unref();
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Pipe to another stream
|
||||
*/
|
||||
pipe<T extends NodeJS.WritableStream>(destination: T, options?: {
|
||||
end?: boolean;
|
||||
}): T {
|
||||
return this.socket.pipe(destination, options);
|
||||
}
|
||||
|
||||
/**
|
||||
* Cork the stream
|
||||
*/
|
||||
cork(): void {
|
||||
if ('cork' in this.socket && typeof this.socket.cork === 'function') {
|
||||
this.socket.cork();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Uncork the stream
|
||||
*/
|
||||
uncork(): void {
|
||||
if ('uncork' in this.socket && typeof this.socket.uncork === 'function') {
|
||||
this.socket.uncork();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the number of bytes read
|
||||
*/
|
||||
get bytesRead(): number {
|
||||
return this.socket.bytesRead;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the number of bytes written
|
||||
*/
|
||||
get bytesWritten(): number {
|
||||
return this.socket.bytesWritten;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the socket is connecting
|
||||
*/
|
||||
get connecting(): boolean {
|
||||
return this.socket.connecting;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the socket is destroyed
|
||||
*/
|
||||
get destroyed(): boolean {
|
||||
return this.socket.destroyed;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the socket is readable
|
||||
*/
|
||||
get readable(): boolean {
|
||||
return this.socket.readable;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the socket is writable
|
||||
*/
|
||||
get writable(): boolean {
|
||||
return this.socket.writable;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get pending status
|
||||
*/
|
||||
get pending(): boolean {
|
||||
return this.socket.pending;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get ready state
|
||||
*/
|
||||
get readyState(): string {
|
||||
return this.socket.readyState;
|
||||
}
|
||||
|
||||
/**
|
||||
* Address info
|
||||
*/
|
||||
address(): plugins.net.AddressInfo | {} | null {
|
||||
const addr = this.socket.address();
|
||||
if (addr === null) return null;
|
||||
if (typeof addr === 'string') return addr as any;
|
||||
return addr;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set socket encoding
|
||||
*/
|
||||
setEncoding(encoding?: BufferEncoding): this {
|
||||
this.socket.setEncoding(encoding);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Connect method (for client sockets)
|
||||
*/
|
||||
connect(options: plugins.net.SocketConnectOpts, connectionListener?: () => void): this;
|
||||
connect(port: number, host?: string, connectionListener?: () => void): this;
|
||||
connect(path: string, connectionListener?: () => void): this;
|
||||
connect(...args: any[]): this {
|
||||
(this.socket as any).connect(...args);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Forward all events from the underlying socket
|
||||
*/
|
||||
private forwardSocketEvents(): void {
|
||||
const events = ['data', 'end', 'close', 'error', 'drain', 'timeout', 'connect', 'ready', 'lookup'];
|
||||
events.forEach(event => {
|
||||
this.socket.on(event, (...args) => {
|
||||
this.emit(event, ...args);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
@ -3,6 +3,7 @@ import { HttpProxy } from '../http-proxy/index.js';
|
||||
import { setupBidirectionalForwarding } from '../../core/utils/socket-utils.js';
|
||||
import type { IConnectionRecord, ISmartProxyOptions } from './models/interfaces.js';
|
||||
import type { IRouteConfig } from './models/route-types.js';
|
||||
import { WrappedSocket } from '../../core/models/wrapped-socket.js';
|
||||
|
||||
export class HttpProxyBridge {
|
||||
private httpProxy: HttpProxy | null = null;
|
||||
@ -98,7 +99,7 @@ export class HttpProxyBridge {
|
||||
*/
|
||||
public async forwardToHttpProxy(
|
||||
connectionId: string,
|
||||
socket: plugins.net.Socket,
|
||||
socket: plugins.net.Socket | WrappedSocket,
|
||||
record: IConnectionRecord,
|
||||
initialChunk: Buffer,
|
||||
httpProxyPort: number,
|
||||
@ -125,7 +126,10 @@ export class HttpProxyBridge {
|
||||
}
|
||||
|
||||
// Use centralized bidirectional forwarding
|
||||
setupBidirectionalForwarding(socket, proxySocket, {
|
||||
// Extract underlying socket if it's a WrappedSocket
|
||||
const underlyingSocket = socket instanceof WrappedSocket ? socket.socket : socket;
|
||||
|
||||
setupBidirectionalForwarding(underlyingSocket, proxySocket, {
|
||||
onClientData: (chunk) => {
|
||||
// Update stats if needed
|
||||
if (record) {
|
||||
|
@ -12,6 +12,7 @@ import { TimeoutManager } from './timeout-manager.js';
|
||||
import { SharedRouteManager as RouteManager } from '../../core/routing/route-manager.js';
|
||||
import { cleanupSocket, createIndependentSocketHandlers, setupSocketHandlers, createSocketWithErrorHandler, setupBidirectionalForwarding } from '../../core/utils/socket-utils.js';
|
||||
import { WrappedSocket } from '../../core/models/wrapped-socket.js';
|
||||
import { getUnderlyingSocket } from '../../core/models/socket-types.js';
|
||||
|
||||
/**
|
||||
* Handles new connection processing and setup logic with support for route-based configuration
|
||||
@ -192,7 +193,7 @@ export class RouteConnectionHandler {
|
||||
// If no routes require TLS handling and it's not port 443, route immediately
|
||||
if (!needsTlsHandling && localPort !== 443) {
|
||||
// Extract underlying socket for socket-utils functions
|
||||
const underlyingSocket = socket instanceof WrappedSocket ? socket.socket : socket;
|
||||
const underlyingSocket = getUnderlyingSocket(socket);
|
||||
// Set up proper socket handlers for immediate routing
|
||||
setupSocketHandlers(
|
||||
underlyingSocket,
|
||||
@ -222,7 +223,7 @@ export class RouteConnectionHandler {
|
||||
);
|
||||
|
||||
// Route immediately for non-TLS connections
|
||||
this.routeConnection(underlyingSocket, record, '', undefined);
|
||||
this.routeConnection(socket, record, '', undefined);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -379,8 +380,7 @@ export class RouteConnectionHandler {
|
||||
}
|
||||
|
||||
// Find the appropriate route for this connection
|
||||
const underlyingSocket = socket instanceof WrappedSocket ? socket.socket : socket;
|
||||
this.routeConnection(underlyingSocket, record, serverName, chunk);
|
||||
this.routeConnection(socket, record, serverName, chunk);
|
||||
});
|
||||
}
|
||||
|
||||
@ -388,7 +388,7 @@ export class RouteConnectionHandler {
|
||||
* Route the connection based on match criteria
|
||||
*/
|
||||
private routeConnection(
|
||||
socket: plugins.net.Socket,
|
||||
socket: plugins.net.Socket | WrappedSocket,
|
||||
record: IConnectionRecord,
|
||||
serverName: string,
|
||||
initialChunk?: Buffer
|
||||
@ -576,7 +576,7 @@ export class RouteConnectionHandler {
|
||||
* Handle a forward action for a route
|
||||
*/
|
||||
private handleForwardAction(
|
||||
socket: plugins.net.Socket,
|
||||
socket: plugins.net.Socket | WrappedSocket,
|
||||
record: IConnectionRecord,
|
||||
route: IRouteConfig,
|
||||
initialChunk?: Buffer
|
||||
@ -893,7 +893,7 @@ export class RouteConnectionHandler {
|
||||
* Handle a socket-handler action for a route
|
||||
*/
|
||||
private async handleSocketHandlerAction(
|
||||
socket: plugins.net.Socket,
|
||||
socket: plugins.net.Socket | WrappedSocket,
|
||||
record: IConnectionRecord,
|
||||
route: IRouteConfig,
|
||||
initialChunk?: Buffer
|
||||
@ -957,8 +957,9 @@ export class RouteConnectionHandler {
|
||||
});
|
||||
|
||||
try {
|
||||
// Call the handler with socket AND context
|
||||
const result = route.action.socketHandler(socket, routeContext);
|
||||
// Call the handler with the appropriate socket (extract underlying if needed)
|
||||
const handlerSocket = getUnderlyingSocket(socket);
|
||||
const result = route.action.socketHandler(handlerSocket, routeContext);
|
||||
|
||||
// Handle async handlers properly
|
||||
if (result instanceof Promise) {
|
||||
@ -1012,7 +1013,7 @@ export class RouteConnectionHandler {
|
||||
* Sets up a direct connection to the target
|
||||
*/
|
||||
private setupDirectConnection(
|
||||
socket: plugins.net.Socket,
|
||||
socket: plugins.net.Socket | WrappedSocket,
|
||||
record: IConnectionRecord,
|
||||
serverName?: string,
|
||||
initialChunk?: Buffer,
|
||||
@ -1162,7 +1163,10 @@ export class RouteConnectionHandler {
|
||||
}
|
||||
|
||||
// Use centralized bidirectional forwarding setup
|
||||
setupBidirectionalForwarding(socket, targetSocket, {
|
||||
// Extract underlying sockets for socket-utils functions
|
||||
const incomingSocket = getUnderlyingSocket(socket);
|
||||
|
||||
setupBidirectionalForwarding(incomingSocket, targetSocket, {
|
||||
onClientData: (chunk) => {
|
||||
record.bytesReceived += chunk.length;
|
||||
this.timeoutManager.updateActivity(record);
|
||||
|
Reference in New Issue
Block a user