// File-viewer residue from the original paste: 413 lines, 12 KiB, TypeScript.
import { createHash, timingSafeEqual } from 'node:crypto';

import type { ILogger } from './models/types.js';
import type { IRouteConfig } from '../smart-proxy/models/route-types.js';
import type { IRouteContext } from '../../core/models/route-context.js';
import {
  isIPAuthorized,
  normalizeIP,
  parseBasicAuthHeader,
  cleanupExpiredRateLimits,
  type IRateLimitInfo
} from '../../core/utils/security-utils.js';

/**
 * Manages security features for the HttpProxy.
 *
 * Implements per-route IP filtering and request rate limiting, basic-auth and
 * JWT claim checks, and per-IP connection counting / connection-rate limiting.
 * Uses shared utilities from security-utils.ts.
 */
export class SecurityManager {
  /** Maximum number of cached IP decisions kept per route before that route's cache is reset. */
  private static readonly MAX_IP_CACHE_ENTRIES_PER_ROUTE = 10000;

  /** Length of the sliding window used for connection-rate limiting. */
  private static readonly CONNECTION_RATE_WINDOW_MS = 60 * 1000;

  /** How often the periodic IP tracking cleanup runs. */
  private static readonly IP_CLEANUP_INTERVAL_MS = 60000;

  // Cache IP filtering results to avoid re-evaluating the allow/block lists.
  // Keyed by the route object itself (not by name) so that two routes that
  // share a name, or have no name at all, never see each other's decisions.
  private ipFilterCache: WeakMap<IRouteConfig, Map<string, boolean>> = new WeakMap();

  // Store rate limits per route and key
  private rateLimits: Map<string, Map<string, IRateLimitInfo>> = new Map();

  // Connection tracking by IP
  private connectionsByIP: Map<string, Set<string>> = new Map();
  private connectionRateByIP: Map<string, number[]> = new Map();

  // Handle of the periodic cleanup timer; null while the timer is stopped.
  private ipCleanupTimer: ReturnType<typeof setInterval> | null = null;

  /**
   * @param logger Logger used for debug and error output
   * @param routes Initial routes configuration
   * @param maxConnectionsPerIP Maximum number of simultaneously tracked connections per IP
   * @param connectionRateLimitPerMinute Maximum number of connection attempts per IP per minute
   */
  constructor(
    private logger: ILogger,
    private routes: IRouteConfig[] = [],
    private maxConnectionsPerIP: number = 100,
    private connectionRateLimitPerMinute: number = 300
  ) {
    // Start periodic cleanup for connection tracking
    this.startPeriodicIpCleanup();
  }

  /**
   * Update the routes configuration.
   *
   * Cached IP filter decisions are dropped because they were computed against
   * the previous configuration.
   */
  public setRoutes(routes: IRouteConfig[]): void {
    this.routes = routes;
    // WeakMap has no clear(); replacing it discards every cached decision.
    this.ipFilterCache = new WeakMap();
  }

  /**
   * Check if a client is allowed to access a specific route.
   *
   * Only IP filtering and rate limiting are evaluated here. Basic auth is not
   * checked because it requires HTTP headers (see checkBasicAuth).
   *
   * @param route The route to check access for
   * @param context The route context with client information
   * @returns True if access is allowed, false otherwise
   */
  public isAllowed(route: IRouteConfig, context: IRouteContext): boolean {
    if (!route.security) {
      return true; // No security restrictions
    }

    // --- IP filtering ---
    if (!this.isIpAllowed(route, context.clientIp)) {
      this.logger.debug(`IP ${context.clientIp} is blocked for route ${route.name || 'unnamed'}`);
      return false;
    }

    // --- Rate limiting ---
    if (route.security.rateLimit?.enabled && !this.isWithinRateLimit(route, context)) {
      this.logger.debug(`Rate limit exceeded for route ${route.name || 'unnamed'}`);
      return false;
    }

    return true;
  }

  /**
   * Check if an IP is allowed based on the route's allow/block lists.
   *
   * Decisions are cached per route object and client IP. Each route's cache is
   * reset once it reaches MAX_IP_CACHE_ENTRIES_PER_ROUTE so that a stream of
   * distinct client IPs cannot grow it without bound.
   */
  private isIpAllowed(route: IRouteConfig, clientIp: string): boolean {
    const security = route.security;
    if (!security) {
      return true; // No security restrictions
    }

    let routeCache = this.ipFilterCache.get(route);
    if (!routeCache) {
      routeCache = new Map();
      this.ipFilterCache.set(route, routeCache);
    }

    const cached = routeCache.get(clientIp);
    if (cached !== undefined) {
      return cached;
    }

    // Use shared utility for IP authorization
    const allowed = isIPAuthorized(clientIp, security.ipAllowList, security.ipBlockList);

    if (routeCache.size >= SecurityManager.MAX_IP_CACHE_ENTRIES_PER_ROUTE) {
      routeCache.clear();
    }
    routeCache.set(clientIp, allowed);

    return allowed;
  }

  /**
   * Check if a request is within the route's rate limit and count it.
   *
   * The counter is keyed by client IP, optionally combined with the request
   * path or a configured header value. Every call increments the counter of
   * the current window, including calls that end up rejected.
   *
   * NOTE(review): buckets are grouped by `route.name || 'unnamed'`, so routes
   * without a name share one bucket per key.
   */
  private isWithinRateLimit(route: IRouteConfig, context: IRouteContext): boolean {
    const rateLimit = route.security?.rateLimit;
    if (!rateLimit?.enabled) {
      return true;
    }

    const routeId = route.name || 'unnamed';

    // Determine rate limit key (by IP, path, or header); defaults to the IP
    let key = context.clientIp;
    if (rateLimit.keyBy === 'path' && context.path) {
      key = `${context.clientIp}:${context.path}`;
    } else if (rateLimit.keyBy === 'header' && rateLimit.headerName && context.headers) {
      const headerValue = context.headers[rateLimit.headerName.toLowerCase()];
      if (headerValue) {
        key = `${context.clientIp}:${headerValue}`;
      }
    }

    // Get or create rate limit tracking for this route
    let routeLimits = this.rateLimits.get(routeId);
    if (!routeLimits) {
      routeLimits = new Map();
      this.rateLimits.set(routeId, routeLimits);
    }

    const now = Date.now();

    // Start a fresh window when there is none yet or the previous one expired
    let limit = routeLimits.get(key);
    if (!limit || limit.expiry < now) {
      limit = {
        count: 0,
        expiry: now + (rateLimit.window * 1000)
      };
      routeLimits.set(key, limit);
    }

    // Count this request, then compare. The first request of a window takes
    // the same path as all others, so a limit below 1 blocks it as well.
    limit.count++;
    return limit.count <= rateLimit.maxRequests;
  }

  /**
   * Clean up expired rate limits.
   * Should be called periodically to prevent memory leaks.
   */
  public cleanupExpiredRateLimits(): void {
    cleanupExpiredRateLimits(this.rateLimits, {
      info: this.logger.info.bind(this.logger),
      warn: this.logger.warn.bind(this.logger),
      error: this.logger.error.bind(this.logger),
      debug: this.logger.debug?.bind(this.logger)
    });
  }

  /**
   * Check basic auth credentials.
   *
   * Every configured user is compared using a constant-time comparison, and
   * the loop does not exit early, so response time does not reveal which user
   * matched or how many leading characters were correct.
   *
   * @param route The route to check auth for
   * @param username The provided username
   * @param password The provided password
   * @returns True if credentials are valid or basic auth is not enabled, false otherwise
   */
  public checkBasicAuth(route: IRouteConfig, username: string, password: string): boolean {
    const basicAuth = route.security?.basicAuth;
    if (!basicAuth?.enabled) {
      return true;
    }

    let matched = false;
    for (const user of basicAuth.users ?? []) {
      const usernameMatches = SecurityManager.constantTimeEquals(user.username, username);
      const passwordMatches = SecurityManager.constantTimeEquals(user.password, password);
      if (usernameMatches && passwordMatches) {
        matched = true;
      }
    }

    return matched;
  }

  /**
   * Check the claims of a JWT token.
   *
   * SECURITY: the token signature is NOT verified here; only the structure and
   * the `exp`, `iss` and `aud` claims are checked. A caller must not treat a
   * `true` result as proof of authenticity until signature verification is added.
   *
   * @param route The route to verify the token for
   * @param token The JWT token to verify
   * @returns True if the claims are acceptable or JWT auth is not enabled, false otherwise
   */
  public verifyJwtToken(route: IRouteConfig, token: string): boolean {
    const jwtAuth = route.security?.jwtAuth;
    if (!jwtAuth?.enabled) {
      return true;
    }

    try {
      // Verify structure: header.payload.signature
      const parts = token.split('.');
      if (parts.length !== 3) {
        return false;
      }

      // Decode payload; the claims set must be a JSON object
      const payload: unknown = JSON.parse(Buffer.from(parts[1], 'base64').toString());
      if (typeof payload !== 'object' || payload === null) {
        return false;
      }
      const claims = payload as Record<string, unknown>;

      // Check expiration (seconds since epoch)
      if (claims.exp && Number(claims.exp) < Math.floor(Date.now() / 1000)) {
        return false;
      }

      // Check issuer
      if (jwtAuth.issuer && claims.iss !== jwtAuth.issuer) {
        return false;
      }

      // Check audience; `aud` may be a single value or an array of values
      if (jwtAuth.audience) {
        const audiences: unknown[] = Array.isArray(claims.aud) ? claims.aud : [claims.aud];
        if (!audiences.includes(jwtAuth.audience)) {
          return false;
        }
      }

      return true;
    } catch (err) {
      this.logger.error(`Error verifying JWT: ${err}`);
      return false;
    }
  }

  /**
   * Get connections count by IP (checks normalized variants).
   *
   * @returns Size of the first tracked connection set found among the variants, or 0
   */
  public getConnectionCountByIP(ip: string): number {
    for (const variant of normalizeIP(ip)) {
      const connections = this.connectionsByIP.get(variant);
      if (connections) {
        return connections.size;
      }
    }
    return 0;
  }

  /**
   * Check and update connection rate for an IP.
   *
   * The attempt is always recorded, even when it exceeds the limit.
   *
   * @returns true if within rate limit, false if exceeding limit
   */
  public checkConnectionRate(ip: string): boolean {
    // Make sure cleanup is running again if tracking was cleared earlier
    this.startPeriodicIpCleanup();

    const now = Date.now();
    const key = this.resolveTrackingKey(this.connectionRateByIP, ip);

    const previous = this.connectionRateByIP.get(key);
    if (!previous) {
      this.connectionRateByIP.set(key, [now]);
      return true;
    }

    // Keep only attempts inside the sliding window, then record this one
    const timestamps = previous.filter(
      (time) => now - time < SecurityManager.CONNECTION_RATE_WINDOW_MS
    );
    timestamps.push(now);
    this.connectionRateByIP.set(key, timestamps);

    return timestamps.length <= this.connectionRateLimitPerMinute;
  }

  /**
   * Track connection by IP.
   *
   * The connection is stored under an already tracked variant of the IP when
   * one exists, otherwise under the IP as given.
   */
  public trackConnectionByIP(ip: string, connectionId: string): void {
    // Make sure cleanup is running again if tracking was cleared earlier
    this.startPeriodicIpCleanup();

    const key = this.resolveTrackingKey(this.connectionsByIP, ip);
    let connections = this.connectionsByIP.get(key);
    if (!connections) {
      connections = new Set();
      this.connectionsByIP.set(key, connections);
    }
    connections.add(connectionId);
  }

  /**
   * Remove connection tracking for an IP.
   *
   * Only the first tracked variant of the IP is touched; its entry is dropped
   * entirely once no connections remain.
   */
  public removeConnectionByIP(ip: string, connectionId: string): void {
    for (const variant of normalizeIP(ip)) {
      const connections = this.connectionsByIP.get(variant);
      if (!connections) {
        continue;
      }
      connections.delete(connectionId);
      if (connections.size === 0) {
        this.connectionsByIP.delete(variant);
      }
      break;
    }
  }

  /**
   * Check if IP should be allowed considering connection rate and max connections.
   *
   * The connection count is checked first; the rate check (which records the
   * attempt) only runs when the count is below the maximum.
   *
   * @returns Object with result and reason
   */
  public validateIP(ip: string): { allowed: boolean; reason?: string } {
    // Check connection count limit
    if (this.getConnectionCountByIP(ip) >= this.maxConnectionsPerIP) {
      return {
        allowed: false,
        reason: `Maximum connections per IP (${this.maxConnectionsPerIP}) exceeded`
      };
    }

    // Check connection rate limit
    if (!this.checkConnectionRate(ip)) {
      return {
        allowed: false,
        reason: `Connection rate limit (${this.connectionRateLimitPerMinute}/min) exceeded`
      };
    }

    return { allowed: true };
  }

  /**
   * Clears all IP tracking data (for shutdown).
   *
   * Also stops the periodic cleanup timer so a discarded manager is not kept
   * alive by it. The timer restarts on the next tracking call.
   */
  public clearIPTracking(): void {
    this.connectionsByIP.clear();
    this.connectionRateByIP.clear();
    this.stopPeriodicIpCleanup();
  }

  /**
   * Compare two values as strings without leaking, through timing, where they
   * first differ. Both inputs are hashed so the compared buffers always have
   * equal length. Non-string inputs never match.
   */
  private static constantTimeEquals(expected: unknown, provided: unknown): boolean {
    if (typeof expected !== 'string' || typeof provided !== 'string') {
      return false;
    }
    const expectedDigest = createHash('sha256').update(expected).digest();
    const providedDigest = createHash('sha256').update(provided).digest();
    return timingSafeEqual(expectedDigest, providedDigest);
  }

  /**
   * Find the key under which an IP is tracked in the given map: the first
   * normalized variant that is already present, otherwise the IP as given.
   */
  private resolveTrackingKey(tracked: ReadonlyMap<string, unknown>, ip: string): string {
    for (const variant of normalizeIP(ip)) {
      if (tracked.has(variant)) {
        return variant;
      }
    }
    return ip;
  }

  /**
   * Start periodic cleanup of IP tracking data. Does nothing when the timer is
   * already running.
   */
  private startPeriodicIpCleanup(): void {
    if (this.ipCleanupTimer !== null) {
      return;
    }
    const timer = setInterval(() => {
      this.performIpCleanup();
    }, SecurityManager.IP_CLEANUP_INTERVAL_MS);
    // Do not let the cleanup timer keep the process alive
    timer.unref();
    this.ipCleanupTimer = timer;
  }

  /**
   * Stop the periodic cleanup timer if it is running.
   */
  private stopPeriodicIpCleanup(): void {
    if (this.ipCleanupTimer !== null) {
      clearInterval(this.ipCleanupTimer);
      this.ipCleanupTimer = null;
    }
  }

  /**
   * Perform cleanup of expired IP data: drops connection-rate timestamps older
   * than the rate window and removes IPs with no tracked connections.
   */
  private performIpCleanup(): void {
    const now = Date.now();
    let cleanedRateLimits = 0;
    let cleanedIPs = 0;

    // Clean up expired rate limit timestamps
    for (const [ip, timestamps] of this.connectionRateByIP.entries()) {
      const validTimestamps = timestamps.filter(
        (time) => now - time < SecurityManager.CONNECTION_RATE_WINDOW_MS
      );

      if (validTimestamps.length === 0) {
        this.connectionRateByIP.delete(ip);
        cleanedRateLimits++;
      } else if (validTimestamps.length < timestamps.length) {
        this.connectionRateByIP.set(ip, validTimestamps);
      }
    }

    // Clean up IPs with no active connections
    for (const [ip, connections] of this.connectionsByIP.entries()) {
      if (connections.size === 0) {
        this.connectionsByIP.delete(ip);
        cleanedIPs++;
      }
    }

    if (cleanedRateLimits > 0 || cleanedIPs > 0) {
      this.logger.debug(`IP cleanup: removed ${cleanedIPs} IPs and ${cleanedRateLimits} rate limits`);
    }
  }
}