/**
 * Embeddings handler.
 */

import * as http from 'node:http';
import type {
  IApiError,
  IEmbeddingData,
  IEmbeddingsRequest,
  IEmbeddingsResponse,
} from '../../interfaces/api.ts';
import { ClusterCoordinator } from '../../cluster/coordinator.ts';
import { ContainerManager } from '../../containers/container-manager.ts';
import { UpstreamTimeoutError } from '../../containers/base-container.ts';
import { API_SERVER } from '../../constants.ts';
import { logger } from '../../logger.ts';
import { ModelRegistry } from '../../models/registry.ts';

/**
 * Serves embeddings requests.
 *
 * A request is answered, in order of preference, by:
 *  1. a container that `ContainerManager.findContainerForModel` already knows,
 *  2. a local container found after `ClusterCoordinator.ensureModelViaControlPlane`
 *     reports the model on this node, or
 *  3. forwarding the request to `<endpoint>/v1/embeddings` on the node reported
 *     by the control plane.
 *
 * Every failure path writes a JSON `IApiError` body; `handleEmbeddings` is not
 * expected to reject.
 */
export class EmbeddingsHandler {
  private readonly containerManager: ContainerManager;
  private readonly modelRegistry: ModelRegistry;
  private readonly clusterCoordinator: ClusterCoordinator;

  constructor(
    containerManager: ContainerManager,
    modelRegistry: ModelRegistry,
    clusterCoordinator: ClusterCoordinator,
  ) {
    this.containerManager = containerManager;
    this.modelRegistry = modelRegistry;
    this.clusterCoordinator = clusterCoordinator;
  }

  /**
   * Handles one embeddings request and writes the HTTP response.
   *
   * Status codes written by this method:
   *  - 200: embeddings generated by a local container
   *  - upstream status: when the request was forwarded to another node
   *  - 404 (`model_not_found`): the control plane could not provide the model
   *  - 503 (`server_error`): the model is assigned to this node but no
   *    container is available for it yet
   *  - 504 (`upstream_timeout`): an upstream call raised `UpstreamTimeoutError`
   *  - 500 (`server_error`): any other failure, including model resolution
   *
   * @param req  Incoming request; only its headers are read (for forwarding).
   * @param res  Response to write to.
   * @param body Parsed request body.
   */
  public async handleEmbeddings(
    req: http.IncomingMessage,
    res: http.ServerResponse,
    body: IEmbeddingsRequest,
  ): Promise<void> {
    try {
      // Resolved inside the try block so that a registry failure becomes a
      // JSON 500 instead of a rejected promise with no response written.
      const canonicalModel = await this.resolveCanonicalModel(body.model);
      const requestBody: IEmbeddingsRequest = {
        ...body,
        model: canonicalModel,
      };

      logger.dim(`Embeddings request for model: ${canonicalModel}`);

      const container = await this.containerManager.findContainerForModel(canonicalModel);
      if (container) {
        this.sendJson(res, 200, await this.generateEmbeddings(container, requestBody));
        return;
      }

      const ensured = await this.clusterCoordinator.ensureModelViaControlPlane(canonicalModel);
      if (!ensured) {
        this.sendError(
          res,
          404,
          `Embedding model "${canonicalModel}" not found`,
          'model_not_found',
        );
        return;
      }

      if (ensured.location.nodeName === this.clusterCoordinator.getLocalNodeName()) {
        // The control plane placed the model on this node: look it up again.
        const localContainer = await this.containerManager.findContainerForModel(canonicalModel);
        if (!localContainer) {
          this.sendError(
            res,
            503,
            `Embedding model "${canonicalModel}" is not ready`,
            'server_error',
          );
          return;
        }

        this.sendJson(res, 200, await this.generateEmbeddings(localContainer, requestBody));
        return;
      }

      // The model lives on another node: relay the request and mirror the
      // upstream status, content type and body back to the client.
      const upstream = await this.fetchWithTimeout(`${ensured.location.endpoint}/v1/embeddings`, {
        method: 'POST',
        headers: this.buildForwardHeaders(req),
        body: JSON.stringify(requestBody),
      });

      const text = await upstream.text();
      res.writeHead(upstream.status, {
        'Content-Type': upstream.headers.get('content-type') || 'application/json',
      });
      res.end(text);
    } catch (error) {
      if (error instanceof UpstreamTimeoutError) {
        this.sendError(res, 504, error.message, 'upstream_timeout');
        return;
      }

      const message = error instanceof Error ? error.message : String(error);
      logger.error(`Embeddings error: ${message}`);
      this.sendError(res, 500, `Embeddings generation failed: ${message}`, 'server_error');
    }
  }

  /**
   * Maps a requested model name to the registry's model id, falling back to
   * the requested name when the registry has no (or an empty) id for it.
   */
  private async resolveCanonicalModel(modelName: string): Promise<string> {
    const model = await this.modelRegistry.getModel(modelName);
    return model?.id || modelName;
  }

  /**
   * Builds an embeddings list response by embedding each input in turn.
   *
   * Inputs are sent to the container one at a time, in order; `index` in each
   * result entry is the input's position. Token usage is the sum of the
   * per-input token counts.
   */
  private async generateEmbeddings(
    container: import('../../containers/base-container.ts').BaseContainer,
    request: IEmbeddingsRequest,
  ): Promise<IEmbeddingsResponse> {
    const inputs = Array.isArray(request.input) ? request.input : [request.input];
    const embeddings: IEmbeddingData[] = [];
    let totalTokens = 0;

    for (const [index, input] of inputs.entries()) {
      const embedding = await this.getEmbeddingFromContainer(container, request.model, input);

      embeddings.push({
        object: 'embedding',
        embedding: embedding.vector,
        index,
      });

      totalTokens += embedding.tokenCount;
    }

    return {
      object: 'list',
      data: embeddings,
      model: request.model,
      usage: {
        prompt_tokens: totalTokens,
        total_tokens: totalTokens,
      },
    };
  }

  /**
   * Dispatches a single input to the backend matching `container.type`.
   *
   * @throws Error when the container type is neither `vllm` nor `tgi`.
   */
  private async getEmbeddingFromContainer(
    container: import('../../containers/base-container.ts').BaseContainer,
    model: string,
    input: string,
  ): Promise<{ vector: number[]; tokenCount: number }> {
    const endpoint = container.getEndpoint();
    const containerType = container.type;

    if (containerType === 'vllm') {
      return this.getVllmEmbedding(endpoint, model, input);
    }

    if (containerType === 'tgi') {
      return this.getTgiEmbedding(endpoint, model, input);
    }

    throw new Error(`Container type ${containerType} does not support embeddings`);
  }

  /**
   * Requests one embedding from `<endpoint>/v1/embeddings`.
   *
   * The token count is the `usage.total_tokens` reported by the upstream, or 0
   * when the upstream omits it.
   *
   * @throws Error on a non-2xx reply or a reply without an embedding vector.
   */
  private async getVllmEmbedding(
    endpoint: string,
    model: string,
    input: string,
  ): Promise<{ vector: number[]; tokenCount: number }> {
    const response = await this.fetchWithTimeout(`${endpoint}/v1/embeddings`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ model, input }),
    });

    if (!response.ok) {
      throw new Error(`vLLM embedding error: ${await response.text()}`);
    }

    // The payload is external input: verify its shape instead of letting a
    // missing field surface as an opaque TypeError.
    const payload = (await response.json()) as Partial<IEmbeddingsResponse> | null;
    const vector = payload?.data?.[0]?.embedding;
    if (!Array.isArray(vector)) {
      throw new Error('vLLM embedding error: response did not contain an embedding vector');
    }

    const reportedTokens = payload?.usage?.total_tokens;
    return {
      vector,
      tokenCount: typeof reportedTokens === 'number' ? reportedTokens : 0,
    };
  }

  /**
   * Requests one embedding from `<endpoint>/embed`.
   *
   * The reply carries no usage information, so the token count is estimated
   * as one token per four input characters (rounded up).
   *
   * @throws Error on a non-2xx reply or a reply without an embedding vector.
   */
  private async getTgiEmbedding(
    endpoint: string,
    _model: string,
    input: string,
  ): Promise<{ vector: number[]; tokenCount: number }> {
    const response = await this.fetchWithTimeout(`${endpoint}/embed`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ inputs: input }),
    });

    if (!response.ok) {
      throw new Error(`TGI embedding error: ${await response.text()}`);
    }

    // Expected shape is an array of vectors; only the first one is used.
    const payload: unknown = await response.json();
    const vector: unknown = Array.isArray(payload) ? payload[0] : undefined;
    if (!Array.isArray(vector)) {
      throw new Error('TGI embedding error: response did not contain an embedding vector');
    }

    return {
      vector,
      tokenCount: Math.ceil(input.length / 4),
    };
  }

  /**
   * Builds the headers for a forwarded request: always a JSON content type,
   * plus the caller's `Authorization` and `X-Request-Id` when each is present
   * as a single string value. No other incoming header is forwarded.
   */
  private buildForwardHeaders(req: http.IncomingMessage): Record<string, string> {
    const headers: Record<string, string> = {
      'Content-Type': 'application/json',
    };

    if (typeof req.headers.authorization === 'string') {
      headers.Authorization = req.headers.authorization;
    }

    if (typeof req.headers['x-request-id'] === 'string') {
      headers['X-Request-Id'] = req.headers['x-request-id'];
    }

    return headers;
  }

  /**
   * `fetch` with an abort after `API_SERVER.REQUEST_TIMEOUT_MS`.
   *
   * The timer is cleared as soon as `fetch` settles, so it bounds the time
   * until the response object is available; reading the body afterwards is
   * not covered by it.
   *
   * @throws UpstreamTimeoutError when the request is aborted by the timer.
   */
  private async fetchWithTimeout(url: string, init: RequestInit): Promise<Response> {
    const controller = new AbortController();
    const timeout = setTimeout(() => controller.abort(), API_SERVER.REQUEST_TIMEOUT_MS);

    try {
      return await fetch(url, {
        ...init,
        signal: controller.signal,
      });
    } catch (error) {
      if (error instanceof Error && error.name === 'AbortError') {
        throw new UpstreamTimeoutError();
      }
      throw error;
    } finally {
      clearTimeout(timeout);
    }
  }

  /** Writes `payload` as a JSON response with the given status code. */
  private sendJson(res: http.ServerResponse, statusCode: number, payload: unknown): void {
    res.writeHead(statusCode, { 'Content-Type': 'application/json' });
    res.end(JSON.stringify(payload));
  }

  /**
   * Writes a JSON `IApiError` response.
   *
   * When headers have already been sent the status can no longer be changed
   * (a second `writeHead` would throw), so the response is only closed.
   */
  private sendError(
    res: http.ServerResponse,
    statusCode: number,
    message: string,
    type: string,
    param?: string,
  ): void {
    if (res.headersSent) {
      if (!res.writableEnded) {
        res.end();
      }
      return;
    }

    const error: IApiError = {
      error: {
        message,
        type,
        param,
      },
    };

    this.sendJson(res, statusCode, error);
  }
}