/**
 * Sanity Middleware
 *
 * Validates request structure and parameters.
 */

import type { IChatCompletionRequest, IEmbeddingsRequest } from '../../interfaces/api.ts';
import { ModelRegistry } from '../../models/registry.ts';

/**
 * Validation result
 */
export interface IValidationResult {
  valid: boolean;
  error?: string;
  param?: string;
}

/** Roles accepted in a chat message's `role` field. */
const VALID_MESSAGE_ROLES: readonly string[] = ['system', 'user', 'assistant', 'tool'];

/**
 * Rule for an optional top-level chat parameter. A rule is only applied when
 * the parameter is present (i.e. not `undefined`).
 */
type ChatParamRule =
  | {
      param: string;
      kind: 'number';
      min: number;
      max: number;
      /** When true the value must also be a whole number. */
      integer?: boolean;
      error: string;
    }
  | { param: string; kind: 'boolean'; error: string };

/**
 * Optional chat parameters, in the order they are checked. Order matters: the
 * first failing rule is the one reported back to the client.
 */
const CHAT_PARAM_RULES: readonly ChatParamRule[] = [
  { param: 'temperature', kind: 'number', min: 0, max: 2, error: '"temperature" must be between 0 and 2' },
  { param: 'top_p', kind: 'number', min: 0, max: 1, error: '"top_p" must be between 0 and 1' },
  {
    param: 'max_tokens',
    kind: 'number',
    min: 1,
    max: Number.POSITIVE_INFINITY,
    integer: true,
    error: '"max_tokens" must be a positive integer',
  },
  { param: 'n', kind: 'number', min: 1, max: 10, integer: true, error: '"n" must be between 1 and 10' },
  { param: 'stream', kind: 'boolean', error: '"stream" must be a boolean' },
  {
    param: 'presence_penalty',
    kind: 'number',
    min: -2,
    max: 2,
    error: '"presence_penalty" must be between -2 and 2',
  },
  {
    param: 'frequency_penalty',
    kind: 'number',
    min: -2,
    max: 2,
    error: '"frequency_penalty" must be between -2 and 2',
  },
];

/**
 * Returns true when `value` satisfies `rule`.
 * NaN and ±Infinity never satisfy a numeric rule (a plain range comparison
 * would let NaN through, since every comparison with NaN is false).
 */
function satisfiesRule(value: unknown, rule: ChatParamRule): boolean {
  if (rule.kind === 'boolean') {
    return typeof value === 'boolean';
  }
  return (
    typeof value === 'number' &&
    Number.isFinite(value) &&
    value >= rule.min &&
    value <= rule.max &&
    (!rule.integer || Number.isInteger(value))
  );
}

/**
 * Request validation middleware.
 *
 * The `validate*` methods report the first problem found through the returned
 * `IValidationResult` (with `param` naming the offending field) rather than by
 * throwing.
 */
export class SanityMiddleware {
  private readonly modelRegistry: ModelRegistry;

  constructor(modelRegistry: ModelRegistry) {
    this.modelRegistry = modelRegistry;
  }

  /**
   * Validate a chat completion request body.
   *
   * Checks `model` (non-empty string), `messages` (non-empty array of valid
   * messages), then each optional parameter listed in `CHAT_PARAM_RULES`.
   *
   * @param body - Untrusted request body (typically parsed JSON).
   * @returns `{ valid: true }` or the first validation failure.
   */
  public validateChatRequest(body: unknown): IValidationResult {
    if (!body || typeof body !== 'object') {
      return { valid: false, error: 'Request body must be a JSON object' };
    }
    const request = body as Record<string, unknown>;

    // Validate model
    if (!request.model || typeof request.model !== 'string') {
      return { valid: false, error: 'Missing or invalid "model" field', param: 'model' };
    }

    // Validate messages
    const messages = request.messages;
    if (!Array.isArray(messages)) {
      return { valid: false, error: 'Missing or invalid "messages" field', param: 'messages' };
    }
    if (messages.length === 0) {
      return { valid: false, error: '"messages" array cannot be empty', param: 'messages' };
    }
    for (let i = 0; i < messages.length; i++) {
      const msgValidation = this.validateMessage(messages[i] as Record<string, unknown>, i);
      if (!msgValidation.valid) {
        return msgValidation;
      }
    }

    // Validate optional parameters (only when present)
    for (const rule of CHAT_PARAM_RULES) {
      const value = request[rule.param];
      if (value !== undefined && !satisfiesRule(value, rule)) {
        return { valid: false, error: rule.error, param: rule.param };
      }
    }

    return { valid: true };
  }

  /**
   * Validate a single message in the chat request.
   *
   * @param msg - The raw message entry (may be any JSON value at runtime).
   * @param index - Position in `messages`, used in error text and `param`.
   */
  private validateMessage(msg: Record<string, unknown>, index: number): IValidationResult {
    // Arrays are `typeof 'object'` too, but are never a valid message.
    if (!msg || typeof msg !== 'object' || Array.isArray(msg)) {
      return { valid: false, error: `Message at index ${index} must be an object`, param: `messages[${index}]` };
    }

    // Validate role
    const role = msg.role;
    if (typeof role !== 'string' || !VALID_MESSAGE_ROLES.includes(role)) {
      return {
        valid: false,
        error: `Invalid role at index ${index}. Must be one of: ${VALID_MESSAGE_ROLES.join(', ')}`,
        param: `messages[${index}].role`,
      };
    }

    // Content may be null/undefined for an assistant message carrying
    // tool_calls, but when it is present it must still be a string.
    const contentOptional = role === 'assistant' && Boolean(msg.tool_calls);
    const content = msg.content;
    if (content === undefined || content === null) {
      if (!contentOptional) {
        return {
          valid: false,
          error: `Missing content at index ${index}`,
          param: `messages[${index}].content`,
        };
      }
    } else if (typeof content !== 'string') {
      return {
        valid: false,
        error: `Content at index ${index} must be a string`,
        param: `messages[${index}].content`,
      };
    }

    // Tool responses must reference the call they answer by a non-empty string id.
    const toolCallId = msg.tool_call_id;
    if (role === 'tool' && (typeof toolCallId !== 'string' || toolCallId === '')) {
      return {
        valid: false,
        error: `Tool message at index ${index} requires tool_call_id`,
        param: `messages[${index}].tool_call_id`,
      };
    }

    return { valid: true };
  }

  /**
   * Validate an embeddings request body.
   *
   * `input` must be a string or a non-empty array of strings; `encoding_format`,
   * when present, must be `"float"` or `"base64"`.
   *
   * @param body - Untrusted request body (typically parsed JSON).
   * @returns `{ valid: true }` or the first validation failure.
   */
  public validateEmbeddingsRequest(body: unknown): IValidationResult {
    if (!body || typeof body !== 'object') {
      return { valid: false, error: 'Request body must be a JSON object' };
    }
    const request = body as Record<string, unknown>;

    // Validate model
    if (!request.model || typeof request.model !== 'string') {
      return { valid: false, error: 'Missing or invalid "model" field', param: 'model' };
    }

    // Validate input
    const input = request.input;
    if (input === undefined || input === null) {
      return { valid: false, error: 'Missing "input" field', param: 'input' };
    }
    if (Array.isArray(input)) {
      if (input.length === 0) {
        return { valid: false, error: '"input" array cannot be empty', param: 'input' };
      }
      const badIndex = input.findIndex((item) => typeof item !== 'string');
      if (badIndex !== -1) {
        return { valid: false, error: `"input[${badIndex}]" must be a string`, param: `input[${badIndex}]` };
      }
    } else if (typeof input !== 'string') {
      return { valid: false, error: '"input" must be a string or array of strings', param: 'input' };
    }

    // Validate encoding_format
    const format = request.encoding_format;
    if (format !== undefined && format !== 'float' && format !== 'base64') {
      return { valid: false, error: '"encoding_format" must be "float" or "base64"', param: 'encoding_format' };
    }

    return { valid: true };
  }

  /**
   * Check whether a model is in the greenlist (async validation).
   *
   * @param modelName - Model name from the request.
   * @returns `{ valid: true }` when `modelRegistry.isModelGreenlit` resolves truthy.
   */
  public async validateModelGreenlist(modelName: string): Promise<IValidationResult> {
    const isGreenlit = await this.modelRegistry.isModelGreenlit(modelName);
    if (!isGreenlit) {
      return {
        valid: false,
        error: `Model "${modelName}" is not greenlit. Contact administrator to add it to the greenlist.`,
        param: 'model',
      };
    }
    return { valid: true };
  }

  /**
   * Build a chat request containing only the known fields; any other fields in
   * `body` are dropped. Values are cast, not re-checked, so call
   * `validateChatRequest` first.
   */
  public sanitizeChatRequest(body: Record<string, unknown>): IChatCompletionRequest {
    return {
      model: body.model as string,
      messages: body.messages as IChatCompletionRequest['messages'],
      max_tokens: body.max_tokens as number | undefined,
      temperature: body.temperature as number | undefined,
      top_p: body.top_p as number | undefined,
      n: body.n as number | undefined,
      stream: body.stream as boolean | undefined,
      stop: body.stop as string | string[] | undefined,
      presence_penalty: body.presence_penalty as number | undefined,
      frequency_penalty: body.frequency_penalty as number | undefined,
      user: body.user as string | undefined,
      tools: body.tools as IChatCompletionRequest['tools'],
      tool_choice: body.tool_choice as IChatCompletionRequest['tool_choice'],
    };
  }

  /**
   * Build an embeddings request containing only the known fields; values are
   * cast, not re-checked, so call `validateEmbeddingsRequest` first.
   */
  public sanitizeEmbeddingsRequest(body: Record<string, unknown>): IEmbeddingsRequest {
    return {
      model: body.model as string,
      input: body.input as string | string[],
      user: body.user as string | undefined,
      encoding_format: body.encoding_format as 'float' | 'base64' | undefined,
    };
  }
}