diff --git a/test/api-timeout_test.ts b/test/api-timeout_test.ts
new file mode 100644
index 0000000..a532918
--- /dev/null
+++ b/test/api-timeout_test.ts
@@ -0,0 +1,128 @@
+import { assertEquals } from 'jsr:@std/assert@^1.0.0';
+import { ChatHandler } from '../ts/api/handlers/chat.ts';
+import { EmbeddingsHandler } from '../ts/api/handlers/embeddings.ts';
+import { UpstreamTimeoutError } from '../ts/containers/base-container.ts';
+
+/**
+ * Minimal response double handed to the handlers in place of a real HTTP
+ * response; it records the status code, headers and body for assertions.
+ */
+class TestResponse {
+  public statusCode = 200;
+  public headers: Record<string, string> = {};
+  public body = '';
+
+  public writeHead(statusCode: number, headers: Record<string, string>): TestResponse {
+    this.statusCode = statusCode;
+    this.headers = headers;
+    return this;
+  }
+
+  public end(body = ''): TestResponse {
+    this.body = body;
+    return this;
+  }
+
+  // Written chunks are discarded; the tests below only inspect what `end()` stored.
+  public write(_chunk: string | Uint8Array): boolean {
+    return true;
+  }
+}
+
+Deno.test('ChatHandler maps upstream timeouts to 504 responses', async () => {
+  const handler = new ChatHandler(
+    // Container lookup double: both completion paths reject with a timeout.
+    {
+      async findContainerForModel() {
+        return {
+          async chatCompletion() {
+            throw new UpstreamTimeoutError();
+          },
+          async chatCompletionStream() {
+            throw new UpstreamTimeoutError();
+          },
+        };
+      },
+    } as never,
+    {
+      async getModel(modelName: string) {
+        return { id: modelName };
+      },
+    } as never,
+    {
+      async loadModel() {
+        return { success: false };
+      },
+    } as never,
+    {
+      shouldDeployLocallyFirst() {
+        return false;
+      },
+    } as never,
+  );
+
+  const response = new TestResponse();
+  await handler.handleChatCompletion(
+    { headers: {} } as never,
+    response as never,
+    { model: 'meta-llama/Llama-3.1-8B-Instruct', messages: [{ role: 'user', content: 'hi' }] },
+  );
+
+  assertEquals(response.statusCode, 504);
+  assertEquals(JSON.parse(response.body).error.type, 'upstream_timeout');
+});
+
+Deno.test('EmbeddingsHandler maps upstream timeouts to 504 responses', async () => {
+  const originalFetch = globalThis.fetch;
+  // Stub fetch so every upstream call rejects with an AbortError.
+  globalThis.fetch = async () => {
+    const error = new Error('request aborted');
+    error.name = 'AbortError';
+    throw error;
+  };
+
+  try {
+    const handler = new EmbeddingsHandler(
+      {
+        async findContainerForModel() {
+          return null;
+        },
+      } as never,
+      {
+        async getModel(modelName: string) {
+          return { id: modelName };
+        },
+      } as never,
+      {
+        async ensureModelViaControlPlane(modelName: string) {
+          return {
+            location: {
+              modelId: modelName,
+              nodeName: 'worker-a',
+              endpoint: 'http://worker-a:8080',
+              healthy: true,
+              engine: 'vllm',
+              containerId: 'remote',
+            },
+          };
+        },
+        getLocalNodeName() {
+          return 'control';
+        },
+      } as never,
+    );
+
+    const response = new TestResponse();
+    await handler.handleEmbeddings(
+      { headers: {} } as never,
+      response as never,
+      { model: 'BAAI/bge-m3', input: 'hello' },
+    );
+
+    assertEquals(response.statusCode, 504);
+    assertEquals(JSON.parse(response.body).error.type, 'upstream_timeout');
+  } finally {
+    // Restore the real fetch even when an assertion above fails.
+    globalThis.fetch = originalFetch;
+  }
+});
diff --git a/ts/api/handlers/chat.ts b/ts/api/handlers/chat.ts
index 4d566d9..b4c2e94 100644
--- a/ts/api/handlers/chat.ts
+++ b/ts/api/handlers/chat.ts
@@ -6,6 +6,7 @@ import * as http from 'node:http';
 import type { IApiError, IChatCompletionRequest } from '../../interfaces/api.ts';
 import { ClusterCoordinator } from '../../cluster/coordinator.ts';
 import { ContainerManager } from '../../containers/container-manager.ts';
+import { UpstreamTimeoutError } from '../../containers/base-container.ts';
 import { API_SERVER } from '../../constants.ts';
 import { logger } from '../../logger.ts';
 import { ModelRegistry } from '../../models/registry.ts';
@@ -86,6 +87,11 @@ export class ChatHandler {
 
       await this.proxyChatRequest(req, res, ensured.location.endpoint, requestBody);
     } catch (error) {
+      if (error instanceof UpstreamTimeoutError) {
+        this.sendError(res, 504, error.message, 'upstream_timeout');
+        return;
+      }
+
       const message = error instanceof Error ? error.message : String(error);
       logger.error(`Chat completion error: ${message}`);
       this.sendError(res, 500, `Chat completion failed: ${message}`, 'server_error');
@@ -166,6 +172,11 @@ export class ChatHandler {
       headers: this.buildForwardHeaders(req),
       body: JSON.stringify(body),
       signal: controller.signal,
+    }).catch((error) => {
+      if (error instanceof Error && error.name === 'AbortError') {
+        throw new UpstreamTimeoutError();
+      }
+      throw error;
     }).finally(() => clearTimeout(timeout));
 
     if (body.stream) {
diff --git a/ts/api/handlers/embeddings.ts b/ts/api/handlers/embeddings.ts
index 5c90a76..df187c9 100644
--- a/ts/api/handlers/embeddings.ts
+++ b/ts/api/handlers/embeddings.ts
@@ -11,6 +11,7 @@ import type {
 } from '../../interfaces/api.ts';
 import { ClusterCoordinator } from '../../cluster/coordinator.ts';
 import { ContainerManager } from '../../containers/container-manager.ts';
+import { UpstreamTimeoutError } from '../../containers/base-container.ts';
 import { API_SERVER } from '../../constants.ts';
 import { logger } from '../../logger.ts';
 import { ModelRegistry } from '../../models/registry.ts';
@@ -93,6 +94,11 @@ export class EmbeddingsHandler {
       });
       res.end(text);
     } catch (error) {
+      if (error instanceof UpstreamTimeoutError) {
+        this.sendError(res, 504, error.message, 'upstream_timeout');
+        return;
+      }
+
       const message = error instanceof Error ? error.message : String(error);
       logger.error(`Embeddings error: ${message}`);
       this.sendError(res, 500, `Embeddings generation failed: ${message}`, 'server_error');
@@ -224,6 +230,11 @@ export class EmbeddingsHandler {
         ...init,
         signal: controller.signal,
       });
+    } catch (error) {
+      if (error instanceof Error && error.name === 'AbortError') {
+        throw new UpstreamTimeoutError();
+      }
+      throw error;
     } finally {
       clearTimeout(timeout);
     }
diff --git a/ts/containers/base-container.ts b/ts/containers/base-container.ts
index eb676a1..fd64a20 100644
--- a/ts/containers/base-container.ts
+++ b/ts/containers/base-container.ts
@@ -24,6 +24,16 @@ export type TModelPullProgress = (progress: {
   percent?: number;
 }) => void;
 
+/**
+ * Raised in place of a fetch `AbortError` so callers can detect upstream timeouts.
+ */
+export class UpstreamTimeoutError extends Error {
+  constructor(message: string = 'Upstream request timed out') {
+    super(message);
+    this.name = 'UpstreamTimeoutError';
+  }
+}
+
 /**
  * Abstract base class for AI model containers
  */
@@ -181,6 +191,11 @@ export abstract class BaseContainer {
       });
 
       return response;
+    } catch (error) {
+      if (error instanceof Error && error.name === 'AbortError') {
+        throw new UpstreamTimeoutError();
+      }
+      throw error;
     } finally {
       clearTimeout(timeoutId);
     }