feat(cluster,api,models,cli): add cluster-aware model catalog deployments and request routing
This commit is contained in:
@@ -3,6 +3,6 @@
|
||||
*/
|
||||
export const commitinfo = {
|
||||
name: '@modelgrid.com/modelgrid',
|
||||
version: '1.0.1',
|
||||
description: 'ModelGrid - GPU infrastructure management daemon for AI model containers with OpenAI-compatible API'
|
||||
version: '1.1.0',
|
||||
description: 'ModelGrid - vLLM deployment manager with an OpenAI-compatible API and OSS model catalog'
|
||||
}
|
||||
|
||||
+149
-65
@@ -1,58 +1,89 @@
|
||||
/**
|
||||
* Chat Completions Handler
|
||||
*
|
||||
* Handles /v1/chat/completions and /v1/completions endpoints.
|
||||
* Chat completions handler.
|
||||
*/
|
||||
|
||||
import * as http from 'node:http';
|
||||
import type {
|
||||
IChatCompletionRequest,
|
||||
IChatCompletionResponse,
|
||||
IApiError,
|
||||
} from '../../interfaces/api.ts';
|
||||
import { logger } from '../../logger.ts';
|
||||
import type { IApiError, IChatCompletionRequest } from '../../interfaces/api.ts';
|
||||
import { ClusterCoordinator } from '../../cluster/coordinator.ts';
|
||||
import { ContainerManager } from '../../containers/container-manager.ts';
|
||||
import { logger } from '../../logger.ts';
|
||||
import { ModelRegistry } from '../../models/registry.ts';
|
||||
import { ModelLoader } from '../../models/loader.ts';
|
||||
|
||||
/**
|
||||
* Handler for chat completion requests
|
||||
*/
|
||||
export class ChatHandler {
|
||||
private containerManager: ContainerManager;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private modelLoader: ModelLoader;
|
||||
private clusterCoordinator: ClusterCoordinator;
|
||||
|
||||
constructor(containerManager: ContainerManager, modelLoader: ModelLoader) {
|
||||
constructor(
|
||||
containerManager: ContainerManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
modelLoader: ModelLoader,
|
||||
clusterCoordinator: ClusterCoordinator,
|
||||
) {
|
||||
this.containerManager = containerManager;
|
||||
this.modelRegistry = modelRegistry;
|
||||
this.modelLoader = modelLoader;
|
||||
this.clusterCoordinator = clusterCoordinator;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle POST /v1/chat/completions
|
||||
*/
|
||||
public async handleChatCompletion(
|
||||
req: http.IncomingMessage,
|
||||
res: http.ServerResponse,
|
||||
body: IChatCompletionRequest,
|
||||
): Promise<void> {
|
||||
const modelName = body.model;
|
||||
const isStream = body.stream === true;
|
||||
const canonicalModel = await this.resolveCanonicalModel(body.model);
|
||||
const requestBody: IChatCompletionRequest = {
|
||||
...body,
|
||||
model: canonicalModel,
|
||||
};
|
||||
|
||||
logger.dim(`Chat completion request for model: ${modelName}`);
|
||||
logger.dim(`Chat completion request for model: ${canonicalModel}`);
|
||||
|
||||
try {
|
||||
// Find or load the model
|
||||
const container = await this.findOrLoadModel(modelName);
|
||||
if (!container) {
|
||||
this.sendError(res, 404, `Model "${modelName}" not found or could not be loaded`, 'model_not_found');
|
||||
const container = await this.findOrLoadLocalModel(canonicalModel);
|
||||
if (container) {
|
||||
if (requestBody.stream) {
|
||||
await this.handleStreamingCompletion(res, container, requestBody);
|
||||
} else {
|
||||
await this.handleNonStreamingCompletion(res, container, requestBody);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Route to streaming or non-streaming handler
|
||||
if (isStream) {
|
||||
await this.handleStreamingCompletion(res, container, body);
|
||||
} else {
|
||||
await this.handleNonStreamingCompletion(res, container, body);
|
||||
const ensured = await this.clusterCoordinator.ensureModelViaControlPlane(canonicalModel);
|
||||
if (!ensured) {
|
||||
this.sendError(
|
||||
res,
|
||||
404,
|
||||
`Model "${canonicalModel}" not found or could not be deployed`,
|
||||
'model_not_found',
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (ensured.location.nodeName === this.clusterCoordinator.getLocalNodeName()) {
|
||||
const localContainer = await this.findLocalContainer(canonicalModel);
|
||||
if (!localContainer) {
|
||||
this.sendError(
|
||||
res,
|
||||
503,
|
||||
`Model "${canonicalModel}" was scheduled locally but is not ready`,
|
||||
'server_error',
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (requestBody.stream) {
|
||||
await this.handleStreamingCompletion(res, localContainer, requestBody);
|
||||
} else {
|
||||
await this.handleNonStreamingCompletion(res, localContainer, requestBody);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
await this.proxyChatRequest(req, res, ensured.location.endpoint, requestBody);
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
logger.error(`Chat completion error: ${message}`);
|
||||
@@ -60,34 +91,38 @@ export class ChatHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Find container with model or attempt to load it
|
||||
*/
|
||||
private async findOrLoadModel(
|
||||
private async resolveCanonicalModel(modelName: string): Promise<string> {
|
||||
const model = await this.modelRegistry.getModel(modelName);
|
||||
return model?.id || modelName;
|
||||
}
|
||||
|
||||
private async findLocalContainer(
|
||||
modelName: string,
|
||||
): Promise<import('../../containers/base-container.ts').BaseContainer | null> {
|
||||
// First, check if model is already loaded
|
||||
const container = await this.containerManager.findContainerForModel(modelName);
|
||||
if (container) {
|
||||
return container;
|
||||
}
|
||||
|
||||
// Try to load the model
|
||||
logger.info(`Model ${modelName} not loaded, attempting to load...`);
|
||||
const loadResult = await this.modelLoader.loadModel(modelName);
|
||||
|
||||
if (!loadResult.success) {
|
||||
logger.error(`Failed to load model: ${loadResult.error}`);
|
||||
return null;
|
||||
}
|
||||
|
||||
// Find the container again after loading
|
||||
return this.containerManager.findContainerForModel(modelName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle non-streaming chat completion
|
||||
*/
|
||||
private async findOrLoadLocalModel(
|
||||
modelName: string,
|
||||
): Promise<import('../../containers/base-container.ts').BaseContainer | null> {
|
||||
const existing = await this.findLocalContainer(modelName);
|
||||
if (existing) {
|
||||
return existing;
|
||||
}
|
||||
|
||||
if (!this.clusterCoordinator.shouldDeployLocallyFirst()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
logger.info(`Model ${modelName} not loaded, attempting local deploy...`);
|
||||
const loadResult = await this.modelLoader.loadModel(modelName);
|
||||
if (!loadResult.success) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return this.findLocalContainer(loadResult.model);
|
||||
}
|
||||
|
||||
private async handleNonStreamingCompletion(
|
||||
res: http.ServerResponse,
|
||||
container: import('../../containers/base-container.ts').BaseContainer,
|
||||
@@ -99,35 +134,85 @@ export class ChatHandler {
|
||||
res.end(JSON.stringify(response));
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle streaming chat completion
|
||||
*/
|
||||
private async handleStreamingCompletion(
|
||||
res: http.ServerResponse,
|
||||
container: import('../../containers/base-container.ts').BaseContainer,
|
||||
body: IChatCompletionRequest,
|
||||
): Promise<void> {
|
||||
// Set SSE headers
|
||||
res.writeHead(200, {
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
'Connection': 'keep-alive',
|
||||
Connection: 'keep-alive',
|
||||
'X-Accel-Buffering': 'no',
|
||||
});
|
||||
|
||||
// Stream chunks to client
|
||||
await container.chatCompletionStream(body, (chunk) => {
|
||||
res.write(`data: ${chunk}\n\n`);
|
||||
res.write(chunk);
|
||||
});
|
||||
|
||||
// Send final done message
|
||||
res.write('data: [DONE]\n\n');
|
||||
res.end();
|
||||
}
|
||||
|
||||
/**
|
||||
* Send error response
|
||||
*/
|
||||
private async proxyChatRequest(
|
||||
req: http.IncomingMessage,
|
||||
res: http.ServerResponse,
|
||||
targetEndpoint: string,
|
||||
body: IChatCompletionRequest,
|
||||
): Promise<void> {
|
||||
const response = await fetch(`${targetEndpoint}/v1/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: this.buildForwardHeaders(req),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
|
||||
if (body.stream) {
|
||||
res.writeHead(response.status, {
|
||||
'Content-Type': response.headers.get('content-type') || 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
Connection: 'keep-alive',
|
||||
});
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
res.end();
|
||||
return;
|
||||
}
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
|
||||
res.write(value);
|
||||
}
|
||||
|
||||
res.end();
|
||||
return;
|
||||
}
|
||||
|
||||
const text = await response.text();
|
||||
res.writeHead(response.status, {
|
||||
'Content-Type': response.headers.get('content-type') || 'application/json',
|
||||
});
|
||||
res.end(text);
|
||||
}
|
||||
|
||||
private buildForwardHeaders(req: http.IncomingMessage): Record<string, string> {
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
};
|
||||
|
||||
if (typeof req.headers.authorization === 'string') {
|
||||
headers.Authorization = req.headers.authorization;
|
||||
}
|
||||
|
||||
if (typeof req.headers['x-request-id'] === 'string') {
|
||||
headers['X-Request-Id'] = req.headers['x-request-id'];
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
|
||||
private sendError(
|
||||
res: http.ServerResponse,
|
||||
statusCode: number,
|
||||
@@ -140,7 +225,6 @@ export class ChatHandler {
|
||||
message,
|
||||
type,
|
||||
param,
|
||||
code: null,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,316 @@
|
||||
import * as http from 'node:http';
|
||||
import type { IApiError } from '../../interfaces/api.ts';
|
||||
import type { IClusterNodeHeartbeat } from '../../interfaces/cluster.ts';
|
||||
import { ClusterCoordinator } from '../../cluster/coordinator.ts';
|
||||
import { CLUSTER } from '../../constants.ts';
|
||||
|
||||
export class ClusterHandler {
|
||||
private clusterCoordinator: ClusterCoordinator;
|
||||
|
||||
constructor(clusterCoordinator: ClusterCoordinator) {
|
||||
this.clusterCoordinator = clusterCoordinator;
|
||||
}
|
||||
|
||||
public async handle(
|
||||
req: http.IncomingMessage,
|
||||
res: http.ServerResponse,
|
||||
path: string,
|
||||
url: URL,
|
||||
): Promise<void> {
|
||||
if (!this.authenticate(req)) {
|
||||
return this.sendError(res, 401, 'Invalid cluster secret', 'authentication_error');
|
||||
}
|
||||
|
||||
if (path === '/_cluster/status' && req.method === 'GET') {
|
||||
return this.sendJson(res, 200, this.clusterCoordinator.getStatus());
|
||||
}
|
||||
|
||||
if (path === '/_cluster/nodes' && req.method === 'GET') {
|
||||
return this.sendJson(res, 200, this.clusterCoordinator.getStatus().nodes);
|
||||
}
|
||||
|
||||
if (path === '/_cluster/desired' && req.method === 'GET') {
|
||||
return this.sendJson(res, 200, this.clusterCoordinator.getDesiredDeployments());
|
||||
}
|
||||
|
||||
if (
|
||||
(path === '/_cluster/nodes/register' || path === '/_cluster/nodes/heartbeat') &&
|
||||
req.method === 'POST'
|
||||
) {
|
||||
const body = await this.parseBody(req) as IClusterNodeHeartbeat | null;
|
||||
if (!body) {
|
||||
return this.sendError(
|
||||
res,
|
||||
400,
|
||||
'Invalid cluster heartbeat payload',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
this.clusterCoordinator.acceptHeartbeat(body);
|
||||
return this.sendJson(res, 200, { ok: true });
|
||||
}
|
||||
|
||||
if (path === '/_cluster/models/resolve' && req.method === 'GET') {
|
||||
const model = url.searchParams.get('model');
|
||||
if (!model) {
|
||||
return this.sendError(res, 400, 'Missing model query parameter', 'invalid_request_error');
|
||||
}
|
||||
|
||||
const resolved = await this.clusterCoordinator.resolveModel(model);
|
||||
if (!resolved) {
|
||||
return this.sendError(res, 404, `Model "${model}" not found in cluster`, 'model_not_found');
|
||||
}
|
||||
|
||||
return this.sendJson(res, 200, resolved);
|
||||
}
|
||||
|
||||
if (path === '/_cluster/models/ensure' && req.method === 'POST') {
|
||||
if (!this.clusterCoordinator.canManageClusterState()) {
|
||||
return this.sendError(
|
||||
res,
|
||||
409,
|
||||
'This node is not the control plane',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const body = await this.parseBody(req) as { model?: string } | null;
|
||||
if (!body?.model) {
|
||||
return this.sendError(res, 400, 'Missing model in request body', 'invalid_request_error');
|
||||
}
|
||||
|
||||
const ensured = await this.clusterCoordinator.ensureModel(body.model);
|
||||
if (!ensured) {
|
||||
return this.sendError(res, 503, `Unable to schedule model "${body.model}"`, 'server_error');
|
||||
}
|
||||
|
||||
return this.sendJson(res, 200, ensured);
|
||||
}
|
||||
|
||||
if (path === '/_cluster/models/desired' && req.method === 'POST') {
|
||||
if (!this.clusterCoordinator.canManageClusterState()) {
|
||||
return this.sendError(
|
||||
res,
|
||||
409,
|
||||
'This node is not the control plane',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const body = await this.parseBody(req) as { model?: string; desiredReplicas?: number } | null;
|
||||
if (!body?.model || body.desiredReplicas === undefined) {
|
||||
return this.sendError(
|
||||
res,
|
||||
400,
|
||||
'Missing model or desiredReplicas in request body',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const desiredDeployment = await this.clusterCoordinator.setDesiredReplicas(
|
||||
body.model,
|
||||
body.desiredReplicas,
|
||||
);
|
||||
if (!desiredDeployment) {
|
||||
return this.sendError(res, 404, `Model "${body.model}" not found`, 'model_not_found');
|
||||
}
|
||||
|
||||
return this.sendJson(res, 200, desiredDeployment);
|
||||
}
|
||||
|
||||
if (path === '/_cluster/models/desired/remove' && req.method === 'POST') {
|
||||
if (!this.clusterCoordinator.canManageClusterState()) {
|
||||
return this.sendError(
|
||||
res,
|
||||
409,
|
||||
'This node is not the control plane',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const body = await this.parseBody(req) as { model?: string } | null;
|
||||
if (!body?.model) {
|
||||
return this.sendError(res, 400, 'Missing model in request body', 'invalid_request_error');
|
||||
}
|
||||
|
||||
const removed = await this.clusterCoordinator.clearDesiredDeployment(body.model);
|
||||
return this.sendJson(res, 200, { removed });
|
||||
}
|
||||
|
||||
if (path === '/_cluster/deployments' && req.method === 'POST') {
|
||||
const body = await this.parseBody(req) as { model?: string; replicaOrdinal?: number } | null;
|
||||
if (!body?.model) {
|
||||
return this.sendError(res, 400, 'Missing model in request body', 'invalid_request_error');
|
||||
}
|
||||
|
||||
const deployed = body.replicaOrdinal !== undefined
|
||||
? await this.clusterCoordinator.deployReplicaLocally(body.model, body.replicaOrdinal)
|
||||
: await this.clusterCoordinator.deployModelLocally(body.model);
|
||||
if (!deployed) {
|
||||
return this.sendError(res, 503, `Unable to deploy model "${body.model}"`, 'server_error');
|
||||
}
|
||||
|
||||
return this.sendJson(res, 200, deployed);
|
||||
}
|
||||
|
||||
if (path === '/_cluster/nodes/cordon' && req.method === 'POST') {
|
||||
if (!this.clusterCoordinator.canManageClusterState()) {
|
||||
return this.sendError(
|
||||
res,
|
||||
409,
|
||||
'This node is not the control plane',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const body = await this.parseBody(req) as { nodeName?: string } | null;
|
||||
if (!body?.nodeName) {
|
||||
return this.sendError(
|
||||
res,
|
||||
400,
|
||||
'Missing nodeName in request body',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const schedulerState = this.clusterCoordinator.setNodeSchedulerState(
|
||||
body.nodeName,
|
||||
'cordoned',
|
||||
);
|
||||
return this.sendJson(res, 200, { nodeName: body.nodeName, schedulerState });
|
||||
}
|
||||
|
||||
if (path === '/_cluster/nodes/uncordon' && req.method === 'POST') {
|
||||
if (!this.clusterCoordinator.canManageClusterState()) {
|
||||
return this.sendError(
|
||||
res,
|
||||
409,
|
||||
'This node is not the control plane',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const body = await this.parseBody(req) as { nodeName?: string } | null;
|
||||
if (!body?.nodeName) {
|
||||
return this.sendError(
|
||||
res,
|
||||
400,
|
||||
'Missing nodeName in request body',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const schedulerState = this.clusterCoordinator.setNodeSchedulerState(body.nodeName, 'active');
|
||||
return this.sendJson(res, 200, { nodeName: body.nodeName, schedulerState });
|
||||
}
|
||||
|
||||
if (path === '/_cluster/nodes/drain' && req.method === 'POST') {
|
||||
if (!this.clusterCoordinator.canManageClusterState()) {
|
||||
return this.sendError(
|
||||
res,
|
||||
409,
|
||||
'This node is not the control plane',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const body = await this.parseBody(req) as { nodeName?: string } | null;
|
||||
if (!body?.nodeName) {
|
||||
return this.sendError(
|
||||
res,
|
||||
400,
|
||||
'Missing nodeName in request body',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const schedulerState = this.clusterCoordinator.setNodeSchedulerState(
|
||||
body.nodeName,
|
||||
'draining',
|
||||
);
|
||||
return this.sendJson(res, 200, { nodeName: body.nodeName, schedulerState });
|
||||
}
|
||||
|
||||
if (path === '/_cluster/nodes/activate' && req.method === 'POST') {
|
||||
if (!this.clusterCoordinator.canManageClusterState()) {
|
||||
return this.sendError(
|
||||
res,
|
||||
409,
|
||||
'This node is not the control plane',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const body = await this.parseBody(req) as { nodeName?: string } | null;
|
||||
if (!body?.nodeName) {
|
||||
return this.sendError(
|
||||
res,
|
||||
400,
|
||||
'Missing nodeName in request body',
|
||||
'invalid_request_error',
|
||||
);
|
||||
}
|
||||
|
||||
const schedulerState = this.clusterCoordinator.setNodeSchedulerState(body.nodeName, 'active');
|
||||
return this.sendJson(res, 200, { nodeName: body.nodeName, schedulerState });
|
||||
}
|
||||
|
||||
return this.sendError(res, 404, `Unknown cluster endpoint: ${path}`, 'invalid_request_error');
|
||||
}
|
||||
|
||||
private authenticate(req: http.IncomingMessage): boolean {
|
||||
const sharedSecret = this.clusterCoordinator.getSharedSecret();
|
||||
if (!sharedSecret) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return req.headers[CLUSTER.AUTH_HEADER_NAME] === sharedSecret;
|
||||
}
|
||||
|
||||
private async parseBody(req: http.IncomingMessage): Promise<unknown | null> {
|
||||
return new Promise((resolve) => {
|
||||
let body = '';
|
||||
|
||||
req.on('data', (chunk) => {
|
||||
body += chunk.toString();
|
||||
});
|
||||
|
||||
req.on('end', () => {
|
||||
if (!body) {
|
||||
resolve(null);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
resolve(JSON.parse(body));
|
||||
} catch {
|
||||
resolve(null);
|
||||
}
|
||||
});
|
||||
|
||||
req.on('error', () => resolve(null));
|
||||
});
|
||||
}
|
||||
|
||||
private sendJson(res: http.ServerResponse, statusCode: number, body: unknown): void {
|
||||
res.writeHead(statusCode, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify(body));
|
||||
}
|
||||
|
||||
private sendError(
|
||||
res: http.ServerResponse,
|
||||
statusCode: number,
|
||||
message: string,
|
||||
type: string,
|
||||
): void {
|
||||
const error: IApiError = {
|
||||
error: {
|
||||
message,
|
||||
type,
|
||||
},
|
||||
};
|
||||
|
||||
this.sendJson(res, statusCode, error);
|
||||
}
|
||||
}
|
||||
@@ -1,53 +1,96 @@
|
||||
/**
|
||||
* Embeddings Handler
|
||||
*
|
||||
* Handles /v1/embeddings endpoint.
|
||||
* Embeddings handler.
|
||||
*/
|
||||
|
||||
import * as http from 'node:http';
|
||||
import type {
|
||||
IApiError,
|
||||
IEmbeddingData,
|
||||
IEmbeddingsRequest,
|
||||
IEmbeddingsResponse,
|
||||
IEmbeddingData,
|
||||
IApiError,
|
||||
} from '../../interfaces/api.ts';
|
||||
import { logger } from '../../logger.ts';
|
||||
import { ClusterCoordinator } from '../../cluster/coordinator.ts';
|
||||
import { ContainerManager } from '../../containers/container-manager.ts';
|
||||
import { logger } from '../../logger.ts';
|
||||
import { ModelRegistry } from '../../models/registry.ts';
|
||||
|
||||
/**
|
||||
* Handler for embeddings requests
|
||||
*/
|
||||
export class EmbeddingsHandler {
|
||||
private containerManager: ContainerManager;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private clusterCoordinator: ClusterCoordinator;
|
||||
|
||||
constructor(containerManager: ContainerManager) {
|
||||
constructor(
|
||||
containerManager: ContainerManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
clusterCoordinator: ClusterCoordinator,
|
||||
) {
|
||||
this.containerManager = containerManager;
|
||||
this.modelRegistry = modelRegistry;
|
||||
this.clusterCoordinator = clusterCoordinator;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle POST /v1/embeddings
|
||||
*/
|
||||
public async handleEmbeddings(
|
||||
req: http.IncomingMessage,
|
||||
res: http.ServerResponse,
|
||||
body: IEmbeddingsRequest,
|
||||
): Promise<void> {
|
||||
const modelName = body.model;
|
||||
const canonicalModel = await this.resolveCanonicalModel(body.model);
|
||||
const requestBody: IEmbeddingsRequest = {
|
||||
...body,
|
||||
model: canonicalModel,
|
||||
};
|
||||
|
||||
logger.dim(`Embeddings request for model: ${modelName}`);
|
||||
logger.dim(`Embeddings request for model: ${canonicalModel}`);
|
||||
|
||||
try {
|
||||
// Find container with the embedding model
|
||||
const container = await this.containerManager.findContainerForModel(modelName);
|
||||
if (!container) {
|
||||
this.sendError(res, 404, `Embedding model "${modelName}" not found`, 'model_not_found');
|
||||
const container = await this.containerManager.findContainerForModel(canonicalModel);
|
||||
if (container) {
|
||||
const response = await this.generateEmbeddings(container, requestBody);
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify(response));
|
||||
return;
|
||||
}
|
||||
|
||||
// Generate embeddings
|
||||
const response = await this.generateEmbeddings(container, body);
|
||||
const ensured = await this.clusterCoordinator.ensureModelViaControlPlane(canonicalModel);
|
||||
if (!ensured) {
|
||||
this.sendError(
|
||||
res,
|
||||
404,
|
||||
`Embedding model "${canonicalModel}" not found`,
|
||||
'model_not_found',
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify(response));
|
||||
if (ensured.location.nodeName === this.clusterCoordinator.getLocalNodeName()) {
|
||||
const localContainer = await this.containerManager.findContainerForModel(canonicalModel);
|
||||
if (!localContainer) {
|
||||
this.sendError(
|
||||
res,
|
||||
503,
|
||||
`Embedding model "${canonicalModel}" is not ready`,
|
||||
'server_error',
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const response = await this.generateEmbeddings(localContainer, requestBody);
|
||||
res.writeHead(200, { 'Content-Type': 'application/json' });
|
||||
res.end(JSON.stringify(response));
|
||||
return;
|
||||
}
|
||||
|
||||
const response = await fetch(`${ensured.location.endpoint}/v1/embeddings`, {
|
||||
method: 'POST',
|
||||
headers: this.buildForwardHeaders(req),
|
||||
body: JSON.stringify(requestBody),
|
||||
});
|
||||
|
||||
const text = await response.text();
|
||||
res.writeHead(response.status, {
|
||||
'Content-Type': response.headers.get('content-type') || 'application/json',
|
||||
});
|
||||
res.end(text);
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
logger.error(`Embeddings error: ${message}`);
|
||||
@@ -55,9 +98,11 @@ export class EmbeddingsHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate embeddings from container
|
||||
*/
|
||||
private async resolveCanonicalModel(modelName: string): Promise<string> {
|
||||
const model = await this.modelRegistry.getModel(modelName);
|
||||
return model?.id || modelName;
|
||||
}
|
||||
|
||||
private async generateEmbeddings(
|
||||
container: import('../../containers/base-container.ts').BaseContainer,
|
||||
request: IEmbeddingsRequest,
|
||||
@@ -66,7 +111,6 @@ export class EmbeddingsHandler {
|
||||
const embeddings: IEmbeddingData[] = [];
|
||||
let totalTokens = 0;
|
||||
|
||||
// Generate embeddings for each input
|
||||
for (let i = 0; i < inputs.length; i++) {
|
||||
const input = inputs[i];
|
||||
const embedding = await this.getEmbeddingFromContainer(container, request.model, input);
|
||||
@@ -91,9 +135,6 @@ export class EmbeddingsHandler {
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get embedding from container (container-specific implementation)
|
||||
*/
|
||||
private async getEmbeddingFromContainer(
|
||||
container: import('../../containers/base-container.ts').BaseContainer,
|
||||
model: string,
|
||||
@@ -102,54 +143,17 @@ export class EmbeddingsHandler {
|
||||
const endpoint = container.getEndpoint();
|
||||
const containerType = container.type;
|
||||
|
||||
// Route to container-specific embedding endpoint
|
||||
if (containerType === 'ollama') {
|
||||
return this.getOllamaEmbedding(endpoint, model, input);
|
||||
} else if (containerType === 'vllm') {
|
||||
if (containerType === 'vllm') {
|
||||
return this.getVllmEmbedding(endpoint, model, input);
|
||||
} else if (containerType === 'tgi') {
|
||||
}
|
||||
|
||||
if (containerType === 'tgi') {
|
||||
return this.getTgiEmbedding(endpoint, model, input);
|
||||
}
|
||||
|
||||
throw new Error(`Container type ${containerType} does not support embeddings`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get embedding from Ollama
|
||||
*/
|
||||
private async getOllamaEmbedding(
|
||||
endpoint: string,
|
||||
model: string,
|
||||
input: string,
|
||||
): Promise<{ vector: number[]; tokenCount: number }> {
|
||||
const response = await fetch(`${endpoint}/api/embeddings`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
model,
|
||||
prompt: input,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`Ollama embedding error: ${errorText}`);
|
||||
}
|
||||
|
||||
const result = await response.json() as { embedding: number[] };
|
||||
|
||||
// Estimate token count (rough approximation: ~4 chars per token)
|
||||
const tokenCount = Math.ceil(input.length / 4);
|
||||
|
||||
return {
|
||||
vector: result.embedding,
|
||||
tokenCount,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get embedding from vLLM (OpenAI-compatible)
|
||||
*/
|
||||
private async getVllmEmbedding(
|
||||
endpoint: string,
|
||||
model: string,
|
||||
@@ -158,61 +162,58 @@ export class EmbeddingsHandler {
|
||||
const response = await fetch(`${endpoint}/v1/embeddings`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
model,
|
||||
input,
|
||||
}),
|
||||
body: JSON.stringify({ model, input }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`vLLM embedding error: ${errorText}`);
|
||||
throw new Error(`vLLM embedding error: ${await response.text()}`);
|
||||
}
|
||||
|
||||
const result = await response.json() as IEmbeddingsResponse;
|
||||
|
||||
return {
|
||||
vector: result.data[0].embedding,
|
||||
tokenCount: result.usage.total_tokens,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get embedding from TGI
|
||||
*/
|
||||
private async getTgiEmbedding(
|
||||
endpoint: string,
|
||||
_model: string,
|
||||
input: string,
|
||||
): Promise<{ vector: number[]; tokenCount: number }> {
|
||||
// TGI uses /embed endpoint
|
||||
const response = await fetch(`${endpoint}/embed`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
inputs: input,
|
||||
}),
|
||||
body: JSON.stringify({ inputs: input }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`TGI embedding error: ${errorText}`);
|
||||
throw new Error(`TGI embedding error: ${await response.text()}`);
|
||||
}
|
||||
|
||||
const result = await response.json() as number[][];
|
||||
|
||||
// Estimate token count
|
||||
const tokenCount = Math.ceil(input.length / 4);
|
||||
|
||||
return {
|
||||
vector: result[0],
|
||||
tokenCount,
|
||||
tokenCount: Math.ceil(input.length / 4),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Send error response
|
||||
*/
|
||||
private buildForwardHeaders(req: http.IncomingMessage): Record<string, string> {
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
};
|
||||
|
||||
if (typeof req.headers.authorization === 'string') {
|
||||
headers.Authorization = req.headers.authorization;
|
||||
}
|
||||
|
||||
if (typeof req.headers['x-request-id'] === 'string') {
|
||||
headers['X-Request-Id'] = req.headers['x-request-id'];
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
|
||||
private sendError(
|
||||
res: http.ServerResponse,
|
||||
statusCode: number,
|
||||
@@ -225,7 +226,6 @@ export class EmbeddingsHandler {
|
||||
message,
|
||||
type,
|
||||
param,
|
||||
code: null,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -5,5 +5,6 @@
|
||||
*/
|
||||
|
||||
export { ChatHandler } from './chat.ts';
|
||||
export { ClusterHandler } from './cluster.ts';
|
||||
export { ModelsHandler } from './models.ts';
|
||||
export { EmbeddingsHandler } from './embeddings.ts';
|
||||
|
||||
+53
-50
@@ -1,34 +1,29 @@
|
||||
/**
|
||||
* Models Handler
|
||||
*
|
||||
* Handles /v1/models endpoints.
|
||||
* Models handler.
|
||||
*/
|
||||
|
||||
import * as http from 'node:http';
|
||||
import type {
|
||||
IModelInfo,
|
||||
IListModelsResponse,
|
||||
IApiError,
|
||||
} from '../../interfaces/api.ts';
|
||||
import { logger } from '../../logger.ts';
|
||||
import type { IApiError, IListModelsResponse, IModelInfo } from '../../interfaces/api.ts';
|
||||
import { ClusterCoordinator } from '../../cluster/coordinator.ts';
|
||||
import { ContainerManager } from '../../containers/container-manager.ts';
|
||||
import { logger } from '../../logger.ts';
|
||||
import { ModelRegistry } from '../../models/registry.ts';
|
||||
|
||||
/**
|
||||
* Handler for model-related requests
|
||||
*/
|
||||
export class ModelsHandler {
|
||||
private containerManager: ContainerManager;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private clusterCoordinator: ClusterCoordinator;
|
||||
|
||||
constructor(containerManager: ContainerManager, modelRegistry: ModelRegistry) {
|
||||
constructor(
|
||||
containerManager: ContainerManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
clusterCoordinator: ClusterCoordinator,
|
||||
) {
|
||||
this.containerManager = containerManager;
|
||||
this.modelRegistry = modelRegistry;
|
||||
this.clusterCoordinator = clusterCoordinator;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle GET /v1/models
|
||||
*/
|
||||
public async handleListModels(res: http.ServerResponse): Promise<void> {
|
||||
try {
|
||||
const models = await this.getAvailableModels();
|
||||
@@ -47,13 +42,12 @@ export class ModelsHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle GET /v1/models/:model
|
||||
*/
|
||||
public async handleGetModel(res: http.ServerResponse, modelId: string): Promise<void> {
|
||||
try {
|
||||
const models = await this.getAvailableModels();
|
||||
const model = models.find((m) => m.id === modelId);
|
||||
const requested = await this.modelRegistry.getModel(modelId);
|
||||
const canonicalId = requested?.id || modelId;
|
||||
const model = models.find((entry) => entry.id === canonicalId);
|
||||
|
||||
if (!model) {
|
||||
this.sendError(res, 404, `Model "${modelId}" not found`, 'model_not_found');
|
||||
@@ -69,51 +63,61 @@ export class ModelsHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all available models from containers and greenlist
|
||||
*/
|
||||
private async getAvailableModels(): Promise<IModelInfo[]> {
|
||||
const models: IModelInfo[] = [];
|
||||
const seen = new Set<string>();
|
||||
const timestamp = Math.floor(Date.now() / 1000);
|
||||
|
||||
// Get models from running containers
|
||||
const containerModels = await this.containerManager.getAllAvailableModels();
|
||||
for (const [modelId, modelInfo] of containerModels) {
|
||||
if (!seen.has(modelId)) {
|
||||
seen.add(modelId);
|
||||
models.push({
|
||||
id: modelId,
|
||||
object: 'model',
|
||||
created: timestamp,
|
||||
owned_by: `modelgrid-${modelInfo.container}`,
|
||||
});
|
||||
for (const [modelId, endpoints] of containerModels) {
|
||||
if (seen.has(modelId)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const primaryEndpoint = endpoints[0];
|
||||
seen.add(modelId);
|
||||
models.push({
|
||||
id: modelId,
|
||||
object: 'model',
|
||||
created: timestamp,
|
||||
owned_by: `modelgrid-${primaryEndpoint?.type || 'vllm'}`,
|
||||
});
|
||||
}
|
||||
|
||||
// Add greenlit models that aren't loaded yet
|
||||
const greenlitModels = await this.modelRegistry.getAllGreenlitModels();
|
||||
for (const greenlit of greenlitModels) {
|
||||
if (!seen.has(greenlit.name)) {
|
||||
seen.add(greenlit.name);
|
||||
models.push({
|
||||
id: greenlit.name,
|
||||
object: 'model',
|
||||
created: timestamp,
|
||||
owned_by: `modelgrid-${greenlit.container}`,
|
||||
});
|
||||
const clusterStatus = this.clusterCoordinator.getStatus();
|
||||
for (const [modelId, locations] of Object.entries(clusterStatus.models)) {
|
||||
if (seen.has(modelId) || locations.length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
seen.add(modelId);
|
||||
models.push({
|
||||
id: modelId,
|
||||
object: 'model',
|
||||
created: timestamp,
|
||||
owned_by: `modelgrid-${locations[0].engine}`,
|
||||
});
|
||||
}
|
||||
|
||||
// Sort alphabetically
|
||||
models.sort((a, b) => a.id.localeCompare(b.id));
|
||||
const catalogModels = await this.modelRegistry.getAllModels();
|
||||
for (const model of catalogModels) {
|
||||
if (seen.has(model.id)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
seen.add(model.id);
|
||||
models.push({
|
||||
id: model.id,
|
||||
object: 'model',
|
||||
created: timestamp,
|
||||
owned_by: `modelgrid-${model.engine}`,
|
||||
});
|
||||
}
|
||||
|
||||
models.sort((left, right) => left.id.localeCompare(right.id));
|
||||
return models;
|
||||
}
|
||||
|
||||
/**
|
||||
* Send error response
|
||||
*/
|
||||
private sendError(
|
||||
res: http.ServerResponse,
|
||||
statusCode: number,
|
||||
@@ -126,7 +130,6 @@ export class ModelsHandler {
|
||||
message,
|
||||
type,
|
||||
param,
|
||||
code: null,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
+39
-11
@@ -63,7 +63,11 @@ export class SanityMiddleware {
|
||||
if (request.temperature !== undefined) {
|
||||
const temp = request.temperature as number;
|
||||
if (typeof temp !== 'number' || temp < 0 || temp > 2) {
|
||||
return { valid: false, error: '"temperature" must be between 0 and 2', param: 'temperature' };
|
||||
return {
|
||||
valid: false,
|
||||
error: '"temperature" must be between 0 and 2',
|
||||
param: 'temperature',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +81,11 @@ export class SanityMiddleware {
|
||||
if (request.max_tokens !== undefined) {
|
||||
const maxTokens = request.max_tokens as number;
|
||||
if (typeof maxTokens !== 'number' || maxTokens < 1) {
|
||||
return { valid: false, error: '"max_tokens" must be a positive integer', param: 'max_tokens' };
|
||||
return {
|
||||
valid: false,
|
||||
error: '"max_tokens" must be a positive integer',
|
||||
param: 'max_tokens',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,14 +103,22 @@ export class SanityMiddleware {
|
||||
if (request.presence_penalty !== undefined) {
|
||||
const pp = request.presence_penalty as number;
|
||||
if (typeof pp !== 'number' || pp < -2 || pp > 2) {
|
||||
return { valid: false, error: '"presence_penalty" must be between -2 and 2', param: 'presence_penalty' };
|
||||
return {
|
||||
valid: false,
|
||||
error: '"presence_penalty" must be between -2 and 2',
|
||||
param: 'presence_penalty',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (request.frequency_penalty !== undefined) {
|
||||
const fp = request.frequency_penalty as number;
|
||||
if (typeof fp !== 'number' || fp < -2 || fp > 2) {
|
||||
return { valid: false, error: '"frequency_penalty" must be between -2 and 2', param: 'frequency_penalty' };
|
||||
return {
|
||||
valid: false,
|
||||
error: '"frequency_penalty" must be between -2 and 2',
|
||||
param: 'frequency_penalty',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,7 +130,11 @@ export class SanityMiddleware {
|
||||
*/
|
||||
private validateMessage(msg: Record<string, unknown>, index: number): IValidationResult {
|
||||
if (!msg || typeof msg !== 'object') {
|
||||
return { valid: false, error: `Message at index ${index} must be an object`, param: `messages[${index}]` };
|
||||
return {
|
||||
valid: false,
|
||||
error: `Message at index ${index} must be an object`,
|
||||
param: `messages[${index}]`,
|
||||
};
|
||||
}
|
||||
|
||||
// Validate role
|
||||
@@ -178,7 +198,11 @@ export class SanityMiddleware {
|
||||
|
||||
const input = request.input;
|
||||
if (typeof input !== 'string' && !Array.isArray(input)) {
|
||||
return { valid: false, error: '"input" must be a string or array of strings', param: 'input' };
|
||||
return {
|
||||
valid: false,
|
||||
error: '"input" must be a string or array of strings',
|
||||
param: 'input',
|
||||
};
|
||||
}
|
||||
|
||||
if (Array.isArray(input)) {
|
||||
@@ -197,7 +221,11 @@ export class SanityMiddleware {
|
||||
if (request.encoding_format !== undefined) {
|
||||
const format = request.encoding_format as string;
|
||||
if (format !== 'float' && format !== 'base64') {
|
||||
return { valid: false, error: '"encoding_format" must be "float" or "base64"', param: 'encoding_format' };
|
||||
return {
|
||||
valid: false,
|
||||
error: '"encoding_format" must be "float" or "base64"',
|
||||
param: 'encoding_format',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -205,14 +233,14 @@ export class SanityMiddleware {
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if model is in greenlist (async validation)
|
||||
* Check if model is in the public registry.
|
||||
*/
|
||||
public async validateModelGreenlist(modelName: string): Promise<IValidationResult> {
|
||||
const isGreenlit = await this.modelRegistry.isModelGreenlit(modelName);
|
||||
if (!isGreenlit) {
|
||||
const isListed = await this.modelRegistry.isModelListed(modelName);
|
||||
if (!isListed) {
|
||||
return {
|
||||
valid: false,
|
||||
error: `Model "${modelName}" is not greenlit. Contact administrator to add it to the greenlist.`,
|
||||
error: `Model "${modelName}" is not listed in the registry.`,
|
||||
param: 'model',
|
||||
};
|
||||
}
|
||||
|
||||
+38
-17
@@ -5,7 +5,8 @@
|
||||
*/
|
||||
|
||||
import * as http from 'node:http';
|
||||
import type { IApiError } from '../interfaces/api.ts';
|
||||
import type { IApiError, IChatCompletionRequest } from '../interfaces/api.ts';
|
||||
import { ClusterCoordinator } from '../cluster/coordinator.ts';
|
||||
import { logger } from '../logger.ts';
|
||||
import { ContainerManager } from '../containers/container-manager.ts';
|
||||
import { ModelRegistry } from '../models/registry.ts';
|
||||
@@ -23,6 +24,7 @@ export class ApiRouter {
|
||||
private containerManager: ContainerManager;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private modelLoader: ModelLoader;
|
||||
private clusterCoordinator: ClusterCoordinator;
|
||||
private chatHandler: ChatHandler;
|
||||
private modelsHandler: ModelsHandler;
|
||||
private embeddingsHandler: EmbeddingsHandler;
|
||||
@@ -33,16 +35,27 @@ export class ApiRouter {
|
||||
containerManager: ContainerManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
modelLoader: ModelLoader,
|
||||
clusterCoordinator: ClusterCoordinator,
|
||||
apiKeys: string[],
|
||||
) {
|
||||
this.containerManager = containerManager;
|
||||
this.modelRegistry = modelRegistry;
|
||||
this.modelLoader = modelLoader;
|
||||
this.clusterCoordinator = clusterCoordinator;
|
||||
|
||||
// Initialize handlers
|
||||
this.chatHandler = new ChatHandler(containerManager, modelLoader);
|
||||
this.modelsHandler = new ModelsHandler(containerManager, modelRegistry);
|
||||
this.embeddingsHandler = new EmbeddingsHandler(containerManager);
|
||||
this.chatHandler = new ChatHandler(
|
||||
containerManager,
|
||||
modelRegistry,
|
||||
modelLoader,
|
||||
clusterCoordinator,
|
||||
);
|
||||
this.modelsHandler = new ModelsHandler(containerManager, modelRegistry, clusterCoordinator);
|
||||
this.embeddingsHandler = new EmbeddingsHandler(
|
||||
containerManager,
|
||||
modelRegistry,
|
||||
clusterCoordinator,
|
||||
);
|
||||
|
||||
// Initialize middleware
|
||||
this.authMiddleware = new AuthMiddleware(apiKeys);
|
||||
@@ -119,8 +132,8 @@ export class ApiRouter {
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle request
|
||||
await this.chatHandler.handleChatCompletion(req, res, body);
|
||||
const requestBody = this.sanityMiddleware.sanitizeChatRequest(body as Record<string, unknown>);
|
||||
await this.chatHandler.handleChatCompletion(req, res, requestBody);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -149,7 +162,7 @@ export class ApiRouter {
|
||||
}
|
||||
|
||||
// Convert to chat format and handle
|
||||
const chatBody = this.convertCompletionToChat(body);
|
||||
const chatBody = this.convertCompletionToChat(body as Record<string, unknown>);
|
||||
await this.chatHandler.handleChatCompletion(req, res, chatBody);
|
||||
}
|
||||
|
||||
@@ -222,7 +235,16 @@ export class ApiRouter {
|
||||
return;
|
||||
}
|
||||
|
||||
await this.embeddingsHandler.handleEmbeddings(res, body);
|
||||
const validation = this.sanityMiddleware.validateEmbeddingsRequest(body);
|
||||
if (!validation.valid) {
|
||||
this.sendError(res, 400, validation.error || 'Invalid request', 'invalid_request_error');
|
||||
return;
|
||||
}
|
||||
|
||||
const requestBody = this.sanityMiddleware.sanitizeEmbeddingsRequest(
|
||||
body as Record<string, unknown>,
|
||||
);
|
||||
await this.embeddingsHandler.handleEmbeddings(req, res, requestBody);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -257,21 +279,21 @@ export class ApiRouter {
|
||||
/**
|
||||
* Convert legacy completion request to chat format
|
||||
*/
|
||||
private convertCompletionToChat(body: Record<string, unknown>): Record<string, unknown> {
|
||||
private convertCompletionToChat(body: Record<string, unknown>): IChatCompletionRequest {
|
||||
const prompt = body.prompt as string | string[];
|
||||
const promptText = Array.isArray(prompt) ? prompt.join('\n') : prompt;
|
||||
|
||||
return {
|
||||
model: body.model,
|
||||
model: body.model as string,
|
||||
messages: [
|
||||
{ role: 'user', content: promptText },
|
||||
],
|
||||
max_tokens: body.max_tokens,
|
||||
temperature: body.temperature,
|
||||
top_p: body.top_p,
|
||||
n: body.n,
|
||||
stream: body.stream,
|
||||
stop: body.stop,
|
||||
max_tokens: body.max_tokens as number | undefined,
|
||||
temperature: body.temperature as number | undefined,
|
||||
top_p: body.top_p as number | undefined,
|
||||
n: body.n as number | undefined,
|
||||
stream: body.stream as boolean | undefined,
|
||||
stop: body.stop as string | string[] | undefined,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -290,7 +312,6 @@ export class ApiRouter {
|
||||
message,
|
||||
type,
|
||||
param,
|
||||
code: null,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
+17
-4
@@ -7,13 +7,15 @@
|
||||
import * as http from 'node:http';
|
||||
import type { IApiConfig } from '../interfaces/config.ts';
|
||||
import type { IHealthResponse } from '../interfaces/api.ts';
|
||||
import { ClusterCoordinator } from '../cluster/coordinator.ts';
|
||||
import { logger } from '../logger.ts';
|
||||
import { API_SERVER } from '../constants.ts';
|
||||
import { VERSION } from '../constants.ts';
|
||||
import { ApiRouter } from './router.ts';
|
||||
import { ContainerManager } from '../containers/container-manager.ts';
|
||||
import { ModelRegistry } from '../models/registry.ts';
|
||||
import { ModelLoader } from '../models/loader.ts';
|
||||
import { GpuDetector } from '../hardware/gpu-detector.ts';
|
||||
import { ClusterHandler } from './handlers/cluster.ts';
|
||||
|
||||
/**
|
||||
* API Server for ModelGrid
|
||||
@@ -26,22 +28,29 @@ export class ApiServer {
|
||||
private modelRegistry: ModelRegistry;
|
||||
private modelLoader: ModelLoader;
|
||||
private gpuDetector: GpuDetector;
|
||||
private clusterCoordinator: ClusterCoordinator;
|
||||
private clusterHandler: ClusterHandler;
|
||||
private startTime: number = 0;
|
||||
|
||||
constructor(
|
||||
config: IApiConfig,
|
||||
containerManager: ContainerManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
modelLoader: ModelLoader,
|
||||
clusterCoordinator: ClusterCoordinator,
|
||||
) {
|
||||
this.config = config;
|
||||
this.containerManager = containerManager;
|
||||
this.modelRegistry = modelRegistry;
|
||||
this.gpuDetector = new GpuDetector();
|
||||
this.modelLoader = new ModelLoader(modelRegistry, containerManager, true);
|
||||
this.modelLoader = modelLoader;
|
||||
this.clusterCoordinator = clusterCoordinator;
|
||||
this.clusterHandler = new ClusterHandler(clusterCoordinator);
|
||||
this.router = new ApiRouter(
|
||||
containerManager,
|
||||
modelRegistry,
|
||||
this.modelLoader,
|
||||
clusterCoordinator,
|
||||
config.apiKeys,
|
||||
);
|
||||
}
|
||||
@@ -120,6 +129,11 @@ export class ApiServer {
|
||||
const url = new URL(req.url || '/', `http://${req.headers.host || 'localhost'}`);
|
||||
const path = url.pathname;
|
||||
|
||||
if (path.startsWith('/_cluster')) {
|
||||
await this.clusterHandler.handle(req, res, path, url);
|
||||
return;
|
||||
}
|
||||
|
||||
// Health check endpoint (no auth required)
|
||||
if (path === '/health' || path === '/healthz') {
|
||||
await this.handleHealthCheck(res);
|
||||
@@ -194,7 +208,7 @@ export class ApiServer {
|
||||
|
||||
const response: IHealthResponse = {
|
||||
status,
|
||||
version: '1.0.0', // TODO: Get from config
|
||||
version: VERSION,
|
||||
uptime: Math.floor((Date.now() - this.startTime) / 1000),
|
||||
containers: statuses.size,
|
||||
models: models.size,
|
||||
@@ -276,7 +290,6 @@ export class ApiServer {
|
||||
error: {
|
||||
message,
|
||||
type,
|
||||
code: null,
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
@@ -58,6 +58,7 @@ export class ModelGridCli {
|
||||
const serviceHandler = this.modelgrid.getServiceHandler();
|
||||
const gpuHandler = this.modelgrid.getGpuHandler();
|
||||
const containerHandler = this.modelgrid.getContainerHandler();
|
||||
const clusterHandler = this.modelgrid.getClusterHandler();
|
||||
const modelHandler = this.modelgrid.getModelHandler();
|
||||
const configHandler = this.modelgrid.getConfigHandler();
|
||||
|
||||
@@ -99,6 +100,51 @@ export class ModelGridCli {
|
||||
return;
|
||||
}
|
||||
|
||||
if (command === 'cluster') {
|
||||
const subcommand = commandArgs[0] || 'status';
|
||||
const subcommandArgs = commandArgs.slice(1);
|
||||
|
||||
switch (subcommand) {
|
||||
case 'status':
|
||||
await clusterHandler.status();
|
||||
break;
|
||||
case 'nodes':
|
||||
await clusterHandler.nodes();
|
||||
break;
|
||||
case 'models':
|
||||
await clusterHandler.models();
|
||||
break;
|
||||
case 'desired':
|
||||
await clusterHandler.desired();
|
||||
break;
|
||||
case 'ensure':
|
||||
await clusterHandler.ensure(subcommandArgs[0]);
|
||||
break;
|
||||
case 'scale':
|
||||
await clusterHandler.scale(subcommandArgs[0], parseInt(subcommandArgs[1] || '', 10));
|
||||
break;
|
||||
case 'clear':
|
||||
await clusterHandler.clear(subcommandArgs[0]);
|
||||
break;
|
||||
case 'cordon':
|
||||
await clusterHandler.cordon(subcommandArgs[0]);
|
||||
break;
|
||||
case 'uncordon':
|
||||
await clusterHandler.uncordon(subcommandArgs[0]);
|
||||
break;
|
||||
case 'drain':
|
||||
await clusterHandler.drain(subcommandArgs[0]);
|
||||
break;
|
||||
case 'activate':
|
||||
await clusterHandler.activate(subcommandArgs[0]);
|
||||
break;
|
||||
default:
|
||||
this.showClusterHelp();
|
||||
break;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// GPU commands
|
||||
if (command === 'gpu') {
|
||||
const subcommand = commandArgs[0] || 'list';
|
||||
@@ -226,6 +272,12 @@ export class ModelGridCli {
|
||||
|
||||
// Top-level commands
|
||||
switch (command) {
|
||||
case 'run':
|
||||
await modelHandler.pull(commandArgs[0]);
|
||||
break;
|
||||
case 'ps':
|
||||
await containerHandler.list();
|
||||
break;
|
||||
case 'update':
|
||||
await serviceHandler.update();
|
||||
break;
|
||||
@@ -267,10 +319,13 @@ export class ModelGridCli {
|
||||
console.log('');
|
||||
|
||||
logger.log(theme.info('Commands:'));
|
||||
this.printCommand('run <model>', 'Deploy a vLLM model');
|
||||
this.printCommand('ps', 'List active deployments');
|
||||
this.printCommand('service <subcommand>', 'Manage systemd service');
|
||||
this.printCommand('gpu <subcommand>', 'Manage GPU hardware');
|
||||
this.printCommand('container <subcommand>', 'Manage AI containers');
|
||||
this.printCommand('model <subcommand>', 'Manage AI models');
|
||||
this.printCommand('container <subcommand>', 'Manage deployments directly');
|
||||
this.printCommand('model <subcommand>', 'Browse and deploy catalog models');
|
||||
this.printCommand('cluster <subcommand>', 'Inspect cluster control plane');
|
||||
this.printCommand('config <subcommand>', 'Manage configuration');
|
||||
this.printCommand('update', 'Update ModelGrid', theme.dim('(requires root)'));
|
||||
this.printCommand('uninstall', 'Remove ModelGrid', theme.dim('(requires root)'));
|
||||
@@ -280,9 +335,9 @@ export class ModelGridCli {
|
||||
|
||||
logger.log(theme.info('Quick Start:'));
|
||||
logger.dim(' modelgrid gpu list # Detect GPUs');
|
||||
logger.dim(' modelgrid container add # Add an Ollama/vLLM container');
|
||||
logger.dim(' modelgrid container start # Start containers');
|
||||
logger.dim(' modelgrid model pull llama3 # Pull a model');
|
||||
logger.dim(' modelgrid model list # Browse catalog');
|
||||
logger.dim(' modelgrid run <model> # Deploy a vLLM model');
|
||||
logger.dim(' modelgrid ps # List active deployments');
|
||||
logger.dim(' modelgrid service enable # Install as service');
|
||||
console.log('');
|
||||
|
||||
@@ -290,7 +345,9 @@ export class ModelGridCli {
|
||||
logger.dim(' curl -X POST http://localhost:8080/v1/chat/completions \\');
|
||||
logger.dim(' -H "Authorization: Bearer YOUR_API_KEY" \\');
|
||||
logger.dim(' -H "Content-Type: application/json" \\');
|
||||
logger.dim(' -d \'{"model": "llama3", "messages": [{"role": "user", "content": "Hello"}]}\'');
|
||||
logger.dim(
|
||||
' -d \'{"model": "llama3", "messages": [{"role": "user", "content": "Hello"}]}\'',
|
||||
);
|
||||
console.log('');
|
||||
}
|
||||
|
||||
@@ -360,17 +417,17 @@ Usage:
|
||||
modelgrid container <subcommand> [arguments]
|
||||
|
||||
Subcommands:
|
||||
list List all configured containers
|
||||
add Add a new container interactively
|
||||
remove <id> Remove a container by ID
|
||||
start [id] Start a container (or all if no ID)
|
||||
stop [id] Stop a container (or all if no ID)
|
||||
logs <id> Show container logs
|
||||
list List all configured deployments
|
||||
add Add a vLLM deployment interactively
|
||||
remove <id> Remove a deployment by ID
|
||||
start [id] Start a deployment (or all if no ID)
|
||||
stop [id] Stop a deployment (or all if no ID)
|
||||
logs <id> Show deployment logs
|
||||
|
||||
Examples:
|
||||
modelgrid container add # Add new container
|
||||
modelgrid container start ollama # Start specific container
|
||||
modelgrid container logs ollama # View container logs
|
||||
modelgrid container add # Add new deployment
|
||||
modelgrid container start qwen2 # Start specific deployment
|
||||
modelgrid container logs qwen2 # View deployment logs
|
||||
`);
|
||||
}
|
||||
|
||||
@@ -385,16 +442,43 @@ Usage:
|
||||
modelgrid model <subcommand> [arguments]
|
||||
|
||||
Subcommands:
|
||||
list List all available models
|
||||
pull <name> Pull a model (must be greenlit)
|
||||
remove <name> Remove a model
|
||||
status Show model loading recommendations
|
||||
refresh Refresh greenlist cache
|
||||
list List all catalog models
|
||||
pull <name> Deploy a model from the registry
|
||||
remove <name> Remove a deployed model
|
||||
status Show deployment recommendations
|
||||
refresh Refresh the model catalog cache
|
||||
|
||||
Examples:
|
||||
modelgrid model list # Show all models
|
||||
modelgrid model pull llama3:8b # Pull a model
|
||||
modelgrid model status # Show VRAM recommendations
|
||||
modelgrid model list # Show all models
|
||||
modelgrid model pull meta-llama/Llama-3.1-8B-Instruct
|
||||
modelgrid model status # Show GPU-fit recommendations
|
||||
`);
|
||||
}
|
||||
|
||||
private showClusterHelp(): void {
|
||||
logger.log(`
|
||||
ModelGrid - Cluster Commands
|
||||
|
||||
Usage:
|
||||
modelgrid cluster <subcommand> [arguments]
|
||||
|
||||
Subcommands:
|
||||
status Show cluster status
|
||||
nodes List registered nodes
|
||||
models List clustered model locations
|
||||
desired Show desired deployment targets
|
||||
ensure <name> Ask the control plane to schedule a model
|
||||
scale <name> <replicas> Set desired replica count
|
||||
clear <name> Remove desired deployment target
|
||||
cordon <node> Prevent new placements on a node
|
||||
uncordon <node> Re-enable placements on a node
|
||||
drain <node> Mark a node for evacuation
|
||||
activate <node> Mark a node active again
|
||||
|
||||
Examples:
|
||||
modelgrid cluster status
|
||||
modelgrid cluster ensure meta-llama/Llama-3.1-8B-Instruct
|
||||
modelgrid cluster cordon worker-a
|
||||
`);
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
import * as fs from 'node:fs/promises';
|
||||
import { CLUSTER, PATHS } from '../constants.ts';
|
||||
import type { IModelGridConfig } from '../interfaces/config.ts';
|
||||
import { logger } from '../logger.ts';
|
||||
|
||||
export class ClusterHandler {
|
||||
public async status(): Promise<void> {
|
||||
const response = await this.request('/_cluster/status');
|
||||
if (!response) {
|
||||
return;
|
||||
}
|
||||
|
||||
logger.log(JSON.stringify(response, null, 2));
|
||||
}
|
||||
|
||||
public async nodes(): Promise<void> {
|
||||
const response = await this.request('/_cluster/nodes');
|
||||
if (!response) {
|
||||
return;
|
||||
}
|
||||
|
||||
logger.log(JSON.stringify(response, null, 2));
|
||||
}
|
||||
|
||||
public async models(): Promise<void> {
|
||||
const response = await this.request('/_cluster/status');
|
||||
if (!response || typeof response !== 'object' || !('models' in response)) {
|
||||
return;
|
||||
}
|
||||
|
||||
logger.log(JSON.stringify((response as { models: unknown }).models, null, 2));
|
||||
}
|
||||
|
||||
public async desired(): Promise<void> {
|
||||
const response = await this.request('/_cluster/desired');
|
||||
if (!response) {
|
||||
return;
|
||||
}
|
||||
|
||||
logger.log(JSON.stringify(response, null, 2));
|
||||
}
|
||||
|
||||
public async ensure(model: string): Promise<void> {
|
||||
if (!model) {
|
||||
logger.error('Model ID is required');
|
||||
return;
|
||||
}
|
||||
|
||||
const response = await this.request('/_cluster/models/ensure', {
|
||||
method: 'POST',
|
||||
body: { model },
|
||||
});
|
||||
if (!response) {
|
||||
return;
|
||||
}
|
||||
|
||||
logger.log(JSON.stringify(response, null, 2));
|
||||
}
|
||||
|
||||
public async scale(model: string, desiredReplicas: number): Promise<void> {
|
||||
if (!model || Number.isNaN(desiredReplicas)) {
|
||||
logger.error('Model ID and desired replica count are required');
|
||||
return;
|
||||
}
|
||||
|
||||
const response = await this.request('/_cluster/models/desired', {
|
||||
method: 'POST',
|
||||
body: { model, desiredReplicas },
|
||||
});
|
||||
if (!response) {
|
||||
return;
|
||||
}
|
||||
|
||||
logger.log(JSON.stringify(response, null, 2));
|
||||
}
|
||||
|
||||
public async clear(model: string): Promise<void> {
|
||||
if (!model) {
|
||||
logger.error('Model ID is required');
|
||||
return;
|
||||
}
|
||||
|
||||
const response = await this.request('/_cluster/models/desired/remove', {
|
||||
method: 'POST',
|
||||
body: { model },
|
||||
});
|
||||
if (!response) {
|
||||
return;
|
||||
}
|
||||
|
||||
logger.log(JSON.stringify(response, null, 2));
|
||||
}
|
||||
|
||||
public async cordon(nodeName: string): Promise<void> {
|
||||
await this.updateNodeState('/_cluster/nodes/cordon', nodeName);
|
||||
}
|
||||
|
||||
public async uncordon(nodeName: string): Promise<void> {
|
||||
await this.updateNodeState('/_cluster/nodes/uncordon', nodeName);
|
||||
}
|
||||
|
||||
public async drain(nodeName: string): Promise<void> {
|
||||
await this.updateNodeState('/_cluster/nodes/drain', nodeName);
|
||||
}
|
||||
|
||||
public async activate(nodeName: string): Promise<void> {
|
||||
await this.updateNodeState('/_cluster/nodes/activate', nodeName);
|
||||
}
|
||||
|
||||
private async request(
|
||||
path: string,
|
||||
options: {
|
||||
method?: 'GET' | 'POST';
|
||||
body?: unknown;
|
||||
} = {},
|
||||
): Promise<unknown | null> {
|
||||
const config = await this.readConfig();
|
||||
if (!config) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const endpoint = this.resolveEndpoint(config);
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
};
|
||||
|
||||
if (config.cluster.sharedSecret) {
|
||||
headers[CLUSTER.AUTH_HEADER_NAME] = config.cluster.sharedSecret;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`${endpoint}${path}`, {
|
||||
method: options.method || 'GET',
|
||||
headers,
|
||||
body: options.body ? JSON.stringify(options.body) : undefined,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
logger.error(`Cluster request failed: ${response.status} ${await response.text()}`);
|
||||
return null;
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Cluster request failed: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private async readConfig(): Promise<IModelGridConfig | null> {
|
||||
try {
|
||||
return JSON.parse(await fs.readFile(PATHS.CONFIG_FILE, 'utf-8')) as IModelGridConfig;
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to read config: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private resolveEndpoint(config: IModelGridConfig): string {
|
||||
if (config.cluster.controlPlaneUrl) {
|
||||
return config.cluster.controlPlaneUrl;
|
||||
}
|
||||
|
||||
if (config.cluster.advertiseUrl) {
|
||||
return config.cluster.advertiseUrl;
|
||||
}
|
||||
|
||||
const host = config.api.host === '0.0.0.0' ? '127.0.0.1' : config.api.host;
|
||||
return `http://${host}:${config.api.port}`;
|
||||
}
|
||||
|
||||
private async updateNodeState(path: string, nodeName: string): Promise<void> {
|
||||
if (!nodeName) {
|
||||
logger.error('Node name is required');
|
||||
return;
|
||||
}
|
||||
|
||||
const response = await this.request(path, {
|
||||
method: 'POST',
|
||||
body: { nodeName },
|
||||
});
|
||||
if (!response) {
|
||||
return;
|
||||
}
|
||||
|
||||
logger.log(JSON.stringify(response, null, 2));
|
||||
}
|
||||
}
|
||||
+62
-12
@@ -25,6 +25,26 @@ export class ConfigHandler {
|
||||
const configPath = PATHS.CONFIG_FILE;
|
||||
const configContent = await fs.readFile(configPath, 'utf-8');
|
||||
const config = JSON.parse(configContent) as IModelGridConfig;
|
||||
const modelConfig = {
|
||||
registryUrl: config.models.registryUrl ||
|
||||
(config.models as { greenlistUrl?: string }).greenlistUrl ||
|
||||
'https://list.modelgrid.com/catalog/models.json',
|
||||
autoDeploy: config.models.autoDeploy ??
|
||||
(config.models as { autoPull?: boolean }).autoPull ?? true,
|
||||
defaultEngine: config.models.defaultEngine || 'vllm',
|
||||
autoLoad: config.models.autoLoad || [],
|
||||
};
|
||||
const clusterConfig = config.cluster || {
|
||||
enabled: false,
|
||||
nodeName: 'modelgrid-local',
|
||||
role: 'standalone',
|
||||
bindHost: '0.0.0.0',
|
||||
gossipPort: 7946,
|
||||
sharedSecret: undefined,
|
||||
advertiseUrl: undefined,
|
||||
controlPlaneUrl: undefined,
|
||||
heartbeatIntervalMs: 5000,
|
||||
};
|
||||
|
||||
// Overview
|
||||
logger.logBox(
|
||||
@@ -48,9 +68,7 @@ export class ConfigHandler {
|
||||
`Host: ${theme.info(config.api.host)}`,
|
||||
`Port: ${theme.highlight(String(config.api.port))}`,
|
||||
`API Keys: ${config.api.apiKeys.length} configured`,
|
||||
...(config.api.rateLimit
|
||||
? [`Rate Limit: ${config.api.rateLimit} req/min`]
|
||||
: []),
|
||||
...(config.api.rateLimit ? [`Rate Limit: ${config.api.rateLimit} req/min`] : []),
|
||||
'',
|
||||
theme.dim('Endpoint:'),
|
||||
` http://${config.api.host}:${config.api.port}/v1/chat/completions`,
|
||||
@@ -88,12 +106,33 @@ export class ConfigHandler {
|
||||
logger.logBox(
|
||||
'Models',
|
||||
[
|
||||
`Auto Pull: ${config.models.autoPull ? theme.success('Enabled') : theme.dim('Disabled')}`,
|
||||
`Default Container: ${config.models.defaultContainer}`,
|
||||
`Auto Load: ${config.models.autoLoad.length} model(s)`,
|
||||
`Auto Deploy: ${
|
||||
modelConfig.autoDeploy ? theme.success('Enabled') : theme.dim('Disabled')
|
||||
}`,
|
||||
`Default Engine: ${modelConfig.defaultEngine}`,
|
||||
`Auto Load: ${modelConfig.autoLoad.length} model(s)`,
|
||||
'',
|
||||
theme.dim('Greenlist URL:'),
|
||||
` ${config.models.greenlistUrl}`,
|
||||
theme.dim('Registry URL:'),
|
||||
` ${modelConfig.registryUrl}`,
|
||||
],
|
||||
70,
|
||||
'default',
|
||||
);
|
||||
|
||||
logger.log('');
|
||||
logger.logBox(
|
||||
'Cluster',
|
||||
[
|
||||
`Enabled: ${clusterConfig.enabled ? theme.success('Yes') : theme.dim('No')}`,
|
||||
`Node: ${clusterConfig.nodeName}`,
|
||||
`Role: ${clusterConfig.role}`,
|
||||
`Bind Host: ${clusterConfig.bindHost}:${clusterConfig.gossipPort}`,
|
||||
`Shared Secret: ${
|
||||
clusterConfig.sharedSecret ? theme.success('Configured') : theme.dim('Not set')
|
||||
}`,
|
||||
`Advertise URL: ${clusterConfig.advertiseUrl || theme.dim('Default loopback')}`,
|
||||
`Control Plane: ${clusterConfig.controlPlaneUrl || theme.dim('Not configured')}`,
|
||||
`Heartbeat: ${clusterConfig.heartbeatIntervalMs}ms`,
|
||||
],
|
||||
70,
|
||||
'default',
|
||||
@@ -110,7 +149,7 @@ export class ConfigHandler {
|
||||
name: c.name,
|
||||
type: c.type,
|
||||
image: c.image.length > 40 ? c.image.substring(0, 37) + '...' : c.image,
|
||||
port: c.port,
|
||||
port: String(c.port),
|
||||
gpus: c.gpuIds.length > 0 ? c.gpuIds.join(',') : theme.dim('None'),
|
||||
}));
|
||||
|
||||
@@ -189,11 +228,22 @@ export class ConfigHandler {
|
||||
},
|
||||
containers: [],
|
||||
models: {
|
||||
greenlistUrl: 'https://code.foss.global/modelgrid.com/model_lists/raw/branch/main/greenlit.json',
|
||||
autoPull: true,
|
||||
defaultContainer: 'ollama',
|
||||
registryUrl: 'https://list.modelgrid.com/catalog/models.json',
|
||||
autoDeploy: true,
|
||||
defaultEngine: 'vllm',
|
||||
autoLoad: [],
|
||||
},
|
||||
cluster: {
|
||||
enabled: false,
|
||||
nodeName: 'modelgrid-local',
|
||||
role: 'standalone',
|
||||
bindHost: '0.0.0.0',
|
||||
gossipPort: 7946,
|
||||
sharedSecret: '',
|
||||
advertiseUrl: 'http://127.0.0.1:8080',
|
||||
heartbeatIntervalMs: 5000,
|
||||
seedNodes: [],
|
||||
},
|
||||
checkInterval: 30000,
|
||||
};
|
||||
|
||||
|
||||
+62
-138
@@ -1,47 +1,36 @@
|
||||
/**
|
||||
* Container Handler
|
||||
*
|
||||
* CLI commands for container management.
|
||||
* Deployment handler for container-backed runtimes.
|
||||
*/
|
||||
|
||||
import { logger } from '../logger.ts';
|
||||
import { theme } from '../colors.ts';
|
||||
import { ContainerManager } from '../containers/container-manager.ts';
|
||||
import { DockerManager } from '../docker/docker-manager.ts';
|
||||
import type { IContainerConfig } from '../interfaces/container.ts';
|
||||
import { VllmContainer } from '../containers/vllm.ts';
|
||||
import type { ITableColumn } from '../logger.ts';
|
||||
import * as helpers from '../helpers/index.ts';
|
||||
|
||||
/**
|
||||
* Handler for container-related CLI commands
|
||||
*/
|
||||
export class ContainerHandler {
|
||||
private containerManager: ContainerManager;
|
||||
private dockerManager: DockerManager;
|
||||
|
||||
constructor(containerManager: ContainerManager) {
|
||||
this.containerManager = containerManager;
|
||||
this.dockerManager = new DockerManager();
|
||||
}
|
||||
|
||||
/**
|
||||
* List all configured containers
|
||||
*/
|
||||
public async list(): Promise<void> {
|
||||
logger.log('');
|
||||
logger.info('Containers');
|
||||
logger.info('Deployments');
|
||||
logger.log('');
|
||||
|
||||
const containers = this.containerManager.getAllContainers();
|
||||
|
||||
if (containers.length === 0) {
|
||||
logger.logBox(
|
||||
'No Containers',
|
||||
'No Deployments',
|
||||
[
|
||||
'No containers are configured.',
|
||||
'No vLLM deployments are configured.',
|
||||
'',
|
||||
theme.dim('Add a container with:'),
|
||||
` ${theme.command('modelgrid container add')}`,
|
||||
theme.dim('Create one with:'),
|
||||
` ${theme.command('modelgrid run <model-id>')}`,
|
||||
],
|
||||
60,
|
||||
'warning',
|
||||
@@ -49,7 +38,7 @@ export class ContainerHandler {
|
||||
return;
|
||||
}
|
||||
|
||||
const rows = [];
|
||||
const rows: Record<string, string | number>[] = [];
|
||||
|
||||
for (const container of containers) {
|
||||
const status = await container.getStatus();
|
||||
@@ -57,28 +46,22 @@ export class ContainerHandler {
|
||||
|
||||
rows.push({
|
||||
id: config.id,
|
||||
name: config.name,
|
||||
type: this.formatContainerType(container.type),
|
||||
status: status.running
|
||||
? theme.success('Running')
|
||||
: theme.dim('Stopped'),
|
||||
health: status.running
|
||||
? this.formatHealth(status.health)
|
||||
: theme.dim('N/A'),
|
||||
port: config.externalPort || config.port,
|
||||
models: status.loadedModels.length,
|
||||
model: config.models[0] || theme.dim('N/A'),
|
||||
engine: this.formatContainerType(container.type),
|
||||
status: status.running ? theme.success('Running') : theme.dim('Stopped'),
|
||||
health: status.running ? this.formatHealth(status.health) : theme.dim('N/A'),
|
||||
port: String(config.externalPort || config.port),
|
||||
gpus: config.gpuIds.length > 0 ? config.gpuIds.join(',') : theme.dim('None'),
|
||||
});
|
||||
}
|
||||
|
||||
const columns: ITableColumn[] = [
|
||||
{ header: 'ID', key: 'id', align: 'left' },
|
||||
{ header: 'Name', key: 'name', align: 'left', color: theme.highlight },
|
||||
{ header: 'Type', key: 'type', align: 'left' },
|
||||
{ header: 'Model', key: 'model', align: 'left', color: theme.highlight },
|
||||
{ header: 'Engine', key: 'engine', align: 'left' },
|
||||
{ header: 'Status', key: 'status', align: 'left' },
|
||||
{ header: 'Health', key: 'health', align: 'left' },
|
||||
{ header: 'Port', key: 'port', align: 'right', color: theme.info },
|
||||
{ header: 'Models', key: 'models', align: 'right' },
|
||||
{ header: 'GPUs', key: 'gpus', align: 'left' },
|
||||
];
|
||||
|
||||
@@ -86,94 +69,70 @@ export class ContainerHandler {
|
||||
logger.log('');
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a new container interactively
|
||||
*/
|
||||
public async add(): Promise<void> {
|
||||
const { prompt, close, select } = await helpers.createPrompt();
|
||||
const { prompt, close } = await helpers.createPrompt();
|
||||
|
||||
try {
|
||||
logger.log('');
|
||||
logger.highlight('Add Container');
|
||||
logger.dim('Configure a new AI model container');
|
||||
logger.highlight('Create vLLM Deployment');
|
||||
logger.dim('Provision a single-model vLLM runtime');
|
||||
logger.log('');
|
||||
|
||||
// Select container type
|
||||
const typeIndex = await select('Select container type:', [
|
||||
'Ollama - Easy to use, good for local models',
|
||||
'vLLM - High performance, OpenAI compatible',
|
||||
'TGI - HuggingFace Text Generation Inference',
|
||||
]);
|
||||
|
||||
const types = ['ollama', 'vllm', 'tgi'] as const;
|
||||
const containerType = types[typeIndex];
|
||||
|
||||
// Container name
|
||||
const name = await prompt('Container name: ');
|
||||
if (!name.trim()) {
|
||||
logger.error('Container name is required');
|
||||
const modelName = await prompt('Model ID or Hugging Face repo: ');
|
||||
if (!modelName.trim()) {
|
||||
logger.error('Model ID is required');
|
||||
return;
|
||||
}
|
||||
|
||||
// Generate ID from name
|
||||
const id = name.toLowerCase().replace(/[^a-z0-9-]/g, '-');
|
||||
const name = await prompt(
|
||||
`Deployment name [${modelName.split('/').pop() || 'deployment'}]: `,
|
||||
);
|
||||
const deploymentName = name.trim() || modelName.split('/').pop() || 'deployment';
|
||||
const deploymentId = deploymentName.toLowerCase().replace(/[^a-z0-9-]/g, '-');
|
||||
|
||||
// Port
|
||||
const defaultPorts = { ollama: 11434, vllm: 8000, tgi: 8080 };
|
||||
const portStr = await prompt(`Port [${defaultPorts[containerType]}]: `);
|
||||
const port = portStr ? parseInt(portStr, 10) : defaultPorts[containerType];
|
||||
const portStr = await prompt('Port [8000]: ');
|
||||
const port = portStr ? parseInt(portStr, 10) : 8000;
|
||||
|
||||
// GPU assignment
|
||||
const gpuStr = await prompt('GPU IDs (comma-separated, or "all", or empty for none): ');
|
||||
const gpuStr = await prompt('GPU IDs (comma-separated, or "all"): ');
|
||||
let gpuIds: string[] = [];
|
||||
|
||||
if (gpuStr.trim().toLowerCase() === 'all') {
|
||||
const { GpuDetector } = await import('../hardware/gpu-detector.ts');
|
||||
const detector = new GpuDetector();
|
||||
const gpus = await detector.detectGpus();
|
||||
gpuIds = gpus.map((g) => g.id);
|
||||
gpuIds = gpus.map((gpu) => gpu.id);
|
||||
} else if (gpuStr.trim()) {
|
||||
gpuIds = gpuStr.split(',').map((s) => s.trim());
|
||||
gpuIds = gpuStr.split(',').map((value) => value.trim());
|
||||
}
|
||||
|
||||
// Build config
|
||||
const config: IContainerConfig = {
|
||||
id,
|
||||
type: containerType,
|
||||
name,
|
||||
image: this.getDefaultImage(containerType),
|
||||
const config = VllmContainer.createConfig(deploymentId, deploymentName, modelName, gpuIds, {
|
||||
port,
|
||||
gpuIds,
|
||||
models: [],
|
||||
};
|
||||
});
|
||||
config.models = [modelName];
|
||||
|
||||
// Add container
|
||||
await this.containerManager.addContainer(config);
|
||||
this.containerManager.addContainer(config);
|
||||
|
||||
logger.log('');
|
||||
logger.success(`Container "${name}" added successfully`);
|
||||
logger.success(`Deployment "${deploymentName}" added successfully`);
|
||||
logger.log('');
|
||||
logger.dim('Start the container with:');
|
||||
logger.log(` ${theme.command(`modelgrid container start ${id}`)}`);
|
||||
logger.dim('Start it with:');
|
||||
logger.log(` ${theme.command(`modelgrid container start ${deploymentId}`)}`);
|
||||
logger.log('');
|
||||
} finally {
|
||||
close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a container
|
||||
*/
|
||||
public async remove(containerId: string): Promise<void> {
|
||||
if (!containerId) {
|
||||
logger.error('Container ID is required');
|
||||
logger.error('Deployment ID is required');
|
||||
return;
|
||||
}
|
||||
|
||||
const { prompt, close } = await helpers.createPrompt();
|
||||
|
||||
try {
|
||||
const confirm = await prompt(`Remove container "${containerId}"? (y/N): `);
|
||||
const confirm = await prompt(`Remove deployment "${containerId}"? (y/N): `);
|
||||
|
||||
if (confirm.toLowerCase() !== 'y') {
|
||||
logger.log('Aborted');
|
||||
@@ -183,83 +142,72 @@ export class ContainerHandler {
|
||||
const success = await this.containerManager.removeContainer(containerId);
|
||||
|
||||
if (success) {
|
||||
logger.success(`Container "${containerId}" removed`);
|
||||
logger.success(`Deployment "${containerId}" removed`);
|
||||
} else {
|
||||
logger.error(`Failed to remove container "${containerId}"`);
|
||||
logger.error(`Failed to remove deployment "${containerId}"`);
|
||||
}
|
||||
} finally {
|
||||
close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Start a container
|
||||
*/
|
||||
public async start(containerId?: string): Promise<void> {
|
||||
if (containerId) {
|
||||
// Start specific container
|
||||
const container = this.containerManager.getContainer(containerId);
|
||||
if (!container) {
|
||||
logger.error(`Container "${containerId}" not found`);
|
||||
logger.error(`Deployment "${containerId}" not found`);
|
||||
return;
|
||||
}
|
||||
|
||||
logger.info(`Starting container "${containerId}"...`);
|
||||
logger.info(`Starting deployment "${containerId}"...`);
|
||||
const success = await container.start();
|
||||
|
||||
if (success) {
|
||||
logger.success(`Container "${containerId}" started`);
|
||||
logger.success(`Deployment "${containerId}" started`);
|
||||
} else {
|
||||
logger.error(`Failed to start container "${containerId}"`);
|
||||
logger.error(`Failed to start deployment "${containerId}"`);
|
||||
}
|
||||
} else {
|
||||
// Start all containers
|
||||
logger.info('Starting all containers...');
|
||||
await this.containerManager.startAll();
|
||||
logger.success('All containers started');
|
||||
return;
|
||||
}
|
||||
|
||||
logger.info('Starting all deployments...');
|
||||
await this.containerManager.startAll();
|
||||
logger.success('All deployments started');
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop a container
|
||||
*/
|
||||
public async stop(containerId?: string): Promise<void> {
|
||||
if (containerId) {
|
||||
// Stop specific container
|
||||
const container = this.containerManager.getContainer(containerId);
|
||||
if (!container) {
|
||||
logger.error(`Container "${containerId}" not found`);
|
||||
logger.error(`Deployment "${containerId}" not found`);
|
||||
return;
|
||||
}
|
||||
|
||||
logger.info(`Stopping container "${containerId}"...`);
|
||||
logger.info(`Stopping deployment "${containerId}"...`);
|
||||
const success = await container.stop();
|
||||
|
||||
if (success) {
|
||||
logger.success(`Container "${containerId}" stopped`);
|
||||
logger.success(`Deployment "${containerId}" stopped`);
|
||||
} else {
|
||||
logger.error(`Failed to stop container "${containerId}"`);
|
||||
logger.error(`Failed to stop deployment "${containerId}"`);
|
||||
}
|
||||
} else {
|
||||
// Stop all containers
|
||||
logger.info('Stopping all containers...');
|
||||
await this.containerManager.stopAll();
|
||||
logger.success('All containers stopped');
|
||||
return;
|
||||
}
|
||||
|
||||
logger.info('Stopping all deployments...');
|
||||
await this.containerManager.stopAll();
|
||||
logger.success('All deployments stopped');
|
||||
}
|
||||
|
||||
/**
|
||||
* Show container logs
|
||||
*/
|
||||
public async logs(containerId: string, lines: number = 100): Promise<void> {
|
||||
if (!containerId) {
|
||||
logger.error('Container ID is required');
|
||||
logger.error('Deployment ID is required');
|
||||
return;
|
||||
}
|
||||
|
||||
const container = this.containerManager.getContainer(containerId);
|
||||
if (!container) {
|
||||
logger.error(`Container "${containerId}" not found`);
|
||||
logger.error(`Deployment "${containerId}" not found`);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -267,13 +215,8 @@ export class ContainerHandler {
|
||||
console.log(logs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Format container type for display
|
||||
*/
|
||||
private formatContainerType(type: string): string {
|
||||
switch (type) {
|
||||
case 'ollama':
|
||||
return theme.containerOllama('Ollama');
|
||||
case 'vllm':
|
||||
return theme.containerVllm('vLLM');
|
||||
case 'tgi':
|
||||
@@ -283,9 +226,6 @@ export class ContainerHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Format health status
|
||||
*/
|
||||
private formatHealth(health: string): string {
|
||||
switch (health) {
|
||||
case 'healthy':
|
||||
@@ -298,20 +238,4 @@ export class ContainerHandler {
|
||||
return theme.dim(health);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get default image for container type
|
||||
*/
|
||||
private getDefaultImage(type: string): string {
|
||||
switch (type) {
|
||||
case 'ollama':
|
||||
return 'ollama/ollama:latest';
|
||||
case 'vllm':
|
||||
return 'vllm/vllm-openai:latest';
|
||||
case 'tgi':
|
||||
return 'ghcr.io/huggingface/text-generation-inference:latest';
|
||||
default:
|
||||
return '';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+38
-16
@@ -86,22 +86,30 @@ export class GpuHandler {
|
||||
logger.info('GPU Status');
|
||||
logger.log('');
|
||||
|
||||
const gpuStatus = await this.gpuDetector.getGpuStatus();
|
||||
const gpuInfo = await this.gpuDetector.detectGpus();
|
||||
const gpuStatus = await this.gpuDetector.getAllGpuStatus();
|
||||
|
||||
if (gpuStatus.length === 0) {
|
||||
if (gpuStatus.size === 0) {
|
||||
logger.warn('No GPUs detected');
|
||||
return;
|
||||
}
|
||||
|
||||
for (const gpu of gpuStatus) {
|
||||
const utilizationBar = this.createProgressBar(gpu.utilization, 30);
|
||||
const memoryBar = this.createProgressBar(gpu.memoryUsed / gpu.memoryTotal * 100, 30);
|
||||
for (const [gpuId, status] of gpuStatus) {
|
||||
const info = gpuInfo.find((gpu) => gpu.id === gpuId);
|
||||
const utilizationBar = this.createProgressBar(status.utilization, 30);
|
||||
const memoryBar = this.createProgressBar(status.memoryUsed / status.memoryTotal * 100, 30);
|
||||
|
||||
logger.logBoxTitle(`GPU ${gpu.id}: ${gpu.name}`, 70, 'info');
|
||||
logger.logBoxLine(`Utilization: ${utilizationBar} ${gpu.utilization.toFixed(1)}%`);
|
||||
logger.logBoxLine(`Memory: ${memoryBar} ${Math.round(gpu.memoryUsed)}/${Math.round(gpu.memoryTotal)} MB`);
|
||||
logger.logBoxLine(`Temperature: ${this.formatTemperature(gpu.temperature)}`);
|
||||
logger.logBoxLine(`Power: ${gpu.powerDraw.toFixed(0)}W / ${gpu.powerLimit.toFixed(0)}W`);
|
||||
logger.logBoxTitle(`GPU ${status.id}: ${info?.model || 'Unknown GPU'}`, 70, 'info');
|
||||
logger.logBoxLine(`Utilization: ${utilizationBar} ${status.utilization.toFixed(1)}%`);
|
||||
logger.logBoxLine(
|
||||
`Memory: ${memoryBar} ${Math.round(status.memoryUsed)}/${
|
||||
Math.round(status.memoryTotal)
|
||||
} MB`,
|
||||
);
|
||||
logger.logBoxLine(`Temperature: ${this.formatTemperature(status.temperature)}`);
|
||||
logger.logBoxLine(
|
||||
`Power: ${status.powerUsage.toFixed(0)}W / ${status.powerLimit.toFixed(0)}W`,
|
||||
);
|
||||
logger.logBoxEnd();
|
||||
logger.log('');
|
||||
}
|
||||
@@ -138,13 +146,23 @@ export class GpuHandler {
|
||||
|
||||
const status = await driver.getStatus();
|
||||
|
||||
logger.logBoxTitle(`${this.formatVendor(vendor)} Driver`, 60, status.installed ? 'success' : 'warning');
|
||||
logger.logBoxLine(`Installed: ${status.installed ? theme.success('Yes') : theme.error('No')}`);
|
||||
logger.logBoxTitle(
|
||||
`${this.formatVendor(vendor)} Driver`,
|
||||
60,
|
||||
status.installed ? 'success' : 'warning',
|
||||
);
|
||||
logger.logBoxLine(
|
||||
`Installed: ${status.installed ? theme.success('Yes') : theme.error('No')}`,
|
||||
);
|
||||
|
||||
if (status.installed) {
|
||||
logger.logBoxLine(`Version: ${status.version || 'Unknown'}`);
|
||||
logger.logBoxLine(`Runtime: ${status.runtimeVersion || 'Unknown'}`);
|
||||
logger.logBoxLine(`Container Support: ${status.containerSupport ? theme.success('Yes') : theme.warning('No')}`);
|
||||
logger.logBoxLine(`Runtime: ${status.containerRuntimeVersion || 'Unknown'}`);
|
||||
logger.logBoxLine(
|
||||
`Container Support: ${
|
||||
status.containerSupport ? theme.success('Yes') : theme.warning('No')
|
||||
}`,
|
||||
);
|
||||
} else {
|
||||
logger.logBoxLine('');
|
||||
logger.logBoxLine(theme.dim('Run `modelgrid gpu install` to install drivers'));
|
||||
@@ -183,14 +201,18 @@ export class GpuHandler {
|
||||
|
||||
logger.info(`Installing ${this.formatVendor(vendor)} drivers...`);
|
||||
|
||||
const success = await driver.install();
|
||||
const success = await driver.install({
|
||||
installToolkit: true,
|
||||
installContainerSupport: true,
|
||||
nonInteractive: false,
|
||||
});
|
||||
|
||||
if (success) {
|
||||
logger.success(`${this.formatVendor(vendor)} drivers installed successfully`);
|
||||
|
||||
// Setup container support
|
||||
logger.info('Setting up container support...');
|
||||
const containerSuccess = await driver.setupContainer();
|
||||
const containerSuccess = await driver.installContainerSupport();
|
||||
|
||||
if (containerSuccess) {
|
||||
logger.success('Container support configured');
|
||||
|
||||
+61
-89
@@ -1,55 +1,48 @@
|
||||
/**
|
||||
* Model Handler
|
||||
*
|
||||
* CLI commands for model management.
|
||||
* Model handler for catalog-backed vLLM deployments.
|
||||
*/
|
||||
|
||||
import { logger } from '../logger.ts';
|
||||
import { theme } from '../colors.ts';
|
||||
import { ClusterCoordinator } from '../cluster/coordinator.ts';
|
||||
import { ContainerManager } from '../containers/container-manager.ts';
|
||||
import { ModelRegistry } from '../models/registry.ts';
|
||||
import { ModelLoader } from '../models/loader.ts';
|
||||
import type { ITableColumn } from '../logger.ts';
|
||||
|
||||
/**
|
||||
* Handler for model-related CLI commands
|
||||
*/
|
||||
export class ModelHandler {
|
||||
private containerManager: ContainerManager;
|
||||
private clusterCoordinator: ClusterCoordinator;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private modelLoader: ModelLoader;
|
||||
|
||||
constructor(
|
||||
containerManager: ContainerManager,
|
||||
clusterCoordinator: ClusterCoordinator,
|
||||
modelRegistry: ModelRegistry,
|
||||
) {
|
||||
this.containerManager = containerManager;
|
||||
this.clusterCoordinator = clusterCoordinator;
|
||||
this.modelRegistry = modelRegistry;
|
||||
this.modelLoader = new ModelLoader(modelRegistry, containerManager);
|
||||
}
|
||||
|
||||
/**
|
||||
* List all available models
|
||||
*/
|
||||
public async list(): Promise<void> {
|
||||
logger.log('');
|
||||
logger.info('Models');
|
||||
logger.info('Model Catalog');
|
||||
logger.log('');
|
||||
|
||||
// Get loaded models from containers
|
||||
const loadedModels = await this.containerManager.getAllAvailableModels();
|
||||
const catalogModels = await this.modelRegistry.getAllModels();
|
||||
|
||||
// Get greenlit models
|
||||
const greenlitModels = await this.modelRegistry.getAllGreenlitModels();
|
||||
|
||||
if (loadedModels.size === 0 && greenlitModels.length === 0) {
|
||||
if (loadedModels.size === 0 && catalogModels.length === 0) {
|
||||
logger.logBox(
|
||||
'No Models',
|
||||
[
|
||||
'No models are loaded or greenlit.',
|
||||
'The local registry cache is empty.',
|
||||
'',
|
||||
theme.dim('Pull a model with:'),
|
||||
` ${theme.command('modelgrid model pull <name>')}`,
|
||||
theme.dim('Refresh with:'),
|
||||
` ${theme.command('modelgrid model refresh')}`,
|
||||
],
|
||||
60,
|
||||
'warning',
|
||||
@@ -57,56 +50,51 @@ export class ModelHandler {
|
||||
return;
|
||||
}
|
||||
|
||||
// Show loaded models
|
||||
if (loadedModels.size > 0) {
|
||||
logger.info(`Loaded Models (${loadedModels.size}):`);
|
||||
logger.info(`Running Deployments (${loadedModels.size}):`);
|
||||
logger.log('');
|
||||
|
||||
const rows = [];
|
||||
for (const [name, info] of loadedModels) {
|
||||
const rows: Record<string, string | number>[] = [];
|
||||
for (const [name, endpoints] of loadedModels) {
|
||||
const primaryEndpoint = endpoints[0];
|
||||
rows.push({
|
||||
name,
|
||||
container: info.container,
|
||||
size: info.size ? this.formatSize(info.size) : theme.dim('N/A'),
|
||||
format: info.format || theme.dim('N/A'),
|
||||
modified: info.modifiedAt
|
||||
? new Date(info.modifiedAt).toLocaleDateString()
|
||||
: theme.dim('N/A'),
|
||||
model: name,
|
||||
engine: primaryEndpoint?.type || 'vllm',
|
||||
replicas: String(endpoints.length),
|
||||
endpoint: primaryEndpoint?.url || theme.dim('N/A'),
|
||||
});
|
||||
}
|
||||
|
||||
const columns: ITableColumn[] = [
|
||||
{ header: 'Name', key: 'name', align: 'left', color: theme.highlight },
|
||||
{ header: 'Container', key: 'container', align: 'left' },
|
||||
{ header: 'Size', key: 'size', align: 'right', color: theme.info },
|
||||
{ header: 'Format', key: 'format', align: 'left' },
|
||||
{ header: 'Modified', key: 'modified', align: 'left', color: theme.dim },
|
||||
{ header: 'Model', key: 'model', align: 'left', color: theme.highlight },
|
||||
{ header: 'Engine', key: 'engine', align: 'left' },
|
||||
{ header: 'Replicas', key: 'replicas', align: 'right', color: theme.info },
|
||||
{ header: 'Endpoint', key: 'endpoint', align: 'left', color: theme.dim },
|
||||
];
|
||||
|
||||
logger.logTable(columns, rows);
|
||||
logger.log('');
|
||||
}
|
||||
|
||||
// Show greenlit models (not yet loaded)
|
||||
const loadedNames = new Set(loadedModels.keys());
|
||||
const unloadedGreenlit = greenlitModels.filter((m) => !loadedNames.has(m.name));
|
||||
const available = catalogModels.filter((model) => !loadedNames.has(model.id));
|
||||
|
||||
if (unloadedGreenlit.length > 0) {
|
||||
logger.info(`Available to Pull (${unloadedGreenlit.length}):`);
|
||||
if (available.length > 0) {
|
||||
logger.info(`Available To Deploy (${available.length}):`);
|
||||
logger.log('');
|
||||
|
||||
const rows = unloadedGreenlit.map((m) => ({
|
||||
name: m.name,
|
||||
container: m.container,
|
||||
vram: `${m.minVram} GB`,
|
||||
tags: m.tags?.join(', ') || theme.dim('None'),
|
||||
const rows: Record<string, string | number>[] = available.map((model) => ({
|
||||
model: model.id,
|
||||
family: model.metadata?.family || theme.dim('N/A'),
|
||||
vram: `${model.requirements.minVramGb} GB`,
|
||||
capabilities: this.formatCapabilities(model.capabilities),
|
||||
}));
|
||||
|
||||
const columns: ITableColumn[] = [
|
||||
{ header: 'Name', key: 'name', align: 'left' },
|
||||
{ header: 'Container', key: 'container', align: 'left' },
|
||||
{ header: 'Model', key: 'model', align: 'left' },
|
||||
{ header: 'Family', key: 'family', align: 'left' },
|
||||
{ header: 'Min VRAM', key: 'vram', align: 'right', color: theme.info },
|
||||
{ header: 'Tags', key: 'tags', align: 'left', color: theme.dim },
|
||||
{ header: 'Capabilities', key: 'capabilities', align: 'left', color: theme.dim },
|
||||
];
|
||||
|
||||
logger.logTable(columns, rows);
|
||||
@@ -114,47 +102,42 @@ export class ModelHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Pull a model
|
||||
*/
|
||||
public async pull(modelName: string): Promise<void> {
|
||||
if (!modelName) {
|
||||
logger.error('Model name is required');
|
||||
logger.error('Model ID is required');
|
||||
return;
|
||||
}
|
||||
|
||||
logger.log('');
|
||||
logger.info(`Pulling model: ${modelName}`);
|
||||
logger.info(`Deploying model: ${modelName}`);
|
||||
logger.log('');
|
||||
|
||||
const result = await this.modelLoader.loadModel(modelName);
|
||||
const result = await this.clusterCoordinator.ensureModelViaControlPlane(modelName);
|
||||
|
||||
if (result.success) {
|
||||
if (result.alreadyLoaded) {
|
||||
logger.success(`Model "${modelName}" is already loaded`);
|
||||
if (result) {
|
||||
if (result.created) {
|
||||
logger.success(`Model "${result.model}" deployed successfully`);
|
||||
} else {
|
||||
logger.success(`Model "${modelName}" pulled successfully`);
|
||||
}
|
||||
if (result.container) {
|
||||
logger.dim(`Container: ${result.container}`);
|
||||
logger.success(`Model "${result.model}" is already available`);
|
||||
}
|
||||
logger.dim(`Node: ${result.location.nodeName}`);
|
||||
logger.dim(`Endpoint: ${result.location.endpoint}`);
|
||||
} else {
|
||||
logger.error(`Failed to pull model: ${result.error}`);
|
||||
logger.error(`Failed to deploy model: could not schedule ${modelName}`);
|
||||
}
|
||||
|
||||
logger.log('');
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a model
|
||||
*/
|
||||
public async remove(modelName: string): Promise<void> {
|
||||
if (!modelName) {
|
||||
logger.error('Model name is required');
|
||||
logger.error('Model ID is required');
|
||||
return;
|
||||
}
|
||||
|
||||
logger.info(`Removing model: ${modelName}`);
|
||||
logger.info(`Removing deployment for model: ${modelName}`);
|
||||
|
||||
await this.clusterCoordinator.clearDesiredDeployment(modelName);
|
||||
|
||||
const success = await this.modelLoader.unloadModel(modelName);
|
||||
|
||||
@@ -165,38 +148,27 @@ export class ModelHandler {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Show model loading status and recommendations
|
||||
*/
|
||||
public async status(): Promise<void> {
|
||||
logger.log('');
|
||||
await this.modelLoader.printStatus();
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh greenlist cache
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
logger.info('Refreshing greenlist...');
|
||||
|
||||
await this.modelRegistry.refreshGreenlist();
|
||||
|
||||
logger.success('Greenlist refreshed');
|
||||
logger.info('Refreshing model catalog...');
|
||||
await this.modelRegistry.fetchCatalog(true);
|
||||
logger.success('Model catalog refreshed');
|
||||
}
|
||||
|
||||
/**
|
||||
* Format file size
|
||||
*/
|
||||
private formatSize(bytes: number): string {
|
||||
const units = ['B', 'KB', 'MB', 'GB', 'TB'];
|
||||
let size = bytes;
|
||||
let unitIndex = 0;
|
||||
private formatCapabilities(capabilities: {
|
||||
chat?: boolean;
|
||||
completions?: boolean;
|
||||
embeddings?: boolean;
|
||||
tools?: boolean;
|
||||
}): string {
|
||||
const enabled = Object.entries(capabilities)
|
||||
.filter(([, value]) => value)
|
||||
.map(([key]) => key);
|
||||
|
||||
while (size >= 1024 && unitIndex < units.length - 1) {
|
||||
size /= 1024;
|
||||
unitIndex++;
|
||||
}
|
||||
|
||||
return `${size.toFixed(1)} ${units[unitIndex]}`;
|
||||
return enabled.length > 0 ? enabled.join(', ') : theme.dim('none');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,7 +27,9 @@ export class ServiceHandler {
|
||||
public async enable(): Promise<void> {
|
||||
this.checkRootAccess('This command must be run as root.');
|
||||
await this.modelgrid.getSystemd().install();
|
||||
logger.log('ModelGrid service has been installed. Use "modelgrid service start" to start the service.');
|
||||
logger.log(
|
||||
'ModelGrid service has been installed. Use "modelgrid service start" to start the service.',
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -41,7 +43,9 @@ export class ServiceHandler {
|
||||
}
|
||||
await this.modelgrid.getDaemon().start();
|
||||
} catch (error) {
|
||||
logger.error(`Daemon start failed: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Daemon start failed: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
@@ -127,13 +131,18 @@ export class ServiceHandler {
|
||||
|
||||
try {
|
||||
const currentVersion = this.modelgrid.getVersion();
|
||||
const apiUrl = 'https://code.foss.global/api/v1/repos/modelgrid.com/modelgrid/releases/latest';
|
||||
const apiUrl =
|
||||
'https://code.foss.global/api/v1/repos/modelgrid.com/modelgrid/releases/latest';
|
||||
const response = execSync(`curl -sSL ${apiUrl}`).toString();
|
||||
const release = JSON.parse(response);
|
||||
const latestVersion = release.tag_name;
|
||||
|
||||
const normalizedCurrent = currentVersion.startsWith('v') ? currentVersion : `v${currentVersion}`;
|
||||
const normalizedLatest = latestVersion.startsWith('v') ? latestVersion : `v${latestVersion}`;
|
||||
const normalizedCurrent = currentVersion.startsWith('v')
|
||||
? currentVersion
|
||||
: `v${currentVersion}`;
|
||||
const normalizedLatest = latestVersion.startsWith('v')
|
||||
? latestVersion
|
||||
: `v${latestVersion}`;
|
||||
|
||||
logger.dim(`Current version: ${normalizedCurrent}`);
|
||||
logger.dim(`Latest version: ${normalizedLatest}`);
|
||||
@@ -149,7 +158,8 @@ export class ServiceHandler {
|
||||
logger.dim('Downloading and installing...');
|
||||
console.log('');
|
||||
|
||||
const installUrl = 'https://code.foss.global/modelgrid.com/modelgrid/raw/branch/main/install.sh';
|
||||
const installUrl =
|
||||
'https://code.foss.global/modelgrid.com/modelgrid/raw/branch/main/install.sh';
|
||||
|
||||
execSync(`curl -sSL ${installUrl} | bash`, {
|
||||
stdio: 'inherit',
|
||||
|
||||
@@ -0,0 +1,456 @@
|
||||
import os from 'node:os';
|
||||
import * as fs from 'node:fs/promises';
|
||||
import type { IModelCatalogEntry } from '../interfaces/catalog.ts';
|
||||
import type {
|
||||
IClusterConfig,
|
||||
IClusterDesiredDeployment,
|
||||
IClusterGpuTopologyGroup,
|
||||
IClusterModelLocation,
|
||||
IClusterNodeHeartbeat,
|
||||
IClusterNodeStatus,
|
||||
IClusterStatusResponse,
|
||||
TClusterNodeSchedulerState,
|
||||
} from '../interfaces/cluster.ts';
|
||||
import { CLUSTER, PATHS } from '../constants.ts';
|
||||
|
||||
export class ClusterManager {
|
||||
private config: IClusterConfig = {
|
||||
enabled: false,
|
||||
nodeName: os.hostname(),
|
||||
role: 'standalone',
|
||||
bindHost: CLUSTER.DEFAULT_BIND_HOST,
|
||||
gossipPort: CLUSTER.DEFAULT_GOSSIP_PORT,
|
||||
heartbeatIntervalMs: CLUSTER.DEFAULT_HEARTBEAT_INTERVAL_MS,
|
||||
seedNodes: [],
|
||||
};
|
||||
private localNode: IClusterNodeHeartbeat | null = null;
|
||||
private knownNodes = new Map<string, IClusterNodeHeartbeat>();
|
||||
private desiredDeployments = new Map<string, IClusterDesiredDeployment>();
|
||||
private nodeSchedulerStates = new Map<string, TClusterNodeSchedulerState>();
|
||||
private persistQueued = false;
|
||||
private controlPersistQueued = false;
|
||||
|
||||
public async initialize(): Promise<void> {
|
||||
try {
|
||||
const stateContent = await fs.readFile(this.getStateFilePath(), 'utf-8');
|
||||
const data = JSON.parse(stateContent) as { nodes?: IClusterNodeHeartbeat[] };
|
||||
|
||||
for (const node of data.nodes || []) {
|
||||
this.knownNodes.set(node.nodeName, node);
|
||||
if (node.nodeName === this.config.nodeName) {
|
||||
this.localNode = node;
|
||||
}
|
||||
}
|
||||
|
||||
this.pruneStaleNodes();
|
||||
} catch {
|
||||
// No persisted cluster state yet.
|
||||
}
|
||||
|
||||
try {
|
||||
const controlStateContent = await fs.readFile(this.getControlStateFilePath(), 'utf-8');
|
||||
const data = JSON.parse(controlStateContent) as {
|
||||
desiredDeployments?: IClusterDesiredDeployment[];
|
||||
nodeSchedulerStates?: Record<string, TClusterNodeSchedulerState>;
|
||||
};
|
||||
|
||||
for (const deployment of data.desiredDeployments || []) {
|
||||
this.desiredDeployments.set(deployment.modelId, deployment);
|
||||
}
|
||||
|
||||
for (const [nodeName, schedulerState] of Object.entries(data.nodeSchedulerStates || {})) {
|
||||
this.nodeSchedulerStates.set(nodeName, schedulerState);
|
||||
}
|
||||
} catch {
|
||||
// No persisted control state yet.
|
||||
}
|
||||
}
|
||||
|
||||
public configure(config: IClusterConfig): void {
|
||||
this.config = {
|
||||
...config,
|
||||
heartbeatIntervalMs: config.heartbeatIntervalMs || CLUSTER.DEFAULT_HEARTBEAT_INTERVAL_MS,
|
||||
seedNodes: config.seedNodes || [],
|
||||
};
|
||||
}
|
||||
|
||||
public getConfig(): IClusterConfig {
|
||||
return this.config;
|
||||
}
|
||||
|
||||
public isEnabled(): boolean {
|
||||
return this.config.enabled;
|
||||
}
|
||||
|
||||
public isControlPlane(): boolean {
|
||||
return this.config.enabled && this.config.role === 'control-plane';
|
||||
}
|
||||
|
||||
public isWorker(): boolean {
|
||||
return this.config.enabled && this.config.role === 'worker';
|
||||
}
|
||||
|
||||
public getModeLabel(): string {
|
||||
if (!this.config.enabled) {
|
||||
return 'standalone';
|
||||
}
|
||||
|
||||
return this.config.role;
|
||||
}
|
||||
|
||||
public getHeartbeatIntervalMs(): number {
|
||||
return this.config.heartbeatIntervalMs || CLUSTER.DEFAULT_HEARTBEAT_INTERVAL_MS;
|
||||
}
|
||||
|
||||
public getAdvertisedEndpoint(): string | undefined {
|
||||
return this.localNode?.endpoint || this.config.advertiseUrl;
|
||||
}
|
||||
|
||||
public getControlPlaneUrl(): string | undefined {
|
||||
return this.config.controlPlaneUrl;
|
||||
}
|
||||
|
||||
public getSharedSecret(): string | undefined {
|
||||
return this.config.sharedSecret || undefined;
|
||||
}
|
||||
|
||||
public updateLocalNode(heartbeat: IClusterNodeHeartbeat): void {
|
||||
this.localNode = heartbeat;
|
||||
this.knownNodes.set(heartbeat.nodeName, heartbeat);
|
||||
this.schedulePersist();
|
||||
}
|
||||
|
||||
public upsertNode(heartbeat: IClusterNodeHeartbeat): void {
|
||||
this.knownNodes.set(heartbeat.nodeName, heartbeat);
|
||||
this.schedulePersist();
|
||||
}
|
||||
|
||||
public getLocalNodeStatus(): IClusterNodeStatus {
|
||||
return {
|
||||
nodeName: this.config.nodeName,
|
||||
role: this.config.role,
|
||||
endpoint: this.getAdvertisedEndpoint(),
|
||||
healthy: true,
|
||||
schedulerState: this.getNodeSchedulerState(this.config.nodeName),
|
||||
};
|
||||
}
|
||||
|
||||
public getLocalNode(): IClusterNodeHeartbeat | null {
|
||||
return this.localNode;
|
||||
}
|
||||
|
||||
public getNode(nodeName: string): IClusterNodeHeartbeat | null {
|
||||
const node = this.knownNodes.get(nodeName);
|
||||
if (!node) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return this.decorateNode(node);
|
||||
}
|
||||
|
||||
public pruneStaleNodes(): void {
|
||||
const now = Date.now();
|
||||
for (const [nodeName, node] of this.knownNodes) {
|
||||
if (nodeName === this.config.nodeName) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (now - node.lastSeenAt > CLUSTER.NODE_STALE_AFTER_MS) {
|
||||
this.knownNodes.delete(nodeName);
|
||||
this.schedulePersist();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public getAllNodes(): IClusterNodeHeartbeat[] {
|
||||
this.pruneStaleNodes();
|
||||
return Array.from(this.knownNodes.values()).map((node) => this.decorateNode(node)).sort(
|
||||
(left, right) => {
|
||||
if (left.nodeName === this.config.nodeName) {
|
||||
return -1;
|
||||
}
|
||||
if (right.nodeName === this.config.nodeName) {
|
||||
return 1;
|
||||
}
|
||||
return left.nodeName.localeCompare(right.nodeName);
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
public getHealthyNodes(): IClusterNodeHeartbeat[] {
|
||||
return this.getAllNodes().filter((node) => node.healthy);
|
||||
}
|
||||
|
||||
public getNodeSchedulerState(nodeName: string): TClusterNodeSchedulerState {
|
||||
return this.nodeSchedulerStates.get(nodeName) || 'active';
|
||||
}
|
||||
|
||||
public setNodeSchedulerState(
|
||||
nodeName: string,
|
||||
schedulerState: TClusterNodeSchedulerState,
|
||||
): TClusterNodeSchedulerState {
|
||||
this.nodeSchedulerStates.set(nodeName, schedulerState);
|
||||
this.scheduleControlPersist();
|
||||
return schedulerState;
|
||||
}
|
||||
|
||||
public getDesiredDeployments(): IClusterDesiredDeployment[] {
|
||||
return Array.from(this.desiredDeployments.values()).sort((left, right) =>
|
||||
left.modelId.localeCompare(right.modelId)
|
||||
);
|
||||
}
|
||||
|
||||
public getDesiredDeployment(modelId: string): IClusterDesiredDeployment | null {
|
||||
return this.desiredDeployments.get(modelId) || null;
|
||||
}
|
||||
|
||||
public upsertDesiredDeployment(
|
||||
modelId: string,
|
||||
desiredReplicas: number,
|
||||
): IClusterDesiredDeployment {
|
||||
const deployment: IClusterDesiredDeployment = {
|
||||
modelId,
|
||||
desiredReplicas,
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
this.desiredDeployments.set(modelId, deployment);
|
||||
this.scheduleControlPersist();
|
||||
return deployment;
|
||||
}
|
||||
|
||||
public removeDesiredDeployment(modelId: string): boolean {
|
||||
const removed = this.desiredDeployments.delete(modelId);
|
||||
if (removed) {
|
||||
this.scheduleControlPersist();
|
||||
}
|
||||
return removed;
|
||||
}
|
||||
|
||||
public getModelLocations(modelId: string): IClusterModelLocation[] {
|
||||
const locations: IClusterModelLocation[] = [];
|
||||
|
||||
for (const node of this.getHealthyNodes()) {
|
||||
for (const deployment of node.deployments) {
|
||||
if (deployment.modelId !== modelId || !deployment.healthy) {
|
||||
continue;
|
||||
}
|
||||
|
||||
locations.push({
|
||||
modelId,
|
||||
nodeName: node.nodeName,
|
||||
endpoint: deployment.endpoint,
|
||||
healthy: deployment.healthy,
|
||||
engine: deployment.engine,
|
||||
containerId: deployment.containerId,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return locations;
|
||||
}
|
||||
|
||||
public getActiveModelLocations(modelId: string): IClusterModelLocation[] {
|
||||
return this.getModelLocations(modelId).filter((location) =>
|
||||
this.getNodeSchedulerState(location.nodeName) === 'active'
|
||||
);
|
||||
}
|
||||
|
||||
public resolveModel(modelId: string): IClusterModelLocation | null {
|
||||
const locations = this.getModelLocations(modelId);
|
||||
if (locations.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
locations.sort((left, right) => {
|
||||
const schedulerPreference = this.compareSchedulerState(
|
||||
this.getNodeSchedulerState(left.nodeName),
|
||||
this.getNodeSchedulerState(right.nodeName),
|
||||
);
|
||||
if (schedulerPreference !== 0) {
|
||||
return schedulerPreference;
|
||||
}
|
||||
|
||||
if (left.nodeName === this.config.nodeName) {
|
||||
return -1;
|
||||
}
|
||||
if (right.nodeName === this.config.nodeName) {
|
||||
return 1;
|
||||
}
|
||||
return left.nodeName.localeCompare(right.nodeName);
|
||||
});
|
||||
|
||||
return locations[0];
|
||||
}
|
||||
|
||||
public pickNodeForModel(
|
||||
model: IModelCatalogEntry,
|
||||
excludedNodeNames: string[] = [],
|
||||
): IClusterNodeHeartbeat | null {
|
||||
const requiredVram = model.requirements.minVramGb;
|
||||
const minGpuCount = model.requirements.minGpuCount || 1;
|
||||
const preferredTensorParallel = model.launchDefaults?.tensorParallelSize || minGpuCount;
|
||||
|
||||
const eligible = this.getHealthyNodes().filter((node) => {
|
||||
if (excludedNodeNames.includes(node.nodeName)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node.role === 'standalone' && node.nodeName !== this.config.nodeName) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node.schedulerState && node.schedulerState !== 'active') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return node.resources.availableVramGb >= requiredVram &&
|
||||
this.hasEligibleTopologyGroup(node.resources.topologyGroups, requiredVram, minGpuCount);
|
||||
});
|
||||
|
||||
if (eligible.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
eligible.sort((left, right) => {
|
||||
if (left.nodeName === this.config.nodeName) {
|
||||
return -1;
|
||||
}
|
||||
if (right.nodeName === this.config.nodeName) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (right.resources.availableVramGb !== left.resources.availableVramGb) {
|
||||
return right.resources.availableVramGb - left.resources.availableVramGb;
|
||||
}
|
||||
|
||||
const leftTopologyDelta = Math.abs(
|
||||
left.resources.largestGpuGroupCount - preferredTensorParallel,
|
||||
);
|
||||
const rightTopologyDelta = Math.abs(
|
||||
right.resources.largestGpuGroupCount - preferredTensorParallel,
|
||||
);
|
||||
if (leftTopologyDelta !== rightTopologyDelta) {
|
||||
return leftTopologyDelta - rightTopologyDelta;
|
||||
}
|
||||
|
||||
return left.resources.deploymentCount - right.resources.deploymentCount;
|
||||
});
|
||||
|
||||
return eligible[0];
|
||||
}
|
||||
|
||||
public getStatus(): IClusterStatusResponse {
|
||||
const models: Record<string, IClusterModelLocation[]> = {};
|
||||
for (const node of this.getHealthyNodes()) {
|
||||
for (const deployment of node.deployments) {
|
||||
if (!models[deployment.modelId]) {
|
||||
models[deployment.modelId] = [];
|
||||
}
|
||||
|
||||
models[deployment.modelId].push({
|
||||
modelId: deployment.modelId,
|
||||
nodeName: node.nodeName,
|
||||
endpoint: deployment.endpoint,
|
||||
healthy: deployment.healthy,
|
||||
engine: deployment.engine,
|
||||
containerId: deployment.containerId,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
localNode: this.localNode ? this.decorateNode(this.localNode) : null,
|
||||
nodes: this.getAllNodes(),
|
||||
models,
|
||||
desiredDeployments: this.getDesiredDeployments(),
|
||||
};
|
||||
}
|
||||
|
||||
private hasEligibleTopologyGroup(
|
||||
groups: IClusterGpuTopologyGroup[],
|
||||
requiredVramGb: number,
|
||||
minGpuCount: number,
|
||||
): boolean {
|
||||
return groups.some((group) =>
|
||||
group.gpuCount >= minGpuCount && group.totalVramGb >= requiredVramGb
|
||||
);
|
||||
}
|
||||
|
||||
private getStateFilePath(): string {
|
||||
return `${PATHS.DATA_DIR}/cluster-state.json`;
|
||||
}
|
||||
|
||||
private getControlStateFilePath(): string {
|
||||
return `${PATHS.DATA_DIR}/cluster-control-state.json`;
|
||||
}
|
||||
|
||||
private schedulePersist(): void {
|
||||
if (this.persistQueued) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.persistQueued = true;
|
||||
queueMicrotask(() => {
|
||||
this.persistQueued = false;
|
||||
void this.persistState();
|
||||
});
|
||||
}
|
||||
|
||||
private scheduleControlPersist(): void {
|
||||
if (this.controlPersistQueued) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.controlPersistQueued = true;
|
||||
queueMicrotask(() => {
|
||||
this.controlPersistQueued = false;
|
||||
void this.persistControlState();
|
||||
});
|
||||
}
|
||||
|
||||
private async persistState(): Promise<void> {
|
||||
try {
|
||||
await fs.mkdir(PATHS.DATA_DIR, { recursive: true });
|
||||
await fs.writeFile(
|
||||
this.getStateFilePath(),
|
||||
JSON.stringify({ nodes: Array.from(this.knownNodes.values()) }, null, 2),
|
||||
);
|
||||
} catch {
|
||||
// Persistence failure should not break the control plane.
|
||||
}
|
||||
}
|
||||
|
||||
private async persistControlState(): Promise<void> {
|
||||
try {
|
||||
await fs.mkdir(PATHS.DATA_DIR, { recursive: true });
|
||||
await fs.writeFile(
|
||||
this.getControlStateFilePath(),
|
||||
JSON.stringify(
|
||||
{
|
||||
desiredDeployments: this.getDesiredDeployments(),
|
||||
nodeSchedulerStates: Object.fromEntries(this.nodeSchedulerStates.entries()),
|
||||
},
|
||||
null,
|
||||
2,
|
||||
),
|
||||
);
|
||||
} catch {
|
||||
// Persistence failure should not break the control plane.
|
||||
}
|
||||
}
|
||||
|
||||
private decorateNode(node: IClusterNodeHeartbeat): IClusterNodeHeartbeat {
|
||||
return {
|
||||
...node,
|
||||
schedulerState: this.getNodeSchedulerState(node.nodeName),
|
||||
};
|
||||
}
|
||||
|
||||
private compareSchedulerState(
|
||||
left: TClusterNodeSchedulerState,
|
||||
right: TClusterNodeSchedulerState,
|
||||
): number {
|
||||
const order: TClusterNodeSchedulerState[] = ['active', 'cordoned', 'draining'];
|
||||
return order.indexOf(left) - order.indexOf(right);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,438 @@
|
||||
import type {
|
||||
IClusterDesiredDeployment,
|
||||
IClusterEnsureResponse,
|
||||
IClusterNodeHeartbeat,
|
||||
IClusterNodeResources,
|
||||
IClusterStatusResponse,
|
||||
TClusterNodeSchedulerState,
|
||||
} from '../interfaces/cluster.ts';
|
||||
import { ContainerManager } from '../containers/container-manager.ts';
|
||||
import { GpuDetector } from '../hardware/gpu-detector.ts';
|
||||
import { logger } from '../logger.ts';
|
||||
import { ModelRegistry } from '../models/registry.ts';
|
||||
import { ModelLoader } from '../models/loader.ts';
|
||||
import { CLUSTER } from '../constants.ts';
|
||||
import { filterOutUsedGpus, summarizeGpuTopologyGroups } from './placement.ts';
|
||||
import { ClusterManager } from './cluster-manager.ts';
|
||||
|
||||
export class ClusterCoordinator {
|
||||
private clusterManager: ClusterManager;
|
||||
private containerManager: ContainerManager;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private modelLoader: ModelLoader;
|
||||
private gpuDetector: GpuDetector;
|
||||
|
||||
constructor(
|
||||
clusterManager: ClusterManager,
|
||||
containerManager: ContainerManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
modelLoader: ModelLoader,
|
||||
) {
|
||||
this.clusterManager = clusterManager;
|
||||
this.containerManager = containerManager;
|
||||
this.modelRegistry = modelRegistry;
|
||||
this.modelLoader = modelLoader;
|
||||
this.gpuDetector = new GpuDetector();
|
||||
}
|
||||
|
||||
public async buildLocalHeartbeat(endpoint: string): Promise<IClusterNodeHeartbeat> {
|
||||
const [gpus, statuses, models] = await Promise.all([
|
||||
this.gpuDetector.detectGpus(),
|
||||
this.containerManager.getAllStatus(),
|
||||
this.containerManager.getAllAvailableModels(),
|
||||
]);
|
||||
|
||||
const deploymentCount = Array.from(statuses.values()).filter((status) => status.running).length;
|
||||
const runningContainers = this.containerManager.getAllContainers().filter((container) => {
|
||||
const status = statuses.get(container.getConfig().id);
|
||||
return status?.running === true;
|
||||
});
|
||||
const resources = await this.buildResourceSummary(
|
||||
gpus,
|
||||
deploymentCount,
|
||||
models,
|
||||
runningContainers,
|
||||
);
|
||||
|
||||
return {
|
||||
nodeName: this.clusterManager.getConfig().nodeName,
|
||||
role: this.clusterManager.getConfig().role,
|
||||
endpoint,
|
||||
healthy: true,
|
||||
resources,
|
||||
deployments: Array.from(models.entries()).map(([modelId, endpoints]) => ({
|
||||
modelId,
|
||||
engine: 'vllm' as const,
|
||||
endpoint,
|
||||
healthy: endpoints.some((entry) => entry.healthy),
|
||||
containerId: endpoints[0]?.containerId,
|
||||
})),
|
||||
lastSeenAt: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
public async syncLocalState(endpoint: string): Promise<IClusterNodeHeartbeat> {
|
||||
const heartbeat = await this.buildLocalHeartbeat(endpoint);
|
||||
this.clusterManager.updateLocalNode(heartbeat);
|
||||
return heartbeat;
|
||||
}
|
||||
|
||||
public async sendHeartbeat(): Promise<void> {
|
||||
if (!this.clusterManager.isEnabled()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const endpoint = this.clusterManager.getAdvertisedEndpoint();
|
||||
const controlPlaneUrl = this.clusterManager.getControlPlaneUrl();
|
||||
if (!endpoint || !controlPlaneUrl) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (controlPlaneUrl === endpoint) {
|
||||
return;
|
||||
}
|
||||
|
||||
const heartbeat = await this.syncLocalState(endpoint);
|
||||
|
||||
try {
|
||||
await fetch(`${controlPlaneUrl}/_cluster/nodes/heartbeat`, {
|
||||
method: 'POST',
|
||||
headers: this.buildClusterHeaders(),
|
||||
body: JSON.stringify(heartbeat),
|
||||
});
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
`Cluster heartbeat failed: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
public acceptHeartbeat(heartbeat: IClusterNodeHeartbeat): void {
|
||||
this.clusterManager.upsertNode(heartbeat);
|
||||
}
|
||||
|
||||
public getStatus(): IClusterStatusResponse {
|
||||
return this.clusterManager.getStatus();
|
||||
}
|
||||
|
||||
public getDesiredDeployments(): IClusterDesiredDeployment[] {
|
||||
return this.clusterManager.getDesiredDeployments();
|
||||
}
|
||||
|
||||
public getLocalNodeName(): string {
|
||||
return this.clusterManager.getConfig().nodeName;
|
||||
}
|
||||
|
||||
public getSharedSecret(): string | undefined {
|
||||
return this.clusterManager.getSharedSecret();
|
||||
}
|
||||
|
||||
public setNodeSchedulerState(
|
||||
nodeName: string,
|
||||
schedulerState: TClusterNodeSchedulerState,
|
||||
): TClusterNodeSchedulerState {
|
||||
return this.clusterManager.setNodeSchedulerState(nodeName, schedulerState);
|
||||
}
|
||||
|
||||
public async setDesiredReplicas(
|
||||
modelName: string,
|
||||
desiredReplicas: number,
|
||||
): Promise<IClusterDesiredDeployment | null> {
|
||||
const model = await this.modelRegistry.getModel(modelName);
|
||||
if (!model) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (desiredReplicas <= 0) {
|
||||
this.clusterManager.removeDesiredDeployment(model.id);
|
||||
return {
|
||||
modelId: model.id,
|
||||
desiredReplicas: 0,
|
||||
updatedAt: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
return this.clusterManager.upsertDesiredDeployment(model.id, Math.max(desiredReplicas, 0));
|
||||
}
|
||||
|
||||
public async clearDesiredDeployment(modelName: string): Promise<boolean> {
|
||||
const model = await this.modelRegistry.getModel(modelName);
|
||||
if (!model) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return this.clusterManager.removeDesiredDeployment(model.id);
|
||||
}
|
||||
|
||||
public shouldDeployLocallyFirst(): boolean {
|
||||
if (!this.clusterManager.isEnabled()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return this.clusterManager.isControlPlane() || !this.clusterManager.getControlPlaneUrl();
|
||||
}
|
||||
|
||||
public canManageClusterState(): boolean {
|
||||
return !this.clusterManager.isEnabled() || this.clusterManager.isControlPlane();
|
||||
}
|
||||
|
||||
public async resolveModel(modelName: string): Promise<IClusterEnsureResponse | null> {
|
||||
const model = await this.modelRegistry.getModel(modelName);
|
||||
if (!model) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const location = this.clusterManager.resolveModel(model.id);
|
||||
if (!location) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
model: model.id,
|
||||
location,
|
||||
created: false,
|
||||
};
|
||||
}
|
||||
|
||||
public async ensureModel(modelName: string): Promise<IClusterEnsureResponse | null> {
|
||||
const model = await this.modelRegistry.getModel(modelName);
|
||||
if (!model) {
|
||||
return null;
|
||||
}
|
||||
|
||||
this.rememberDesiredDeployment(model.id, model.launchDefaults?.replicas || 1);
|
||||
|
||||
const existing = this.clusterManager.getActiveModelLocations(model.id)[0];
|
||||
if (existing) {
|
||||
return {
|
||||
model: model.id,
|
||||
location: existing,
|
||||
created: false,
|
||||
};
|
||||
}
|
||||
|
||||
if (!this.clusterManager.isEnabled() || !this.clusterManager.isControlPlane()) {
|
||||
const local = await this.deployModelLocally(model.id);
|
||||
if (!local) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return local;
|
||||
}
|
||||
|
||||
const targetNode = this.clusterManager.pickNodeForModel(model);
|
||||
if (!targetNode) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (targetNode.nodeName === this.clusterManager.getConfig().nodeName) {
|
||||
return this.deployModelLocally(model.id);
|
||||
}
|
||||
|
||||
return this.requestRemoteDeployment(targetNode.endpoint, model.id);
|
||||
}
|
||||
|
||||
public async ensureModelViaControlPlane(
|
||||
modelName: string,
|
||||
): Promise<IClusterEnsureResponse | null> {
|
||||
const controlPlaneUrl = this.clusterManager.getControlPlaneUrl();
|
||||
const localEndpoint = this.clusterManager.getAdvertisedEndpoint();
|
||||
|
||||
if (!controlPlaneUrl || controlPlaneUrl === localEndpoint) {
|
||||
return this.ensureModel(modelName);
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`${controlPlaneUrl}/_cluster/models/ensure`, {
|
||||
method: 'POST',
|
||||
headers: this.buildClusterHeaders(),
|
||||
body: JSON.stringify({ model: modelName }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return await response.json() as IClusterEnsureResponse;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public async deployModelLocally(modelName: string): Promise<IClusterEnsureResponse | null> {
|
||||
const model = await this.modelRegistry.getModel(modelName);
|
||||
if (model) {
|
||||
this.rememberDesiredDeployment(model.id, model.launchDefaults?.replicas || 1);
|
||||
}
|
||||
|
||||
const result = await this.modelLoader.loadModel(modelName);
|
||||
if (!result.success) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const endpoint = this.clusterManager.getAdvertisedEndpoint();
|
||||
if (endpoint) {
|
||||
await this.syncLocalState(endpoint);
|
||||
}
|
||||
|
||||
const resolved = await this.resolveModel(result.model);
|
||||
if (!resolved) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
...resolved,
|
||||
created: !result.alreadyLoaded,
|
||||
};
|
||||
}
|
||||
|
||||
public async reconcileDesiredReplicas(): Promise<void> {
|
||||
if (this.clusterManager.isEnabled() && !this.clusterManager.isControlPlane()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const desiredDeployments = this.clusterManager.getDesiredDeployments();
|
||||
for (const desiredDeployment of desiredDeployments) {
|
||||
if (desiredDeployment.desiredReplicas <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const model = await this.modelRegistry.getModel(desiredDeployment.modelId);
|
||||
if (!model) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const existingLocations = this.clusterManager.getActiveModelLocations(model.id);
|
||||
const missingReplicas = desiredDeployment.desiredReplicas - existingLocations.length;
|
||||
if (missingReplicas <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (let index = 0; index < missingReplicas; index++) {
|
||||
const targetNode = this.clusterManager.pickNodeForModel(model);
|
||||
if (!targetNode) {
|
||||
break;
|
||||
}
|
||||
|
||||
const replicaOrdinal = existingLocations.length + index;
|
||||
const result = targetNode.nodeName === this.clusterManager.getConfig().nodeName
|
||||
? await this.deployReplicaLocally(model.id, replicaOrdinal)
|
||||
: await this.requestRemoteDeployment(targetNode.endpoint, model.id, replicaOrdinal);
|
||||
|
||||
if (!result) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async deployReplicaLocally(
|
||||
modelName: string,
|
||||
replicaOrdinal?: number,
|
||||
): Promise<IClusterEnsureResponse | null> {
|
||||
const model = await this.modelRegistry.getModel(modelName);
|
||||
if (model) {
|
||||
this.rememberDesiredDeployment(
|
||||
model.id,
|
||||
Math.max((replicaOrdinal ?? 0) + 1, model.launchDefaults?.replicas || 1),
|
||||
);
|
||||
}
|
||||
|
||||
const result = await this.modelLoader.deployReplica(modelName, replicaOrdinal);
|
||||
if (!result.success) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const endpoint = this.clusterManager.getAdvertisedEndpoint();
|
||||
if (endpoint) {
|
||||
await this.syncLocalState(endpoint);
|
||||
}
|
||||
|
||||
const resolved = await this.resolveModel(result.model);
|
||||
if (!resolved) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
...resolved,
|
||||
created: !result.alreadyLoaded,
|
||||
};
|
||||
}
|
||||
|
||||
private async requestRemoteDeployment(
|
||||
nodeEndpoint: string,
|
||||
modelName: string,
|
||||
replicaOrdinal?: number,
|
||||
): Promise<IClusterEnsureResponse | null> {
|
||||
try {
|
||||
const response = await fetch(`${nodeEndpoint}/_cluster/deployments`, {
|
||||
method: 'POST',
|
||||
headers: this.buildClusterHeaders(),
|
||||
body: JSON.stringify({ model: modelName, replicaOrdinal }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return await response.json() as IClusterEnsureResponse;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private async buildResourceSummary(
|
||||
gpus: Awaited<ReturnType<GpuDetector['detectGpus']>>,
|
||||
deploymentCount: number,
|
||||
_models: Awaited<ReturnType<ContainerManager['getAllAvailableModels']>>,
|
||||
runningContainers: ReturnType<ContainerManager['getAllContainers']>,
|
||||
): Promise<IClusterNodeResources> {
|
||||
const totalVramGb = Math.round(gpus.reduce((sum, gpu) => sum + gpu.vram, 0) / 1024);
|
||||
const usedGpuIds = runningContainers.flatMap((container) => container.getConfig().gpuIds);
|
||||
const availableGpus = filterOutUsedGpus(gpus, usedGpuIds);
|
||||
const topologyGroups = summarizeGpuTopologyGroups(availableGpus);
|
||||
const availableVramGb = Math.round(
|
||||
availableGpus.reduce((sum, gpu) => sum + gpu.vram, 0) / 1024,
|
||||
);
|
||||
|
||||
const maxSingleGpuVramGb = availableGpus.length > 0
|
||||
? Math.max(...availableGpus.map((gpu) => Math.round(gpu.vram / 1024)))
|
||||
: 0;
|
||||
const largestGpuGroupCount = topologyGroups.length > 0
|
||||
? Math.max(...topologyGroups.map((group) => group.gpuCount))
|
||||
: 0;
|
||||
const largestGpuGroupVramGb = topologyGroups.length > 0
|
||||
? Math.max(...topologyGroups.map((group) => group.totalVramGb))
|
||||
: 0;
|
||||
|
||||
return {
|
||||
gpuCount: gpus.length,
|
||||
totalVramGb,
|
||||
availableVramGb,
|
||||
maxSingleGpuVramGb,
|
||||
largestGpuGroupCount,
|
||||
largestGpuGroupVramGb,
|
||||
deploymentCount,
|
||||
topologyGroups,
|
||||
};
|
||||
}
|
||||
|
||||
private buildClusterHeaders(): Record<string, string> {
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
};
|
||||
|
||||
const sharedSecret = this.clusterManager.getSharedSecret();
|
||||
if (sharedSecret) {
|
||||
headers[CLUSTER.AUTH_HEADER_NAME] = sharedSecret;
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
|
||||
private rememberDesiredDeployment(modelId: string, minimumReplicas: number): void {
|
||||
const existing = this.clusterManager.getDesiredDeployment(modelId);
|
||||
const desiredReplicas = Math.max(existing?.desiredReplicas || 0, minimumReplicas, 1);
|
||||
this.clusterManager.upsertDesiredDeployment(modelId, desiredReplicas);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
/**
 * Cluster module public surface.
 *
 * Re-exports the cluster state manager and the deployment coordinator.
 */
export { ClusterManager } from './cluster-manager.ts';
export { ClusterCoordinator } from './coordinator.ts';
|
||||
@@ -0,0 +1,114 @@
|
||||
import type { IModelCatalogEntry } from '../interfaces/catalog.ts';
|
||||
import type { IGpuInfo, TGpuVendor } from '../interfaces/gpu.ts';
|
||||
import type { IClusterGpuTopologyGroup } from '../interfaces/cluster.ts';
|
||||
|
||||
function parsePciBusNumber(gpu: IGpuInfo): number {
|
||||
const source = gpu.pciBusId || gpu.pciSlot;
|
||||
const match = source.match(/(?:[0-9a-f]{4}:)?([0-9a-f]{2}):/i);
|
||||
if (!match) {
|
||||
return gpu.index;
|
||||
}
|
||||
|
||||
return parseInt(match[1], 16);
|
||||
}
|
||||
|
||||
export function buildGpuTopologyGroups(gpus: IGpuInfo[]): IClusterGpuTopologyGroup[] {
|
||||
const sorted = [...gpus].sort((left, right) => {
|
||||
if (left.vendor !== right.vendor) {
|
||||
return left.vendor.localeCompare(right.vendor);
|
||||
}
|
||||
|
||||
return parsePciBusNumber(left) - parsePciBusNumber(right);
|
||||
});
|
||||
|
||||
const groups: IClusterGpuTopologyGroup[] = [];
|
||||
|
||||
for (const gpu of sorted) {
|
||||
const busNumber = parsePciBusNumber(gpu);
|
||||
const previousGroup = groups[groups.length - 1];
|
||||
const previousBus = previousGroup?.busNumbers[previousGroup.busNumbers.length - 1];
|
||||
|
||||
const belongsToPreviousGroup = previousGroup &&
|
||||
previousGroup.vendor === gpu.vendor &&
|
||||
previousBus !== undefined &&
|
||||
busNumber - previousBus <= 1;
|
||||
|
||||
if (belongsToPreviousGroup) {
|
||||
previousGroup.gpuIds.push(gpu.id);
|
||||
previousGroup.busNumbers.push(busNumber);
|
||||
previousGroup.totalVramGb += Math.round(gpu.vram / 1024);
|
||||
previousGroup.maxSingleGpuVramGb = Math.max(
|
||||
previousGroup.maxSingleGpuVramGb,
|
||||
Math.round(gpu.vram / 1024),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
groups.push({
|
||||
id: `${gpu.vendor}-${groups.length + 1}`,
|
||||
vendor: gpu.vendor,
|
||||
gpuIds: [gpu.id],
|
||||
gpuCount: 1,
|
||||
totalVramGb: Math.round(gpu.vram / 1024),
|
||||
maxSingleGpuVramGb: Math.round(gpu.vram / 1024),
|
||||
busNumbers: [busNumber],
|
||||
});
|
||||
}
|
||||
|
||||
for (const group of groups) {
|
||||
group.gpuCount = group.gpuIds.length;
|
||||
}
|
||||
|
||||
return groups;
|
||||
}
|
||||
|
||||
export function summarizeGpuTopologyGroups(gpus: IGpuInfo[]): IClusterGpuTopologyGroup[] {
|
||||
return buildGpuTopologyGroups(gpus);
|
||||
}
|
||||
|
||||
export function selectPlacementForModel(
|
||||
model: IModelCatalogEntry,
|
||||
gpus: IGpuInfo[],
|
||||
): { gpuIds: string[]; tensorParallelSize: number; topologyGroupId: string } | null {
|
||||
const minGpuCount = model.requirements.minGpuCount || 1;
|
||||
const preferredTensorParallel = model.launchDefaults?.tensorParallelSize || minGpuCount;
|
||||
const topologyGroups = buildGpuTopologyGroups(gpus);
|
||||
|
||||
const eligibleGroups = topologyGroups.filter((group) =>
|
||||
group.gpuCount >= minGpuCount && group.totalVramGb >= model.requirements.minVramGb
|
||||
);
|
||||
|
||||
if (eligibleGroups.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
eligibleGroups.sort((left, right) => {
|
||||
const leftCountDelta = Math.abs(left.gpuCount - preferredTensorParallel);
|
||||
const rightCountDelta = Math.abs(right.gpuCount - preferredTensorParallel);
|
||||
if (leftCountDelta !== rightCountDelta) {
|
||||
return leftCountDelta - rightCountDelta;
|
||||
}
|
||||
|
||||
const leftVramDelta = left.totalVramGb - model.requirements.minVramGb;
|
||||
const rightVramDelta = right.totalVramGb - model.requirements.minVramGb;
|
||||
if (leftVramDelta !== rightVramDelta) {
|
||||
return leftVramDelta - rightVramDelta;
|
||||
}
|
||||
|
||||
return left.id.localeCompare(right.id);
|
||||
});
|
||||
|
||||
const selectedGroup = eligibleGroups[0];
|
||||
const tensorParallelSize = Math.min(preferredTensorParallel, selectedGroup.gpuCount);
|
||||
|
||||
return {
|
||||
gpuIds: selectedGroup.gpuIds.slice(0, tensorParallelSize),
|
||||
tensorParallelSize,
|
||||
topologyGroupId: selectedGroup.id,
|
||||
};
|
||||
}
|
||||
|
||||
export function filterOutUsedGpus(gpus: IGpuInfo[], usedGpuIds: string[]): IGpuInfo[] {
|
||||
const usedSet = new Set(usedGpuIds);
|
||||
return gpus.filter((gpu) => !usedSet.has(gpu.id));
|
||||
}
|
||||
+8
-3
@@ -37,6 +37,13 @@ export const theme = {
|
||||
containerStopped: colors.red,
|
||||
containerStarting: colors.yellow,
|
||||
|
||||
// Named vendor/container helpers
|
||||
gpuNvidia: colors.green,
|
||||
gpuAmd: colors.red,
|
||||
gpuIntel: colors.blue,
|
||||
containerVllm: colors.cyan,
|
||||
containerTgi: colors.magenta,
|
||||
|
||||
// Box borders
|
||||
borderSuccess: colors.green,
|
||||
borderError: colors.red,
|
||||
@@ -127,10 +134,8 @@ export function formatContainerStatus(
|
||||
/**
|
||||
* Format container type with color
|
||||
*/
|
||||
export function formatContainerType(type: 'ollama' | 'vllm' | 'tgi' | 'custom'): string {
|
||||
export function formatContainerType(type: 'vllm' | 'tgi' | 'custom'): string {
|
||||
switch (type) {
|
||||
case 'ollama':
|
||||
return colors.green('Ollama');
|
||||
case 'vllm':
|
||||
return colors.cyan('vLLM');
|
||||
case 'tgi':
|
||||
|
||||
+88
-13
@@ -5,6 +5,8 @@
|
||||
* This makes configuration easier and code more self-documenting.
|
||||
*/
|
||||
|
||||
export const VERSION = '1.0.1';
|
||||
|
||||
/**
|
||||
* Default timing values in milliseconds
|
||||
*/
|
||||
@@ -106,9 +108,6 @@ export const CONTAINER_PORTS = {
|
||||
* Container image defaults
|
||||
*/
|
||||
export const CONTAINER_IMAGES = {
|
||||
/** Ollama official image */
|
||||
OLLAMA: 'ollama/ollama:latest',
|
||||
|
||||
/** vLLM official image */
|
||||
VLLM: 'vllm/vllm-openai:latest',
|
||||
|
||||
@@ -120,20 +119,96 @@ export const CONTAINER_IMAGES = {
|
||||
* Model registry constants
|
||||
*/
|
||||
export const MODEL_REGISTRY = {
|
||||
/** Default greenlit models URL */
|
||||
DEFAULT_GREENLIST_URL:
|
||||
'https://code.foss.global/modelgrid.com/model_lists/raw/branch/main/greenlit.json',
|
||||
/** Default public catalog URL */
|
||||
DEFAULT_CATALOG_URL: 'https://list.modelgrid.com/catalog/models.json',
|
||||
|
||||
/** Fallback greenlist if remote fetch fails */
|
||||
FALLBACK_GREENLIST: [
|
||||
{ name: 'llama3.2:1b', container: 'ollama', minVram: 4 },
|
||||
{ name: 'llama3.2:3b', container: 'ollama', minVram: 6 },
|
||||
{ name: 'llama3:8b', container: 'ollama', minVram: 8 },
|
||||
{ name: 'mistral:7b', container: 'ollama', minVram: 8 },
|
||||
{ name: 'codellama:7b', container: 'ollama', minVram: 8 },
|
||||
/** Fallback catalog if remote fetch fails */
|
||||
FALLBACK_CATALOG: [
|
||||
{
|
||||
id: 'Qwen/Qwen2.5-7B-Instruct',
|
||||
aliases: ['qwen2.5-7b-instruct'],
|
||||
engine: 'vllm',
|
||||
source: {
|
||||
repo: 'Qwen/Qwen2.5-7B-Instruct',
|
||||
license: 'apache-2.0',
|
||||
},
|
||||
capabilities: {
|
||||
chat: true,
|
||||
completions: true,
|
||||
tools: true,
|
||||
},
|
||||
requirements: {
|
||||
minVramGb: 16,
|
||||
recommendedVramGb: 24,
|
||||
minGpuCount: 1,
|
||||
},
|
||||
metadata: {
|
||||
family: 'Qwen2.5',
|
||||
parameterCount: '7B',
|
||||
contextWindow: 131072,
|
||||
summary: 'General purpose instruct model for chat and tool use.',
|
||||
tags: ['chat', 'tool-use', 'instruct'],
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'meta-llama/Llama-3.1-8B-Instruct',
|
||||
aliases: ['llama-3.1-8b-instruct'],
|
||||
engine: 'vllm',
|
||||
source: {
|
||||
repo: 'meta-llama/Llama-3.1-8B-Instruct',
|
||||
license: 'llama3.1',
|
||||
},
|
||||
capabilities: {
|
||||
chat: true,
|
||||
completions: true,
|
||||
tools: true,
|
||||
},
|
||||
requirements: {
|
||||
minVramGb: 18,
|
||||
recommendedVramGb: 24,
|
||||
minGpuCount: 1,
|
||||
},
|
||||
metadata: {
|
||||
family: 'Llama 3.1',
|
||||
parameterCount: '8B',
|
||||
contextWindow: 131072,
|
||||
summary: 'High quality instruct model with good ecosystem support.',
|
||||
tags: ['chat', 'tool-use', 'instruct'],
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'BAAI/bge-m3',
|
||||
aliases: ['bge-m3'],
|
||||
engine: 'vllm',
|
||||
source: {
|
||||
repo: 'BAAI/bge-m3',
|
||||
license: 'mit',
|
||||
},
|
||||
capabilities: {
|
||||
embeddings: true,
|
||||
},
|
||||
requirements: {
|
||||
minVramGb: 8,
|
||||
recommendedVramGb: 12,
|
||||
minGpuCount: 1,
|
||||
},
|
||||
metadata: {
|
||||
family: 'BGE',
|
||||
summary: 'Multilingual embedding model for retrieval workloads.',
|
||||
tags: ['embeddings', 'retrieval', 'multilingual'],
|
||||
},
|
||||
},
|
||||
],
|
||||
} as const;
|
||||
|
||||
export const CLUSTER = {
|
||||
DEFAULT_BIND_HOST: '0.0.0.0',
|
||||
DEFAULT_GOSSIP_PORT: 7946,
|
||||
DEFAULT_HEARTBEAT_INTERVAL_MS: 5000,
|
||||
NODE_STALE_AFTER_MS: 20000,
|
||||
AUTH_HEADER_NAME: 'x-modelgrid-cluster-secret',
|
||||
} as const;
|
||||
|
||||
/**
|
||||
* Configuration paths
|
||||
*/
|
||||
|
||||
@@ -6,14 +6,13 @@
|
||||
|
||||
import type {
|
||||
IContainerConfig,
|
||||
IContainerStatus,
|
||||
IContainerEndpoint,
|
||||
IContainerStatus,
|
||||
TContainerType,
|
||||
} from '../interfaces/container.ts';
|
||||
import { logger } from '../logger.ts';
|
||||
import { DockerManager } from '../docker/docker-manager.ts';
|
||||
import { BaseContainer } from './base-container.ts';
|
||||
import { OllamaContainer } from './ollama.ts';
|
||||
import { VllmContainer } from './vllm.ts';
|
||||
import { TgiContainer } from './tgi.ts';
|
||||
|
||||
@@ -47,8 +46,6 @@ export class ContainerManager {
|
||||
*/
|
||||
private createContainerInstance(config: IContainerConfig): BaseContainer {
|
||||
switch (config.type) {
|
||||
case 'ollama':
|
||||
return new OllamaContainer(config);
|
||||
case 'vllm':
|
||||
return new VllmContainer(config);
|
||||
case 'tgi':
|
||||
@@ -108,7 +105,11 @@ export class ContainerManager {
|
||||
try {
|
||||
this.addContainer(config);
|
||||
} catch (error) {
|
||||
logger.warn(`Failed to load container ${config.id}: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.warn(
|
||||
`Failed to load container ${config.id}: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -128,7 +129,11 @@ export class ContainerManager {
|
||||
const success = await container.start();
|
||||
results.set(id, success);
|
||||
} catch (error) {
|
||||
logger.error(`Failed to start container ${id}: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to start container ${id}: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
);
|
||||
results.set(id, false);
|
||||
}
|
||||
}
|
||||
@@ -147,7 +152,11 @@ export class ContainerManager {
|
||||
const success = await container.stop();
|
||||
results.set(id, success);
|
||||
} catch (error) {
|
||||
logger.error(`Failed to stop container ${id}: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to stop container ${id}: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
);
|
||||
results.set(id, false);
|
||||
}
|
||||
}
|
||||
@@ -166,7 +175,11 @@ export class ContainerManager {
|
||||
const status = await container.getStatus();
|
||||
statuses.set(id, status);
|
||||
} catch (error) {
|
||||
logger.warn(`Failed to get status for container ${id}: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.warn(
|
||||
`Failed to get status for container ${id}: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -266,7 +279,7 @@ export class ContainerManager {
|
||||
*/
|
||||
public async pullModel(
|
||||
modelName: string,
|
||||
containerType: TContainerType = 'ollama',
|
||||
containerType: TContainerType = 'vllm',
|
||||
containerId?: string,
|
||||
): Promise<boolean> {
|
||||
// Find or create appropriate container
|
||||
@@ -313,6 +326,16 @@ export class ContainerManager {
|
||||
return results;
|
||||
}
|
||||
|
||||
public async checkAllHealth(): Promise<boolean> {
|
||||
const results = await this.healthCheck();
|
||||
|
||||
if (results.size === 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return Array.from(results.values()).every((healthy) => healthy);
|
||||
}
|
||||
|
||||
/**
|
||||
* Print container status summary
|
||||
*/
|
||||
@@ -329,9 +352,7 @@ export class ContainerManager {
|
||||
for (const [id, status] of statuses) {
|
||||
const runningStr = status.running ? 'Running' : 'Stopped';
|
||||
const healthStr = status.health;
|
||||
const modelsStr = status.loadedModels.length > 0
|
||||
? status.loadedModels.join(', ')
|
||||
: 'None';
|
||||
const modelsStr = status.loadedModels.length > 0 ? status.loadedModels.join(', ') : 'None';
|
||||
|
||||
logger.logBoxLine(`${status.name} (${id})`);
|
||||
logger.logBoxLine(` Type: ${status.type} | Status: ${runningStr} | Health: ${healthStr}`);
|
||||
@@ -339,7 +360,9 @@ export class ContainerManager {
|
||||
logger.logBoxLine(` Endpoint: ${status.endpoint}`);
|
||||
|
||||
if (status.gpuUtilization !== undefined) {
|
||||
logger.logBoxLine(` GPU: ${status.gpuUtilization}% | Memory: ${status.memoryUsage || 0}MB`);
|
||||
logger.logBoxLine(
|
||||
` GPU: ${status.gpuUtilization}% | Memory: ${status.memoryUsage || 0}MB`,
|
||||
);
|
||||
}
|
||||
logger.logBoxLine('');
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
*/
|
||||
|
||||
export { BaseContainer } from './base-container.ts';
|
||||
export { OllamaContainer } from './ollama.ts';
|
||||
export { VllmContainer } from './vllm.ts';
|
||||
export { TgiContainer } from './tgi.ts';
|
||||
export { ContainerManager } from './container-manager.ts';
|
||||
|
||||
@@ -1,387 +0,0 @@
|
||||
/**
|
||||
* Ollama Container
|
||||
*
|
||||
* Manages Ollama containers for running local LLMs.
|
||||
*/
|
||||
|
||||
import type {
|
||||
IContainerConfig,
|
||||
ILoadedModel,
|
||||
TContainerType,
|
||||
} from '../interfaces/container.ts';
|
||||
import type {
|
||||
IChatCompletionRequest,
|
||||
IChatCompletionResponse,
|
||||
IChatCompletionChoice,
|
||||
IChatMessage,
|
||||
} from '../interfaces/api.ts';
|
||||
import { CONTAINER_IMAGES, CONTAINER_PORTS } from '../constants.ts';
|
||||
import { logger } from '../logger.ts';
|
||||
import { BaseContainer, type TModelPullProgress } from './base-container.ts';
|
||||
|
||||
/**
|
||||
* Ollama API response types
|
||||
*/
|
||||
interface IOllamaTagsResponse {
|
||||
models: Array<{
|
||||
name: string;
|
||||
size: number;
|
||||
digest: string;
|
||||
modified_at: string;
|
||||
}>;
|
||||
}
|
||||
|
||||
interface IOllamaChatRequest {
|
||||
model: string;
|
||||
messages: Array<{
|
||||
role: string;
|
||||
content: string;
|
||||
}>;
|
||||
stream?: boolean;
|
||||
options?: {
|
||||
temperature?: number;
|
||||
top_p?: number;
|
||||
num_predict?: number;
|
||||
stop?: string[];
|
||||
};
|
||||
}
|
||||
|
||||
interface IOllamaChatResponse {
|
||||
model: string;
|
||||
created_at: string;
|
||||
message: {
|
||||
role: string;
|
||||
content: string;
|
||||
};
|
||||
done: boolean;
|
||||
total_duration?: number;
|
||||
load_duration?: number;
|
||||
prompt_eval_count?: number;
|
||||
eval_count?: number;
|
||||
}
|
||||
|
||||
interface IOllamaPullResponse {
|
||||
status: string;
|
||||
digest?: string;
|
||||
total?: number;
|
||||
completed?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ollama container implementation
|
||||
*/
|
||||
export class OllamaContainer extends BaseContainer {
|
||||
public readonly type: TContainerType = 'ollama';
|
||||
public readonly displayName = 'Ollama';
|
||||
public readonly defaultImage = CONTAINER_IMAGES.OLLAMA;
|
||||
public readonly defaultPort = CONTAINER_PORTS.OLLAMA;
|
||||
|
||||
constructor(config: IContainerConfig) {
|
||||
super(config);
|
||||
|
||||
// Set defaults if not provided
|
||||
if (!config.image) {
|
||||
config.image = this.defaultImage;
|
||||
}
|
||||
if (!config.port) {
|
||||
config.port = this.defaultPort;
|
||||
}
|
||||
|
||||
// Add default volume for model storage
|
||||
if (!config.volumes || config.volumes.length === 0) {
|
||||
config.volumes = [`modelgrid-ollama-${config.id}:/root/.ollama`];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create Ollama container configuration
|
||||
*/
|
||||
public static createConfig(
|
||||
id: string,
|
||||
name: string,
|
||||
gpuIds: string[],
|
||||
options: Partial<IContainerConfig> = {},
|
||||
): IContainerConfig {
|
||||
return {
|
||||
id,
|
||||
name,
|
||||
type: 'ollama',
|
||||
image: options.image || CONTAINER_IMAGES.OLLAMA,
|
||||
gpuIds,
|
||||
port: options.port || CONTAINER_PORTS.OLLAMA,
|
||||
externalPort: options.externalPort,
|
||||
models: options.models || [],
|
||||
env: options.env,
|
||||
volumes: options.volumes || [`modelgrid-ollama-${id}:/root/.ollama`],
|
||||
autoStart: options.autoStart ?? true,
|
||||
restartPolicy: options.restartPolicy || 'unless-stopped',
|
||||
memoryLimit: options.memoryLimit,
|
||||
cpuLimit: options.cpuLimit,
|
||||
command: options.command,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if Ollama is healthy
|
||||
*/
|
||||
public async isHealthy(): Promise<boolean> {
|
||||
try {
|
||||
const response = await this.fetch('/api/tags', { timeout: 5000 });
|
||||
return response.ok;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* List available models
|
||||
*/
|
||||
public async listModels(): Promise<string[]> {
|
||||
try {
|
||||
const data = await this.fetchJson<IOllamaTagsResponse>('/api/tags');
|
||||
return (data.models || []).map((m) => m.name);
|
||||
} catch (error) {
|
||||
logger.warn(`Failed to list Ollama models: ${error instanceof Error ? error.message : String(error)}`);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get loaded models with details
|
||||
*/
|
||||
public async getLoadedModels(): Promise<ILoadedModel[]> {
|
||||
try {
|
||||
const data = await this.fetchJson<IOllamaTagsResponse>('/api/tags');
|
||||
return (data.models || []).map((m) => ({
|
||||
name: m.name,
|
||||
size: m.size,
|
||||
format: m.digest.substring(0, 12),
|
||||
loaded: true, // Ollama doesn't distinguish loaded vs available
|
||||
requestCount: 0,
|
||||
}));
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Pull a model
|
||||
*/
|
||||
public async pullModel(modelName: string, onProgress?: TModelPullProgress): Promise<boolean> {
|
||||
try {
|
||||
logger.info(`Pulling model: ${modelName}`);
|
||||
|
||||
const response = await this.fetch('/api/pull', {
|
||||
method: 'POST',
|
||||
body: { name: modelName },
|
||||
timeout: 3600000, // 1 hour for large models
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP ${response.status}`);
|
||||
}
|
||||
|
||||
// Read streaming response
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error('No response body');
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
let lastStatus = '';
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
const text = decoder.decode(value);
|
||||
const lines = text.split('\n').filter((l) => l.trim());
|
||||
|
||||
for (const line of lines) {
|
||||
try {
|
||||
const data = JSON.parse(line) as IOllamaPullResponse;
|
||||
const status = data.status;
|
||||
|
||||
if (status !== lastStatus) {
|
||||
lastStatus = status;
|
||||
let percent: number | undefined;
|
||||
|
||||
if (data.total && data.completed) {
|
||||
percent = Math.round((data.completed / data.total) * 100);
|
||||
}
|
||||
|
||||
if (onProgress) {
|
||||
onProgress({ model: modelName, status, percent });
|
||||
} else {
|
||||
const progressStr = percent !== undefined ? ` (${percent}%)` : '';
|
||||
logger.dim(` ${status}${progressStr}`);
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Invalid JSON line, skip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.success(`Model ${modelName} pulled successfully`);
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error(`Failed to pull model ${modelName}: ${error instanceof Error ? error.message : String(error)}`);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a model
|
||||
*/
|
||||
public async removeModel(modelName: string): Promise<boolean> {
|
||||
try {
|
||||
const response = await this.fetch('/api/delete', {
|
||||
method: 'DELETE',
|
||||
body: { name: modelName },
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
logger.success(`Model ${modelName} removed`);
|
||||
return true;
|
||||
}
|
||||
|
||||
throw new Error(`HTTP ${response.status}`);
|
||||
} catch (error) {
|
||||
logger.error(`Failed to remove model ${modelName}: ${error instanceof Error ? error.message : String(error)}`);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a chat completion request
|
||||
*/
|
||||
public async chatCompletion(request: IChatCompletionRequest): Promise<IChatCompletionResponse> {
|
||||
const ollamaRequest: IOllamaChatRequest = {
|
||||
model: request.model,
|
||||
messages: request.messages.map((m) => ({
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
})),
|
||||
stream: false,
|
||||
options: {
|
||||
temperature: request.temperature,
|
||||
top_p: request.top_p,
|
||||
num_predict: request.max_tokens,
|
||||
stop: Array.isArray(request.stop) ? request.stop : request.stop ? [request.stop] : undefined,
|
||||
},
|
||||
};
|
||||
|
||||
const response = await this.fetchJson<IOllamaChatResponse>('/api/chat', {
|
||||
method: 'POST',
|
||||
body: ollamaRequest,
|
||||
timeout: 300000, // 5 minutes
|
||||
});
|
||||
|
||||
// Convert to OpenAI format
|
||||
const created = Math.floor(Date.now() / 1000);
|
||||
|
||||
const choice: IChatCompletionChoice = {
|
||||
index: 0,
|
||||
message: {
|
||||
role: 'assistant',
|
||||
content: response.message.content,
|
||||
},
|
||||
finish_reason: response.done ? 'stop' : null,
|
||||
};
|
||||
|
||||
return {
|
||||
id: this.generateRequestId(),
|
||||
object: 'chat.completion',
|
||||
created,
|
||||
model: request.model,
|
||||
choices: [choice],
|
||||
usage: {
|
||||
prompt_tokens: response.prompt_eval_count || 0,
|
||||
completion_tokens: response.eval_count || 0,
|
||||
total_tokens: (response.prompt_eval_count || 0) + (response.eval_count || 0),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream a chat completion request
|
||||
*/
|
||||
public async chatCompletionStream(
|
||||
request: IChatCompletionRequest,
|
||||
onChunk: (chunk: string) => void,
|
||||
): Promise<void> {
|
||||
const ollamaRequest: IOllamaChatRequest = {
|
||||
model: request.model,
|
||||
messages: request.messages.map((m) => ({
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
})),
|
||||
stream: true,
|
||||
options: {
|
||||
temperature: request.temperature,
|
||||
top_p: request.top_p,
|
||||
num_predict: request.max_tokens,
|
||||
stop: Array.isArray(request.stop) ? request.stop : request.stop ? [request.stop] : undefined,
|
||||
},
|
||||
};
|
||||
|
||||
const response = await this.fetch('/api/chat', {
|
||||
method: 'POST',
|
||||
body: ollamaRequest,
|
||||
timeout: 300000,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP ${response.status}`);
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error('No response body');
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
const requestId = this.generateRequestId();
|
||||
const created = Math.floor(Date.now() / 1000);
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
const text = decoder.decode(value);
|
||||
const lines = text.split('\n').filter((l) => l.trim());
|
||||
|
||||
for (const line of lines) {
|
||||
try {
|
||||
const data = JSON.parse(line) as IOllamaChatResponse;
|
||||
|
||||
// Convert to OpenAI streaming format
|
||||
const chunk = {
|
||||
id: requestId,
|
||||
object: 'chat.completion.chunk',
|
||||
created,
|
||||
model: request.model,
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
delta: {
|
||||
content: data.message.content,
|
||||
} as Partial<IChatMessage>,
|
||||
finish_reason: data.done ? 'stop' : null,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
onChunk(`data: ${JSON.stringify(chunk)}\n\n`);
|
||||
|
||||
if (data.done) {
|
||||
onChunk('data: [DONE]\n\n');
|
||||
}
|
||||
} catch {
|
||||
// Invalid JSON, skip
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
+15
-9
@@ -4,15 +4,11 @@
|
||||
* Manages HuggingFace Text Generation Inference containers.
|
||||
*/
|
||||
|
||||
import type { IContainerConfig, ILoadedModel, TContainerType } from '../interfaces/container.ts';
|
||||
import type {
|
||||
IContainerConfig,
|
||||
ILoadedModel,
|
||||
TContainerType,
|
||||
} from '../interfaces/container.ts';
|
||||
import type {
|
||||
IChatCompletionChoice,
|
||||
IChatCompletionRequest,
|
||||
IChatCompletionResponse,
|
||||
IChatCompletionChoice,
|
||||
IChatMessage,
|
||||
} from '../interfaces/api.ts';
|
||||
import { CONTAINER_IMAGES, CONTAINER_PORTS } from '../constants.ts';
|
||||
@@ -161,7 +157,9 @@ export class TgiContainer extends BaseContainer {
|
||||
const info = await this.fetchJson<ITgiInfoResponse>('/info');
|
||||
return [info.model_id];
|
||||
} catch (error) {
|
||||
logger.warn(`Failed to get TGI info: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.warn(
|
||||
`Failed to get TGI info: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
return this.config.models || [];
|
||||
}
|
||||
}
|
||||
@@ -232,7 +230,11 @@ export class TgiContainer extends BaseContainer {
|
||||
temperature: request.temperature,
|
||||
top_p: request.top_p,
|
||||
max_new_tokens: request.max_tokens || 1024,
|
||||
stop: Array.isArray(request.stop) ? request.stop : request.stop ? [request.stop] : undefined,
|
||||
stop: Array.isArray(request.stop)
|
||||
? request.stop
|
||||
: request.stop
|
||||
? [request.stop]
|
||||
: undefined,
|
||||
do_sample: (request.temperature || 0) > 0,
|
||||
return_full_text: false,
|
||||
},
|
||||
@@ -288,7 +290,11 @@ export class TgiContainer extends BaseContainer {
|
||||
temperature: request.temperature,
|
||||
top_p: request.top_p,
|
||||
max_new_tokens: request.max_tokens || 1024,
|
||||
stop: Array.isArray(request.stop) ? request.stop : request.stop ? [request.stop] : undefined,
|
||||
stop: Array.isArray(request.stop)
|
||||
? request.stop
|
||||
: request.stop
|
||||
? [request.stop]
|
||||
: undefined,
|
||||
do_sample: (request.temperature || 0) > 0,
|
||||
},
|
||||
},
|
||||
|
||||
+30
-13
@@ -4,11 +4,7 @@
|
||||
* Manages vLLM containers for high-performance LLM inference.
|
||||
*/
|
||||
|
||||
import type {
|
||||
IContainerConfig,
|
||||
ILoadedModel,
|
||||
TContainerType,
|
||||
} from '../interfaces/container.ts';
|
||||
import type { IContainerConfig, ILoadedModel, TContainerType } from '../interfaces/container.ts';
|
||||
import type {
|
||||
IChatCompletionRequest,
|
||||
IChatCompletionResponse,
|
||||
@@ -72,20 +68,26 @@ export class VllmContainer extends BaseContainer {
|
||||
gpuIds: string[],
|
||||
options: Partial<IContainerConfig> = {},
|
||||
): IContainerConfig {
|
||||
// vLLM requires model to be specified at startup
|
||||
const command = [
|
||||
'--model', modelName,
|
||||
'--host', '0.0.0.0',
|
||||
'--port', String(options.port || CONTAINER_PORTS.VLLM),
|
||||
const command = options.command ? [...options.command] : [
|
||||
'--model',
|
||||
modelName,
|
||||
];
|
||||
|
||||
if (!command.includes('--host')) {
|
||||
command.push('--host', '0.0.0.0');
|
||||
}
|
||||
|
||||
if (!command.includes('--port')) {
|
||||
command.push('--port', String(options.port || CONTAINER_PORTS.VLLM));
|
||||
}
|
||||
|
||||
// Add tensor parallelism if multiple GPUs
|
||||
if (gpuIds.length > 1) {
|
||||
if (gpuIds.length > 1 && !command.includes('--tensor-parallel-size')) {
|
||||
command.push('--tensor-parallel-size', String(gpuIds.length));
|
||||
}
|
||||
|
||||
// Add additional options
|
||||
if (options.env?.VLLM_MAX_MODEL_LEN) {
|
||||
if (options.env?.VLLM_MAX_MODEL_LEN && !command.includes('--max-model-len')) {
|
||||
command.push('--max-model-len', options.env.VLLM_MAX_MODEL_LEN);
|
||||
}
|
||||
|
||||
@@ -128,11 +130,17 @@ export class VllmContainer extends BaseContainer {
|
||||
* vLLM serves a single model per instance
|
||||
*/
|
||||
public async listModels(): Promise<string[]> {
|
||||
if (this.config.models.length > 0) {
|
||||
return this.config.models;
|
||||
}
|
||||
|
||||
try {
|
||||
const data = await this.fetchJson<IVllmModelsResponse>('/v1/models');
|
||||
return (data.data || []).map((m) => m.id);
|
||||
} catch (error) {
|
||||
logger.warn(`Failed to list vLLM models: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.warn(
|
||||
`Failed to list vLLM models: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
return this.config.models || [];
|
||||
}
|
||||
}
|
||||
@@ -141,6 +149,15 @@ export class VllmContainer extends BaseContainer {
|
||||
* Get loaded models with details
|
||||
*/
|
||||
public async getLoadedModels(): Promise<ILoadedModel[]> {
|
||||
if (this.config.models.length > 0) {
|
||||
return this.config.models.map((name) => ({
|
||||
name,
|
||||
size: 0,
|
||||
loaded: true,
|
||||
requestCount: 0,
|
||||
}));
|
||||
}
|
||||
|
||||
try {
|
||||
const data = await this.fetchJson<IVllmModelsResponse>('/v1/models');
|
||||
return (data.data || []).map((m) => ({
|
||||
|
||||
+40
-5
@@ -54,6 +54,10 @@ export class Daemon {
|
||||
// Preload models if configured
|
||||
await this.preloadModels(config);
|
||||
|
||||
await this.syncClusterState(config);
|
||||
await this.modelgrid.getClusterCoordinator().reconcileDesiredReplicas();
|
||||
await this.syncClusterState(config);
|
||||
|
||||
// Setup signal handlers
|
||||
this.setupSignalHandlers();
|
||||
|
||||
@@ -63,7 +67,9 @@ export class Daemon {
|
||||
await this.monitor();
|
||||
} catch (error) {
|
||||
this.isRunning = false;
|
||||
logger.error(`Daemon failed to start: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Daemon failed to start: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
@@ -101,6 +107,8 @@ export class Daemon {
|
||||
config.api,
|
||||
this.modelgrid.getContainerManager(),
|
||||
this.modelgrid.getModelRegistry(),
|
||||
this.modelgrid.getModelLoader(),
|
||||
this.modelgrid.getClusterCoordinator(),
|
||||
);
|
||||
|
||||
await this.apiServer.start();
|
||||
@@ -151,8 +159,16 @@ export class Daemon {
|
||||
|
||||
logger.info(`Preloading ${config.models.autoLoad.length} model(s)...`);
|
||||
|
||||
const modelLoader = this.modelgrid.getModelLoader();
|
||||
const results = await modelLoader.preloadModels(config.models.autoLoad);
|
||||
const clusterCoordinator = this.modelgrid.getClusterCoordinator();
|
||||
const results = new Map<string, { success: boolean; error?: string }>();
|
||||
|
||||
for (const modelName of config.models.autoLoad) {
|
||||
const ensured = await clusterCoordinator.ensureModel(modelName);
|
||||
results.set(modelName, {
|
||||
success: !!ensured,
|
||||
error: ensured ? undefined : 'Failed to schedule preload',
|
||||
});
|
||||
}
|
||||
|
||||
let loaded = 0;
|
||||
let failed = 0;
|
||||
@@ -203,6 +219,10 @@ export class Daemon {
|
||||
// Check container health
|
||||
await this.checkContainerHealth();
|
||||
|
||||
await this.syncClusterState();
|
||||
await this.modelgrid.getClusterCoordinator().reconcileDesiredReplicas();
|
||||
await this.syncClusterState();
|
||||
|
||||
// Log periodic status
|
||||
this.logPeriodicStatus();
|
||||
|
||||
@@ -245,6 +265,19 @@ export class Daemon {
|
||||
}
|
||||
}
|
||||
|
||||
private async syncClusterState(config?: IModelGridConfig): Promise<void> {
|
||||
const effectiveConfig = config || this.modelgrid.getConfig();
|
||||
if (!effectiveConfig) {
|
||||
return;
|
||||
}
|
||||
|
||||
const advertiseUrl = effectiveConfig.cluster.advertiseUrl ||
|
||||
`http://127.0.0.1:${effectiveConfig.api.port}`;
|
||||
const coordinator = this.modelgrid.getClusterCoordinator();
|
||||
await coordinator.syncLocalState(advertiseUrl);
|
||||
await coordinator.sendHeartbeat();
|
||||
}
|
||||
|
||||
/**
|
||||
* Log configuration loaded message
|
||||
*/
|
||||
@@ -252,8 +285,10 @@ export class Daemon {
|
||||
logger.log('');
|
||||
logger.logBoxTitle('Configuration Loaded', 60, 'success');
|
||||
logger.logBoxLine(`API Port: ${config.api.port}`);
|
||||
logger.logBoxLine(`Containers: ${config.containers.length}`);
|
||||
logger.logBoxLine(`Auto-pull: ${config.models.autoPull ? 'Enabled' : 'Disabled'}`);
|
||||
logger.logBoxLine(`Deployments: ${config.containers.length}`);
|
||||
logger.logBoxLine(`Auto-deploy: ${config.models.autoDeploy ? 'Enabled' : 'Disabled'}`);
|
||||
logger.logBoxLine(`Registry: ${config.models.registryUrl}`);
|
||||
logger.logBoxLine(`Cluster Mode: ${config.cluster.role}`);
|
||||
logger.logBoxLine(`Check Interval: ${config.checkInterval / 1000}s`);
|
||||
logger.logBoxEnd();
|
||||
logger.log('');
|
||||
|
||||
@@ -71,7 +71,11 @@ export class ContainerRuntime {
|
||||
logger.success(`Started existing container: ${containerName}`);
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error(`Failed to start existing container: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to start existing container: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
);
|
||||
// Try to remove and recreate
|
||||
await this.removeContainer(config.id);
|
||||
}
|
||||
@@ -93,7 +97,9 @@ export class ContainerRuntime {
|
||||
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error(`Failed to start container: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to start container: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -118,7 +124,9 @@ export class ContainerRuntime {
|
||||
logger.success(`Container ${containerName} stopped`);
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error(`Failed to stop container: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to stop container: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -140,7 +148,9 @@ export class ContainerRuntime {
|
||||
logger.success(`Container ${containerName} removed`);
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error(`Failed to remove container: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to remove container: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -164,7 +174,9 @@ export class ContainerRuntime {
|
||||
logger.success(`Container ${containerName} restarted`);
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error(`Failed to restart container: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to restart container: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -248,7 +260,9 @@ export class ContainerRuntime {
|
||||
status.cpuUsage = stats.cpuUsage;
|
||||
}
|
||||
} catch (error) {
|
||||
logger.dim(`Error getting container status: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.dim(
|
||||
`Error getting container status: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
}
|
||||
|
||||
return status;
|
||||
@@ -295,16 +309,6 @@ export class ContainerRuntime {
|
||||
|
||||
try {
|
||||
switch (config.type) {
|
||||
case 'ollama': {
|
||||
// Query Ollama API for loaded models
|
||||
const { stdout } = await execAsync(
|
||||
`docker exec ${containerName} curl -s http://localhost:11434/api/tags`,
|
||||
{ timeout: 5000 },
|
||||
);
|
||||
const data = JSON.parse(stdout);
|
||||
return (data.models || []).map((m: { name: string }) => m.name);
|
||||
}
|
||||
|
||||
case 'vllm':
|
||||
case 'tgi': {
|
||||
// These typically serve a single model
|
||||
|
||||
@@ -296,7 +296,9 @@ export class DockerManager {
|
||||
await execAsync('systemctl enable docker');
|
||||
logger.success('Docker service started and enabled');
|
||||
} catch (error) {
|
||||
logger.warn(`Could not start Docker service: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.warn(
|
||||
`Could not start Docker service: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -308,7 +310,9 @@ export class DockerManager {
|
||||
await execAsync('systemctl stop docker');
|
||||
logger.success('Docker service stopped');
|
||||
} catch (error) {
|
||||
logger.warn(`Could not stop Docker service: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.warn(
|
||||
`Could not stop Docker service: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -320,7 +324,11 @@ export class DockerManager {
|
||||
await execAsync('systemctl restart docker');
|
||||
logger.success('Docker service restarted');
|
||||
} catch (error) {
|
||||
logger.warn(`Could not restart Docker service: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.warn(
|
||||
`Could not restart Docker service: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -352,7 +360,9 @@ export class DockerManager {
|
||||
logger.success(`Created Docker network '${this.networkName}'`);
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error(`Failed to create network: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to create network: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -372,7 +382,9 @@ export class DockerManager {
|
||||
logger.success(`Removed Docker network '${this.networkName}'`);
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error(`Failed to remove network: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to remove network: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -389,7 +401,9 @@ export class DockerManager {
|
||||
logger.success(`Pulled image: ${image}`);
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error(`Failed to pull image: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to pull image: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -454,7 +468,11 @@ export class DockerManager {
|
||||
logger.info('Log out and log back in for the change to take effect');
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error(`Failed to add user to docker group: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to add user to docker group: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
+27
-9
@@ -20,10 +20,13 @@ export class AmdDriver extends BaseDriver {
|
||||
*/
|
||||
public async isInstalled(): Promise<boolean> {
|
||||
try {
|
||||
const { stdout } = await this.execCommand('rocm-smi --showdriverversion 2>/dev/null | head -1', {
|
||||
timeout: 5000,
|
||||
ignoreErrors: true,
|
||||
});
|
||||
const { stdout } = await this.execCommand(
|
||||
'rocm-smi --showdriverversion 2>/dev/null | head -1',
|
||||
{
|
||||
timeout: 5000,
|
||||
ignoreErrors: true,
|
||||
},
|
||||
);
|
||||
return stdout.includes('Driver');
|
||||
} catch {
|
||||
return false;
|
||||
@@ -114,7 +117,10 @@ export class AmdDriver extends BaseDriver {
|
||||
try {
|
||||
if (distro.id === 'ubuntu') {
|
||||
return await this.installOnUbuntu(options);
|
||||
} else if (distro.id === 'rhel' || distro.id === 'centos' || distro.id === 'rocky' || distro.id === 'almalinux') {
|
||||
} else if (
|
||||
distro.id === 'rhel' || distro.id === 'centos' || distro.id === 'rocky' ||
|
||||
distro.id === 'almalinux'
|
||||
) {
|
||||
return await this.installOnRhel(options);
|
||||
} else {
|
||||
logger.error(`Unsupported distribution: ${distro.id}`);
|
||||
@@ -122,7 +128,11 @@ export class AmdDriver extends BaseDriver {
|
||||
return false;
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Failed to install AMD ROCm drivers: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to install AMD ROCm drivers: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -152,7 +162,9 @@ export class AmdDriver extends BaseDriver {
|
||||
|
||||
// Add AMDGPU repository
|
||||
await this.execCommand(
|
||||
`echo "deb [arch=amd64] https://repo.radeon.com/amdgpu/${rocmVersion}/ubuntu ${ubuntuVersion === '2204' ? 'jammy' : 'focal'} main" > /etc/apt/sources.list.d/amdgpu.list`,
|
||||
`echo "deb [arch=amd64] https://repo.radeon.com/amdgpu/${rocmVersion}/ubuntu ${
|
||||
ubuntuVersion === '2204' ? 'jammy' : 'focal'
|
||||
} main" > /etc/apt/sources.list.d/amdgpu.list`,
|
||||
);
|
||||
|
||||
await this.aptUpdate();
|
||||
@@ -250,7 +262,9 @@ EOF`,
|
||||
// No special runtime needed, just need to pass --device flags
|
||||
|
||||
// Verify device files exist
|
||||
const { stdout: devices } = await this.execCommand('ls -la /dev/kfd /dev/dri/render* 2>/dev/null || true');
|
||||
const { stdout: devices } = await this.execCommand(
|
||||
'ls -la /dev/kfd /dev/dri/render* 2>/dev/null || true',
|
||||
);
|
||||
|
||||
if (!devices.includes('/dev/kfd')) {
|
||||
logger.warn('/dev/kfd not found. ROCm driver may not be properly loaded.');
|
||||
@@ -266,7 +280,11 @@ EOF`,
|
||||
logger.info(' --device=/dev/kfd --device=/dev/dri --group-add video');
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error(`Failed to configure ROCm container support: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to configure ROCm container support: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -177,7 +177,9 @@ export abstract class BaseDriver {
|
||||
protected async addAptRepository(repo: string, keyUrl?: string): Promise<void> {
|
||||
if (keyUrl) {
|
||||
// Add GPG key
|
||||
await this.execCommand(`curl -fsSL ${keyUrl} | gpg --dearmor -o /usr/share/keyrings/$(basename ${keyUrl}).gpg`);
|
||||
await this.execCommand(
|
||||
`curl -fsSL ${keyUrl} | gpg --dearmor -o /usr/share/keyrings/$(basename ${keyUrl}).gpg`,
|
||||
);
|
||||
}
|
||||
await this.execCommand(`add-apt-repository -y "${repo}"`);
|
||||
}
|
||||
@@ -188,7 +190,11 @@ export abstract class BaseDriver {
|
||||
public async logStatus(): Promise<void> {
|
||||
const status = await this.getStatus();
|
||||
|
||||
logger.logBoxTitle(`${this.displayName} Driver Status`, 60, status.installed ? 'success' : 'warning');
|
||||
logger.logBoxTitle(
|
||||
`${this.displayName} Driver Status`,
|
||||
60,
|
||||
status.installed ? 'success' : 'warning',
|
||||
);
|
||||
logger.logBoxLine(`Installed: ${status.installed ? 'Yes' : 'No'}`);
|
||||
|
||||
if (status.installed) {
|
||||
|
||||
@@ -21,7 +21,7 @@ export class DriverManager {
|
||||
|
||||
constructor() {
|
||||
this.gpuDetector = new GpuDetector();
|
||||
this.drivers = new Map([
|
||||
this.drivers = new Map<TGpuVendor, BaseDriver>([
|
||||
['nvidia', new NvidiaDriver()],
|
||||
['amd', new AmdDriver()],
|
||||
['intel', new IntelDriver()],
|
||||
@@ -197,10 +197,15 @@ export class DriverManager {
|
||||
// Print status for each vendor
|
||||
for (const [vendor, gpuList] of vendorGpus) {
|
||||
if (vendor === 'unknown') {
|
||||
logger.logBox('Unknown GPUs', [
|
||||
`${gpuList.length} GPU(s) with unknown vendor`,
|
||||
'Manual driver installation may be required',
|
||||
], 50, 'warning');
|
||||
logger.logBox(
|
||||
'Unknown GPUs',
|
||||
[
|
||||
`${gpuList.length} GPU(s) with unknown vendor`,
|
||||
'Manual driver installation may be required',
|
||||
],
|
||||
50,
|
||||
'warning',
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -219,9 +224,7 @@ export class DriverManager {
|
||||
const args: string[] = [];
|
||||
|
||||
// Filter to specific GPUs if provided
|
||||
const targetGpus = gpuIds
|
||||
? gpus.filter((g) => gpuIds.includes(g.id))
|
||||
: gpus;
|
||||
const targetGpus = gpuIds ? gpus.filter((g) => gpuIds.includes(g.id)) : gpus;
|
||||
|
||||
if (targetGpus.length === 0) {
|
||||
return args;
|
||||
|
||||
+18
-4
@@ -138,7 +138,11 @@ export class IntelDriver extends BaseDriver {
|
||||
return false;
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Failed to install Intel drivers: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to install Intel drivers: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -159,7 +163,11 @@ export class IntelDriver extends BaseDriver {
|
||||
);
|
||||
|
||||
const distro = await this.getLinuxDistro();
|
||||
const ubuntuCodename = distro.version === '22.04' ? 'jammy' : distro.version === '24.04' ? 'noble' : 'jammy';
|
||||
const ubuntuCodename = distro.version === '22.04'
|
||||
? 'jammy'
|
||||
: distro.version === '24.04'
|
||||
? 'noble'
|
||||
: 'jammy';
|
||||
|
||||
await this.execCommand(
|
||||
`echo "deb [arch=amd64 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu ${ubuntuCodename} arc" > /etc/apt/sources.list.d/intel-graphics.list`,
|
||||
@@ -308,7 +316,9 @@ EOF`,
|
||||
try {
|
||||
// Intel GPUs work by passing through device files
|
||||
// Verify render devices exist
|
||||
const { stdout: devices } = await this.execCommand('ls -la /dev/dri/renderD* 2>/dev/null || true');
|
||||
const { stdout: devices } = await this.execCommand(
|
||||
'ls -la /dev/dri/renderD* 2>/dev/null || true',
|
||||
);
|
||||
|
||||
if (!devices.includes('renderD')) {
|
||||
logger.warn('/dev/dri/renderD* not found. Intel GPU driver may not be properly loaded.');
|
||||
@@ -323,7 +333,11 @@ EOF`,
|
||||
logger.info(' --device=/dev/dri --group-add render');
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error(`Failed to configure Intel container support: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to configure Intel container support: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
+37
-14
@@ -20,10 +20,13 @@ export class NvidiaDriver extends BaseDriver {
|
||||
*/
|
||||
public async isInstalled(): Promise<boolean> {
|
||||
try {
|
||||
const { stdout } = await this.execCommand('nvidia-smi --query-gpu=driver_version --format=csv,noheader', {
|
||||
timeout: 5000,
|
||||
ignoreErrors: true,
|
||||
});
|
||||
const { stdout } = await this.execCommand(
|
||||
'nvidia-smi --query-gpu=driver_version --format=csv,noheader',
|
||||
{
|
||||
timeout: 5000,
|
||||
ignoreErrors: true,
|
||||
},
|
||||
);
|
||||
return stdout.trim().length > 0;
|
||||
} catch {
|
||||
return false;
|
||||
@@ -115,7 +118,10 @@ export class NvidiaDriver extends BaseDriver {
|
||||
try {
|
||||
if (distro.id === 'ubuntu' || distro.id === 'debian') {
|
||||
return await this.installOnDebian(options);
|
||||
} else if (distro.id === 'fedora' || distro.id === 'rhel' || distro.id === 'centos' || distro.id === 'rocky' || distro.id === 'almalinux') {
|
||||
} else if (
|
||||
distro.id === 'fedora' || distro.id === 'rhel' || distro.id === 'centos' ||
|
||||
distro.id === 'rocky' || distro.id === 'almalinux'
|
||||
) {
|
||||
return await this.installOnRhel(options);
|
||||
} else {
|
||||
logger.error(`Unsupported distribution: ${distro.id}`);
|
||||
@@ -123,7 +129,11 @@ export class NvidiaDriver extends BaseDriver {
|
||||
return false;
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Failed to install NVIDIA drivers: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to install NVIDIA drivers: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -181,7 +191,9 @@ export class NvidiaDriver extends BaseDriver {
|
||||
|
||||
// Add NVIDIA CUDA repository
|
||||
const distro = await this.getLinuxDistro();
|
||||
const repoUrl = `https://developer.download.nvidia.com/compute/cuda/repos/rhel${distro.version.split('.')[0]}/x86_64/cuda-rhel${distro.version.split('.')[0]}.repo`;
|
||||
const repoUrl = `https://developer.download.nvidia.com/compute/cuda/repos/rhel${
|
||||
distro.version.split('.')[0]
|
||||
}/x86_64/cuda-rhel${distro.version.split('.')[0]}.repo`;
|
||||
|
||||
await this.execCommand(`dnf config-manager --add-repo ${repoUrl}`);
|
||||
|
||||
@@ -213,8 +225,11 @@ export class NvidiaDriver extends BaseDriver {
|
||||
|
||||
if (distro.id === 'ubuntu' || distro.id === 'debian') {
|
||||
// Add CUDA repository
|
||||
const cudaKeyUrl = 'https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb';
|
||||
await this.execCommand(`wget -q ${cudaKeyUrl} -O /tmp/cuda-keyring.deb && dpkg -i /tmp/cuda-keyring.deb`);
|
||||
const cudaKeyUrl =
|
||||
'https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb';
|
||||
await this.execCommand(
|
||||
`wget -q ${cudaKeyUrl} -O /tmp/cuda-keyring.deb && dpkg -i /tmp/cuda-keyring.deb`,
|
||||
);
|
||||
await this.aptUpdate();
|
||||
|
||||
const cudaPackage = options.toolkitVersion
|
||||
@@ -247,8 +262,8 @@ export class NvidiaDriver extends BaseDriver {
|
||||
const distribution = `${distro.id}${distro.version}`;
|
||||
await this.execCommand(
|
||||
`curl -s -L https://nvidia.github.io/libnvidia-container/${distribution}/libnvidia-container.list | ` +
|
||||
'sed "s#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g" | ' +
|
||||
'tee /etc/apt/sources.list.d/nvidia-container-toolkit.list',
|
||||
'sed "s#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g" | ' +
|
||||
'tee /etc/apt/sources.list.d/nvidia-container-toolkit.list',
|
||||
);
|
||||
|
||||
await this.aptUpdate();
|
||||
@@ -257,7 +272,7 @@ export class NvidiaDriver extends BaseDriver {
|
||||
// RHEL/Fedora
|
||||
await this.execCommand(
|
||||
'curl -s -L https://nvidia.github.io/libnvidia-container/stable/rpm/nvidia-container-toolkit.repo | ' +
|
||||
'tee /etc/yum.repos.d/nvidia-container-toolkit.repo',
|
||||
'tee /etc/yum.repos.d/nvidia-container-toolkit.repo',
|
||||
);
|
||||
await this.dnfInstall('nvidia-container-toolkit');
|
||||
}
|
||||
@@ -268,7 +283,11 @@ export class NvidiaDriver extends BaseDriver {
|
||||
logger.success('NVIDIA Container Toolkit installed successfully');
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error(`Failed to install NVIDIA Container Toolkit: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to install NVIDIA Container Toolkit: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -288,7 +307,11 @@ export class NvidiaDriver extends BaseDriver {
|
||||
|
||||
logger.success('Docker configured to use NVIDIA runtime');
|
||||
} catch (error) {
|
||||
logger.warn(`Could not configure Docker runtime automatically: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.warn(
|
||||
`Could not configure Docker runtime automatically: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
);
|
||||
logger.info('Please run: nvidia-ctk runtime configure --runtime=docker');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,9 +96,12 @@ export class GpuDetector {
|
||||
// Get CUDA version separately
|
||||
if (gpus.length > 0) {
|
||||
try {
|
||||
const { stdout: cudaOut } = await execAsync('nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -1 && nvcc --version 2>/dev/null | grep "release" | sed "s/.*release \\([0-9.]*\\).*/\\1/"', {
|
||||
timeout: 5000,
|
||||
});
|
||||
const { stdout: cudaOut } = await execAsync(
|
||||
'nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -1 && nvcc --version 2>/dev/null | grep "release" | sed "s/.*release \\([0-9.]*\\).*/\\1/"',
|
||||
{
|
||||
timeout: 5000,
|
||||
},
|
||||
);
|
||||
const cudaMatch = cudaOut.match(/(\d+\.\d+)/);
|
||||
if (cudaMatch) {
|
||||
for (const gpu of gpus) {
|
||||
@@ -142,7 +145,9 @@ export class GpuDetector {
|
||||
id: `amd-${index}`,
|
||||
vendor: 'amd',
|
||||
model: String(cardData['Card series'] || cardData['card_series'] || 'AMD GPU'),
|
||||
vram: this.parseMemory(String(cardData['VRAM Total Memory (B)'] || cardData['vram_total'] || '0')),
|
||||
vram: this.parseMemory(
|
||||
String(cardData['VRAM Total Memory (B)'] || cardData['vram_total'] || '0'),
|
||||
),
|
||||
driverVersion: String(cardData['Driver version'] || cardData['driver_version'] || ''),
|
||||
rocmVersion: await this.getRocmVersion(),
|
||||
pciSlot: String(cardData['PCI Bus'] || cardData['pci_bus'] || ''),
|
||||
@@ -371,14 +376,17 @@ export class GpuDetector {
|
||||
);
|
||||
|
||||
const parts = stdout.trim().split(',').map((p: string) => p.trim());
|
||||
const [utilization, memUsed, memTotal, temp, power, powerLimit, fan, gpuClock, memClock] = parts;
|
||||
const [utilization, memUsed, memTotal, temp, power, powerLimit, fan, gpuClock, memClock] =
|
||||
parts;
|
||||
|
||||
return {
|
||||
id: gpu.id,
|
||||
utilization: parseInt(utilization, 10) || 0,
|
||||
memoryUsed: parseInt(memUsed, 10) || 0,
|
||||
memoryTotal: parseInt(memTotal, 10) || gpu.vram,
|
||||
memoryPercent: memTotal ? Math.round((parseInt(memUsed, 10) / parseInt(memTotal, 10)) * 100) : 0,
|
||||
memoryPercent: memTotal
|
||||
? Math.round((parseInt(memUsed, 10) / parseInt(memTotal, 10)) * 100)
|
||||
: 0,
|
||||
temperature: parseInt(temp, 10) || 0,
|
||||
powerUsage: parseFloat(power) || 0,
|
||||
powerLimit: parseFloat(powerLimit) || 0,
|
||||
@@ -513,7 +521,7 @@ export class GpuDetector {
|
||||
bytes /= 1024;
|
||||
break;
|
||||
case 'B':
|
||||
bytes /= (1024 * 1024);
|
||||
bytes /= 1024 * 1024;
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -542,7 +550,9 @@ export class GpuDetector {
|
||||
*/
|
||||
private async getRocmVersion(): Promise<string | undefined> {
|
||||
try {
|
||||
const { stdout } = await execAsync('cat /opt/rocm/.info/version 2>/dev/null || rocminfo 2>/dev/null | grep "ROCm" | head -1');
|
||||
const { stdout } = await execAsync(
|
||||
'cat /opt/rocm/.info/version 2>/dev/null || rocminfo 2>/dev/null | grep "ROCm" | head -1',
|
||||
);
|
||||
const match = stdout.match(/(\d+\.\d+(?:\.\d+)?)/);
|
||||
return match ? match[1] : undefined;
|
||||
} catch {
|
||||
@@ -555,7 +565,9 @@ export class GpuDetector {
|
||||
*/
|
||||
private async getOneApiVersion(): Promise<string | undefined> {
|
||||
try {
|
||||
const { stdout } = await execAsync('source /opt/intel/oneapi/setvars.sh 2>/dev/null && echo $ONEAPI_ROOT 2>/dev/null || cat /opt/intel/oneapi/compiler/latest/env/vars.sh 2>/dev/null | grep VERSION');
|
||||
const { stdout } = await execAsync(
|
||||
'source /opt/intel/oneapi/setvars.sh 2>/dev/null && echo $ONEAPI_ROOT 2>/dev/null || cat /opt/intel/oneapi/compiler/latest/env/vars.sh 2>/dev/null | grep VERSION',
|
||||
);
|
||||
const match = stdout.match(/(\d+\.\d+(?:\.\d+)?)/);
|
||||
return match ? match[1] : undefined;
|
||||
} catch {
|
||||
|
||||
@@ -105,7 +105,9 @@ export class SystemInfo {
|
||||
*/
|
||||
private async getNvidiaContainerVersion(): Promise<string | undefined> {
|
||||
try {
|
||||
const { stdout } = await execAsync('nvidia-container-cli --version 2>&1 | head -1', { timeout: 5000 });
|
||||
const { stdout } = await execAsync('nvidia-container-cli --version 2>&1 | head -1', {
|
||||
timeout: 5000,
|
||||
});
|
||||
const match = stdout.match(/version (\d+\.\d+\.\d+)/);
|
||||
return match ? match[1] : undefined;
|
||||
} catch {
|
||||
@@ -156,7 +158,9 @@ export class SystemInfo {
|
||||
*/
|
||||
public async getAvailableDiskSpace(path: string = '/var/lib'): Promise<number> {
|
||||
try {
|
||||
const { stdout } = await execAsync(`df -m "${path}" | tail -1 | awk '{print $4}'`, { timeout: 5000 });
|
||||
const { stdout } = await execAsync(`df -m "${path}" | tail -1 | awk '{print $4}'`, {
|
||||
timeout: 5000,
|
||||
});
|
||||
return parseInt(stdout.trim(), 10) || 0;
|
||||
} catch {
|
||||
return 0;
|
||||
@@ -198,7 +202,11 @@ export class SystemInfo {
|
||||
logger.logBoxLine(`OS: ${info.os}`);
|
||||
logger.logBoxLine(`Kernel: ${info.kernelVersion}`);
|
||||
logger.logBoxLine(`CPU: ${info.cpuModel} (${info.cpuCores} cores)`);
|
||||
logger.logBoxLine(`RAM: ${Math.round(info.ramTotal / 1024)} GB total, ${Math.round(info.ramAvailable / 1024)} GB available`);
|
||||
logger.logBoxLine(
|
||||
`RAM: ${Math.round(info.ramTotal / 1024)} GB total, ${
|
||||
Math.round(info.ramAvailable / 1024)
|
||||
} GB available`,
|
||||
);
|
||||
logger.logBoxLine('');
|
||||
|
||||
if (info.dockerVersion) {
|
||||
|
||||
@@ -36,5 +36,6 @@ export * from './hardware/index.ts';
|
||||
export * from './drivers/index.ts';
|
||||
export * from './docker/index.ts';
|
||||
export * from './containers/index.ts';
|
||||
export * from './cluster/index.ts';
|
||||
export * from './models/index.ts';
|
||||
export * from './api/index.ts';
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
/**
|
||||
* Model catalog interfaces for list.modelgrid.com.
|
||||
*/
|
||||
|
||||
export interface IModelCapabilitySet {
|
||||
chat?: boolean;
|
||||
completions?: boolean;
|
||||
embeddings?: boolean;
|
||||
tools?: boolean;
|
||||
}
|
||||
|
||||
export interface IVllmLaunchProfile {
|
||||
replicas?: number;
|
||||
tensorParallelSize?: number;
|
||||
pipelineParallelSize?: number;
|
||||
maxModelLen?: number;
|
||||
gpuMemoryUtilization?: number;
|
||||
quantization?: string;
|
||||
dtype?: string;
|
||||
generationConfig?: 'auto' | 'vllm';
|
||||
extraArgs?: string[];
|
||||
env?: Record<string, string>;
|
||||
}
|
||||
|
||||
export interface IModelCatalogEntry {
|
||||
id: string;
|
||||
aliases?: string[];
|
||||
engine: 'vllm';
|
||||
source: {
|
||||
repo: string;
|
||||
revision?: string;
|
||||
tokenizer?: string;
|
||||
license?: string;
|
||||
homepage?: string;
|
||||
};
|
||||
capabilities: IModelCapabilitySet;
|
||||
requirements: {
|
||||
minVramGb: number;
|
||||
recommendedVramGb?: number;
|
||||
minGpuCount?: number;
|
||||
};
|
||||
launchDefaults?: IVllmLaunchProfile;
|
||||
metadata?: {
|
||||
family?: string;
|
||||
parameterCount?: string;
|
||||
contextWindow?: number;
|
||||
summary?: string;
|
||||
tags?: string[];
|
||||
};
|
||||
}
|
||||
|
||||
export interface IModelCatalog {
|
||||
version: string;
|
||||
generatedAt: string;
|
||||
models: IModelCatalogEntry[];
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
/**
|
||||
* Cluster and deployment interfaces.
|
||||
*/
|
||||
|
||||
export type TClusterRole = 'standalone' | 'control-plane' | 'worker';
|
||||
export type TClusterNodeSchedulerState = 'active' | 'cordoned' | 'draining';
|
||||
|
||||
export interface IClusterConfig {
|
||||
enabled: boolean;
|
||||
nodeName: string;
|
||||
role: TClusterRole;
|
||||
bindHost: string;
|
||||
gossipPort: number;
|
||||
sharedSecret?: string;
|
||||
advertiseUrl?: string;
|
||||
controlPlaneUrl?: string;
|
||||
heartbeatIntervalMs?: number;
|
||||
seedNodes?: string[];
|
||||
}
|
||||
|
||||
export interface IClusterNodeStatus {
|
||||
nodeName: string;
|
||||
role: TClusterRole;
|
||||
endpoint?: string;
|
||||
healthy: boolean;
|
||||
schedulerState?: TClusterNodeSchedulerState;
|
||||
}
|
||||
|
||||
export interface IClusterNodeResources {
|
||||
gpuCount: number;
|
||||
totalVramGb: number;
|
||||
availableVramGb: number;
|
||||
maxSingleGpuVramGb: number;
|
||||
largestGpuGroupCount: number;
|
||||
largestGpuGroupVramGb: number;
|
||||
deploymentCount: number;
|
||||
topologyGroups: IClusterGpuTopologyGroup[];
|
||||
}
|
||||
|
||||
export interface IClusterGpuTopologyGroup {
|
||||
id: string;
|
||||
vendor: 'nvidia' | 'amd' | 'intel' | 'unknown';
|
||||
gpuIds: string[];
|
||||
gpuCount: number;
|
||||
totalVramGb: number;
|
||||
maxSingleGpuVramGb: number;
|
||||
busNumbers: number[];
|
||||
}
|
||||
|
||||
export interface IClusterDeploymentAdvertisement {
|
||||
modelId: string;
|
||||
engine: 'vllm';
|
||||
endpoint: string;
|
||||
healthy: boolean;
|
||||
containerId?: string;
|
||||
}
|
||||
|
||||
export interface IClusterNodeHeartbeat extends IClusterNodeStatus {
|
||||
endpoint: string;
|
||||
resources: IClusterNodeResources;
|
||||
deployments: IClusterDeploymentAdvertisement[];
|
||||
lastSeenAt: number;
|
||||
}
|
||||
|
||||
export interface IClusterModelLocation {
|
||||
modelId: string;
|
||||
nodeName: string;
|
||||
endpoint: string;
|
||||
healthy: boolean;
|
||||
engine: 'vllm';
|
||||
containerId?: string;
|
||||
}
|
||||
|
||||
export interface IClusterEnsureResponse {
|
||||
model: string;
|
||||
location: IClusterModelLocation;
|
||||
created: boolean;
|
||||
}
|
||||
|
||||
export interface IClusterDesiredDeployment {
|
||||
modelId: string;
|
||||
desiredReplicas: number;
|
||||
updatedAt: number;
|
||||
}
|
||||
|
||||
export interface IClusterStatusResponse {
|
||||
localNode: IClusterNodeHeartbeat | null;
|
||||
nodes: IClusterNodeHeartbeat[];
|
||||
models: Record<string, IClusterModelLocation[]>;
|
||||
desiredDeployments: IClusterDesiredDeployment[];
|
||||
}
|
||||
+13
-36
@@ -1,9 +1,9 @@
|
||||
/**
|
||||
* ModelGrid Configuration Interfaces
|
||||
*
|
||||
* Defines the configuration structure for the ModelGrid daemon.
|
||||
* ModelGrid configuration interfaces.
|
||||
*/
|
||||
|
||||
import type { IModelCatalog, IModelCatalogEntry } from './catalog.ts';
|
||||
import type { IClusterConfig } from './cluster.ts';
|
||||
import type { IContainerConfig } from './container.ts';
|
||||
|
||||
/**
|
||||
@@ -50,12 +50,12 @@ export interface IGpuAssignmentConfig {
|
||||
* Model management configuration
|
||||
*/
|
||||
export interface IModelConfig {
|
||||
/** URL to fetch greenlit models list */
|
||||
greenlistUrl: string;
|
||||
/** Whether to auto-pull models when requested */
|
||||
autoPull: boolean;
|
||||
/** Default container type for new models */
|
||||
defaultContainer: 'ollama' | 'vllm' | 'tgi';
|
||||
/** URL to fetch the public catalog */
|
||||
registryUrl: string;
|
||||
/** Whether to auto-start a deployment when requested */
|
||||
autoDeploy: boolean;
|
||||
/** Default engine for new deployments */
|
||||
defaultEngine: 'vllm';
|
||||
/** Models to auto-load on startup */
|
||||
autoLoad: string[];
|
||||
}
|
||||
@@ -76,37 +76,14 @@ export interface IModelGridConfig {
|
||||
containers: IContainerConfig[];
|
||||
/** Model management configuration */
|
||||
models: IModelConfig;
|
||||
/** Cluster configuration */
|
||||
cluster: IClusterConfig;
|
||||
/** Health check interval in milliseconds */
|
||||
checkInterval: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Greenlit model entry from remote list
|
||||
*/
|
||||
export interface IGreenlitModel {
|
||||
/** Model name (e.g., "llama3:8b") */
|
||||
name: string;
|
||||
/** Preferred container type */
|
||||
container: 'ollama' | 'vllm' | 'tgi';
|
||||
/** Minimum VRAM required in GB */
|
||||
minVram: number;
|
||||
/** Optional tags for categorization */
|
||||
tags?: string[];
|
||||
/** Optional description */
|
||||
description?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Greenlit models list structure
|
||||
*/
|
||||
export interface IGreenlitModelsList {
|
||||
/** List version */
|
||||
version: string;
|
||||
/** Last updated timestamp */
|
||||
lastUpdated: string;
|
||||
/** List of greenlit models */
|
||||
models: IGreenlitModel[];
|
||||
}
|
||||
export type IRegistryModel = IModelCatalogEntry;
|
||||
export type IRegistryCatalog = IModelCatalog;
|
||||
|
||||
/**
|
||||
* Update status information
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
/**
|
||||
* ModelGrid Container Interfaces
|
||||
*
|
||||
* Defines types for container management (Ollama, vLLM, TGI).
|
||||
* ModelGrid container interfaces.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Container type
|
||||
*/
|
||||
export type TContainerType = 'ollama' | 'vllm' | 'tgi' | 'custom';
|
||||
export type TContainerType = 'vllm' | 'tgi' | 'custom';
|
||||
|
||||
/**
|
||||
* Container health status
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
*/
|
||||
|
||||
export * from './config.ts';
|
||||
export * from './catalog.ts';
|
||||
export * from './cluster.ts';
|
||||
export * from './gpu.ts';
|
||||
export * from './container.ts';
|
||||
export * from './api.ts';
|
||||
|
||||
+5
-1
@@ -276,7 +276,11 @@ export class Logger {
|
||||
* @param rows Array of data objects
|
||||
* @param title Optional table title
|
||||
*/
|
||||
public logTable(columns: ITableColumn[], rows: Record<string, string>[], title?: string): void {
|
||||
public logTable(
|
||||
columns: ITableColumn[],
|
||||
rows: Record<string, string | number>[],
|
||||
title?: string,
|
||||
): void {
|
||||
if (rows.length === 0) {
|
||||
this.dim('No data to display');
|
||||
return;
|
||||
|
||||
+107
-4
@@ -14,10 +14,13 @@ import { SystemInfo } from './hardware/system-info.ts';
|
||||
import { DriverManager } from './drivers/driver-manager.ts';
|
||||
import { DockerManager } from './docker/docker-manager.ts';
|
||||
import { ContainerManager } from './containers/container-manager.ts';
|
||||
import { ClusterCoordinator } from './cluster/coordinator.ts';
|
||||
import { ClusterManager } from './cluster/cluster-manager.ts';
|
||||
import { ModelRegistry } from './models/registry.ts';
|
||||
import { ModelLoader } from './models/loader.ts';
|
||||
import { GpuHandler } from './cli/gpu-handler.ts';
|
||||
import { ContainerHandler } from './cli/container-handler.ts';
|
||||
import { ClusterHandler } from './cli/cluster-handler.ts';
|
||||
import { ModelHandler } from './cli/model-handler.ts';
|
||||
import { ConfigHandler } from './cli/config-handler.ts';
|
||||
import { ServiceHandler } from './cli/service-handler.ts';
|
||||
@@ -35,12 +38,15 @@ export class ModelGrid {
|
||||
private driverManager: DriverManager;
|
||||
private dockerManager: DockerManager;
|
||||
private containerManager: ContainerManager;
|
||||
private clusterManager: ClusterManager;
|
||||
private clusterCoordinator?: ClusterCoordinator;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private modelLoader?: ModelLoader;
|
||||
|
||||
// CLI Handlers
|
||||
private gpuHandler: GpuHandler;
|
||||
private containerHandler: ContainerHandler;
|
||||
private clusterHandler: ClusterHandler;
|
||||
private modelHandler: ModelHandler;
|
||||
private configHandler: ConfigHandler;
|
||||
private serviceHandler: ServiceHandler;
|
||||
@@ -52,6 +58,7 @@ export class ModelGrid {
|
||||
this.driverManager = new DriverManager();
|
||||
this.dockerManager = new DockerManager();
|
||||
this.containerManager = new ContainerManager();
|
||||
this.clusterManager = new ClusterManager();
|
||||
this.modelRegistry = new ModelRegistry();
|
||||
this.systemd = new Systemd();
|
||||
this.daemon = new Daemon(this);
|
||||
@@ -59,7 +66,12 @@ export class ModelGrid {
|
||||
// Initialize CLI handlers
|
||||
this.gpuHandler = new GpuHandler();
|
||||
this.containerHandler = new ContainerHandler(this.containerManager);
|
||||
this.modelHandler = new ModelHandler(this.containerManager, this.modelRegistry);
|
||||
this.clusterHandler = new ClusterHandler();
|
||||
this.modelHandler = new ModelHandler(
|
||||
this.containerManager,
|
||||
this.getClusterCoordinator(),
|
||||
this.modelRegistry,
|
||||
);
|
||||
this.configHandler = new ConfigHandler();
|
||||
this.serviceHandler = new ServiceHandler(this);
|
||||
}
|
||||
@@ -70,7 +82,14 @@ export class ModelGrid {
|
||||
public async loadConfig(): Promise<void> {
|
||||
try {
|
||||
const configContent = await fs.readFile(PATHS.CONFIG_FILE, 'utf-8');
|
||||
this.config = JSON.parse(configContent) as IModelGridConfig;
|
||||
this.config = this.normalizeConfig(
|
||||
JSON.parse(configContent) as Partial<IModelGridConfig> & {
|
||||
models?: {
|
||||
greenlistUrl?: string;
|
||||
autoPull?: boolean;
|
||||
} & Partial<IModelGridConfig['models']>;
|
||||
},
|
||||
);
|
||||
logger.dim(`Configuration loaded from ${PATHS.CONFIG_FILE}`);
|
||||
} catch (error) {
|
||||
if ((error as NodeJS.ErrnoException).code === 'ENOENT') {
|
||||
@@ -163,6 +182,23 @@ export class ModelGrid {
|
||||
return this.containerManager;
|
||||
}
|
||||
|
||||
public getClusterManager(): ClusterManager {
|
||||
return this.clusterManager;
|
||||
}
|
||||
|
||||
public getClusterCoordinator(): ClusterCoordinator {
|
||||
if (!this.clusterCoordinator) {
|
||||
this.clusterCoordinator = new ClusterCoordinator(
|
||||
this.clusterManager,
|
||||
this.containerManager,
|
||||
this.modelRegistry,
|
||||
this.getModelLoader(),
|
||||
);
|
||||
}
|
||||
|
||||
return this.clusterCoordinator;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Model Registry instance
|
||||
*/
|
||||
@@ -203,6 +239,10 @@ export class ModelGrid {
|
||||
return this.modelHandler;
|
||||
}
|
||||
|
||||
public getClusterHandler(): ClusterHandler {
|
||||
return this.clusterHandler;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Config Handler
|
||||
*/
|
||||
@@ -234,18 +274,81 @@ export class ModelGrid {
|
||||
}
|
||||
|
||||
// Initialize model registry
|
||||
this.modelRegistry.setGreenlistUrl(this.config.models.greenlistUrl);
|
||||
this.modelRegistry.setCatalogUrl(this.config.models.registryUrl);
|
||||
this.clusterManager.configure(this.config.cluster);
|
||||
await this.clusterManager.initialize();
|
||||
|
||||
// Create model loader
|
||||
this.modelLoader = new ModelLoader(
|
||||
this.modelRegistry,
|
||||
this.containerManager,
|
||||
this.config.models.autoPull,
|
||||
this.config.models.autoDeploy,
|
||||
);
|
||||
this.clusterCoordinator = new ClusterCoordinator(
|
||||
this.clusterManager,
|
||||
this.containerManager,
|
||||
this.modelRegistry,
|
||||
this.modelLoader,
|
||||
);
|
||||
|
||||
logger.success('ModelGrid initialized');
|
||||
}
|
||||
|
||||
private normalizeConfig(
|
||||
config: Partial<IModelGridConfig> & {
|
||||
models?: {
|
||||
greenlistUrl?: string;
|
||||
autoPull?: boolean;
|
||||
} & Partial<IModelGridConfig['models']>;
|
||||
},
|
||||
): IModelGridConfig {
|
||||
const filteredContainers = (config.containers || []).filter(
|
||||
(container) => (container as { type?: string }).type !== 'ollama',
|
||||
);
|
||||
|
||||
return {
|
||||
version: config.version || VERSION,
|
||||
api: {
|
||||
port: config.api?.port || 8080,
|
||||
host: config.api?.host || '0.0.0.0',
|
||||
apiKeys: config.api?.apiKeys || [],
|
||||
rateLimit: config.api?.rateLimit,
|
||||
cors: config.api?.cors ?? true,
|
||||
corsOrigins: config.api?.corsOrigins || ['*'],
|
||||
},
|
||||
docker: {
|
||||
networkName: config.docker?.networkName || 'modelgrid',
|
||||
runtime: config.docker?.runtime || 'docker',
|
||||
socketPath: config.docker?.socketPath,
|
||||
},
|
||||
gpus: {
|
||||
autoDetect: config.gpus?.autoDetect ?? true,
|
||||
assignments: config.gpus?.assignments || {},
|
||||
},
|
||||
containers: filteredContainers,
|
||||
models: {
|
||||
registryUrl: config.models?.registryUrl || config.models?.greenlistUrl ||
|
||||
'https://list.modelgrid.com/catalog/models.json',
|
||||
autoDeploy: config.models?.autoDeploy ?? config.models?.autoPull ?? true,
|
||||
defaultEngine: 'vllm',
|
||||
autoLoad: config.models?.autoLoad || [],
|
||||
},
|
||||
cluster: {
|
||||
enabled: config.cluster?.enabled ?? false,
|
||||
nodeName: config.cluster?.nodeName || 'modelgrid-local',
|
||||
role: config.cluster?.role || 'standalone',
|
||||
bindHost: config.cluster?.bindHost || '0.0.0.0',
|
||||
gossipPort: config.cluster?.gossipPort || 7946,
|
||||
sharedSecret: config.cluster?.sharedSecret,
|
||||
advertiseUrl: config.cluster?.advertiseUrl,
|
||||
controlPlaneUrl: config.cluster?.controlPlaneUrl,
|
||||
heartbeatIntervalMs: config.cluster?.heartbeatIntervalMs || 5000,
|
||||
seedNodes: config.cluster?.seedNodes || [],
|
||||
},
|
||||
checkInterval: config.checkInterval || 30000,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Shutdown the ModelGrid system
|
||||
*/
|
||||
|
||||
+195
-148
@@ -1,18 +1,16 @@
|
||||
/**
|
||||
* Model Loader
|
||||
*
|
||||
* Handles automatic model loading with greenlist validation.
|
||||
* Model loader for vLLM deployments.
|
||||
*/
|
||||
|
||||
import type { TContainerType } from '../interfaces/container.ts';
|
||||
import type { IModelCatalogEntry } from '../interfaces/catalog.ts';
|
||||
import type { IGpuInfo } from '../interfaces/gpu.ts';
|
||||
import { filterOutUsedGpus, selectPlacementForModel } from '../cluster/placement.ts';
|
||||
import { VllmContainer } from '../containers/vllm.ts';
|
||||
import { logger } from '../logger.ts';
|
||||
import { ModelRegistry } from './registry.ts';
|
||||
import { ContainerManager } from '../containers/container-manager.ts';
|
||||
import { GpuDetector } from '../hardware/gpu-detector.ts';
|
||||
|
||||
/**
|
||||
* Model load result
|
||||
*/
|
||||
export interface IModelLoadResult {
|
||||
success: boolean;
|
||||
model: string;
|
||||
@@ -21,161 +19,112 @@ export interface IModelLoadResult {
|
||||
alreadyLoaded?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Model loader with greenlist validation
|
||||
*/
|
||||
export interface IModelLoadOptions {
|
||||
forceNewReplica?: boolean;
|
||||
replicaOrdinal?: number;
|
||||
}
|
||||
|
||||
export class ModelLoader {
|
||||
private registry: ModelRegistry;
|
||||
private containerManager: ContainerManager;
|
||||
private gpuDetector: GpuDetector;
|
||||
private autoPull: boolean;
|
||||
private autoDeploy: boolean;
|
||||
|
||||
constructor(
|
||||
registry: ModelRegistry,
|
||||
containerManager: ContainerManager,
|
||||
autoPull: boolean = true,
|
||||
autoDeploy: boolean = true,
|
||||
) {
|
||||
this.registry = registry;
|
||||
this.containerManager = containerManager;
|
||||
this.gpuDetector = new GpuDetector();
|
||||
this.autoPull = autoPull;
|
||||
this.autoDeploy = autoDeploy;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load a model with greenlist validation
|
||||
*/
|
||||
public async loadModel(modelName: string): Promise<IModelLoadResult> {
|
||||
public async loadModel(
|
||||
modelName: string,
|
||||
options: IModelLoadOptions = {},
|
||||
): Promise<IModelLoadResult> {
|
||||
logger.info(`Loading model: ${modelName}`);
|
||||
|
||||
// Step 1: Check if model is already loaded in any container
|
||||
const container = await this.containerManager.findContainerForModel(modelName);
|
||||
if (container) {
|
||||
logger.dim(`Model ${modelName} is already available in container ${container.getConfig().id}`);
|
||||
const modelInfo = await this.registry.getModel(modelName);
|
||||
const resolvedModelName = modelInfo?.id || modelName;
|
||||
|
||||
const existing = await this.containerManager.findContainerForModel(resolvedModelName);
|
||||
if (existing && !options.forceNewReplica) {
|
||||
return {
|
||||
success: true,
|
||||
model: modelName,
|
||||
container: container.getConfig().id,
|
||||
model: resolvedModelName,
|
||||
container: existing.getConfig().id,
|
||||
alreadyLoaded: true,
|
||||
};
|
||||
}
|
||||
|
||||
// Step 2: Check if model is greenlit
|
||||
const isGreenlit = await this.registry.isModelGreenlit(modelName);
|
||||
if (!isGreenlit) {
|
||||
logger.error(`Model ${modelName} is not in the greenlit list`);
|
||||
logger.info('Only greenlit models can be auto-pulled for security reasons.');
|
||||
logger.info('Contact your administrator to add this model to the greenlist.');
|
||||
return {
|
||||
success: false,
|
||||
model: modelName,
|
||||
error: `Model "${modelName}" is not greenlit. Request via admin or add to greenlist.`,
|
||||
};
|
||||
}
|
||||
|
||||
// Step 3: Get model info from greenlist
|
||||
const modelInfo = await this.registry.getGreenlitModel(modelName);
|
||||
if (!modelInfo) {
|
||||
return {
|
||||
success: false,
|
||||
model: modelName,
|
||||
error: 'Failed to get model info from greenlist',
|
||||
model: resolvedModelName,
|
||||
error: `Model "${modelName}" is not listed in the registry`,
|
||||
};
|
||||
}
|
||||
|
||||
// Step 4: Check VRAM requirements
|
||||
const gpus = await this.gpuDetector.detectGpus();
|
||||
const totalVram = gpus.reduce((sum, gpu) => sum + gpu.vram, 0);
|
||||
const totalVramGb = Math.round(totalVram / 1024);
|
||||
|
||||
if (modelInfo.minVram > totalVramGb) {
|
||||
logger.error(`Insufficient VRAM for model ${modelName}`);
|
||||
logger.info(`Required: ${modelInfo.minVram}GB, Available: ${totalVramGb}GB`);
|
||||
const placement = this.planPlacement(modelInfo, await this.gpuDetector.detectGpus());
|
||||
if (!placement) {
|
||||
return {
|
||||
success: false,
|
||||
model: modelName,
|
||||
error: `Insufficient VRAM. Required: ${modelInfo.minVram}GB, Available: ${totalVramGb}GB`,
|
||||
model: resolvedModelName,
|
||||
error: 'Insufficient GPU capacity for deployment',
|
||||
};
|
||||
}
|
||||
|
||||
// Step 5: Find or create appropriate container
|
||||
const containerType = modelInfo.container;
|
||||
let targetContainer = await this.findAvailableContainer(containerType);
|
||||
|
||||
if (!targetContainer) {
|
||||
logger.warn(`No ${containerType} container available`);
|
||||
|
||||
// Could auto-create container here if desired
|
||||
if (!this.autoDeploy) {
|
||||
return {
|
||||
success: false,
|
||||
model: modelName,
|
||||
error: `No ${containerType} container available to load model`,
|
||||
model: resolvedModelName,
|
||||
error: 'Automatic deployments are disabled',
|
||||
};
|
||||
}
|
||||
|
||||
// Step 6: Pull the model if auto-pull is enabled
|
||||
if (this.autoPull) {
|
||||
logger.info(`Pulling model ${modelName} to ${containerType} container...`);
|
||||
const deploymentId = this.createDeploymentId(
|
||||
modelInfo.id,
|
||||
options.replicaOrdinal ?? this.getExistingReplicaCount(modelInfo.id),
|
||||
);
|
||||
const deploymentName = this.createDeploymentName(
|
||||
modelInfo.id,
|
||||
options.replicaOrdinal ?? this.getExistingReplicaCount(modelInfo.id),
|
||||
);
|
||||
const config = VllmContainer.createConfig(
|
||||
deploymentId,
|
||||
deploymentName,
|
||||
modelInfo.source.repo,
|
||||
placement.gpuIds,
|
||||
{
|
||||
env: {
|
||||
...(modelInfo.launchDefaults?.env || {}),
|
||||
},
|
||||
command: this.buildVllmCommand(modelInfo, placement.tensorParallelSize),
|
||||
},
|
||||
);
|
||||
config.models = [modelInfo.id];
|
||||
|
||||
const pullSuccess = await targetContainer.pullModel(modelName, (progress) => {
|
||||
const percent = progress.percent !== undefined ? ` (${progress.percent}%)` : '';
|
||||
logger.dim(` ${progress.status}${percent}`);
|
||||
});
|
||||
|
||||
if (!pullSuccess) {
|
||||
return {
|
||||
success: false,
|
||||
model: modelName,
|
||||
error: 'Failed to pull model',
|
||||
};
|
||||
}
|
||||
const container = this.containerManager.addContainer(config);
|
||||
const started = await container.start();
|
||||
if (!started) {
|
||||
await this.containerManager.removeContainer(config.id);
|
||||
return {
|
||||
success: false,
|
||||
model: resolvedModelName,
|
||||
error: 'Failed to start vLLM deployment',
|
||||
};
|
||||
}
|
||||
|
||||
logger.success(`Model ${modelName} loaded successfully`);
|
||||
return {
|
||||
success: true,
|
||||
model: modelName,
|
||||
container: targetContainer.getConfig().id,
|
||||
model: modelInfo.id,
|
||||
container: config.id,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Find an available container of the specified type
|
||||
*/
|
||||
private async findAvailableContainer(
|
||||
containerType: TContainerType,
|
||||
): Promise<import('../containers/base-container.ts').BaseContainer | null> {
|
||||
const containers = this.containerManager.getAllContainers();
|
||||
|
||||
for (const container of containers) {
|
||||
if (container.type !== containerType) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const status = await container.getStatus();
|
||||
if (status.running) {
|
||||
return container;
|
||||
}
|
||||
}
|
||||
|
||||
// No running container found, try to start one
|
||||
for (const container of containers) {
|
||||
if (container.type !== containerType) {
|
||||
continue;
|
||||
}
|
||||
|
||||
logger.info(`Starting ${containerType} container: ${container.getConfig().name}`);
|
||||
const started = await container.start();
|
||||
if (started) {
|
||||
return container;
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Preload a list of models
|
||||
*/
|
||||
public async preloadModels(modelNames: string[]): Promise<Map<string, IModelLoadResult>> {
|
||||
const results = new Map<string, IModelLoadResult>();
|
||||
|
||||
@@ -191,36 +140,45 @@ export class ModelLoader {
|
||||
return results;
|
||||
}
|
||||
|
||||
/**
|
||||
* Unload a model from a container
|
||||
*/
|
||||
public async unloadModel(modelName: string): Promise<boolean> {
|
||||
const container = await this.containerManager.findContainerForModel(modelName);
|
||||
if (!container) {
|
||||
const modelInfo = await this.registry.getModel(modelName);
|
||||
const canonicalModel = modelInfo?.id || modelName;
|
||||
const containers = this.containerManager.getAllContainers().filter((container) =>
|
||||
container.getConfig().models.includes(canonicalModel)
|
||||
);
|
||||
|
||||
if (containers.length === 0) {
|
||||
logger.warn(`Model ${modelName} not found in any container`);
|
||||
return false;
|
||||
}
|
||||
|
||||
return container.removeModel(modelName);
|
||||
let allRemoved = true;
|
||||
for (const container of containers) {
|
||||
const removed = await this.containerManager.removeContainer(container.getConfig().id);
|
||||
allRemoved = allRemoved && removed;
|
||||
}
|
||||
|
||||
return allRemoved;
|
||||
}
|
||||
|
||||
public async deployReplica(
|
||||
modelName: string,
|
||||
replicaOrdinal?: number,
|
||||
): Promise<IModelLoadResult> {
|
||||
return this.loadModel(modelName, {
|
||||
forceNewReplica: true,
|
||||
replicaOrdinal,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if auto-pull is enabled
|
||||
*/
|
||||
public isAutoPullEnabled(): boolean {
|
||||
return this.autoPull;
|
||||
return this.autoDeploy;
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable or disable auto-pull
|
||||
*/
|
||||
public setAutoPull(enabled: boolean): void {
|
||||
this.autoPull = enabled;
|
||||
this.autoDeploy = enabled;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get loading recommendations for available VRAM
|
||||
*/
|
||||
public async getRecommendations(): Promise<{
|
||||
canLoad: string[];
|
||||
cannotLoad: string[];
|
||||
@@ -229,7 +187,7 @@ export class ModelLoader {
|
||||
const gpus = await this.gpuDetector.detectGpus();
|
||||
const totalVramGb = Math.round(gpus.reduce((sum, gpu) => sum + gpu.vram, 0) / 1024);
|
||||
|
||||
const allModels = await this.registry.getAllGreenlitModels();
|
||||
const allModels = await this.registry.getAllModels();
|
||||
const availableModels = await this.containerManager.getAllAvailableModels();
|
||||
const loadedNames = new Set(availableModels.keys());
|
||||
|
||||
@@ -238,27 +196,24 @@ export class ModelLoader {
|
||||
const loaded: string[] = [];
|
||||
|
||||
for (const model of allModels) {
|
||||
if (loadedNames.has(model.name)) {
|
||||
loaded.push(model.name);
|
||||
} else if (model.minVram <= totalVramGb) {
|
||||
canLoad.push(model.name);
|
||||
if (loadedNames.has(model.id)) {
|
||||
loaded.push(model.id);
|
||||
} else if (model.requirements.minVramGb <= totalVramGb) {
|
||||
canLoad.push(model.id);
|
||||
} else {
|
||||
cannotLoad.push(model.name);
|
||||
cannotLoad.push(model.id);
|
||||
}
|
||||
}
|
||||
|
||||
return { canLoad, cannotLoad, loaded };
|
||||
}
|
||||
|
||||
/**
|
||||
* Print loading status
|
||||
*/
|
||||
public async printStatus(): Promise<void> {
|
||||
const recommendations = await this.getRecommendations();
|
||||
|
||||
logger.logBoxTitle('Model Loading Status', 60, 'info');
|
||||
logger.logBoxTitle('Model Deployment Status', 70, 'info');
|
||||
|
||||
logger.logBoxLine(`Loaded Models (${recommendations.loaded.length}):`);
|
||||
logger.logBoxLine(`Running Deployments (${recommendations.loaded.length}):`);
|
||||
if (recommendations.loaded.length > 0) {
|
||||
for (const model of recommendations.loaded) {
|
||||
logger.logBoxLine(` - ${model}`);
|
||||
@@ -268,7 +223,7 @@ export class ModelLoader {
|
||||
}
|
||||
|
||||
logger.logBoxLine('');
|
||||
logger.logBoxLine(`Available to Load (${recommendations.canLoad.length}):`);
|
||||
logger.logBoxLine(`Ready To Deploy (${recommendations.canLoad.length}):`);
|
||||
for (const model of recommendations.canLoad.slice(0, 5)) {
|
||||
logger.logBoxLine(` - ${model}`);
|
||||
}
|
||||
@@ -277,10 +232,10 @@ export class ModelLoader {
|
||||
}
|
||||
|
||||
logger.logBoxLine('');
|
||||
logger.logBoxLine(`Insufficient VRAM (${recommendations.cannotLoad.length}):`);
|
||||
logger.logBoxLine(`Needs Larger GPUs (${recommendations.cannotLoad.length}):`);
|
||||
for (const model of recommendations.cannotLoad.slice(0, 3)) {
|
||||
const info = await this.registry.getGreenlitModel(model);
|
||||
logger.logBoxLine(` - ${model} (needs ${info?.minVram || '?'}GB)`);
|
||||
const info = await this.registry.getModel(model);
|
||||
logger.logBoxLine(` - ${model} (needs ${info?.requirements.minVramGb || '?'}GB)`);
|
||||
}
|
||||
if (recommendations.cannotLoad.length > 3) {
|
||||
logger.logBoxLine(` ... and ${recommendations.cannotLoad.length - 3} more`);
|
||||
@@ -288,4 +243,96 @@ export class ModelLoader {
|
||||
|
||||
logger.logBoxEnd();
|
||||
}
|
||||
|
||||
private planPlacement(
|
||||
modelInfo: IModelCatalogEntry,
|
||||
gpus: IGpuInfo[],
|
||||
): { gpuIds: string[]; tensorParallelSize: number } | null {
|
||||
const usedGpuIds = this.containerManager.getAllContainers().flatMap((container) =>
|
||||
container.getConfig().gpuIds
|
||||
);
|
||||
const freeGpus = filterOutUsedGpus(gpus, usedGpuIds);
|
||||
|
||||
const preferredPlacement = selectPlacementForModel(modelInfo, freeGpus);
|
||||
if (preferredPlacement) {
|
||||
return {
|
||||
gpuIds: preferredPlacement.gpuIds,
|
||||
tensorParallelSize: preferredPlacement.tensorParallelSize,
|
||||
};
|
||||
}
|
||||
|
||||
const fallbackPlacement = selectPlacementForModel(modelInfo, gpus);
|
||||
if (!fallbackPlacement) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
gpuIds: fallbackPlacement.gpuIds,
|
||||
tensorParallelSize: fallbackPlacement.tensorParallelSize,
|
||||
};
|
||||
}
|
||||
|
||||
private buildVllmCommand(
|
||||
modelInfo: IModelCatalogEntry,
|
||||
tensorParallelSize: number,
|
||||
): string[] {
|
||||
const command = ['--model', modelInfo.source.repo];
|
||||
|
||||
if (tensorParallelSize > 1) {
|
||||
command.push('--tensor-parallel-size', String(tensorParallelSize));
|
||||
}
|
||||
|
||||
if (modelInfo.launchDefaults?.maxModelLen) {
|
||||
command.push('--max-model-len', String(modelInfo.launchDefaults.maxModelLen));
|
||||
}
|
||||
|
||||
if (modelInfo.launchDefaults?.gpuMemoryUtilization) {
|
||||
command.push(
|
||||
'--gpu-memory-utilization',
|
||||
String(modelInfo.launchDefaults.gpuMemoryUtilization),
|
||||
);
|
||||
}
|
||||
|
||||
if (modelInfo.launchDefaults?.quantization) {
|
||||
command.push('--quantization', modelInfo.launchDefaults.quantization);
|
||||
}
|
||||
|
||||
if (modelInfo.launchDefaults?.dtype) {
|
||||
command.push('--dtype', modelInfo.launchDefaults.dtype);
|
||||
}
|
||||
|
||||
if (modelInfo.launchDefaults?.generationConfig) {
|
||||
command.push('--generation-config', modelInfo.launchDefaults.generationConfig);
|
||||
}
|
||||
|
||||
if (modelInfo.launchDefaults?.extraArgs) {
|
||||
command.push(...modelInfo.launchDefaults.extraArgs);
|
||||
}
|
||||
|
||||
return command;
|
||||
}
|
||||
|
||||
private getExistingReplicaCount(modelId: string): number {
|
||||
return this.containerManager.getAllContainers().filter((container) =>
|
||||
container.getConfig().models.includes(modelId)
|
||||
).length;
|
||||
}
|
||||
|
||||
private createDeploymentId(modelId: string, replicaOrdinal: number): string {
|
||||
const baseId = modelId.toLowerCase().replace(/[^a-z0-9]+/g, '-').replace(/^-+|-+$/g, '').slice(
|
||||
0,
|
||||
32,
|
||||
);
|
||||
const suffix = replicaOrdinal > 0 ? `-r${replicaOrdinal + 1}` : '';
|
||||
return `vllm-${baseId}${suffix}`;
|
||||
}
|
||||
|
||||
private createDeploymentName(modelId: string, replicaOrdinal: number): string {
|
||||
const baseName = modelId.split('/').pop() || modelId;
|
||||
if (replicaOrdinal === 0) {
|
||||
return baseName;
|
||||
}
|
||||
|
||||
return `${baseName} replica ${replicaOrdinal + 1}`;
|
||||
}
|
||||
}
|
||||
|
||||
+158
-205
@@ -1,252 +1,205 @@
|
||||
/**
|
||||
* Model Registry
|
||||
*
|
||||
* Manages the greenlit model list and model availability.
|
||||
* Model registry backed by list.modelgrid.com.
|
||||
*/
|
||||
|
||||
import type { IGreenlitModel, IGreenlitModelsList } from '../interfaces/config.ts';
|
||||
import type { TContainerType } from '../interfaces/container.ts';
|
||||
import * as fs from 'node:fs/promises';
|
||||
import type { IModelCatalog, IModelCatalogEntry } from '../interfaces/catalog.ts';
|
||||
import { MODEL_REGISTRY, TIMING } from '../constants.ts';
|
||||
import { logger } from '../logger.ts';
|
||||
|
||||
/**
|
||||
* Model registry for managing greenlit models
|
||||
*/
|
||||
export class ModelRegistry {
|
||||
private greenlistUrl: string;
|
||||
private cachedGreenlist: IGreenlitModelsList | null = null;
|
||||
private catalogUrl: string;
|
||||
private cachedCatalog: IModelCatalog | null = null;
|
||||
private cacheTime: number = 0;
|
||||
|
||||
constructor(greenlistUrl: string = MODEL_REGISTRY.DEFAULT_GREENLIST_URL) {
|
||||
this.greenlistUrl = greenlistUrl;
|
||||
constructor(catalogUrl: string = MODEL_REGISTRY.DEFAULT_CATALOG_URL) {
|
||||
this.catalogUrl = catalogUrl;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the greenlist URL
|
||||
*/
|
||||
public setGreenlistUrl(url: string): void {
|
||||
this.greenlistUrl = url;
|
||||
this.cachedGreenlist = null;
|
||||
public setCatalogUrl(url: string): void {
|
||||
this.catalogUrl = url;
|
||||
this.cachedCatalog = null;
|
||||
this.cacheTime = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch the greenlit model list from remote URL
|
||||
*/
|
||||
public async fetchGreenlist(forceRefresh: boolean = false): Promise<IGreenlitModelsList> {
|
||||
// Return cached data if still valid
|
||||
public async fetchCatalog(forceRefresh: boolean = false): Promise<IModelCatalog> {
|
||||
if (
|
||||
!forceRefresh &&
|
||||
this.cachedGreenlist &&
|
||||
this.cachedCatalog &&
|
||||
Date.now() - this.cacheTime < TIMING.GREENLIST_CACHE_DURATION_MS
|
||||
) {
|
||||
return this.cachedGreenlist;
|
||||
return this.cachedCatalog;
|
||||
}
|
||||
|
||||
try {
|
||||
logger.dim(`Fetching greenlit models from: ${this.greenlistUrl}`);
|
||||
logger.dim(`Fetching model catalog from: ${this.catalogUrl}`);
|
||||
const catalog = await this.readCatalogSource(this.catalogUrl);
|
||||
|
||||
const controller = new AbortController();
|
||||
const timeout = setTimeout(() => controller.abort(), 30000);
|
||||
if (!Array.isArray(catalog.models)) {
|
||||
throw new Error('Invalid catalog format: missing models array');
|
||||
}
|
||||
|
||||
const response = await fetch(this.greenlistUrl, {
|
||||
this.cachedCatalog = catalog;
|
||||
this.cacheTime = Date.now();
|
||||
|
||||
logger.dim(`Loaded ${catalog.models.length} catalog models`);
|
||||
return catalog;
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
`Failed to fetch model catalog: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
|
||||
if (!this.cachedCatalog) {
|
||||
logger.dim('Using fallback catalog');
|
||||
return this.getFallbackCatalog();
|
||||
}
|
||||
|
||||
return this.cachedCatalog;
|
||||
}
|
||||
}
|
||||
|
||||
public async isModelListed(modelName: string): Promise<boolean> {
|
||||
return (await this.getModel(modelName)) !== null;
|
||||
}
|
||||
|
||||
public async getModel(modelName: string): Promise<IModelCatalogEntry | null> {
|
||||
const catalog = await this.fetchCatalog();
|
||||
const normalized = this.normalizeModelName(modelName);
|
||||
|
||||
return catalog.models.find((model) => {
|
||||
const candidates = [model.id, ...(model.aliases || [])];
|
||||
return candidates.some((candidate) => this.normalizeModelName(candidate) === normalized);
|
||||
}) || null;
|
||||
}
|
||||
|
||||
public async getAllModels(): Promise<IModelCatalogEntry[]> {
|
||||
const catalog = await this.fetchCatalog();
|
||||
return catalog.models;
|
||||
}
|
||||
|
||||
public async getModelsByEngine(engine: 'vllm'): Promise<IModelCatalogEntry[]> {
|
||||
const catalog = await this.fetchCatalog();
|
||||
return catalog.models.filter((model) => model.engine === engine);
|
||||
}
|
||||
|
||||
public async getModelsWithinVram(maxVramGb: number): Promise<IModelCatalogEntry[]> {
|
||||
const catalog = await this.fetchCatalog();
|
||||
return catalog.models.filter((model) => model.requirements.minVramGb <= maxVramGb);
|
||||
}
|
||||
|
||||
public async getRecommendedEngine(modelName: string): Promise<'vllm' | null> {
|
||||
const model = await this.getModel(modelName);
|
||||
return model ? model.engine : null;
|
||||
}
|
||||
|
||||
public async getMinVram(modelName: string): Promise<number | null> {
|
||||
const model = await this.getModel(modelName);
|
||||
return model ? model.requirements.minVramGb : null;
|
||||
}
|
||||
|
||||
public async modelFitsInVram(modelName: string, availableVramGb: number): Promise<boolean> {
|
||||
const minVram = await this.getMinVram(modelName);
|
||||
if (minVram === null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return availableVramGb >= minVram;
|
||||
}
|
||||
|
||||
public async searchModels(pattern: string): Promise<IModelCatalogEntry[]> {
|
||||
const catalog = await this.fetchCatalog();
|
||||
const normalizedPattern = pattern.toLowerCase();
|
||||
|
||||
return catalog.models.filter((model) =>
|
||||
model.id.toLowerCase().includes(normalizedPattern) ||
|
||||
model.aliases?.some((alias) => alias.toLowerCase().includes(normalizedPattern)) ||
|
||||
model.metadata?.summary?.toLowerCase().includes(normalizedPattern) ||
|
||||
model.metadata?.tags?.some((tag) => tag.toLowerCase().includes(normalizedPattern))
|
||||
);
|
||||
}
|
||||
|
||||
public async getModelsByTags(tags: string[]): Promise<IModelCatalogEntry[]> {
|
||||
const catalog = await this.fetchCatalog();
|
||||
const normalizedTags = tags.map((tag) => tag.toLowerCase());
|
||||
|
||||
return catalog.models.filter((model) =>
|
||||
model.metadata?.tags?.some((tag) => normalizedTags.includes(tag.toLowerCase()))
|
||||
);
|
||||
}
|
||||
|
||||
public clearCache(): void {
|
||||
this.cachedCatalog = null;
|
||||
this.cacheTime = 0;
|
||||
}
|
||||
|
||||
public async printSummary(): Promise<void> {
|
||||
const catalog = await this.fetchCatalog();
|
||||
|
||||
logger.logBoxTitle('Model Catalog', 70, 'info');
|
||||
logger.logBoxLine(`Version: ${catalog.version}`);
|
||||
logger.logBoxLine(`Generated: ${catalog.generatedAt}`);
|
||||
logger.logBoxLine(`Total Models: ${catalog.models.length}`);
|
||||
logger.logBoxLine('');
|
||||
|
||||
for (const model of catalog.models.slice(0, 10)) {
|
||||
logger.logBoxLine(
|
||||
`- ${model.id} (${model.requirements.minVramGb}GB, ${model.engine})`,
|
||||
);
|
||||
}
|
||||
|
||||
if (catalog.models.length > 10) {
|
||||
logger.logBoxLine(`... and ${catalog.models.length - 10} more`);
|
||||
}
|
||||
|
||||
logger.logBoxEnd();
|
||||
}
|
||||
|
||||
private async readCatalogSource(source: string): Promise<IModelCatalog> {
|
||||
if (source.startsWith('file://')) {
|
||||
const filePath = new URL(source);
|
||||
const content = await fs.readFile(filePath, 'utf-8');
|
||||
return JSON.parse(content) as IModelCatalog;
|
||||
}
|
||||
|
||||
if (source.startsWith('/')) {
|
||||
const content = await fs.readFile(source, 'utf-8');
|
||||
return JSON.parse(content) as IModelCatalog;
|
||||
}
|
||||
|
||||
const controller = new AbortController();
|
||||
const timeout = setTimeout(() => controller.abort(), 30000);
|
||||
|
||||
try {
|
||||
const response = await fetch(source, {
|
||||
signal: controller.signal,
|
||||
headers: {
|
||||
'Accept': 'application/json',
|
||||
Accept: 'application/json',
|
||||
'User-Agent': 'ModelGrid/1.0',
|
||||
},
|
||||
});
|
||||
|
||||
clearTimeout(timeout);
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
const greenlist = await response.json() as IGreenlitModelsList;
|
||||
|
||||
// Validate structure
|
||||
if (!greenlist.models || !Array.isArray(greenlist.models)) {
|
||||
throw new Error('Invalid greenlist format: missing models array');
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
this.cachedGreenlist = greenlist;
|
||||
this.cacheTime = Date.now();
|
||||
|
||||
logger.dim(`Loaded ${greenlist.models.length} greenlit models`);
|
||||
return greenlist;
|
||||
} catch (error) {
|
||||
logger.warn(`Failed to fetch greenlist: ${error instanceof Error ? error.message : String(error)}`);
|
||||
|
||||
// Return fallback if we have no cache
|
||||
if (!this.cachedGreenlist) {
|
||||
logger.dim('Using fallback greenlist');
|
||||
return this.getFallbackGreenlist();
|
||||
}
|
||||
|
||||
// Return stale cache
|
||||
return this.cachedGreenlist;
|
||||
return await response.json() as IModelCatalog;
|
||||
} finally {
|
||||
clearTimeout(timeout);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get fallback greenlist
|
||||
*/
|
||||
private getFallbackGreenlist(): IGreenlitModelsList {
|
||||
private getFallbackCatalog(): IModelCatalog {
|
||||
return {
|
||||
version: '1.0',
|
||||
lastUpdated: new Date().toISOString(),
|
||||
models: MODEL_REGISTRY.FALLBACK_GREENLIST as unknown as IGreenlitModel[],
|
||||
generatedAt: new Date().toISOString(),
|
||||
models: MODEL_REGISTRY.FALLBACK_CATALOG as unknown as IModelCatalogEntry[],
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model is greenlit
|
||||
*/
|
||||
public async isModelGreenlit(modelName: string): Promise<boolean> {
|
||||
const greenlist = await this.fetchGreenlist();
|
||||
return greenlist.models.some((m) => this.normalizeModelName(m.name) === this.normalizeModelName(modelName));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get greenlit model info
|
||||
*/
|
||||
public async getGreenlitModel(modelName: string): Promise<IGreenlitModel | null> {
|
||||
const greenlist = await this.fetchGreenlist();
|
||||
const normalized = this.normalizeModelName(modelName);
|
||||
return greenlist.models.find((m) => this.normalizeModelName(m.name) === normalized) || null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all greenlit models
|
||||
*/
|
||||
public async getAllGreenlitModels(): Promise<IGreenlitModel[]> {
|
||||
const greenlist = await this.fetchGreenlist();
|
||||
return greenlist.models;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get greenlit models by container type
|
||||
*/
|
||||
public async getModelsByContainer(containerType: TContainerType): Promise<IGreenlitModel[]> {
|
||||
const greenlist = await this.fetchGreenlist();
|
||||
return greenlist.models.filter((m) => m.container === containerType);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get greenlit models that fit within VRAM limit
|
||||
*/
|
||||
public async getModelsWithinVram(maxVramGb: number): Promise<IGreenlitModel[]> {
|
||||
const greenlist = await this.fetchGreenlist();
|
||||
return greenlist.models.filter((m) => m.minVram <= maxVramGb);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get recommended container type for a model
|
||||
*/
|
||||
public async getRecommendedContainer(modelName: string): Promise<TContainerType | null> {
|
||||
const model = await this.getGreenlitModel(modelName);
|
||||
return model ? model.container : null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get minimum VRAM required for a model
|
||||
*/
|
||||
public async getMinVram(modelName: string): Promise<number | null> {
|
||||
const model = await this.getGreenlitModel(modelName);
|
||||
return model ? model.minVram : null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if model fits in available VRAM
|
||||
*/
|
||||
public async modelFitsInVram(modelName: string, availableVramGb: number): Promise<boolean> {
|
||||
const minVram = await this.getMinVram(modelName);
|
||||
if (minVram === null) {
|
||||
// Model not in greenlist, assume it might fit
|
||||
return true;
|
||||
}
|
||||
return availableVramGb >= minVram;
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalize model name for comparison
|
||||
* Handles variations like "llama3:8b" vs "llama3:8B" vs "meta-llama/llama-3-8b"
|
||||
*/
|
||||
private normalizeModelName(name: string): string {
|
||||
return name
|
||||
.toLowerCase()
|
||||
.replace(/[^a-z0-9:.-]/g, '')
|
||||
.replace(/[^a-z0-9:/._-]/g, '')
|
||||
.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* Search models by name pattern
|
||||
*/
|
||||
public async searchModels(pattern: string): Promise<IGreenlitModel[]> {
|
||||
const greenlist = await this.fetchGreenlist();
|
||||
const normalizedPattern = pattern.toLowerCase();
|
||||
|
||||
return greenlist.models.filter((m) =>
|
||||
m.name.toLowerCase().includes(normalizedPattern) ||
|
||||
m.description?.toLowerCase().includes(normalizedPattern) ||
|
||||
m.tags?.some((t) => t.toLowerCase().includes(normalizedPattern))
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get models by tags
|
||||
*/
|
||||
public async getModelsByTags(tags: string[]): Promise<IGreenlitModel[]> {
|
||||
const greenlist = await this.fetchGreenlist();
|
||||
const normalizedTags = tags.map((t) => t.toLowerCase());
|
||||
|
||||
return greenlist.models.filter((m) =>
|
||||
m.tags?.some((t) => normalizedTags.includes(t.toLowerCase()))
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear the cached greenlist
|
||||
*/
|
||||
public clearCache(): void {
|
||||
this.cachedGreenlist = null;
|
||||
this.cacheTime = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Print greenlist summary
|
||||
*/
|
||||
public async printSummary(): Promise<void> {
|
||||
const greenlist = await this.fetchGreenlist();
|
||||
|
||||
// Group by container type
|
||||
const byContainer = new Map<string, IGreenlitModel[]>();
|
||||
for (const model of greenlist.models) {
|
||||
if (!byContainer.has(model.container)) {
|
||||
byContainer.set(model.container, []);
|
||||
}
|
||||
byContainer.get(model.container)!.push(model);
|
||||
}
|
||||
|
||||
logger.logBoxTitle('Greenlit Models', 60, 'info');
|
||||
logger.logBoxLine(`Version: ${greenlist.version}`);
|
||||
logger.logBoxLine(`Last Updated: ${greenlist.lastUpdated}`);
|
||||
logger.logBoxLine(`Total Models: ${greenlist.models.length}`);
|
||||
logger.logBoxLine('');
|
||||
|
||||
for (const [container, models] of byContainer) {
|
||||
logger.logBoxLine(`${container.toUpperCase()} (${models.length}):`);
|
||||
for (const model of models.slice(0, 5)) {
|
||||
logger.logBoxLine(` - ${model.name} (${model.minVram}GB VRAM)`);
|
||||
}
|
||||
if (models.length > 5) {
|
||||
logger.logBoxLine(` ... and ${models.length - 5} more`);
|
||||
}
|
||||
logger.logBoxLine('');
|
||||
}
|
||||
|
||||
logger.logBoxEnd();
|
||||
}
|
||||
}
|
||||
|
||||
+40
-12
@@ -8,7 +8,7 @@ import process from 'node:process';
|
||||
import { promises as fs } from 'node:fs';
|
||||
import { execSync } from 'node:child_process';
|
||||
import { logger } from './logger.ts';
|
||||
import { theme, symbols } from './colors.ts';
|
||||
import { symbols, theme } from './colors.ts';
|
||||
import { PATHS, VERSION } from './constants.ts';
|
||||
|
||||
/**
|
||||
@@ -122,7 +122,9 @@ WantedBy=multi-user.target
|
||||
// Display GPU status
|
||||
await this.displayGpuStatus();
|
||||
} catch (error) {
|
||||
logger.error(`Failed to get status: ${error instanceof Error ? error.message : String(error)}`);
|
||||
logger.error(
|
||||
`Failed to get status: ${error instanceof Error ? error.message : String(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,9 +154,15 @@ WantedBy=multi-user.target
|
||||
|
||||
logger.log('');
|
||||
if (isActive) {
|
||||
logger.log(`${symbols.running} ${theme.success('Service:')} ${theme.statusActive('active (running)')}`);
|
||||
logger.log(
|
||||
`${symbols.running} ${theme.success('Service:')} ${
|
||||
theme.statusActive('active (running)')
|
||||
}`,
|
||||
);
|
||||
} else {
|
||||
logger.log(`${symbols.stopped} ${theme.dim('Service:')} ${theme.statusInactive('inactive')}`);
|
||||
logger.log(
|
||||
`${symbols.stopped} ${theme.dim('Service:')} ${theme.statusInactive('inactive')}`,
|
||||
);
|
||||
}
|
||||
|
||||
if (pid || memory) {
|
||||
@@ -166,7 +174,9 @@ WantedBy=multi-user.target
|
||||
logger.log('');
|
||||
} catch (_error) {
|
||||
logger.log('');
|
||||
logger.log(`${symbols.stopped} ${theme.dim('Service:')} ${theme.statusInactive('not installed')}`);
|
||||
logger.log(
|
||||
`${symbols.stopped} ${theme.dim('Service:')} ${theme.statusInactive('not installed')}`,
|
||||
);
|
||||
logger.log('');
|
||||
}
|
||||
}
|
||||
@@ -177,8 +187,11 @@ WantedBy=multi-user.target
|
||||
private async displayContainerStatus(): Promise<void> {
|
||||
try {
|
||||
// Try to get container info from docker
|
||||
const output = execSync('docker ps --filter "name=modelgrid" --format "{{.Names}}\\t{{.Status}}"', { encoding: 'utf-8' });
|
||||
const lines = output.trim().split('\n').filter(l => l.trim());
|
||||
const output = execSync(
|
||||
'docker ps --filter "name=modelgrid" --format "{{.Names}}\\t{{.Status}}"',
|
||||
{ encoding: 'utf-8' },
|
||||
);
|
||||
const lines = output.trim().split('\n').filter((l) => l.trim());
|
||||
|
||||
if (lines.length === 0) {
|
||||
logger.info('Containers: None running');
|
||||
@@ -191,7 +204,11 @@ WantedBy=multi-user.target
|
||||
const [name, status] = line.split('\t');
|
||||
const isUp = status?.toLowerCase().includes('up');
|
||||
|
||||
logger.log(` ${isUp ? symbols.running : symbols.stopped} ${theme.highlight(name)} - ${isUp ? theme.success(status) : theme.dim(status)}`);
|
||||
logger.log(
|
||||
` ${isUp ? symbols.running : symbols.stopped} ${theme.highlight(name)} - ${
|
||||
isUp ? theme.success(status) : theme.dim(status)
|
||||
}`,
|
||||
);
|
||||
}
|
||||
logger.log('');
|
||||
} catch (_error) {
|
||||
@@ -205,7 +222,10 @@ WantedBy=multi-user.target
|
||||
private async displayGpuStatus(): Promise<void> {
|
||||
try {
|
||||
// Try nvidia-smi
|
||||
const output = execSync('nvidia-smi --query-gpu=name,utilization.gpu,memory.used,memory.total --format=csv,noheader,nounits', { encoding: 'utf-8' });
|
||||
const output = execSync(
|
||||
'nvidia-smi --query-gpu=name,utilization.gpu,memory.used,memory.total --format=csv,noheader,nounits',
|
||||
{ encoding: 'utf-8' },
|
||||
);
|
||||
const lines = output.trim().split('\n');
|
||||
|
||||
if (lines.length === 0) {
|
||||
@@ -215,11 +235,15 @@ WantedBy=multi-user.target
|
||||
logger.info(`GPUs (${lines.length}):`);
|
||||
|
||||
for (const line of lines) {
|
||||
const [name, util, memUsed, memTotal] = line.split(',').map(s => s.trim());
|
||||
const [name, util, memUsed, memTotal] = line.split(',').map((s) => s.trim());
|
||||
const memPercent = Math.round((parseInt(memUsed) / parseInt(memTotal)) * 100);
|
||||
|
||||
logger.log(` ${symbols.info} ${theme.gpuNvidia(name)}`);
|
||||
logger.log(` Utilization: ${theme.highlight(util + '%')} Memory: ${theme.info(memUsed)}/${memTotal} MB (${memPercent}%)`);
|
||||
logger.log(
|
||||
` Utilization: ${theme.highlight(util + '%')} Memory: ${
|
||||
theme.info(memUsed)
|
||||
}/${memTotal} MB (${memPercent}%)`,
|
||||
);
|
||||
}
|
||||
logger.log('');
|
||||
} catch (_error) {
|
||||
@@ -275,7 +299,11 @@ WantedBy=multi-user.target
|
||||
logger.log('');
|
||||
logger.error('No configuration found');
|
||||
logger.log(` ${theme.dim('Config file:')} ${PATHS.CONFIG_FILE}`);
|
||||
logger.log(` ${theme.dim('Run')} ${theme.command('modelgrid config init')} ${theme.dim('to create one')}`);
|
||||
logger.log(
|
||||
` ${theme.dim('Run')} ${theme.command('modelgrid config init')} ${
|
||||
theme.dim('to create one')
|
||||
}`,
|
||||
);
|
||||
logger.log('');
|
||||
throw new Error('Configuration not found');
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user