feat(cluster,api,models,cli): add cluster-aware model catalog deployments and request routing

This commit is contained in:
2026-04-20 23:00:50 +00:00
parent 83cacd0cf1
commit 4f2266e1b7
55 changed files with 3970 additions and 1630 deletions
+149 -65
View File
@@ -1,58 +1,89 @@
/**
* Chat Completions Handler
*
* Handles /v1/chat/completions and /v1/completions endpoints.
* Chat completions handler.
*/
import * as http from 'node:http';
import type {
IChatCompletionRequest,
IChatCompletionResponse,
IApiError,
} from '../../interfaces/api.ts';
import { logger } from '../../logger.ts';
import type { IApiError, IChatCompletionRequest } from '../../interfaces/api.ts';
import { ClusterCoordinator } from '../../cluster/coordinator.ts';
import { ContainerManager } from '../../containers/container-manager.ts';
import { logger } from '../../logger.ts';
import { ModelRegistry } from '../../models/registry.ts';
import { ModelLoader } from '../../models/loader.ts';
/**
* Handler for chat completion requests
*/
export class ChatHandler {
private containerManager: ContainerManager;
private modelRegistry: ModelRegistry;
private modelLoader: ModelLoader;
private clusterCoordinator: ClusterCoordinator;
constructor(containerManager: ContainerManager, modelLoader: ModelLoader) {
constructor(
containerManager: ContainerManager,
modelRegistry: ModelRegistry,
modelLoader: ModelLoader,
clusterCoordinator: ClusterCoordinator,
) {
this.containerManager = containerManager;
this.modelRegistry = modelRegistry;
this.modelLoader = modelLoader;
this.clusterCoordinator = clusterCoordinator;
}
/**
* Handle POST /v1/chat/completions
*/
public async handleChatCompletion(
req: http.IncomingMessage,
res: http.ServerResponse,
body: IChatCompletionRequest,
): Promise<void> {
const modelName = body.model;
const isStream = body.stream === true;
const canonicalModel = await this.resolveCanonicalModel(body.model);
const requestBody: IChatCompletionRequest = {
...body,
model: canonicalModel,
};
logger.dim(`Chat completion request for model: ${modelName}`);
logger.dim(`Chat completion request for model: ${canonicalModel}`);
try {
// Find or load the model
const container = await this.findOrLoadModel(modelName);
if (!container) {
this.sendError(res, 404, `Model "${modelName}" not found or could not be loaded`, 'model_not_found');
const container = await this.findOrLoadLocalModel(canonicalModel);
if (container) {
if (requestBody.stream) {
await this.handleStreamingCompletion(res, container, requestBody);
} else {
await this.handleNonStreamingCompletion(res, container, requestBody);
}
return;
}
// Route to streaming or non-streaming handler
if (isStream) {
await this.handleStreamingCompletion(res, container, body);
} else {
await this.handleNonStreamingCompletion(res, container, body);
const ensured = await this.clusterCoordinator.ensureModelViaControlPlane(canonicalModel);
if (!ensured) {
this.sendError(
res,
404,
`Model "${canonicalModel}" not found or could not be deployed`,
'model_not_found',
);
return;
}
if (ensured.location.nodeName === this.clusterCoordinator.getLocalNodeName()) {
const localContainer = await this.findLocalContainer(canonicalModel);
if (!localContainer) {
this.sendError(
res,
503,
`Model "${canonicalModel}" was scheduled locally but is not ready`,
'server_error',
);
return;
}
if (requestBody.stream) {
await this.handleStreamingCompletion(res, localContainer, requestBody);
} else {
await this.handleNonStreamingCompletion(res, localContainer, requestBody);
}
return;
}
await this.proxyChatRequest(req, res, ensured.location.endpoint, requestBody);
} catch (error) {
const message = error instanceof Error ? error.message : String(error);
logger.error(`Chat completion error: ${message}`);
@@ -60,34 +91,38 @@ export class ChatHandler {
}
}
/**
* Find container with model or attempt to load it
*/
private async findOrLoadModel(
/**
 * Resolve a requested model name (possibly an alias or short name) to the
 * registry's canonical model id. When the registry has no entry — or the
 * entry carries an empty id — the requested name is returned unchanged.
 */
private async resolveCanonicalModel(modelName: string): Promise<string> {
  const registered = await this.modelRegistry.getModel(modelName);
  return registered?.id ? registered.id : modelName;
}
private async findLocalContainer(
modelName: string,
): Promise<import('../../containers/base-container.ts').BaseContainer | null> {
// First, check if model is already loaded
const container = await this.containerManager.findContainerForModel(modelName);
if (container) {
return container;
}
// Try to load the model
logger.info(`Model ${modelName} not loaded, attempting to load...`);
const loadResult = await this.modelLoader.loadModel(modelName);
if (!loadResult.success) {
logger.error(`Failed to load model: ${loadResult.error}`);
return null;
}
// Find the container again after loading
return this.containerManager.findContainerForModel(modelName);
}
/**
* Handle non-streaming chat completion
*/
/**
 * Find a local container already serving the model, optionally attempting a
 * local deploy first when cluster policy says this node should try to host
 * the model itself.
 *
 * @param modelName - Canonical model id to locate or deploy.
 * @returns The serving container, or null when the model is not (and could
 *   not be made) available locally; the caller then falls back to the
 *   cluster control plane.
 */
private async findOrLoadLocalModel(
  modelName: string,
): Promise<import('../../containers/base-container.ts').BaseContainer | null> {
  const existing = await this.findLocalContainer(modelName);
  if (existing) {
    return existing;
  }
  if (!this.clusterCoordinator.shouldDeployLocallyFirst()) {
    return null;
  }
  logger.info(`Model ${modelName} not loaded, attempting local deploy...`);
  const loadResult = await this.modelLoader.loadModel(modelName);
  if (!loadResult.success) {
    // Surface the failure reason instead of silently returning null —
    // otherwise failed local deploys are invisible in the logs.
    logger.error(`Failed to load model: ${loadResult.error}`);
    return null;
  }
  // Look up by the loader's resolved name, which may differ from the request.
  return this.findLocalContainer(loadResult.model);
}
private async handleNonStreamingCompletion(
res: http.ServerResponse,
container: import('../../containers/base-container.ts').BaseContainer,
@@ -99,35 +134,85 @@ export class ChatHandler {
res.end(JSON.stringify(response));
}
/**
* Handle streaming chat completion
*/
private async handleStreamingCompletion(
res: http.ServerResponse,
container: import('../../containers/base-container.ts').BaseContainer,
body: IChatCompletionRequest,
): Promise<void> {
// Set SSE headers
res.writeHead(200, {
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
Connection: 'keep-alive',
'X-Accel-Buffering': 'no',
});
// Stream chunks to client
await container.chatCompletionStream(body, (chunk) => {
res.write(`data: ${chunk}\n\n`);
res.write(chunk);
});
// Send final done message
res.write('data: [DONE]\n\n');
res.end();
}
/**
* Send error response
*/
/**
 * Proxy a chat completion to the cluster node that hosts the model.
 *
 * For streaming requests the upstream SSE body is piped to the client
 * chunk-by-chunk; for non-streaming requests the upstream response is
 * relayed verbatim with its status and content type.
 *
 * @param req - Original client request (source of forwarded auth headers).
 * @param res - Client response to write the proxied result to.
 * @param targetEndpoint - Base URL of the node hosting the model.
 * @param body - Request body with the canonical model id already set.
 */
private async proxyChatRequest(
  req: http.IncomingMessage,
  res: http.ServerResponse,
  targetEndpoint: string,
  body: IChatCompletionRequest,
): Promise<void> {
  const response = await fetch(`${targetEndpoint}/v1/chat/completions`, {
    method: 'POST',
    headers: this.buildForwardHeaders(req),
    body: JSON.stringify(body),
  });
  if (body.stream) {
    res.writeHead(response.status, {
      'Content-Type': response.headers.get('content-type') || 'text/event-stream',
      'Cache-Control': 'no-cache',
      Connection: 'keep-alive',
    });
    const reader = response.body?.getReader();
    if (!reader) {
      res.end();
      return;
    }
    // Stop pulling from upstream when the client disconnects, so the
    // remote stream is not held open for a consumer that went away.
    req.on('close', () => {
      void reader.cancel().catch(() => {});
    });
    try {
      while (true) {
        const { done, value } = await reader.read();
        if (done) {
          break;
        }
        res.write(value);
      }
    } catch (error) {
      // Headers are already sent; a JSON error body would corrupt the SSE
      // stream, so log and terminate instead of rethrowing to the caller.
      const message = error instanceof Error ? error.message : String(error);
      logger.error(`Upstream stream error while proxying chat completion: ${message}`);
    }
    res.end();
    return;
  }
  const text = await response.text();
  res.writeHead(response.status, {
    'Content-Type': response.headers.get('content-type') || 'application/json',
  });
  res.end(text);
}
/**
 * Build the header set forwarded to a peer node: always JSON content type,
 * plus the caller's Authorization and X-Request-Id headers when present.
 */
private buildForwardHeaders(req: http.IncomingMessage): Record<string, string> {
  const forwarded: Record<string, string> = { 'Content-Type': 'application/json' };
  const { authorization } = req.headers;
  if (typeof authorization === 'string') {
    forwarded.Authorization = authorization;
  }
  const requestId = req.headers['x-request-id'];
  if (typeof requestId === 'string') {
    forwarded['X-Request-Id'] = requestId;
  }
  return forwarded;
}
private sendError(
res: http.ServerResponse,
statusCode: number,
@@ -140,7 +225,6 @@ export class ChatHandler {
message,
type,
param,
code: null,
},
};
+316
View File
@@ -0,0 +1,316 @@
import { timingSafeEqual } from 'node:crypto';
import * as http from 'node:http';
import type { IApiError } from '../../interfaces/api.ts';
import type { IClusterNodeHeartbeat } from '../../interfaces/cluster.ts';
import { ClusterCoordinator } from '../../cluster/coordinator.ts';
import { CLUSTER } from '../../constants.ts';
/**
 * HTTP handler for the internal /_cluster/* control API.
 *
 * Routes cluster status queries, node registration/heartbeats, model
 * resolution and scheduling, and node lifecycle changes
 * (cordon/uncordon/drain/activate) to the ClusterCoordinator. Every route
 * is authenticated against the shared cluster secret when one is
 * configured.
 */
export class ClusterHandler {
  private clusterCoordinator: ClusterCoordinator;

  constructor(clusterCoordinator: ClusterCoordinator) {
    this.clusterCoordinator = clusterCoordinator;
  }

  /**
   * Dispatch a cluster API request to the matching endpoint.
   *
   * @param req - Incoming request; used for the auth header and JSON body.
   * @param res - Response to write the result or error to.
   * @param path - Already-extracted request path.
   * @param url - Full request URL; used for query parameters.
   */
  public async handle(
    req: http.IncomingMessage,
    res: http.ServerResponse,
    path: string,
    url: URL,
  ): Promise<void> {
    if (!this.authenticate(req)) {
      return this.sendError(res, 401, 'Invalid cluster secret', 'authentication_error');
    }

    if (path === '/_cluster/status' && req.method === 'GET') {
      return this.sendJson(res, 200, this.clusterCoordinator.getStatus());
    }

    if (path === '/_cluster/nodes' && req.method === 'GET') {
      return this.sendJson(res, 200, this.clusterCoordinator.getStatus().nodes);
    }

    if (path === '/_cluster/desired' && req.method === 'GET') {
      return this.sendJson(res, 200, this.clusterCoordinator.getDesiredDeployments());
    }

    if (
      (path === '/_cluster/nodes/register' || path === '/_cluster/nodes/heartbeat') &&
      req.method === 'POST'
    ) {
      const body = await this.parseBody(req) as IClusterNodeHeartbeat | null;
      if (!body) {
        return this.sendError(
          res,
          400,
          'Invalid cluster heartbeat payload',
          'invalid_request_error',
        );
      }
      this.clusterCoordinator.acceptHeartbeat(body);
      return this.sendJson(res, 200, { ok: true });
    }

    if (path === '/_cluster/models/resolve' && req.method === 'GET') {
      const model = url.searchParams.get('model');
      if (!model) {
        return this.sendError(res, 400, 'Missing model query parameter', 'invalid_request_error');
      }
      const resolved = await this.clusterCoordinator.resolveModel(model);
      if (!resolved) {
        return this.sendError(res, 404, `Model "${model}" not found in cluster`, 'model_not_found');
      }
      return this.sendJson(res, 200, resolved);
    }

    if (path === '/_cluster/models/ensure' && req.method === 'POST') {
      if (!this.requireControlPlane(res)) {
        return;
      }
      const body = await this.parseBody(req) as { model?: string } | null;
      if (!body?.model) {
        return this.sendError(res, 400, 'Missing model in request body', 'invalid_request_error');
      }
      const ensured = await this.clusterCoordinator.ensureModel(body.model);
      if (!ensured) {
        return this.sendError(res, 503, `Unable to schedule model "${body.model}"`, 'server_error');
      }
      return this.sendJson(res, 200, ensured);
    }

    if (path === '/_cluster/models/desired' && req.method === 'POST') {
      if (!this.requireControlPlane(res)) {
        return;
      }
      const body = await this.parseBody(req) as { model?: string; desiredReplicas?: number } | null;
      if (!body?.model || body.desiredReplicas === undefined) {
        return this.sendError(
          res,
          400,
          'Missing model or desiredReplicas in request body',
          'invalid_request_error',
        );
      }
      const desiredDeployment = await this.clusterCoordinator.setDesiredReplicas(
        body.model,
        body.desiredReplicas,
      );
      if (!desiredDeployment) {
        return this.sendError(res, 404, `Model "${body.model}" not found`, 'model_not_found');
      }
      return this.sendJson(res, 200, desiredDeployment);
    }

    if (path === '/_cluster/models/desired/remove' && req.method === 'POST') {
      if (!this.requireControlPlane(res)) {
        return;
      }
      const body = await this.parseBody(req) as { model?: string } | null;
      if (!body?.model) {
        return this.sendError(res, 400, 'Missing model in request body', 'invalid_request_error');
      }
      const removed = await this.clusterCoordinator.clearDesiredDeployment(body.model);
      return this.sendJson(res, 200, { removed });
    }

    if (path === '/_cluster/deployments' && req.method === 'POST') {
      // Deliberately NOT control-plane gated: the control plane invokes
      // this endpoint on target nodes to place replicas.
      const body = await this.parseBody(req) as { model?: string; replicaOrdinal?: number } | null;
      if (!body?.model) {
        return this.sendError(res, 400, 'Missing model in request body', 'invalid_request_error');
      }
      const deployed = body.replicaOrdinal !== undefined
        ? await this.clusterCoordinator.deployReplicaLocally(body.model, body.replicaOrdinal)
        : await this.clusterCoordinator.deployModelLocally(body.model);
      if (!deployed) {
        return this.sendError(res, 503, `Unable to deploy model "${body.model}"`, 'server_error');
      }
      return this.sendJson(res, 200, deployed);
    }

    // The four node-lifecycle endpoints share one shape: control-plane
    // gate, { nodeName } body, scheduler-state update. Map path -> state.
    if (req.method === 'POST') {
      switch (path) {
        case '/_cluster/nodes/cordon':
          return this.handleNodeStateChange(req, res, 'cordoned');
        case '/_cluster/nodes/uncordon':
          return this.handleNodeStateChange(req, res, 'active');
        case '/_cluster/nodes/drain':
          return this.handleNodeStateChange(req, res, 'draining');
        case '/_cluster/nodes/activate':
          return this.handleNodeStateChange(req, res, 'active');
      }
    }

    return this.sendError(res, 404, `Unknown cluster endpoint: ${path}`, 'invalid_request_error');
  }

  /**
   * Shared implementation for cordon/uncordon/drain/activate: gate on the
   * control plane, parse { nodeName } from the body, apply the state.
   */
  private async handleNodeStateChange(
    req: http.IncomingMessage,
    res: http.ServerResponse,
    state: 'active' | 'cordoned' | 'draining',
  ): Promise<void> {
    if (!this.requireControlPlane(res)) {
      return;
    }
    const body = await this.parseBody(req) as { nodeName?: string } | null;
    if (!body?.nodeName) {
      return this.sendError(
        res,
        400,
        'Missing nodeName in request body',
        'invalid_request_error',
      );
    }
    const schedulerState = this.clusterCoordinator.setNodeSchedulerState(body.nodeName, state);
    return this.sendJson(res, 200, { nodeName: body.nodeName, schedulerState });
  }

  /**
   * Gate for control-plane-only endpoints. Sends a 409 and returns false
   * when this node cannot manage cluster state; true when the caller may
   * proceed.
   */
  private requireControlPlane(res: http.ServerResponse): boolean {
    if (this.clusterCoordinator.canManageClusterState()) {
      return true;
    }
    this.sendError(res, 409, 'This node is not the control plane', 'invalid_request_error');
    return false;
  }

  /**
   * Validate the shared-secret auth header. Requests are allowed through
   * when no secret is configured (single-node / trusted-network mode).
   */
  private authenticate(req: http.IncomingMessage): boolean {
    const sharedSecret = this.clusterCoordinator.getSharedSecret();
    if (!sharedSecret) {
      return true;
    }
    // NOTE(review): Node lowercases incoming header names, so this assumes
    // CLUSTER.AUTH_HEADER_NAME is defined in lowercase — confirm.
    const presented = req.headers[CLUSTER.AUTH_HEADER_NAME];
    if (typeof presented !== 'string') {
      return false;
    }
    // Constant-time comparison so the secret cannot be recovered by timing
    // successive prefix guesses. timingSafeEqual requires equal lengths,
    // so compare lengths first.
    const presentedBuf = Buffer.from(presented);
    const secretBuf = Buffer.from(sharedSecret);
    return presentedBuf.length === secretBuf.length && timingSafeEqual(presentedBuf, secretBuf);
  }

  /**
   * Read and JSON-parse the request body. Resolves to null on an empty
   * body, malformed JSON, or a request error — callers treat null as 400.
   */
  private async parseBody(req: http.IncomingMessage): Promise<unknown | null> {
    return new Promise((resolve) => {
      // NOTE(review): body size is unbounded here; consider a cap if these
      // endpoints are ever reachable from untrusted networks.
      let body = '';
      req.on('data', (chunk) => {
        body += chunk.toString();
      });
      req.on('end', () => {
        if (!body) {
          resolve(null);
          return;
        }
        try {
          resolve(JSON.parse(body));
        } catch {
          resolve(null);
        }
      });
      req.on('error', () => resolve(null));
    });
  }

  /** Serialize a JSON response with the given status code. */
  private sendJson(res: http.ServerResponse, statusCode: number, body: unknown): void {
    res.writeHead(statusCode, { 'Content-Type': 'application/json' });
    res.end(JSON.stringify(body));
  }

  /** Send an OpenAI-style error envelope with the given status code. */
  private sendError(
    res: http.ServerResponse,
    statusCode: number,
    message: string,
    type: string,
  ): void {
    const error: IApiError = {
      error: {
        message,
        type,
      },
    };
    this.sendJson(res, statusCode, error);
  }
}
+96 -96
View File
@@ -1,53 +1,96 @@
/**
* Embeddings Handler
*
* Handles /v1/embeddings endpoint.
* Embeddings handler.
*/
import * as http from 'node:http';
import type {
IApiError,
IEmbeddingData,
IEmbeddingsRequest,
IEmbeddingsResponse,
IEmbeddingData,
IApiError,
} from '../../interfaces/api.ts';
import { logger } from '../../logger.ts';
import { ClusterCoordinator } from '../../cluster/coordinator.ts';
import { ContainerManager } from '../../containers/container-manager.ts';
import { logger } from '../../logger.ts';
import { ModelRegistry } from '../../models/registry.ts';
/**
* Handler for embeddings requests
*/
export class EmbeddingsHandler {
private containerManager: ContainerManager;
private modelRegistry: ModelRegistry;
private clusterCoordinator: ClusterCoordinator;
constructor(containerManager: ContainerManager) {
constructor(
containerManager: ContainerManager,
modelRegistry: ModelRegistry,
clusterCoordinator: ClusterCoordinator,
) {
this.containerManager = containerManager;
this.modelRegistry = modelRegistry;
this.clusterCoordinator = clusterCoordinator;
}
/**
* Handle POST /v1/embeddings
*/
public async handleEmbeddings(
req: http.IncomingMessage,
res: http.ServerResponse,
body: IEmbeddingsRequest,
): Promise<void> {
const modelName = body.model;
const canonicalModel = await this.resolveCanonicalModel(body.model);
const requestBody: IEmbeddingsRequest = {
...body,
model: canonicalModel,
};
logger.dim(`Embeddings request for model: ${modelName}`);
logger.dim(`Embeddings request for model: ${canonicalModel}`);
try {
// Find container with the embedding model
const container = await this.containerManager.findContainerForModel(modelName);
if (!container) {
this.sendError(res, 404, `Embedding model "${modelName}" not found`, 'model_not_found');
const container = await this.containerManager.findContainerForModel(canonicalModel);
if (container) {
const response = await this.generateEmbeddings(container, requestBody);
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(JSON.stringify(response));
return;
}
// Generate embeddings
const response = await this.generateEmbeddings(container, body);
const ensured = await this.clusterCoordinator.ensureModelViaControlPlane(canonicalModel);
if (!ensured) {
this.sendError(
res,
404,
`Embedding model "${canonicalModel}" not found`,
'model_not_found',
);
return;
}
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(JSON.stringify(response));
if (ensured.location.nodeName === this.clusterCoordinator.getLocalNodeName()) {
const localContainer = await this.containerManager.findContainerForModel(canonicalModel);
if (!localContainer) {
this.sendError(
res,
503,
`Embedding model "${canonicalModel}" is not ready`,
'server_error',
);
return;
}
const response = await this.generateEmbeddings(localContainer, requestBody);
res.writeHead(200, { 'Content-Type': 'application/json' });
res.end(JSON.stringify(response));
return;
}
const response = await fetch(`${ensured.location.endpoint}/v1/embeddings`, {
method: 'POST',
headers: this.buildForwardHeaders(req),
body: JSON.stringify(requestBody),
});
const text = await response.text();
res.writeHead(response.status, {
'Content-Type': response.headers.get('content-type') || 'application/json',
});
res.end(text);
} catch (error) {
const message = error instanceof Error ? error.message : String(error);
logger.error(`Embeddings error: ${message}`);
@@ -55,9 +98,11 @@ export class EmbeddingsHandler {
}
}
/**
* Generate embeddings from container
*/
/**
 * Resolve a requested model name (possibly an alias or short name) to the
 * registry's canonical model id. When the registry has no entry — or the
 * entry carries an empty id — the requested name is returned unchanged.
 */
private async resolveCanonicalModel(modelName: string): Promise<string> {
  const registered = await this.modelRegistry.getModel(modelName);
  return registered?.id ? registered.id : modelName;
}
private async generateEmbeddings(
container: import('../../containers/base-container.ts').BaseContainer,
request: IEmbeddingsRequest,
@@ -66,7 +111,6 @@ export class EmbeddingsHandler {
const embeddings: IEmbeddingData[] = [];
let totalTokens = 0;
// Generate embeddings for each input
for (let i = 0; i < inputs.length; i++) {
const input = inputs[i];
const embedding = await this.getEmbeddingFromContainer(container, request.model, input);
@@ -91,9 +135,6 @@ export class EmbeddingsHandler {
};
}
/**
* Get embedding from container (container-specific implementation)
*/
private async getEmbeddingFromContainer(
container: import('../../containers/base-container.ts').BaseContainer,
model: string,
@@ -102,54 +143,17 @@ export class EmbeddingsHandler {
const endpoint = container.getEndpoint();
const containerType = container.type;
// Route to container-specific embedding endpoint
if (containerType === 'ollama') {
return this.getOllamaEmbedding(endpoint, model, input);
} else if (containerType === 'vllm') {
if (containerType === 'vllm') {
return this.getVllmEmbedding(endpoint, model, input);
} else if (containerType === 'tgi') {
}
if (containerType === 'tgi') {
return this.getTgiEmbedding(endpoint, model, input);
}
throw new Error(`Container type ${containerType} does not support embeddings`);
}
/**
* Get embedding from Ollama
*/
private async getOllamaEmbedding(
endpoint: string,
model: string,
input: string,
): Promise<{ vector: number[]; tokenCount: number }> {
const response = await fetch(`${endpoint}/api/embeddings`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
model,
prompt: input,
}),
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`Ollama embedding error: ${errorText}`);
}
const result = await response.json() as { embedding: number[] };
// Estimate token count (rough approximation: ~4 chars per token)
const tokenCount = Math.ceil(input.length / 4);
return {
vector: result.embedding,
tokenCount,
};
}
/**
* Get embedding from vLLM (OpenAI-compatible)
*/
private async getVllmEmbedding(
endpoint: string,
model: string,
@@ -158,61 +162,58 @@ export class EmbeddingsHandler {
const response = await fetch(`${endpoint}/v1/embeddings`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
model,
input,
}),
body: JSON.stringify({ model, input }),
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`vLLM embedding error: ${errorText}`);
throw new Error(`vLLM embedding error: ${await response.text()}`);
}
const result = await response.json() as IEmbeddingsResponse;
return {
vector: result.data[0].embedding,
tokenCount: result.usage.total_tokens,
};
}
/**
* Get embedding from TGI
*/
private async getTgiEmbedding(
endpoint: string,
_model: string,
input: string,
): Promise<{ vector: number[]; tokenCount: number }> {
// TGI uses /embed endpoint
const response = await fetch(`${endpoint}/embed`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
inputs: input,
}),
body: JSON.stringify({ inputs: input }),
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`TGI embedding error: ${errorText}`);
throw new Error(`TGI embedding error: ${await response.text()}`);
}
const result = await response.json() as number[][];
// Estimate token count
const tokenCount = Math.ceil(input.length / 4);
return {
vector: result[0],
tokenCount,
tokenCount: Math.ceil(input.length / 4),
};
}
/**
* Send error response
*/
/**
 * Build the header set forwarded to a peer node: always JSON content type,
 * plus the caller's Authorization and X-Request-Id headers when present.
 */
private buildForwardHeaders(req: http.IncomingMessage): Record<string, string> {
  const forwarded: Record<string, string> = { 'Content-Type': 'application/json' };
  const { authorization } = req.headers;
  if (typeof authorization === 'string') {
    forwarded.Authorization = authorization;
  }
  const requestId = req.headers['x-request-id'];
  if (typeof requestId === 'string') {
    forwarded['X-Request-Id'] = requestId;
  }
  return forwarded;
}
private sendError(
res: http.ServerResponse,
statusCode: number,
@@ -225,7 +226,6 @@ export class EmbeddingsHandler {
message,
type,
param,
code: null,
},
};
+1
View File
@@ -5,5 +5,6 @@
*/
export { ChatHandler } from './chat.ts';
export { ClusterHandler } from './cluster.ts';
export { ModelsHandler } from './models.ts';
export { EmbeddingsHandler } from './embeddings.ts';
+53 -50
View File
@@ -1,34 +1,29 @@
/**
* Models Handler
*
* Handles /v1/models endpoints.
* Models handler.
*/
import * as http from 'node:http';
import type {
IModelInfo,
IListModelsResponse,
IApiError,
} from '../../interfaces/api.ts';
import { logger } from '../../logger.ts';
import type { IApiError, IListModelsResponse, IModelInfo } from '../../interfaces/api.ts';
import { ClusterCoordinator } from '../../cluster/coordinator.ts';
import { ContainerManager } from '../../containers/container-manager.ts';
import { logger } from '../../logger.ts';
import { ModelRegistry } from '../../models/registry.ts';
/**
* Handler for model-related requests
*/
export class ModelsHandler {
private containerManager: ContainerManager;
private modelRegistry: ModelRegistry;
private clusterCoordinator: ClusterCoordinator;
constructor(containerManager: ContainerManager, modelRegistry: ModelRegistry) {
constructor(
containerManager: ContainerManager,
modelRegistry: ModelRegistry,
clusterCoordinator: ClusterCoordinator,
) {
this.containerManager = containerManager;
this.modelRegistry = modelRegistry;
this.clusterCoordinator = clusterCoordinator;
}
/**
* Handle GET /v1/models
*/
public async handleListModels(res: http.ServerResponse): Promise<void> {
try {
const models = await this.getAvailableModels();
@@ -47,13 +42,12 @@ export class ModelsHandler {
}
}
/**
* Handle GET /v1/models/:model
*/
public async handleGetModel(res: http.ServerResponse, modelId: string): Promise<void> {
try {
const models = await this.getAvailableModels();
const model = models.find((m) => m.id === modelId);
const requested = await this.modelRegistry.getModel(modelId);
const canonicalId = requested?.id || modelId;
const model = models.find((entry) => entry.id === canonicalId);
if (!model) {
this.sendError(res, 404, `Model "${modelId}" not found`, 'model_not_found');
@@ -69,51 +63,61 @@ export class ModelsHandler {
}
}
/**
* Get all available models from containers and greenlist
*/
private async getAvailableModels(): Promise<IModelInfo[]> {
const models: IModelInfo[] = [];
const seen = new Set<string>();
const timestamp = Math.floor(Date.now() / 1000);
// Get models from running containers
const containerModels = await this.containerManager.getAllAvailableModels();
for (const [modelId, modelInfo] of containerModels) {
if (!seen.has(modelId)) {
seen.add(modelId);
models.push({
id: modelId,
object: 'model',
created: timestamp,
owned_by: `modelgrid-${modelInfo.container}`,
});
for (const [modelId, endpoints] of containerModels) {
if (seen.has(modelId)) {
continue;
}
const primaryEndpoint = endpoints[0];
seen.add(modelId);
models.push({
id: modelId,
object: 'model',
created: timestamp,
owned_by: `modelgrid-${primaryEndpoint?.type || 'vllm'}`,
});
}
// Add greenlit models that aren't loaded yet
const greenlitModels = await this.modelRegistry.getAllGreenlitModels();
for (const greenlit of greenlitModels) {
if (!seen.has(greenlit.name)) {
seen.add(greenlit.name);
models.push({
id: greenlit.name,
object: 'model',
created: timestamp,
owned_by: `modelgrid-${greenlit.container}`,
});
const clusterStatus = this.clusterCoordinator.getStatus();
for (const [modelId, locations] of Object.entries(clusterStatus.models)) {
if (seen.has(modelId) || locations.length === 0) {
continue;
}
seen.add(modelId);
models.push({
id: modelId,
object: 'model',
created: timestamp,
owned_by: `modelgrid-${locations[0].engine}`,
});
}
// Sort alphabetically
models.sort((a, b) => a.id.localeCompare(b.id));
const catalogModels = await this.modelRegistry.getAllModels();
for (const model of catalogModels) {
if (seen.has(model.id)) {
continue;
}
seen.add(model.id);
models.push({
id: model.id,
object: 'model',
created: timestamp,
owned_by: `modelgrid-${model.engine}`,
});
}
models.sort((left, right) => left.id.localeCompare(right.id));
return models;
}
/**
* Send error response
*/
private sendError(
res: http.ServerResponse,
statusCode: number,
@@ -126,7 +130,6 @@ export class ModelsHandler {
message,
type,
param,
code: null,
},
};