diff --git a/test/api-router_test.ts b/test/api-router_test.ts index 4f2757a..83b9cf3 100644 --- a/test/api-router_test.ts +++ b/test/api-router_test.ts @@ -1,4 +1,5 @@ import { assertEquals } from 'jsr:@std/assert@^1.0.0'; +import { EventEmitter } from 'node:events'; import { ApiRouter } from '../ts/api/router.ts'; class TestResponse { @@ -18,6 +19,29 @@ class TestResponse { } } +class TestRequest extends EventEmitter { + public method: string; + public headers: Record; + public destroyed = false; + public paused = false; + + constructor(method: string, headers: Record) { + super(); + this.method = method; + this.headers = headers; + } + + public pause(): this { + this.paused = true; + return this; + } + + public destroy(): this { + this.destroyed = true; + return this; + } +} + function createRouter(): ApiRouter { return new ApiRouter( {} as never, @@ -55,3 +79,20 @@ Deno.test('ApiRouter rejects protected endpoints without a bearer token', async assertEquals(response.statusCode, 401); assertEquals(JSON.parse(response.body).error.type, 'authentication_error'); }); + +Deno.test('ApiRouter returns 413 for oversized request bodies', async () => { + const router = createRouter(); + const request = new TestRequest('POST', { + authorization: 'Bearer valid-key', + }); + const response = new TestResponse(); + + const routePromise = router.route(request as never, response as never, '/v1/chat/completions'); + request.emit('data', 'x'.repeat(10 * 1024 * 1024 + 1)); + await routePromise; + + assertEquals(response.statusCode, 413); + assertEquals(request.paused, true); + assertEquals(request.destroyed, true); + assertEquals(JSON.parse(response.body).error.message, 'Request body too large'); +}); diff --git a/ts/api/router.ts b/ts/api/router.ts index 030c111..501ea60 100644 --- a/ts/api/router.ts +++ b/ts/api/router.ts @@ -17,6 +17,11 @@ import { EmbeddingsHandler } from './handlers/embeddings.ts'; import { AuthMiddleware } from './middleware/auth.ts'; import { SanityMiddleware } from './middleware/sanity.ts'; +interface IParsedRequestBody { + kind: 'ok' | 'invalid' | 'too_large'; + body?: unknown; +} + /** * API Router - routes requests to handlers */ @@ -119,11 +124,16 @@ export class ApiRouter { } // Parse body - const body = await this.parseRequestBody(req); - if (!body) { + const parsedBody = await this.parseRequestBody(req); + if (parsedBody.kind === 'too_large') { + this.sendError(res, 413, 'Request body too large', 'invalid_request_error'); + return; + } + if (parsedBody.kind !== 'ok') { this.sendError(res, 400, 'Invalid JSON body', 'invalid_request_error'); return; } + const body = parsedBody.body; // Validate request const validation = this.sanityMiddleware.validateChatRequest(body); @@ -155,11 +165,16 @@ export class ApiRouter { } // Parse body - const body = await this.parseRequestBody(req); - if (!body) { + const parsedBody = await this.parseRequestBody(req); + if (parsedBody.kind === 'too_large') { + this.sendError(res, 413, 'Request body too large', 'invalid_request_error'); + return; + } + if (parsedBody.kind !== 'ok') { this.sendError(res, 400, 'Invalid JSON body', 'invalid_request_error'); return; } + const body = parsedBody.body; // Convert to chat format and handle const chatBody = this.convertCompletionToChat(body as Record); @@ -229,11 +244,16 @@ export class ApiRouter { } // Parse body - const body = await this.parseRequestBody(req); - if (!body) { + const parsedBody = await this.parseRequestBody(req); + if (parsedBody.kind === 'too_large') { + this.sendError(res, 413, 'Request body too large', 'invalid_request_error'); + return; + } + if (parsedBody.kind !== 'ok') { this.sendError(res, 400, 'Invalid JSON body', 'invalid_request_error'); return; } + const body = parsedBody.body; const validation = this.sanityMiddleware.validateEmbeddingsRequest(body); if (!validation.valid) { @@ -250,28 +270,45 @@ export class ApiRouter { /** * Parse request body */ - private async parseRequestBody(req: http.IncomingMessage): Promise { + private async parseRequestBody(req: http.IncomingMessage): Promise { return new Promise((resolve) => { let body = ''; + let resolved = false; + + const finish = (result: IParsedRequestBody): void => { + if (resolved) { + return; + } + resolved = true; + resolve(result); + }; req.on('data', (chunk) => { + if (resolved) { + return; + } + body += chunk.toString(); - // Limit body size + if (body.length > 10 * 1024 * 1024) { - resolve(null); + req.pause(); + req.destroy(); + finish({ kind: 'too_large' }); } }); req.on('end', () => { try { - resolve(JSON.parse(body)); + finish({ kind: 'ok', body: JSON.parse(body) }); } catch { - resolve(null); + finish({ kind: 'invalid' }); } }); req.on('error', () => { - resolve(null); + if (!resolved) { + finish({ kind: 'invalid' }); + } }); }); }