feat(cluster,api,models,cli): add cluster-aware model catalog deployments and request routing
This commit is contained in:
+149
-65
@@ -1,58 +1,89 @@
|
||||
/**
|
||||
* Chat Completions Handler
|
||||
*
|
||||
* Handles /v1/chat/completions and /v1/completions endpoints.
|
||||
* Chat completions handler.
|
||||
*/
|
||||
|
||||
import * as http from 'node:http';
|
||||
import type {
|
||||
IChatCompletionRequest,
|
||||
IChatCompletionResponse,
|
||||
IApiError,
|
||||
} from '../../interfaces/api.ts';
|
||||
import { logger } from '../../logger.ts';
|
||||
import type { IApiError, IChatCompletionRequest } from '../../interfaces/api.ts';
|
||||
import { ClusterCoordinator } from '../../cluster/coordinator.ts';
|
||||
import { ContainerManager } from '../../containers/container-manager.ts';
|
||||
import { logger } from '../../logger.ts';
|
||||
import { ModelRegistry } from '../../models/registry.ts';
|
||||
import { ModelLoader } from '../../models/loader.ts';
|
||||
|
||||
/**
|
||||
* Handler for chat completion requests
|
||||
*/
|
||||
export class ChatHandler {
|
||||
private containerManager: ContainerManager;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private modelLoader: ModelLoader;
|
||||
private clusterCoordinator: ClusterCoordinator;
|
||||
|
||||
constructor(containerManager: ContainerManager, modelLoader: ModelLoader) {
|
||||
constructor(
|
||||
containerManager: ContainerManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
modelLoader: ModelLoader,
|
||||
clusterCoordinator: ClusterCoordinator,
|
||||
) {
|
||||
this.containerManager = containerManager;
|
||||
this.modelRegistry = modelRegistry;
|
||||
this.modelLoader = modelLoader;
|
||||
this.clusterCoordinator = clusterCoordinator;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle POST /v1/chat/completions
|
||||
*/
|
||||
public async handleChatCompletion(
|
||||
req: http.IncomingMessage,
|
||||
res: http.ServerResponse,
|
||||
body: IChatCompletionRequest,
|
||||
): Promise<void> {
|
||||
const modelName = body.model;
|
||||
const isStream = body.stream === true;
|
||||
const canonicalModel = await this.resolveCanonicalModel(body.model);
|
||||
const requestBody: IChatCompletionRequest = {
|
||||
...body,
|
||||
model: canonicalModel,
|
||||
};
|
||||
|
||||
logger.dim(`Chat completion request for model: ${modelName}`);
|
||||
logger.dim(`Chat completion request for model: ${canonicalModel}`);
|
||||
|
||||
try {
|
||||
// Find or load the model
|
||||
const container = await this.findOrLoadModel(modelName);
|
||||
if (!container) {
|
||||
this.sendError(res, 404, `Model "${modelName}" not found or could not be loaded`, 'model_not_found');
|
||||
const container = await this.findOrLoadLocalModel(canonicalModel);
|
||||
if (container) {
|
||||
if (requestBody.stream) {
|
||||
await this.handleStreamingCompletion(res, container, requestBody);
|
||||
} else {
|
||||
await this.handleNonStreamingCompletion(res, container, requestBody);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Route to streaming or non-streaming handler
|
||||
if (isStream) {
|
||||
await this.handleStreamingCompletion(res, container, body);
|
||||
} else {
|
||||
await this.handleNonStreamingCompletion(res, container, body);
|
||||
const ensured = await this.clusterCoordinator.ensureModelViaControlPlane(canonicalModel);
|
||||
if (!ensured) {
|
||||
this.sendError(
|
||||
res,
|
||||
404,
|
||||
`Model "${canonicalModel}" not found or could not be deployed`,
|
||||
'model_not_found',
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (ensured.location.nodeName === this.clusterCoordinator.getLocalNodeName()) {
|
||||
const localContainer = await this.findLocalContainer(canonicalModel);
|
||||
if (!localContainer) {
|
||||
this.sendError(
|
||||
res,
|
||||
503,
|
||||
`Model "${canonicalModel}" was scheduled locally but is not ready`,
|
||||
'server_error',
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (requestBody.stream) {
|
||||
await this.handleStreamingCompletion(res, localContainer, requestBody);
|
||||
} else {
|
||||
await this.handleNonStreamingCompletion(res, localContainer, requestBody);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
await this.proxyChatRequest(req, res, ensured.location.endpoint, requestBody);
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
logger.error(`Chat completion error: ${message}`);
|
||||
@@ -60,34 +91,38 @@ export class ChatHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Find container with model or attempt to load it
|
||||
*/
|
||||
private async findOrLoadModel(
|
||||
private async resolveCanonicalModel(modelName: string): Promise<string> {
|
||||
const model = await this.modelRegistry.getModel(modelName);
|
||||
return model?.id || modelName;
|
||||
}
|
||||
|
||||
private async findLocalContainer(
|
||||
modelName: string,
|
||||
): Promise<import('../../containers/base-container.ts').BaseContainer | null> {
|
||||
// First, check if model is already loaded
|
||||
const container = await this.containerManager.findContainerForModel(modelName);
|
||||
if (container) {
|
||||
return container;
|
||||
}
|
||||
|
||||
// Try to load the model
|
||||
logger.info(`Model ${modelName} not loaded, attempting to load...`);
|
||||
const loadResult = await this.modelLoader.loadModel(modelName);
|
||||
|
||||
if (!loadResult.success) {
|
||||
logger.error(`Failed to load model: ${loadResult.error}`);
|
||||
return null;
|
||||
}
|
||||
|
||||
// Find the container again after loading
|
||||
return this.containerManager.findContainerForModel(modelName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle non-streaming chat completion
|
||||
*/
|
||||
private async findOrLoadLocalModel(
|
||||
modelName: string,
|
||||
): Promise<import('../../containers/base-container.ts').BaseContainer | null> {
|
||||
const existing = await this.findLocalContainer(modelName);
|
||||
if (existing) {
|
||||
return existing;
|
||||
}
|
||||
|
||||
if (!this.clusterCoordinator.shouldDeployLocallyFirst()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
logger.info(`Model ${modelName} not loaded, attempting local deploy...`);
|
||||
const loadResult = await this.modelLoader.loadModel(modelName);
|
||||
if (!loadResult.success) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return this.findLocalContainer(loadResult.model);
|
||||
}
|
||||
|
||||
private async handleNonStreamingCompletion(
|
||||
res: http.ServerResponse,
|
||||
container: import('../../containers/base-container.ts').BaseContainer,
|
||||
@@ -99,35 +134,85 @@ export class ChatHandler {
|
||||
res.end(JSON.stringify(response));
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle streaming chat completion
|
||||
*/
|
||||
private async handleStreamingCompletion(
|
||||
res: http.ServerResponse,
|
||||
container: import('../../containers/base-container.ts').BaseContainer,
|
||||
body: IChatCompletionRequest,
|
||||
): Promise<void> {
|
||||
// Set SSE headers
|
||||
res.writeHead(200, {
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
'Connection': 'keep-alive',
|
||||
Connection: 'keep-alive',
|
||||
'X-Accel-Buffering': 'no',
|
||||
});
|
||||
|
||||
// Stream chunks to client
|
||||
await container.chatCompletionStream(body, (chunk) => {
|
||||
res.write(`data: ${chunk}\n\n`);
|
||||
res.write(chunk);
|
||||
});
|
||||
|
||||
// Send final done message
|
||||
res.write('data: [DONE]\n\n');
|
||||
res.end();
|
||||
}
|
||||
|
||||
/**
|
||||
* Send error response
|
||||
*/
|
||||
private async proxyChatRequest(
|
||||
req: http.IncomingMessage,
|
||||
res: http.ServerResponse,
|
||||
targetEndpoint: string,
|
||||
body: IChatCompletionRequest,
|
||||
): Promise<void> {
|
||||
const response = await fetch(`${targetEndpoint}/v1/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: this.buildForwardHeaders(req),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
|
||||
if (body.stream) {
|
||||
res.writeHead(response.status, {
|
||||
'Content-Type': response.headers.get('content-type') || 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
Connection: 'keep-alive',
|
||||
});
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
res.end();
|
||||
return;
|
||||
}
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
|
||||
res.write(value);
|
||||
}
|
||||
|
||||
res.end();
|
||||
return;
|
||||
}
|
||||
|
||||
const text = await response.text();
|
||||
res.writeHead(response.status, {
|
||||
'Content-Type': response.headers.get('content-type') || 'application/json',
|
||||
});
|
||||
res.end(text);
|
||||
}
|
||||
|
||||
private buildForwardHeaders(req: http.IncomingMessage): Record<string, string> {
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
};
|
||||
|
||||
if (typeof req.headers.authorization === 'string') {
|
||||
headers.Authorization = req.headers.authorization;
|
||||
}
|
||||
|
||||
if (typeof req.headers['x-request-id'] === 'string') {
|
||||
headers['X-Request-Id'] = req.headers['x-request-id'];
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
|
||||
private sendError(
|
||||
res: http.ServerResponse,
|
||||
statusCode: number,
|
||||
@@ -140,7 +225,6 @@ export class ChatHandler {
|
||||
message,
|
||||
type,
|
||||
param,
|
||||
code: null,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,316 @@
|
||||
import { timingSafeEqual } from 'node:crypto';
import * as http from 'node:http';
import type { IApiError } from '../../interfaces/api.ts';
import type { IClusterNodeHeartbeat } from '../../interfaces/cluster.ts';
import { ClusterCoordinator } from '../../cluster/coordinator.ts';
import { CLUSTER } from '../../constants.ts';
|
||||
|
||||
export class ClusterHandler {
|
||||
private clusterCoordinator: ClusterCoordinator;
|
||||
|
||||
constructor(clusterCoordinator: ClusterCoordinator) {
|
||||
this.clusterCoordinator = clusterCoordinator;
|
||||
}
|
||||
|
||||
public async handle(
|
||||
req: http.IncomingMessage,
|
||||
res: http.ServerResponse,
|
||||
path: string,
|
||||
url: URL,
|
||||
): Promise<void> {
|
||||
if (!this.authenticate(req)) {
|
||||
return this.sendError(res, 401, 'Invalid cluster secret', 'authentication_error');
|
||||
}
|
||||
|
||||
if (path === '/_cluster/status' && req.method === 'GET') {
|
||||
return this.sendJson(res, 200, this.clusterCoordinator.getStatus());
|
||||
}
|
||||
|
||||
if (path === '/_cluster/nodes' && req.method === 'GET') {
|
||||
return this.sendJson(res, 200, this.clusterCoordinator.getStatus().nodes);
|
||||
}
|
||||
|
||||
if (path === '/_cluster/desired' && req.method === 'GET') {
|
||||
return this.sendJson(res, 200, this.clusterCoordinator.getDesiredDeployments());
|
||||
}
|
||||
|
||||
if (
|
||||
(path === '/_cluster/nodes/register' || path === '/_cluster/nodes/heartbeat') &&
|
||||
req.method === 'POST'
|
||||
) {
|
||||
const body = await this.parseBody(req) as IClusterNodeHeartbeat | null;
|
||||
if (!body) {
|
||||
return this.sendError(
|
||||
res,
|
||||
400,
|
||||
'Invalid cluster heartbeat payload',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
this.clusterCoordinator.acceptHeartbeat(body);
|
||||
return this.sendJson(res, 200, { ok: true });
|
||||
}
|
||||
|
||||
if (path === '/_cluster/models/resolve' && req.method === 'GET') {
|
||||
const model = url.searchParams.get('model');
|
||||
if (!model) {
|
||||
return this.sendError(res, 400, 'Missing model query parameter', 'invalid_request_error');
|
||||
}
|
||||
|
||||
const resolved = await this.clusterCoordinator.resolveModel(model);
|
||||
if (!resolved) {
|
||||
return this.sendError(res, 404, `Model "${model}" not found in cluster`, 'model_not_found');
|
||||
}
|
||||
|
||||
return this.sendJson(res, 200, resolved);
|
||||
}
|
||||
|
||||
if (path === '/_cluster/models/ensure' && req.method === 'POST') {
|
||||
if (!this.clusterCoordinator.canManageClusterState()) {
|
||||
return this.sendError(
|
||||
res,
|
||||
409,
|
||||
'This node is not the control plane',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const body = await this.parseBody(req) as { model?: string } | null;
|
||||
if (!body?.model) {
|
||||
return this.sendError(res, 400, 'Missing model in request body', 'invalid_request_error');
|
||||
}
|
||||
|
||||
const ensured = await this.clusterCoordinator.ensureModel(body.model);
|
||||
if (!ensured) {
|
||||
return this.sendError(res, 503, `Unable to schedule model "${body.model}"`, 'server_error');
|
||||
}
|
||||
|
||||
return this.sendJson(res, 200, ensured);
|
||||
}
|
||||
|
||||
if (path === '/_cluster/models/desired' && req.method === 'POST') {
|
||||
if (!this.clusterCoordinator.canManageClusterState()) {
|
||||
return this.sendError(
|
||||
res,
|
||||
409,
|
||||
'This node is not the control plane',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const body = await this.parseBody(req) as { model?: string; desiredReplicas?: number } | null;
|
||||
if (!body?.model || body.desiredReplicas === undefined) {
|
||||
return this.sendError(
|
||||
res,
|
||||
400,
|
||||
'Missing model or desiredReplicas in request body',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const desiredDeployment = await this.clusterCoordinator.setDesiredReplicas(
|
||||
body.model,
|
||||
body.desiredReplicas,
|
||||
);
|
||||
if (!desiredDeployment) {
|
||||
return this.sendError(res, 404, `Model "${body.model}" not found`, 'model_not_found');
|
||||
}
|
||||
|
||||
return this.sendJson(res, 200, desiredDeployment);
|
||||
}
|
||||
|
||||
if (path === '/_cluster/models/desired/remove' && req.method === 'POST') {
|
||||
if (!this.clusterCoordinator.canManageClusterState()) {
|
||||
return this.sendError(
|
||||
res,
|
||||
409,
|
||||
'This node is not the control plane',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const body = await this.parseBody(req) as { model?: string } | null;
|
||||
if (!body?.model) {
|
||||
return this.sendError(res, 400, 'Missing model in request body', 'invalid_request_error');
|
||||
}
|
||||
|
||||
const removed = await this.clusterCoordinator.clearDesiredDeployment(body.model);
|
||||
return this.sendJson(res, 200, { removed });
|
||||
}
|
||||
|
||||
if (path === '/_cluster/deployments' && req.method === 'POST') {
|
||||
const body = await this.parseBody(req) as { model?: string; replicaOrdinal?: number } | null;
|
||||
if (!body?.model) {
|
||||
return this.sendError(res, 400, 'Missing model in request body', 'invalid_request_error');
|
||||
}
|
||||
|
||||
const deployed = body.replicaOrdinal !== undefined
|
||||
? await this.clusterCoordinator.deployReplicaLocally(body.model, body.replicaOrdinal)
|
||||
: await this.clusterCoordinator.deployModelLocally(body.model);
|
||||
if (!deployed) {
|
||||
return this.sendError(res, 503, `Unable to deploy model "${body.model}"`, 'server_error');
|
||||
}
|
||||
|
||||
return this.sendJson(res, 200, deployed);
|
||||
}
|
||||
|
||||
if (path === '/_cluster/nodes/cordon' && req.method === 'POST') {
|
||||
if (!this.clusterCoordinator.canManageClusterState()) {
|
||||
return this.sendError(
|
||||
res,
|
||||
409,
|
||||
'This node is not the control plane',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const body = await this.parseBody(req) as { nodeName?: string } | null;
|
||||
if (!body?.nodeName) {
|
||||
return this.sendError(
|
||||
res,
|
||||
400,
|
||||
'Missing nodeName in request body',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const schedulerState = this.clusterCoordinator.setNodeSchedulerState(
|
||||
body.nodeName,
|
||||
'cordoned',
|
||||
);
|
||||
return this.sendJson(res, 200, { nodeName: body.nodeName, schedulerState });
|
||||
}
|
||||
|
||||
if (path === '/_cluster/nodes/uncordon' && req.method === 'POST') {
|
||||
if (!this.clusterCoordinator.canManageClusterState()) {
|
||||
return this.sendError(
|
||||
res,
|
||||
409,
|
||||
'This node is not the control plane',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const body = await this.parseBody(req) as { nodeName?: string } | null;
|
||||
if (!body?.nodeName) {
|
||||
return this.sendError(
|
||||
res,
|
||||
400,
|
||||
'Missing nodeName in request body',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const schedulerState = this.clusterCoordinator.setNodeSchedulerState(body.nodeName, 'active');
|
||||
return this.sendJson(res, 200, { nodeName: body.nodeName, schedulerState });
|
||||
}
|
||||
|
||||
if (path === '/_cluster/nodes/drain' && req.method === 'POST') {
|
||||
if (!this.clusterCoordinator.canManageClusterState()) {
|
||||
return this.sendError(
|
||||
res,
|
||||
409,
|
||||
'This node is not the control plane',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const body = await this.parseBody(req) as { nodeName?: string } | null;
|
||||
if (!body?.nodeName) {
|
||||
return this.sendError(
|
||||
res,
|
||||
400,
|
||||
'Missing nodeName in request body',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const schedulerState = this.clusterCoordinator.setNodeSchedulerState(
|
||||
body.nodeName,
|
||||
'draining',
|
||||
);
|
||||
return this.sendJson(res, 200, { nodeName: body.nodeName, schedulerState });
|
||||
}
|
||||
|
||||
if (path === '/_cluster/nodes/activate' && req.method === 'POST') {
|
||||
if (!this.clusterCoordinator.canManageClusterState()) {
|
||||
return this.sendError(
|
||||
res,
|
||||
409,
|
||||
'This node is not the control plane',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const body = await this.parseBody(req) as { nodeName?: string } | null;
|
||||
if (!body?.nodeName) {
|
||||
return this.sendError(
|
||||
res,
|
||||
400,
|
||||
'Missing nodeName in request body',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const schedulerState = this.clusterCoordinator.setNodeSchedulerState(body.nodeName, 'active');
|
||||
return this.sendJson(res, 200, { nodeName: body.nodeName, schedulerState });
|
||||
}
|
||||
|
||||
return this.sendError(res, 404, `Unknown cluster endpoint: ${path}`, 'invalid_request_error');
|
||||
}
|
||||
|
||||
private authenticate(req: http.IncomingMessage): boolean {
|
||||
const sharedSecret = this.clusterCoordinator.getSharedSecret();
|
||||
if (!sharedSecret) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return req.headers[CLUSTER.AUTH_HEADER_NAME] === sharedSecret;
|
||||
}
|
||||
|
||||
private async parseBody(req: http.IncomingMessage): Promise<unknown | null> {
|
||||
return new Promise((resolve) => {
|
||||
let body = '';
|
||||
|
||||
req.on('data', (chunk) => {
|
||||
body += chunk.toString();
|
||||
});
|
||||
|
||||
req.on('end', () => {
|
||||
if (!body) {
|
||||
resolve(null);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
resolve(JSON.parse(body));
|
||||
} catch {
|
||||
resolve(null);
|
||||
}
|
||||
});
|
||||
|
||||
req.on('error', () => resolve(null));
|
||||
});
|
||||
}
|
||||
|
||||
private sendJson(res: http.ServerResponse, statusCode: number, body: unknown): void {
|
||||
res.writeHead(statusCode, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify(body));
|
||||
}
|
||||
|
||||
private sendError(
|
||||
res: http.ServerResponse,
|
||||
statusCode: number,
|
||||
message: string,
|
||||
type: string,
|
||||
): void {
|
||||
const error: IApiError = {
|
||||
error: {
|
||||
message,
|
||||
type,
|
||||
},
|
||||
};
|
||||
|
||||
this.sendJson(res, statusCode, error);
|
||||
}
|
||||
}
|
||||
@@ -1,53 +1,96 @@
|
||||
/**
|
||||
* Embeddings Handler
|
||||
*
|
||||
* Handles /v1/embeddings endpoint.
|
||||
* Embeddings handler.
|
||||
*/
|
||||
|
||||
import * as http from 'node:http';
|
||||
import type {
|
||||
IApiError,
|
||||
IEmbeddingData,
|
||||
IEmbeddingsRequest,
|
||||
IEmbeddingsResponse,
|
||||
IEmbeddingData,
|
||||
IApiError,
|
||||
} from '../../interfaces/api.ts';
|
||||
import { logger } from '../../logger.ts';
|
||||
import { ClusterCoordinator } from '../../cluster/coordinator.ts';
|
||||
import { ContainerManager } from '../../containers/container-manager.ts';
|
||||
import { logger } from '../../logger.ts';
|
||||
import { ModelRegistry } from '../../models/registry.ts';
|
||||
|
||||
/**
|
||||
* Handler for embeddings requests
|
||||
*/
|
||||
export class EmbeddingsHandler {
|
||||
private containerManager: ContainerManager;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private clusterCoordinator: ClusterCoordinator;
|
||||
|
||||
constructor(containerManager: ContainerManager) {
|
||||
constructor(
|
||||
containerManager: ContainerManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
clusterCoordinator: ClusterCoordinator,
|
||||
) {
|
||||
this.containerManager = containerManager;
|
||||
this.modelRegistry = modelRegistry;
|
||||
this.clusterCoordinator = clusterCoordinator;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle POST /v1/embeddings
|
||||
*/
|
||||
public async handleEmbeddings(
|
||||
req: http.IncomingMessage,
|
||||
res: http.ServerResponse,
|
||||
body: IEmbeddingsRequest,
|
||||
): Promise<void> {
|
||||
const modelName = body.model;
|
||||
const canonicalModel = await this.resolveCanonicalModel(body.model);
|
||||
const requestBody: IEmbeddingsRequest = {
|
||||
...body,
|
||||
model: canonicalModel,
|
||||
};
|
||||
|
||||
logger.dim(`Embeddings request for model: ${modelName}`);
|
||||
logger.dim(`Embeddings request for model: ${canonicalModel}`);
|
||||
|
||||
try {
|
||||
// Find container with the embedding model
|
||||
const container = await this.containerManager.findContainerForModel(modelName);
|
||||
if (!container) {
|
||||
this.sendError(res, 404, `Embedding model "${modelName}" not found`, 'model_not_found');
|
||||
const container = await this.containerManager.findContainerForModel(canonicalModel);
|
||||
if (container) {
|
||||
const response = await this.generateEmbeddings(container, requestBody);
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify(response));
|
||||
return;
|
||||
}
|
||||
|
||||
// Generate embeddings
|
||||
const response = await this.generateEmbeddings(container, body);
|
||||
const ensured = await this.clusterCoordinator.ensureModelViaControlPlane(canonicalModel);
|
||||
if (!ensured) {
|
||||
this.sendError(
|
||||
res,
|
||||
404,
|
||||
`Embedding model "${canonicalModel}" not found`,
|
||||
'model_not_found',
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify(response));
|
||||
if (ensured.location.nodeName === this.clusterCoordinator.getLocalNodeName()) {
|
||||
const localContainer = await this.containerManager.findContainerForModel(canonicalModel);
|
||||
if (!localContainer) {
|
||||
this.sendError(
|
||||
res,
|
||||
503,
|
||||
`Embedding model "${canonicalModel}" is not ready`,
|
||||
'server_error',
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const response = await this.generateEmbeddings(localContainer, requestBody);
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify(response));
|
||||
return;
|
||||
}
|
||||
|
||||
const response = await fetch(`${ensured.location.endpoint}/v1/embeddings`, {
|
||||
method: 'POST',
|
||||
headers: this.buildForwardHeaders(req),
|
||||
body: JSON.stringify(requestBody),
|
||||
});
|
||||
|
||||
const text = await response.text();
|
||||
res.writeHead(response.status, {
|
||||
'Content-Type': response.headers.get('content-type') || 'application/json',
|
||||
});
|
||||
res.end(text);
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
logger.error(`Embeddings error: ${message}`);
|
||||
@@ -55,9 +98,11 @@ export class EmbeddingsHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate embeddings from container
|
||||
*/
|
||||
private async resolveCanonicalModel(modelName: string): Promise<string> {
|
||||
const model = await this.modelRegistry.getModel(modelName);
|
||||
return model?.id || modelName;
|
||||
}
|
||||
|
||||
private async generateEmbeddings(
|
||||
container: import('../../containers/base-container.ts').BaseContainer,
|
||||
request: IEmbeddingsRequest,
|
||||
@@ -66,7 +111,6 @@ export class EmbeddingsHandler {
|
||||
const embeddings: IEmbeddingData[] = [];
|
||||
let totalTokens = 0;
|
||||
|
||||
// Generate embeddings for each input
|
||||
for (let i = 0; i < inputs.length; i++) {
|
||||
const input = inputs[i];
|
||||
const embedding = await this.getEmbeddingFromContainer(container, request.model, input);
|
||||
@@ -91,9 +135,6 @@ export class EmbeddingsHandler {
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get embedding from container (container-specific implementation)
|
||||
*/
|
||||
private async getEmbeddingFromContainer(
|
||||
container: import('../../containers/base-container.ts').BaseContainer,
|
||||
model: string,
|
||||
@@ -102,54 +143,17 @@ export class EmbeddingsHandler {
|
||||
const endpoint = container.getEndpoint();
|
||||
const containerType = container.type;
|
||||
|
||||
// Route to container-specific embedding endpoint
|
||||
if (containerType === 'ollama') {
|
||||
return this.getOllamaEmbedding(endpoint, model, input);
|
||||
} else if (containerType === 'vllm') {
|
||||
if (containerType === 'vllm') {
|
||||
return this.getVllmEmbedding(endpoint, model, input);
|
||||
} else if (containerType === 'tgi') {
|
||||
}
|
||||
|
||||
if (containerType === 'tgi') {
|
||||
return this.getTgiEmbedding(endpoint, model, input);
|
||||
}
|
||||
|
||||
throw new Error(`Container type ${containerType} does not support embeddings`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get embedding from Ollama
|
||||
*/
|
||||
private async getOllamaEmbedding(
|
||||
endpoint: string,
|
||||
model: string,
|
||||
input: string,
|
||||
): Promise<{ vector: number[]; tokenCount: number }> {
|
||||
const response = await fetch(`${endpoint}/api/embeddings`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
model,
|
||||
prompt: input,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`Ollama embedding error: ${errorText}`);
|
||||
}
|
||||
|
||||
const result = await response.json() as { embedding: number[] };
|
||||
|
||||
// Estimate token count (rough approximation: ~4 chars per token)
|
||||
const tokenCount = Math.ceil(input.length / 4);
|
||||
|
||||
return {
|
||||
vector: result.embedding,
|
||||
tokenCount,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get embedding from vLLM (OpenAI-compatible)
|
||||
*/
|
||||
private async getVllmEmbedding(
|
||||
endpoint: string,
|
||||
model: string,
|
||||
@@ -158,61 +162,58 @@ export class EmbeddingsHandler {
|
||||
const response = await fetch(`${endpoint}/v1/embeddings`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
model,
|
||||
input,
|
||||
}),
|
||||
body: JSON.stringify({ model, input }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`vLLM embedding error: ${errorText}`);
|
||||
throw new Error(`vLLM embedding error: ${await response.text()}`);
|
||||
}
|
||||
|
||||
const result = await response.json() as IEmbeddingsResponse;
|
||||
|
||||
return {
|
||||
vector: result.data[0].embedding,
|
||||
tokenCount: result.usage.total_tokens,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get embedding from TGI
|
||||
*/
|
||||
private async getTgiEmbedding(
|
||||
endpoint: string,
|
||||
_model: string,
|
||||
input: string,
|
||||
): Promise<{ vector: number[]; tokenCount: number }> {
|
||||
// TGI uses /embed endpoint
|
||||
const response = await fetch(`${endpoint}/embed`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
inputs: input,
|
||||
}),
|
||||
body: JSON.stringify({ inputs: input }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`TGI embedding error: ${errorText}`);
|
||||
throw new Error(`TGI embedding error: ${await response.text()}`);
|
||||
}
|
||||
|
||||
const result = await response.json() as number[][];
|
||||
|
||||
// Estimate token count
|
||||
const tokenCount = Math.ceil(input.length / 4);
|
||||
|
||||
return {
|
||||
vector: result[0],
|
||||
tokenCount,
|
||||
tokenCount: Math.ceil(input.length / 4),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Send error response
|
||||
*/
|
||||
private buildForwardHeaders(req: http.IncomingMessage): Record<string, string> {
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
};
|
||||
|
||||
if (typeof req.headers.authorization === 'string') {
|
||||
headers.Authorization = req.headers.authorization;
|
||||
}
|
||||
|
||||
if (typeof req.headers['x-request-id'] === 'string') {
|
||||
headers['X-Request-Id'] = req.headers['x-request-id'];
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
|
||||
private sendError(
|
||||
res: http.ServerResponse,
|
||||
statusCode: number,
|
||||
@@ -225,7 +226,6 @@ export class EmbeddingsHandler {
|
||||
message,
|
||||
type,
|
||||
param,
|
||||
code: null,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -5,5 +5,6 @@
|
||||
*/
|
||||
|
||||
export { ChatHandler } from './chat.ts';
|
||||
export { ClusterHandler } from './cluster.ts';
|
||||
export { ModelsHandler } from './models.ts';
|
||||
export { EmbeddingsHandler } from './embeddings.ts';
|
||||
|
||||
+53
-50
@@ -1,34 +1,29 @@
|
||||
/**
|
||||
* Models Handler
|
||||
*
|
||||
* Handles /v1/models endpoints.
|
||||
* Models handler.
|
||||
*/
|
||||
|
||||
import * as http from 'node:http';
|
||||
import type {
|
||||
IModelInfo,
|
||||
IListModelsResponse,
|
||||
IApiError,
|
||||
} from '../../interfaces/api.ts';
|
||||
import { logger } from '../../logger.ts';
|
||||
import type { IApiError, IListModelsResponse, IModelInfo } from '../../interfaces/api.ts';
|
||||
import { ClusterCoordinator } from '../../cluster/coordinator.ts';
|
||||
import { ContainerManager } from '../../containers/container-manager.ts';
|
||||
import { logger } from '../../logger.ts';
|
||||
import { ModelRegistry } from '../../models/registry.ts';
|
||||
|
||||
/**
|
||||
* Handler for model-related requests
|
||||
*/
|
||||
export class ModelsHandler {
|
||||
private containerManager: ContainerManager;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private clusterCoordinator: ClusterCoordinator;
|
||||
|
||||
constructor(containerManager: ContainerManager, modelRegistry: ModelRegistry) {
|
||||
constructor(
|
||||
containerManager: ContainerManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
clusterCoordinator: ClusterCoordinator,
|
||||
) {
|
||||
this.containerManager = containerManager;
|
||||
this.modelRegistry = modelRegistry;
|
||||
this.clusterCoordinator = clusterCoordinator;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle GET /v1/models
|
||||
*/
|
||||
public async handleListModels(res: http.ServerResponse): Promise<void> {
|
||||
try {
|
||||
const models = await this.getAvailableModels();
|
||||
@@ -47,13 +42,12 @@ export class ModelsHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle GET /v1/models/:model
|
||||
*/
|
||||
public async handleGetModel(res: http.ServerResponse, modelId: string): Promise<void> {
|
||||
try {
|
||||
const models = await this.getAvailableModels();
|
||||
const model = models.find((m) => m.id === modelId);
|
||||
const requested = await this.modelRegistry.getModel(modelId);
|
||||
const canonicalId = requested?.id || modelId;
|
||||
const model = models.find((entry) => entry.id === canonicalId);
|
||||
|
||||
if (!model) {
|
||||
this.sendError(res, 404, `Model "${modelId}" not found`, 'model_not_found');
|
||||
@@ -69,51 +63,61 @@ export class ModelsHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all available models from containers and greenlist
|
||||
*/
|
||||
private async getAvailableModels(): Promise<IModelInfo[]> {
|
||||
const models: IModelInfo[] = [];
|
||||
const seen = new Set<string>();
|
||||
const timestamp = Math.floor(Date.now() / 1000);
|
||||
|
||||
// Get models from running containers
|
||||
const containerModels = await this.containerManager.getAllAvailableModels();
|
||||
for (const [modelId, modelInfo] of containerModels) {
|
||||
if (!seen.has(modelId)) {
|
||||
seen.add(modelId);
|
||||
models.push({
|
||||
id: modelId,
|
||||
object: 'model',
|
||||
created: timestamp,
|
||||
owned_by: `modelgrid-${modelInfo.container}`,
|
||||
});
|
||||
for (const [modelId, endpoints] of containerModels) {
|
||||
if (seen.has(modelId)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const primaryEndpoint = endpoints[0];
|
||||
seen.add(modelId);
|
||||
models.push({
|
||||
id: modelId,
|
||||
object: 'model',
|
||||
created: timestamp,
|
||||
owned_by: `modelgrid-${primaryEndpoint?.type || 'vllm'}`,
|
||||
});
|
||||
}
|
||||
|
||||
// Add greenlit models that aren't loaded yet
|
||||
const greenlitModels = await this.modelRegistry.getAllGreenlitModels();
|
||||
for (const greenlit of greenlitModels) {
|
||||
if (!seen.has(greenlit.name)) {
|
||||
seen.add(greenlit.name);
|
||||
models.push({
|
||||
id: greenlit.name,
|
||||
object: 'model',
|
||||
created: timestamp,
|
||||
owned_by: `modelgrid-${greenlit.container}`,
|
||||
});
|
||||
const clusterStatus = this.clusterCoordinator.getStatus();
|
||||
for (const [modelId, locations] of Object.entries(clusterStatus.models)) {
|
||||
if (seen.has(modelId) || locations.length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
seen.add(modelId);
|
||||
models.push({
|
||||
id: modelId,
|
||||
object: 'model',
|
||||
created: timestamp,
|
||||
owned_by: `modelgrid-${locations[0].engine}`,
|
||||
});
|
||||
}
|
||||
|
||||
// Sort alphabetically
|
||||
models.sort((a, b) => a.id.localeCompare(b.id));
|
||||
const catalogModels = await this.modelRegistry.getAllModels();
|
||||
for (const model of catalogModels) {
|
||||
if (seen.has(model.id)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
seen.add(model.id);
|
||||
models.push({
|
||||
id: model.id,
|
||||
object: 'model',
|
||||
created: timestamp,
|
||||
owned_by: `modelgrid-${model.engine}`,
|
||||
});
|
||||
}
|
||||
|
||||
models.sort((left, right) => left.id.localeCompare(right.id));
|
||||
return models;
|
||||
}
|
||||
|
||||
/**
|
||||
* Send error response
|
||||
*/
|
||||
private sendError(
|
||||
res: http.ServerResponse,
|
||||
statusCode: number,
|
||||
@@ -126,7 +130,6 @@ export class ModelsHandler {
|
||||
message,
|
||||
type,
|
||||
param,
|
||||
code: null,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
+39
-11
@@ -63,7 +63,11 @@ export class SanityMiddleware {
|
||||
if (request.temperature !== undefined) {
|
||||
const temp = request.temperature as number;
|
||||
if (typeof temp !== 'number' || temp < 0 || temp > 2) {
|
||||
return { valid: false, error: '"temperature" must be between 0 and 2', param: 'temperature' };
|
||||
return {
|
||||
valid: false,
|
||||
error: '"temperature" must be between 0 and 2',
|
||||
param: 'temperature',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +81,11 @@ export class SanityMiddleware {
|
||||
if (request.max_tokens !== undefined) {
|
||||
const maxTokens = request.max_tokens as number;
|
||||
if (typeof maxTokens !== 'number' || maxTokens < 1) {
|
||||
return { valid: false, error: '"max_tokens" must be a positive integer', param: 'max_tokens' };
|
||||
return {
|
||||
valid: false,
|
||||
error: '"max_tokens" must be a positive integer',
|
||||
param: 'max_tokens',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,14 +103,22 @@ export class SanityMiddleware {
|
||||
if (request.presence_penalty !== undefined) {
|
||||
const pp = request.presence_penalty as number;
|
||||
if (typeof pp !== 'number' || pp < -2 || pp > 2) {
|
||||
return { valid: false, error: '"presence_penalty" must be between -2 and 2', param: 'presence_penalty' };
|
||||
return {
|
||||
valid: false,
|
||||
error: '"presence_penalty" must be between -2 and 2',
|
||||
param: 'presence_penalty',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (request.frequency_penalty !== undefined) {
|
||||
const fp = request.frequency_penalty as number;
|
||||
if (typeof fp !== 'number' || fp < -2 || fp > 2) {
|
||||
return { valid: false, error: '"frequency_penalty" must be between -2 and 2', param: 'frequency_penalty' };
|
||||
return {
|
||||
valid: false,
|
||||
error: '"frequency_penalty" must be between -2 and 2',
|
||||
param: 'frequency_penalty',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,7 +130,11 @@ export class SanityMiddleware {
|
||||
*/
|
||||
private validateMessage(msg: Record<string, unknown>, index: number): IValidationResult {
|
||||
if (!msg || typeof msg !== 'object') {
|
||||
return { valid: false, error: `Message at index ${index} must be an object`, param: `messages[${index}]` };
|
||||
return {
|
||||
valid: false,
|
||||
error: `Message at index ${index} must be an object`,
|
||||
param: `messages[${index}]`,
|
||||
};
|
||||
}
|
||||
|
||||
// Validate role
|
||||
@@ -178,7 +198,11 @@ export class SanityMiddleware {
|
||||
|
||||
const input = request.input;
|
||||
if (typeof input !== 'string' && !Array.isArray(input)) {
|
||||
return { valid: false, error: '"input" must be a string or array of strings', param: 'input' };
|
||||
return {
|
||||
valid: false,
|
||||
error: '"input" must be a string or array of strings',
|
||||
param: 'input',
|
||||
};
|
||||
}
|
||||
|
||||
if (Array.isArray(input)) {
|
||||
@@ -197,7 +221,11 @@ export class SanityMiddleware {
|
||||
if (request.encoding_format !== undefined) {
|
||||
const format = request.encoding_format as string;
|
||||
if (format !== 'float' && format !== 'base64') {
|
||||
return { valid: false, error: '"encoding_format" must be "float" or "base64"', param: 'encoding_format' };
|
||||
return {
|
||||
valid: false,
|
||||
error: '"encoding_format" must be "float" or "base64"',
|
||||
param: 'encoding_format',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -205,14 +233,14 @@ export class SanityMiddleware {
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if model is in greenlist (async validation)
|
||||
* Check if model is in the public registry.
|
||||
*/
|
||||
public async validateModelGreenlist(modelName: string): Promise<IValidationResult> {
|
||||
const isGreenlit = await this.modelRegistry.isModelGreenlit(modelName);
|
||||
if (!isGreenlit) {
|
||||
const isListed = await this.modelRegistry.isModelListed(modelName);
|
||||
if (!isListed) {
|
||||
return {
|
||||
valid: false,
|
||||
error: `Model "${modelName}" is not greenlit. Contact administrator to add it to the greenlist.`,
|
||||
error: `Model "${modelName}" is not listed in the registry.`,
|
||||
param: 'model',
|
||||
};
|
||||
}
|
||||
|
||||
+38
-17
@@ -5,7 +5,8 @@
|
||||
*/
|
||||
|
||||
import * as http from 'node:http';
|
||||
import type { IApiError } from '../interfaces/api.ts';
|
||||
import type { IApiError, IChatCompletionRequest } from '../interfaces/api.ts';
|
||||
import { ClusterCoordinator } from '../cluster/coordinator.ts';
|
||||
import { logger } from '../logger.ts';
|
||||
import { ContainerManager } from '../containers/container-manager.ts';
|
||||
import { ModelRegistry } from '../models/registry.ts';
|
||||
@@ -23,6 +24,7 @@ export class ApiRouter {
|
||||
private containerManager: ContainerManager;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private modelLoader: ModelLoader;
|
||||
private clusterCoordinator: ClusterCoordinator;
|
||||
private chatHandler: ChatHandler;
|
||||
private modelsHandler: ModelsHandler;
|
||||
private embeddingsHandler: EmbeddingsHandler;
|
||||
@@ -33,16 +35,27 @@ export class ApiRouter {
|
||||
containerManager: ContainerManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
modelLoader: ModelLoader,
|
||||
clusterCoordinator: ClusterCoordinator,
|
||||
apiKeys: string[],
|
||||
) {
|
||||
this.containerManager = containerManager;
|
||||
this.modelRegistry = modelRegistry;
|
||||
this.modelLoader = modelLoader;
|
||||
this.clusterCoordinator = clusterCoordinator;
|
||||
|
||||
// Initialize handlers
|
||||
this.chatHandler = new ChatHandler(containerManager, modelLoader);
|
||||
this.modelsHandler = new ModelsHandler(containerManager, modelRegistry);
|
||||
this.embeddingsHandler = new EmbeddingsHandler(containerManager);
|
||||
this.chatHandler = new ChatHandler(
|
||||
containerManager,
|
||||
modelRegistry,
|
||||
modelLoader,
|
||||
clusterCoordinator,
|
||||
);
|
||||
this.modelsHandler = new ModelsHandler(containerManager, modelRegistry, clusterCoordinator);
|
||||
this.embeddingsHandler = new EmbeddingsHandler(
|
||||
containerManager,
|
||||
modelRegistry,
|
||||
clusterCoordinator,
|
||||
);
|
||||
|
||||
// Initialize middleware
|
||||
this.authMiddleware = new AuthMiddleware(apiKeys);
|
||||
@@ -119,8 +132,8 @@ export class ApiRouter {
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle request
|
||||
await this.chatHandler.handleChatCompletion(req, res, body);
|
||||
const requestBody = this.sanityMiddleware.sanitizeChatRequest(body as Record<string, unknown>);
|
||||
await this.chatHandler.handleChatCompletion(req, res, requestBody);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -149,7 +162,7 @@ export class ApiRouter {
|
||||
}
|
||||
|
||||
// Convert to chat format and handle
|
||||
const chatBody = this.convertCompletionToChat(body);
|
||||
const chatBody = this.convertCompletionToChat(body as Record<string, unknown>);
|
||||
await this.chatHandler.handleChatCompletion(req, res, chatBody);
|
||||
}
|
||||
|
||||
@@ -222,7 +235,16 @@ export class ApiRouter {
|
||||
return;
|
||||
}
|
||||
|
||||
await this.embeddingsHandler.handleEmbeddings(res, body);
|
||||
const validation = this.sanityMiddleware.validateEmbeddingsRequest(body);
|
||||
if (!validation.valid) {
|
||||
this.sendError(res, 400, validation.error || 'Invalid request', 'invalid_request_error');
|
||||
return;
|
||||
}
|
||||
|
||||
const requestBody = this.sanityMiddleware.sanitizeEmbeddingsRequest(
|
||||
body as Record<string, unknown>,
|
||||
);
|
||||
await this.embeddingsHandler.handleEmbeddings(req, res, requestBody);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -257,21 +279,21 @@ export class ApiRouter {
|
||||
/**
|
||||
* Convert legacy completion request to chat format
|
||||
*/
|
||||
private convertCompletionToChat(body: Record<string, unknown>): Record<string, unknown> {
|
||||
private convertCompletionToChat(body: Record<string, unknown>): IChatCompletionRequest {
|
||||
const prompt = body.prompt as string | string[];
|
||||
const promptText = Array.isArray(prompt) ? prompt.join('\n') : prompt;
|
||||
|
||||
return {
|
||||
model: body.model,
|
||||
model: body.model as string,
|
||||
messages: [
|
||||
{ role: 'user', content: promptText },
|
||||
],
|
||||
max_tokens: body.max_tokens,
|
||||
temperature: body.temperature,
|
||||
top_p: body.top_p,
|
||||
n: body.n,
|
||||
stream: body.stream,
|
||||
stop: body.stop,
|
||||
max_tokens: body.max_tokens as number | undefined,
|
||||
temperature: body.temperature as number | undefined,
|
||||
top_p: body.top_p as number | undefined,
|
||||
n: body.n as number | undefined,
|
||||
stream: body.stream as boolean | undefined,
|
||||
stop: body.stop as string | string[] | undefined,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -290,7 +312,6 @@ export class ApiRouter {
|
||||
message,
|
||||
type,
|
||||
param,
|
||||
code: null,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
+17
-4
@@ -7,13 +7,15 @@
|
||||
import * as http from 'node:http';
|
||||
import type { IApiConfig } from '../interfaces/config.ts';
|
||||
import type { IHealthResponse } from '../interfaces/api.ts';
|
||||
import { ClusterCoordinator } from '../cluster/coordinator.ts';
|
||||
import { logger } from '../logger.ts';
|
||||
import { API_SERVER } from '../constants.ts';
|
||||
import { VERSION } from '../constants.ts';
|
||||
import { ApiRouter } from './router.ts';
|
||||
import { ContainerManager } from '../containers/container-manager.ts';
|
||||
import { ModelRegistry } from '../models/registry.ts';
|
||||
import { ModelLoader } from '../models/loader.ts';
|
||||
import { GpuDetector } from '../hardware/gpu-detector.ts';
|
||||
import { ClusterHandler } from './handlers/cluster.ts';
|
||||
|
||||
/**
|
||||
* API Server for ModelGrid
|
||||
@@ -26,22 +28,29 @@ export class ApiServer {
|
||||
private modelRegistry: ModelRegistry;
|
||||
private modelLoader: ModelLoader;
|
||||
private gpuDetector: GpuDetector;
|
||||
private clusterCoordinator: ClusterCoordinator;
|
||||
private clusterHandler: ClusterHandler;
|
||||
private startTime: number = 0;
|
||||
|
||||
constructor(
|
||||
config: IApiConfig,
|
||||
containerManager: ContainerManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
modelLoader: ModelLoader,
|
||||
clusterCoordinator: ClusterCoordinator,
|
||||
) {
|
||||
this.config = config;
|
||||
this.containerManager = containerManager;
|
||||
this.modelRegistry = modelRegistry;
|
||||
this.gpuDetector = new GpuDetector();
|
||||
this.modelLoader = new ModelLoader(modelRegistry, containerManager, true);
|
||||
this.modelLoader = modelLoader;
|
||||
this.clusterCoordinator = clusterCoordinator;
|
||||
this.clusterHandler = new ClusterHandler(clusterCoordinator);
|
||||
this.router = new ApiRouter(
|
||||
containerManager,
|
||||
modelRegistry,
|
||||
this.modelLoader,
|
||||
clusterCoordinator,
|
||||
config.apiKeys,
|
||||
);
|
||||
}
|
||||
@@ -120,6 +129,11 @@ export class ApiServer {
|
||||
const url = new URL(req.url || '/', `http://${req.headers.host || 'localhost'}`);
|
||||
const path = url.pathname;
|
||||
|
||||
if (path.startsWith('/_cluster')) {
|
||||
await this.clusterHandler.handle(req, res, path, url);
|
||||
return;
|
||||
}
|
||||
|
||||
// Health check endpoint (no auth required)
|
||||
if (path === '/health' || path === '/healthz') {
|
||||
await this.handleHealthCheck(res);
|
||||
@@ -194,7 +208,7 @@ export class ApiServer {
|
||||
|
||||
const response: IHealthResponse = {
|
||||
status,
|
||||
version: '1.0.0', // TODO: Get from config
|
||||
version: VERSION,
|
||||
uptime: Math.floor((Date.now() - this.startTime) / 1000),
|
||||
containers: statuses.size,
|
||||
models: models.size,
|
||||
@@ -276,7 +290,6 @@ export class ApiServer {
|
||||
error: {
|
||||
message,
|
||||
type,
|
||||
code: null,
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user