fix(api): map upstream timeouts to 504 responses
This commit is contained in:
@@ -0,0 +1,120 @@
|
|||||||
|
import { assertEquals } from 'jsr:@std/assert@^1.0.0';
|
||||||
|
import { ChatHandler } from '../ts/api/handlers/chat.ts';
|
||||||
|
import { EmbeddingsHandler } from '../ts/api/handlers/embeddings.ts';
|
||||||
|
import { UpstreamTimeoutError } from '../ts/containers/base-container.ts';
|
||||||
|
|
||||||
|
class TestResponse {
|
||||||
|
public statusCode = 200;
|
||||||
|
public headers: Record<string, string> = {};
|
||||||
|
public body = '';
|
||||||
|
|
||||||
|
public writeHead(statusCode: number, headers: Record<string, string>): TestResponse {
|
||||||
|
this.statusCode = statusCode;
|
||||||
|
this.headers = headers;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public end(body = ''): TestResponse {
|
||||||
|
this.body = body;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public write(_chunk: string | Uint8Array): boolean {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Deno.test('ChatHandler maps upstream timeouts to 504 responses', async () => {
|
||||||
|
const handler = new ChatHandler(
|
||||||
|
{
|
||||||
|
async findContainerForModel() {
|
||||||
|
return {
|
||||||
|
async chatCompletion() {
|
||||||
|
throw new UpstreamTimeoutError();
|
||||||
|
},
|
||||||
|
async chatCompletionStream() {
|
||||||
|
throw new UpstreamTimeoutError();
|
||||||
|
},
|
||||||
|
};
|
||||||
|
},
|
||||||
|
} as never,
|
||||||
|
{
|
||||||
|
async getModel(modelName: string) {
|
||||||
|
return { id: modelName };
|
||||||
|
},
|
||||||
|
} as never,
|
||||||
|
{
|
||||||
|
async loadModel() {
|
||||||
|
return { success: false };
|
||||||
|
},
|
||||||
|
} as never,
|
||||||
|
{
|
||||||
|
shouldDeployLocallyFirst() {
|
||||||
|
return false;
|
||||||
|
},
|
||||||
|
} as never,
|
||||||
|
);
|
||||||
|
|
||||||
|
const response = new TestResponse();
|
||||||
|
await handler.handleChatCompletion(
|
||||||
|
{ headers: {} } as never,
|
||||||
|
response as never,
|
||||||
|
{ model: 'meta-llama/Llama-3.1-8B-Instruct', messages: [{ role: 'user', content: 'hi' }] },
|
||||||
|
);
|
||||||
|
|
||||||
|
assertEquals(response.statusCode, 504);
|
||||||
|
assertEquals(JSON.parse(response.body).error.type, 'upstream_timeout');
|
||||||
|
});
|
||||||
|
|
||||||
|
Deno.test('EmbeddingsHandler maps upstream timeouts to 504 responses', async () => {
|
||||||
|
const originalFetch = globalThis.fetch;
|
||||||
|
globalThis.fetch = async () => {
|
||||||
|
const error = new Error('request aborted');
|
||||||
|
error.name = 'AbortError';
|
||||||
|
throw error;
|
||||||
|
};
|
||||||
|
|
||||||
|
try {
|
||||||
|
const handler = new EmbeddingsHandler(
|
||||||
|
{
|
||||||
|
async findContainerForModel() {
|
||||||
|
return null;
|
||||||
|
},
|
||||||
|
} as never,
|
||||||
|
{
|
||||||
|
async getModel(modelName: string) {
|
||||||
|
return { id: modelName };
|
||||||
|
},
|
||||||
|
} as never,
|
||||||
|
{
|
||||||
|
async ensureModelViaControlPlane(modelName: string) {
|
||||||
|
return {
|
||||||
|
location: {
|
||||||
|
modelId: modelName,
|
||||||
|
nodeName: 'worker-a',
|
||||||
|
endpoint: 'http://worker-a:8080',
|
||||||
|
healthy: true,
|
||||||
|
engine: 'vllm',
|
||||||
|
containerId: 'remote',
|
||||||
|
},
|
||||||
|
};
|
||||||
|
},
|
||||||
|
getLocalNodeName() {
|
||||||
|
return 'control';
|
||||||
|
},
|
||||||
|
} as never,
|
||||||
|
);
|
||||||
|
|
||||||
|
const response = new TestResponse();
|
||||||
|
await handler.handleEmbeddings(
|
||||||
|
{ headers: {} } as never,
|
||||||
|
response as never,
|
||||||
|
{ model: 'BAAI/bge-m3', input: 'hello' },
|
||||||
|
);
|
||||||
|
|
||||||
|
assertEquals(response.statusCode, 504);
|
||||||
|
assertEquals(JSON.parse(response.body).error.type, 'upstream_timeout');
|
||||||
|
} finally {
|
||||||
|
globalThis.fetch = originalFetch;
|
||||||
|
}
|
||||||
|
});
|
||||||
@@ -6,6 +6,7 @@ import * as http from 'node:http';
|
|||||||
import type { IApiError, IChatCompletionRequest } from '../../interfaces/api.ts';
|
import type { IApiError, IChatCompletionRequest } from '../../interfaces/api.ts';
|
||||||
import { ClusterCoordinator } from '../../cluster/coordinator.ts';
|
import { ClusterCoordinator } from '../../cluster/coordinator.ts';
|
||||||
import { ContainerManager } from '../../containers/container-manager.ts';
|
import { ContainerManager } from '../../containers/container-manager.ts';
|
||||||
|
import { UpstreamTimeoutError } from '../../containers/base-container.ts';
|
||||||
import { API_SERVER } from '../../constants.ts';
|
import { API_SERVER } from '../../constants.ts';
|
||||||
import { logger } from '../../logger.ts';
|
import { logger } from '../../logger.ts';
|
||||||
import { ModelRegistry } from '../../models/registry.ts';
|
import { ModelRegistry } from '../../models/registry.ts';
|
||||||
@@ -86,6 +87,11 @@ export class ChatHandler {
|
|||||||
|
|
||||||
await this.proxyChatRequest(req, res, ensured.location.endpoint, requestBody);
|
await this.proxyChatRequest(req, res, ensured.location.endpoint, requestBody);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
if (error instanceof UpstreamTimeoutError) {
|
||||||
|
this.sendError(res, 504, error.message, 'upstream_timeout');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const message = error instanceof Error ? error.message : String(error);
|
const message = error instanceof Error ? error.message : String(error);
|
||||||
logger.error(`Chat completion error: ${message}`);
|
logger.error(`Chat completion error: ${message}`);
|
||||||
this.sendError(res, 500, `Chat completion failed: ${message}`, 'server_error');
|
this.sendError(res, 500, `Chat completion failed: ${message}`, 'server_error');
|
||||||
@@ -166,6 +172,11 @@ export class ChatHandler {
|
|||||||
headers: this.buildForwardHeaders(req),
|
headers: this.buildForwardHeaders(req),
|
||||||
body: JSON.stringify(body),
|
body: JSON.stringify(body),
|
||||||
signal: controller.signal,
|
signal: controller.signal,
|
||||||
|
}).catch((error) => {
|
||||||
|
if (error instanceof Error && error.name === 'AbortError') {
|
||||||
|
throw new UpstreamTimeoutError();
|
||||||
|
}
|
||||||
|
throw error;
|
||||||
}).finally(() => clearTimeout(timeout));
|
}).finally(() => clearTimeout(timeout));
|
||||||
|
|
||||||
if (body.stream) {
|
if (body.stream) {
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import type {
|
|||||||
} from '../../interfaces/api.ts';
|
} from '../../interfaces/api.ts';
|
||||||
import { ClusterCoordinator } from '../../cluster/coordinator.ts';
|
import { ClusterCoordinator } from '../../cluster/coordinator.ts';
|
||||||
import { ContainerManager } from '../../containers/container-manager.ts';
|
import { ContainerManager } from '../../containers/container-manager.ts';
|
||||||
|
import { UpstreamTimeoutError } from '../../containers/base-container.ts';
|
||||||
import { API_SERVER } from '../../constants.ts';
|
import { API_SERVER } from '../../constants.ts';
|
||||||
import { logger } from '../../logger.ts';
|
import { logger } from '../../logger.ts';
|
||||||
import { ModelRegistry } from '../../models/registry.ts';
|
import { ModelRegistry } from '../../models/registry.ts';
|
||||||
@@ -93,6 +94,11 @@ export class EmbeddingsHandler {
|
|||||||
});
|
});
|
||||||
res.end(text);
|
res.end(text);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
if (error instanceof UpstreamTimeoutError) {
|
||||||
|
this.sendError(res, 504, error.message, 'upstream_timeout');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const message = error instanceof Error ? error.message : String(error);
|
const message = error instanceof Error ? error.message : String(error);
|
||||||
logger.error(`Embeddings error: ${message}`);
|
logger.error(`Embeddings error: ${message}`);
|
||||||
this.sendError(res, 500, `Embeddings generation failed: ${message}`, 'server_error');
|
this.sendError(res, 500, `Embeddings generation failed: ${message}`, 'server_error');
|
||||||
@@ -224,6 +230,11 @@ export class EmbeddingsHandler {
|
|||||||
...init,
|
...init,
|
||||||
signal: controller.signal,
|
signal: controller.signal,
|
||||||
});
|
});
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof Error && error.name === 'AbortError') {
|
||||||
|
throw new UpstreamTimeoutError();
|
||||||
|
}
|
||||||
|
throw error;
|
||||||
} finally {
|
} finally {
|
||||||
clearTimeout(timeout);
|
clearTimeout(timeout);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,13 @@ export type TModelPullProgress = (progress: {
|
|||||||
percent?: number;
|
percent?: number;
|
||||||
}) => void;
|
}) => void;
|
||||||
|
|
||||||
|
export class UpstreamTimeoutError extends Error {
|
||||||
|
constructor(message: string = 'Upstream request timed out') {
|
||||||
|
super(message);
|
||||||
|
this.name = 'UpstreamTimeoutError';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Abstract base class for AI model containers
|
* Abstract base class for AI model containers
|
||||||
*/
|
*/
|
||||||
@@ -181,6 +188,11 @@ export abstract class BaseContainer {
|
|||||||
});
|
});
|
||||||
|
|
||||||
return response;
|
return response;
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof Error && error.name === 'AbortError') {
|
||||||
|
throw new UpstreamTimeoutError();
|
||||||
|
}
|
||||||
|
throw error;
|
||||||
} finally {
|
} finally {
|
||||||
clearTimeout(timeoutId);
|
clearTimeout(timeoutId);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user