modelgrid/ts/api/handlers/embeddings.ts

/**
 * Embeddings Handler
 *
 * Handles the /v1/embeddings endpoint.
 */
import * as http from 'node:http';
import type {
  IEmbeddingsRequest,
  IEmbeddingsResponse,
  IEmbeddingData,
  IApiError,
} from '../../interfaces/api.ts';
import type { BaseContainer } from '../../containers/base-container.ts';
import { logger } from '../../logger.ts';
import { ContainerManager } from '../../containers/container-manager.ts';

/**
 * Handler for embeddings requests
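 *
 * @example
 * // Illustrative wiring sketch: assumes routing and JSON body parsing
 * // happen upstream, and `readBody` is a hypothetical helper that
 * // buffers the request stream.
 * const handler = new EmbeddingsHandler(containerManager);
 * const body = JSON.parse(await readBody(req)) as IEmbeddingsRequest;
 * await handler.handleEmbeddings(res, body);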
 */
export class EmbeddingsHandler {
  private containerManager: ContainerManager;

  constructor(containerManager: ContainerManager) {
    this.containerManager = containerManager;
  }

  /**
   * Handle POST /v1/embeddings
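   *
   * Accepts an OpenAI-style body, e.g.
   * `{ "model": "some-embedding-model", "input": ["hello", "world"] }`
   * (model name illustrative); `input` may be a single string or an
   * array of strings.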
   */
  public async handleEmbeddings(
    res: http.ServerResponse,
    body: IEmbeddingsRequest,
  ): Promise<void> {
    const modelName = body.model;
    logger.dim(`Embeddings request for model: ${modelName}`);
    // Reject malformed requests before touching any container
    if (!modelName || body.input === undefined) {
      this.sendError(res, 400, 'Request must include "model" and "input"', 'invalid_request_error');
      return;
    }
    try {
      // Find a container serving the embedding model
      const container = await this.containerManager.findContainerForModel(modelName);
      if (!container) {
        this.sendError(res, 404, `Embedding model "${modelName}" not found`, 'model_not_found');
        return;
      }
      // Generate embeddings and return them in OpenAI list format
      const response = await this.generateEmbeddings(container, body);
      res.writeHead(200, { 'Content-Type': 'application/json' });
      res.end(JSON.stringify(response));
    } catch (error) {
      const message = error instanceof Error ? error.message : String(error);
      logger.error(`Embeddings error: ${message}`);
      this.sendError(res, 500, `Embeddings generation failed: ${message}`, 'server_error');
    }
  }

  /**
   * Generate embeddings from container
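   *
   * Inputs are embedded one at a time (see the loop below); a
   * `Promise.all` fan-out over inputs would be a possible alternative,
   * but is not used here.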
   */
  private async generateEmbeddings(
    container: BaseContainer,
    request: IEmbeddingsRequest,
  ): Promise<IEmbeddingsResponse> {
    // Normalize input: the OpenAI API accepts string | string[]
    const inputs = Array.isArray(request.input) ? request.input : [request.input];
    const embeddings: IEmbeddingData[] = [];
    let totalTokens = 0;
    // Generate an embedding for each input, preserving request order
    for (let i = 0; i < inputs.length; i++) {
      const input = inputs[i];
      const embedding = await this.getEmbeddingFromContainer(container, request.model, input);
      embeddings.push({
        object: 'embedding',
        embedding: embedding.vector,
        index: i,
      });
      totalTokens += embedding.tokenCount;
    }
    return {
      object: 'list',
      data: embeddings,
      model: request.model,
      usage: {
        prompt_tokens: totalTokens,
        total_tokens: totalTokens,
      },
    };
  }

  /**
   * Get embedding from container (container-specific implementation)
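   *
   * Wire formats targeted by the helpers below:
   * - ollama: POST {endpoint}/api/embeddings -> `{ embedding: number[] }`
   * - vllm:   POST {endpoint}/v1/embeddings  -> OpenAI-style embeddings response
   * - tgi:    POST {endpoint}/embed          -> `number[][]`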
   */
  private async getEmbeddingFromContainer(
    container: BaseContainer,
    model: string,
    input: string,
  ): Promise<{ vector: number[]; tokenCount: number }> {
    const endpoint = container.getEndpoint();
    const containerType = container.type;
    // Route to the container-specific embedding endpoint
    if (containerType === 'ollama') {
      return this.getOllamaEmbedding(endpoint, model, input);
    } else if (containerType === 'vllm') {
      return this.getVllmEmbedding(endpoint, model, input);
    } else if (containerType === 'tgi') {
      return this.getTgiEmbedding(endpoint, model, input);
    }
    throw new Error(`Container type ${containerType} does not support embeddings`);
  }

  /**
   * Get embedding from Ollama
   */
  private async getOllamaEmbedding(
    endpoint: string,
    model: string,
    input: string,
  ): Promise<{ vector: number[]; tokenCount: number }> {
    const response = await fetch(`${endpoint}/api/embeddings`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        model,
        prompt: input,
      }),
    });
    if (!response.ok) {
      const errorText = await response.text();
      throw new Error(`Ollama embedding error: ${errorText}`);
    }
    const result = await response.json() as { embedding: number[] };
    // Estimate token count (rough approximation: ~4 chars per token)
    const tokenCount = Math.ceil(input.length / 4);
    return {
      vector: result.embedding,
      tokenCount,
    };
  }

  /**
   * Get embedding from vLLM (OpenAI-compatible)
   */
  private async getVllmEmbedding(
    endpoint: string,
    model: string,
    input: string,
  ): Promise<{ vector: number[]; tokenCount: number }> {
    const response = await fetch(`${endpoint}/v1/embeddings`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        model,
        input,
      }),
    });
    if (!response.ok) {
      const errorText = await response.text();
      throw new Error(`vLLM embedding error: ${errorText}`);
    }
    const result = await response.json() as IEmbeddingsResponse;
    // Guard against an empty data array before indexing into it
    const first = result.data[0];
    if (!first) {
      throw new Error('vLLM embedding response contained no data');
    }
    return {
      vector: first.embedding,
      tokenCount: result.usage.total_tokens,
    };
  }

  /**
   * Get embedding from TGI
   */
  private async getTgiEmbedding(
    endpoint: string,
    _model: string,
    input: string,
  ): Promise<{ vector: number[]; tokenCount: number }> {
    // TGI uses the /embed endpoint, which returns an array of vectors
    const response = await fetch(`${endpoint}/embed`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        inputs: input,
      }),
    });
    if (!response.ok) {
      const errorText = await response.text();
      throw new Error(`TGI embedding error: ${errorText}`);
    }
    const result = await response.json() as number[][];
    // Guard against an empty response before indexing into it
    const vector = result[0];
    if (!vector) {
      throw new Error('TGI embed response contained no vectors');
    }
    // Estimate token count (rough approximation: ~4 chars per token)
    const tokenCount = Math.ceil(input.length / 4);
    return {
      vector,
      tokenCount,
    };
  }

  /**
   * Send error response
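   *
   * Emits an OpenAI-style error envelope, e.g.
   * `{ "error": { "message": "...", "type": "model_not_found", "code": null } }`
   * (`param` is dropped from the serialized JSON when undefined).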
   */
  private sendError(
    res: http.ServerResponse,
    statusCode: number,
    message: string,
    type: string,
    param?: string,
  ): void {
    const error: IApiError = {
      error: {
        message,
        type,
        param,
        code: null,
      },
    };
    res.writeHead(statusCode, { 'Content-Type': 'application/json' });
    res.end(JSON.stringify(error));
  }
}