Files
modelgrid/ts/api/middleware/sanity.ts
Juergen Kunz daaf6559e3
Some checks failed
CI / Type Check & Lint (push) Failing after 5s
CI / Build Test (Current Platform) (push) Failing after 5s
CI / Build All Platforms (push) Successful in 49s
initial
2026-01-30 03:16:57 +00:00

255 lines
8.2 KiB
TypeScript

/**
* Sanity Middleware
*
* Validates request structure and parameters.
*/
import type { IChatCompletionRequest, IEmbeddingsRequest } from '../../interfaces/api.ts';
import { ModelRegistry } from '../../models/registry.ts';
/**
 * Outcome of a single request validation check.
 *
 * A passing check carries only `valid: true`. A failing check sets
 * `valid: false` together with a human-readable `error`, and usually a
 * `param` naming the offending request field (for example `model` or
 * `messages[0].role`).
 */
export interface IValidationResult {
  /** `true` when the checked input was accepted. */
  valid: boolean;
  /** Human-readable explanation of why validation failed. */
  error?: string;
  /** Path of the request field that caused the failure, when one applies. */
  param?: string;
}
/**
 * Request validation middleware.
 *
 * Provides synchronous structural validation for chat-completion and
 * embeddings request bodies, an async greenlist lookup against the model
 * registry, and sanitizers that copy only the recognised request fields.
 */
export class SanityMiddleware {
  /** Registry consulted by `validateModelGreenlist`. */
  private readonly modelRegistry: ModelRegistry;

  constructor(modelRegistry: ModelRegistry) {
    this.modelRegistry = modelRegistry;
  }

  /**
   * Validate a chat completion request body.
   *
   * Checks run in a fixed order and the first failure is returned:
   * `model`, `messages` (and each message in turn), then the optional
   * `temperature`, `top_p`, `max_tokens`, `n`, `stream`, `presence_penalty`
   * and `frequency_penalty` fields. Optional fields are only checked when
   * they are not `undefined`.
   *
   * @param body - Untrusted, already-parsed request body.
   * @returns `{ valid: true }`, or the first failure with `error` and `param`.
   */
  public validateChatRequest(body: unknown): IValidationResult {
    if (!body || typeof body !== 'object') {
      return { valid: false, error: 'Request body must be a JSON object' };
    }
    const request = body as Record<string, unknown>;

    // Validate model
    if (!request.model || typeof request.model !== 'string') {
      return { valid: false, error: 'Missing or invalid "model" field', param: 'model' };
    }

    // Validate messages
    const messages = request.messages;
    if (!Array.isArray(messages)) {
      return { valid: false, error: 'Missing or invalid "messages" field', param: 'messages' };
    }
    if (messages.length === 0) {
      return { valid: false, error: '"messages" array cannot be empty', param: 'messages' };
    }
    for (let i = 0; i < messages.length; i++) {
      const msgValidation = this.validateMessage(messages[i], i);
      if (!msgValidation.valid) {
        return msgValidation;
      }
    }

    // Validate optional parameters. Every check is side-effect free, so
    // evaluating them up front and returning the first failure keeps the
    // same reporting order as checking them one by one.
    const paramChecks: IValidationResult[] = [
      this.validateOptionalNumber(request.temperature, 'temperature', 0, 2, '"temperature" must be between 0 and 2'),
      this.validateOptionalNumber(request.top_p, 'top_p', 0, 1, '"top_p" must be between 0 and 1'),
      // max_tokens and n are counts, so fractional values are rejected as well.
      this.validateOptionalNumber(
        request.max_tokens,
        'max_tokens',
        1,
        Infinity,
        '"max_tokens" must be a positive integer',
        true,
      ),
      this.validateOptionalNumber(request.n, 'n', 1, 10, '"n" must be an integer between 1 and 10', true),
      request.stream === undefined || typeof request.stream === 'boolean'
        ? { valid: true }
        : { valid: false, error: '"stream" must be a boolean', param: 'stream' },
      this.validateOptionalNumber(
        request.presence_penalty,
        'presence_penalty',
        -2,
        2,
        '"presence_penalty" must be between -2 and 2',
      ),
      this.validateOptionalNumber(
        request.frequency_penalty,
        'frequency_penalty',
        -2,
        2,
        '"frequency_penalty" must be between -2 and 2',
      ),
    ];
    for (const check of paramChecks) {
      if (!check.valid) {
        return check;
      }
    }
    return { valid: true };
  }

  /**
   * Validate an optional numeric field against an inclusive range.
   *
   * `undefined` means "not supplied" and is accepted. Any other value must be
   * a number in `[min, max]`. `NaN` is rejected because the combined range
   * test is written so that a failed comparison means "out of range". When
   * `integer` is true the value must also be a whole number.
   *
   * @param value - Raw field value from the request body.
   * @param param - Field name reported on failure.
   * @param min - Inclusive lower bound.
   * @param max - Inclusive upper bound (`Infinity` for "no upper bound").
   * @param error - Message reported on failure.
   * @param integer - Whether fractional values should be rejected.
   */
  private validateOptionalNumber(
    value: unknown,
    param: string,
    min: number,
    max: number,
    error: string,
    integer = false,
  ): IValidationResult {
    if (value === undefined) {
      return { valid: true };
    }
    const inRange = typeof value === 'number' && value >= min && value <= max;
    if (!inRange || (integer && !Number.isInteger(value))) {
      return { valid: false, error, param };
    }
    return { valid: true };
  }

  /**
   * Validate a single message in the chat request.
   *
   * Requires an object with a known `role`. `content` must be a string,
   * except that an assistant message carrying `tool_calls` may omit it or set
   * it to `null`. Tool messages additionally need a non-empty string
   * `tool_call_id`.
   *
   * @param msg - Raw element of the `messages` array.
   * @param index - Position in `messages`, used in error text and `param`.
   */
  private validateMessage(msg: unknown, index: number): IValidationResult {
    if (!msg || typeof msg !== 'object') {
      return { valid: false, error: `Message at index ${index} must be an object`, param: `messages[${index}]` };
    }
    const message = msg as Record<string, unknown>;

    // Validate role
    const validRoles = ['system', 'user', 'assistant', 'tool'];
    if (typeof message.role !== 'string' || !validRoles.includes(message.role)) {
      return {
        valid: false,
        error: `Invalid role at index ${index}. Must be one of: ${validRoles.join(', ')}`,
        param: `messages[${index}].role`,
      };
    }

    // Validate content. Only its presence is optional for assistant messages
    // with tool_calls; when content is supplied it must still be a string.
    const content = message.content;
    const hasContent = content !== undefined && content !== null;
    const contentOptional = message.role === 'assistant' && Boolean(message.tool_calls);
    if (!hasContent && !contentOptional) {
      return {
        valid: false,
        error: `Missing content at index ${index}`,
        param: `messages[${index}].content`,
      };
    }
    if (hasContent && typeof content !== 'string') {
      return {
        valid: false,
        error: `Content at index ${index} must be a string`,
        param: `messages[${index}].content`,
      };
    }

    // Validate tool response message
    if (message.role === 'tool' && (typeof message.tool_call_id !== 'string' || message.tool_call_id === '')) {
      return {
        valid: false,
        error: `Tool message at index ${index} requires tool_call_id`,
        param: `messages[${index}].tool_call_id`,
      };
    }
    return { valid: true };
  }

  /**
   * Validate an embeddings request body.
   *
   * Requires a non-empty string `model` and an `input` that is either a
   * string or a non-empty array of strings. The optional `encoding_format`
   * must be `"float"` or `"base64"`.
   *
   * @param body - Untrusted, already-parsed request body.
   * @returns `{ valid: true }`, or the first failure with `error` and `param`.
   */
  public validateEmbeddingsRequest(body: unknown): IValidationResult {
    if (!body || typeof body !== 'object') {
      return { valid: false, error: 'Request body must be a JSON object' };
    }
    const request = body as Record<string, unknown>;

    // Validate model
    if (!request.model || typeof request.model !== 'string') {
      return { valid: false, error: 'Missing or invalid "model" field', param: 'model' };
    }

    // Validate input
    const input = request.input;
    if (input === undefined || input === null) {
      return { valid: false, error: 'Missing "input" field', param: 'input' };
    }
    if (typeof input !== 'string' && !Array.isArray(input)) {
      return { valid: false, error: '"input" must be a string or array of strings', param: 'input' };
    }
    if (Array.isArray(input)) {
      const badIndex = input.findIndex((item) => typeof item !== 'string');
      if (badIndex !== -1) {
        return { valid: false, error: `"input[${badIndex}]" must be a string`, param: `input[${badIndex}]` };
      }
      if (input.length === 0) {
        return { valid: false, error: '"input" array cannot be empty', param: 'input' };
      }
    }

    // Validate encoding_format
    const format = request.encoding_format;
    if (format !== undefined && format !== 'float' && format !== 'base64') {
      return { valid: false, error: '"encoding_format" must be "float" or "base64"', param: 'encoding_format' };
    }
    return { valid: true };
  }

  /**
   * Check whether a model is on the registry's greenlist.
   *
   * @param modelName - Model identifier taken from the request.
   * @returns `{ valid: true }` when `modelRegistry.isModelGreenlit` resolves
   *   truthy, otherwise a failure pointing at `model`.
   */
  public async validateModelGreenlist(modelName: string): Promise<IValidationResult> {
    const isGreenlit = await this.modelRegistry.isModelGreenlit(modelName);
    if (!isGreenlit) {
      return {
        valid: false,
        error: `Model "${modelName}" is not greenlit. Contact administrator to add it to the greenlist.`,
        param: 'model',
      };
    }
    return { valid: true };
  }

  /**
   * Build a chat request containing only the recognised fields.
   *
   * Unknown fields are dropped. The casts here are unchecked, so call this
   * only after `validateChatRequest` has accepted the same body. Fields that
   * validator does not inspect (`stop`, `user`, `tools`, `tool_choice`) are
   * copied through as-is. Absent fields appear as `undefined` properties.
   */
  public sanitizeChatRequest(body: Record<string, unknown>): IChatCompletionRequest {
    return {
      model: body.model as string,
      messages: body.messages as IChatCompletionRequest['messages'],
      max_tokens: body.max_tokens as number | undefined,
      temperature: body.temperature as number | undefined,
      top_p: body.top_p as number | undefined,
      n: body.n as number | undefined,
      stream: body.stream as boolean | undefined,
      stop: body.stop as string | string[] | undefined,
      presence_penalty: body.presence_penalty as number | undefined,
      frequency_penalty: body.frequency_penalty as number | undefined,
      user: body.user as string | undefined,
      tools: body.tools as IChatCompletionRequest['tools'],
      tool_choice: body.tool_choice as IChatCompletionRequest['tool_choice'],
    };
  }

  /**
   * Build an embeddings request containing only the recognised fields.
   *
   * Like `sanitizeChatRequest`, this performs unchecked casts and expects the
   * body to have passed `validateEmbeddingsRequest`. `user` is not validated
   * there and is copied through as-is.
   */
  public sanitizeEmbeddingsRequest(body: Record<string, unknown>): IEmbeddingsRequest {
    return {
      model: body.model as string,
      input: body.input as string | string[],
      user: body.user as string | undefined,
      encoding_format: body.encoding_format as 'float' | 'base64' | undefined,
    };
  }
}