255 lines
8.2 KiB
TypeScript
255 lines
8.2 KiB
TypeScript
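/**
 * Sanity Middleware
 *
 * Validates the structure and parameters of chat-completion and embeddings
 * requests, checks requested models against the model registry's greenlist,
 * and strips unknown fields from request bodies.
 */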
|
|
|
|
import type { IChatCompletionRequest, IEmbeddingsRequest } from '../../interfaces/api.ts';
|
|
import { ModelRegistry } from '../../models/registry.ts';
|
|
|
|
/**
 * Outcome of a single validation check.
 *
 * A passing check is reported as `{ valid: true }`. A failing check sets
 * `valid` to `false`, explains the problem in `error`, and, when the failure
 * can be attributed to one request field, names that field in `param`
 * (for example `model` or `messages[0].role`).
 */
export interface IValidationResult {
  /** `true` when the checked input passed validation. */
  valid: boolean;

  /** Human-readable reason for a failed check. */
  error?: string;

  /** Path of the request field responsible for a failed check, if any. */
  param?: string;
}
|
|
|
|
/**
 * Request validation middleware.
 *
 * Provides synchronous structural checks for chat-completion and embeddings
 * request bodies, an asynchronous greenlist check backed by the injected
 * `ModelRegistry`, and "sanitize" helpers that copy only known fields.
 */
export class SanityMiddleware {
  /** Roles accepted in `messages[].role`. */
  private static readonly VALID_ROLES: readonly string[] = ['system', 'user', 'assistant', 'tool'];

  private readonly modelRegistry: ModelRegistry;

  /**
   * @param modelRegistry - Registry consulted by `validateModelGreenlist`.
   */
  constructor(modelRegistry: ModelRegistry) {
    this.modelRegistry = modelRegistry;
  }

  /**
   * Validate a chat completion request body.
   *
   * Checks run in a fixed order and the first failure is returned.
   *
   * @param body - Parsed request body of unknown shape.
   * @returns `{ valid: true }` on success; otherwise `valid: false` with an
   *   `error` message and, where applicable, the offending `param`.
   */
  public validateChatRequest(body: unknown): IValidationResult {
    if (!SanityMiddleware.isRecord(body)) {
      return { valid: false, error: 'Request body must be a JSON object' };
    }
    const request = body;

    // Validate model: must be a non-empty string.
    if (typeof request.model !== 'string' || request.model === '') {
      return { valid: false, error: 'Missing or invalid "model" field', param: 'model' };
    }

    // Validate messages: must be a non-empty array of well-formed messages.
    if (!Array.isArray(request.messages)) {
      return { valid: false, error: 'Missing or invalid "messages" field', param: 'messages' };
    }
    const messages: unknown[] = request.messages;
    if (messages.length === 0) {
      return { valid: false, error: '"messages" array cannot be empty', param: 'messages' };
    }
    for (let i = 0; i < messages.length; i++) {
      const msgValidation = this.validateMessage(messages[i], i);
      if (!msgValidation.valid) {
        return msgValidation;
      }
    }

    // Optional parameters. The range helper rejects non-numbers and NaN.
    if (request.temperature !== undefined && !SanityMiddleware.isNumberInRange(request.temperature, 0, 2)) {
      return { valid: false, error: '"temperature" must be between 0 and 2', param: 'temperature' };
    }

    if (request.top_p !== undefined && !SanityMiddleware.isNumberInRange(request.top_p, 0, 1)) {
      return { valid: false, error: '"top_p" must be between 0 and 1', param: 'top_p' };
    }

    // max_tokens and n are counts, so fractional values are rejected too.
    if (
      request.max_tokens !== undefined &&
      !SanityMiddleware.isNumberInRange(request.max_tokens, 1, Number.POSITIVE_INFINITY, true)
    ) {
      return { valid: false, error: '"max_tokens" must be a positive integer', param: 'max_tokens' };
    }

    if (request.n !== undefined && !SanityMiddleware.isNumberInRange(request.n, 1, 10, true)) {
      return { valid: false, error: '"n" must be between 1 and 10', param: 'n' };
    }

    if (request.stream !== undefined && typeof request.stream !== 'boolean') {
      return { valid: false, error: '"stream" must be a boolean', param: 'stream' };
    }

    if (request.presence_penalty !== undefined && !SanityMiddleware.isNumberInRange(request.presence_penalty, -2, 2)) {
      return { valid: false, error: '"presence_penalty" must be between -2 and 2', param: 'presence_penalty' };
    }

    if (request.frequency_penalty !== undefined && !SanityMiddleware.isNumberInRange(request.frequency_penalty, -2, 2)) {
      return { valid: false, error: '"frequency_penalty" must be between -2 and 2', param: 'frequency_penalty' };
    }

    // `sanitizeChatRequest` forwards `stop` as `string | string[]`, so enforce
    // that shape here. `null` is accepted and treated as "not provided".
    if (request.stop !== undefined && request.stop !== null && !SanityMiddleware.isStringOrStringArray(request.stop)) {
      return { valid: false, error: '"stop" must be a string or array of strings', param: 'stop' };
    }

    return { valid: true };
  }

  /**
   * Validate a single entry of the chat request's `messages` array.
   *
   * @param msg - The raw array element; may be any value.
   * @param index - Position in `messages`, used in error text and `param` paths.
   */
  private validateMessage(msg: unknown, index: number): IValidationResult {
    const path = `messages[${index}]`;

    if (!SanityMiddleware.isRecord(msg)) {
      return { valid: false, error: `Message at index ${index} must be an object`, param: path };
    }

    // Validate role
    const roles = SanityMiddleware.VALID_ROLES;
    if (typeof msg.role !== 'string' || !roles.includes(msg.role)) {
      return {
        valid: false,
        error: `Invalid role at index ${index}. Must be one of: ${roles.join(', ')}`,
        param: `${path}.role`,
      };
    }

    // An assistant message carrying (truthy) tool_calls may omit content
    // (null/undefined); every other message must supply it. Whenever content
    // is present it must be a string.
    const contentMayBeOmitted = msg.role === 'assistant' && Boolean(msg.tool_calls);
    if (msg.content === undefined || msg.content === null) {
      if (!contentMayBeOmitted) {
        return {
          valid: false,
          error: `Missing content at index ${index}`,
          param: `${path}.content`,
        };
      }
    } else if (typeof msg.content !== 'string') {
      return {
        valid: false,
        error: `Content at index ${index} must be a string`,
        param: `${path}.content`,
      };
    }

    // Tool responses must reference the call they answer.
    if (msg.role === 'tool' && (typeof msg.tool_call_id !== 'string' || msg.tool_call_id === '')) {
      return {
        valid: false,
        error: `Tool message at index ${index} requires tool_call_id`,
        param: `${path}.tool_call_id`,
      };
    }

    return { valid: true };
  }

  /**
   * Validate an embeddings request body.
   *
   * @param body - Parsed request body of unknown shape.
   * @returns `{ valid: true }` on success; otherwise the first failure found.
   */
  public validateEmbeddingsRequest(body: unknown): IValidationResult {
    if (!SanityMiddleware.isRecord(body)) {
      return { valid: false, error: 'Request body must be a JSON object' };
    }
    const request = body;

    // Validate model
    if (typeof request.model !== 'string' || request.model === '') {
      return { valid: false, error: 'Missing or invalid "model" field', param: 'model' };
    }

    // Validate input: a string or a non-empty array of strings.
    const input = request.input;
    if (input === undefined || input === null) {
      return { valid: false, error: 'Missing "input" field', param: 'input' };
    }
    if (typeof input !== 'string' && !Array.isArray(input)) {
      return { valid: false, error: '"input" must be a string or array of strings', param: 'input' };
    }
    if (Array.isArray(input)) {
      if (input.length === 0) {
        return { valid: false, error: '"input" array cannot be empty', param: 'input' };
      }
      const badIndex = input.findIndex((item: unknown) => typeof item !== 'string');
      if (badIndex !== -1) {
        return { valid: false, error: `"input[${badIndex}]" must be a string`, param: `input[${badIndex}]` };
      }
    }

    // Validate encoding_format
    const format = request.encoding_format;
    if (format !== undefined && format !== 'float' && format !== 'base64') {
      return { valid: false, error: '"encoding_format" must be "float" or "base64"', param: 'encoding_format' };
    }

    return { valid: true };
  }

  /**
   * Check whether a model is on the registry's greenlist.
   *
   * @param modelName - Model identifier taken from the request.
   * @returns `{ valid: true }` when greenlit; otherwise a failure with `param: 'model'`.
   */
  public async validateModelGreenlist(modelName: string): Promise<IValidationResult> {
    const isGreenlit = await this.modelRegistry.isModelGreenlit(modelName);
    if (!isGreenlit) {
      return {
        valid: false,
        error: `Model "${modelName}" is not greenlit. Contact administrator to add it to the greenlist.`,
        param: 'model',
      };
    }
    return { valid: true };
  }

  /**
   * Build a chat request containing only known fields.
   *
   * Performs unchecked casts, so call `validateChatRequest` on the same body
   * first. A `null` `stop` is normalized to `undefined`.
   */
  public sanitizeChatRequest(body: Record<string, unknown>): IChatCompletionRequest {
    return {
      model: body.model as string,
      messages: body.messages as IChatCompletionRequest['messages'],
      max_tokens: body.max_tokens as number | undefined,
      temperature: body.temperature as number | undefined,
      top_p: body.top_p as number | undefined,
      n: body.n as number | undefined,
      stream: body.stream as boolean | undefined,
      stop: (body.stop ?? undefined) as string | string[] | undefined,
      presence_penalty: body.presence_penalty as number | undefined,
      frequency_penalty: body.frequency_penalty as number | undefined,
      user: body.user as string | undefined,
      tools: body.tools as IChatCompletionRequest['tools'],
      tool_choice: body.tool_choice as IChatCompletionRequest['tool_choice'],
    };
  }

  /**
   * Build an embeddings request containing only known fields.
   *
   * Performs unchecked casts, so call `validateEmbeddingsRequest` on the same
   * body first.
   */
  public sanitizeEmbeddingsRequest(body: Record<string, unknown>): IEmbeddingsRequest {
    return {
      model: body.model as string,
      input: body.input as string | string[],
      user: body.user as string | undefined,
      encoding_format: body.encoding_format as 'float' | 'base64' | undefined,
    };
  }

  /** True for non-null, non-array objects, i.e. what a JSON object parses to. */
  private static isRecord(value: unknown): value is Record<string, unknown> {
    return typeof value === 'object' && value !== null && !Array.isArray(value);
  }

  /**
   * True when `value` is a number in the inclusive range [min, max]. NaN fails
   * every comparison and is therefore rejected. With `integer`, the value must
   * also be an integer.
   */
  private static isNumberInRange(value: unknown, min: number, max: number, integer = false): boolean {
    return typeof value === 'number' && value >= min && value <= max && (!integer || Number.isInteger(value));
  }

  /** True for a string, or an array whose every element is a string. */
  private static isStringOrStringArray(value: unknown): boolean {
    return typeof value === 'string' || (Array.isArray(value) && value.every((item: unknown) => typeof item === 'string'));
  }
}
|