/**
 * TGI Container (Text Generation Inference)
 *
 * Manages HuggingFace Text Generation Inference containers.
 */

import type {
  IContainerConfig,
  ILoadedModel,
  TContainerType,
} from '../interfaces/container.ts';
import type {
  IChatCompletionRequest,
  IChatCompletionResponse,
  IChatCompletionChoice,
  IChatMessage,
} from '../interfaces/api.ts';
import { CONTAINER_IMAGES, CONTAINER_PORTS } from '../constants.ts';
import { logger } from '../logger.ts';
import { BaseContainer, type TModelPullProgress } from './base-container.ts';

/**
 * TGI info response
 */
interface ITgiInfoResponse {
  model_id: string;
  model_sha: string;
  model_dtype: string;
  model_device_type: string;
  max_concurrent_requests: number;
  max_best_of: number;
  max_stop_sequences: number;
  max_input_length: number;
  max_total_tokens: number;
  version: string;
}

/**
 * TGI generate request
 */
interface ITgiGenerateRequest {
  inputs: string;
  parameters?: {
    temperature?: number;
    top_p?: number;
    max_new_tokens?: number;
    stop?: string[];
    do_sample?: boolean;
    return_full_text?: boolean;
  };
}

/**
 * TGI generate response
 */
interface ITgiGenerateResponse {
  generated_text: string;
  details?: {
    finish_reason: string;
    generated_tokens: number;
    seed?: number;
  };
}

/**
 * TGI container implementation
 *
 * TGI is optimized for:
 * - Production deployments
 * - Flash Attention support
 * - Quantization (bitsandbytes, GPTQ, AWQ)
 * - Multi-GPU support with tensor parallelism
 */
export class TgiContainer extends BaseContainer {
  public readonly type: TContainerType = 'tgi';
  public readonly displayName = 'TGI';
  public readonly defaultImage = CONTAINER_IMAGES.TGI;
  public readonly defaultPort = CONTAINER_PORTS.TGI;

  constructor(config: IContainerConfig) {
    super(config);

    // Set defaults if not provided
    if (!config.image) {
      config.image = this.defaultImage;
    }
    if (!config.port) {
      config.port = this.defaultPort;
    }

    // Add default volume for the model cache
    if (!config.volumes || config.volumes.length === 0) {
      config.volumes = [`modelgrid-tgi-${config.id}:/data`];
    }
  }

  /**
   * Create a TGI container configuration
   */
  public static createConfig(
    id: string,
    name: string,
    modelName: string,
    gpuIds: string[],
    options: Partial<IContainerConfig> = {},
  ): IContainerConfig {
    const env: Record<string, string> = {
      MODEL_ID: modelName,
      PORT: String(options.port || CONTAINER_PORTS.TGI),
      HUGGING_FACE_HUB_TOKEN: options.env?.HF_TOKEN ||
        options.env?.HUGGING_FACE_HUB_TOKEN || '',
      ...options.env,
    };

    // Add GPU configuration: shard across GPUs when more than one is assigned
    if (gpuIds.length > 1) {
      env.NUM_SHARD = String(gpuIds.length);
    }

    // Add quantization if specified
    if (options.env?.QUANTIZE) {
      env.QUANTIZE = options.env.QUANTIZE;
    }

    return {
      id,
      name,
      type: 'tgi',
      image: options.image || CONTAINER_IMAGES.TGI,
      gpuIds,
      port: options.port || CONTAINER_PORTS.TGI,
      externalPort: options.externalPort,
      models: [modelName],
      env,
      volumes: options.volumes || [`modelgrid-tgi-${id}:/data`],
      autoStart: options.autoStart ?? true,
      restartPolicy: options.restartPolicy || 'unless-stopped',
      memoryLimit: options.memoryLimit,
      cpuLimit: options.cpuLimit,
      command: options.command,
    };
  }
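  /*
   * Illustrative sketch (not part of the original source): what createConfig()
   * produces for a hypothetical two-GPU, quantized deployment. The container id,
   * model id, and GPU ids below are placeholders.
   *
   *   TgiContainer.createConfig('tgi-1', 'tgi-1', 'org/example-model',
   *     ['0', '1'], { env: { QUANTIZE: 'awq' } });
   *
   * yields an env block of roughly:
   *
   *   { MODEL_ID: 'org/example-model', PORT: String(CONTAINER_PORTS.TGI),
   *     HUGGING_FACE_HUB_TOKEN: '', QUANTIZE: 'awq', NUM_SHARD: '2' }
   *
   * NUM_SHARD is only set because two GPU ids were passed.
   */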
  /**
   * Check if TGI is healthy
   */
  public async isHealthy(): Promise<boolean> {
    try {
      const response = await this.fetch('/health', { timeout: 5000 });
      return response.ok;
    } catch {
      return false;
    }
  }

  /**
   * List available models
   * TGI serves a single model per instance
   */
  public async listModels(): Promise<string[]> {
    try {
      const info = await this.fetchJson<ITgiInfoResponse>('/info');
      return [info.model_id];
    } catch (error) {
      logger.warn(
        `Failed to get TGI info: ${error instanceof Error ? error.message : String(error)}`,
      );
      return this.config.models || [];
    }
  }

  /**
   * Get loaded models with details
   */
  public async getLoadedModels(): Promise<ILoadedModel[]> {
    try {
      const info = await this.fetchJson<ITgiInfoResponse>('/info');
      return [{
        name: info.model_id,
        size: 0, // TGI doesn't expose model size
        format: info.model_dtype,
        loaded: true,
        requestCount: 0,
      }];
    } catch {
      return this.config.models.map((name) => ({
        name,
        size: 0,
        loaded: true,
        requestCount: 0,
      }));
    }
  }

  /**
   * Pull a model
   * TGI downloads models automatically at startup
   */
  public async pullModel(modelName: string, onProgress?: TModelPullProgress): Promise<boolean> {
    logger.info(`TGI downloads models at startup. Model: ${modelName}`);
    logger.info('To use a different model, create a new TGI container.');

    if (onProgress) {
      onProgress({
        model: modelName,
        status: 'TGI models are loaded at container startup',
        percent: 100,
      });
    }

    return true;
  }

  /**
   * Remove a model
   * TGI serves a single model per instance
   */
  public async removeModel(modelName: string): Promise<boolean> {
    logger.info('TGI serves a single model per instance.');
    logger.info(`To remove model ${modelName}, stop and remove this container.`);
    return true;
  }

  /**
   * Send a chat completion request
   * Converts the OpenAI request format to the TGI format
   */
  public async chatCompletion(request: IChatCompletionRequest): Promise<IChatCompletionResponse> {
    // Convert messages to TGI prompt format
    const prompt = this.messagesToPrompt(request.messages);

    const tgiRequest: ITgiGenerateRequest = {
      inputs: prompt,
      parameters: {
        temperature: request.temperature,
        top_p: request.top_p,
        max_new_tokens: request.max_tokens || 1024,
        stop: Array.isArray(request.stop)
          ? request.stop
          : request.stop
          ? [request.stop]
          : undefined,
        do_sample: (request.temperature || 0) > 0,
        return_full_text: false,
      },
    };

    const response = await this.fetchJson<ITgiGenerateResponse>('/generate', {
      method: 'POST',
      body: tgiRequest,
      timeout: 300000, // 5 minutes
    });

    // Convert to OpenAI format
    const created = Math.floor(Date.now() / 1000);
    const choice: IChatCompletionChoice = {
      index: 0,
      message: {
        role: 'assistant',
        content: response.generated_text,
      },
      finish_reason: response.details?.finish_reason === 'eos_token' ? 'stop' : 'length',
    };

    return {
      id: this.generateRequestId(),
      object: 'chat.completion',
      created,
      model: this.config.models[0] || 'unknown',
      choices: [choice],
      usage: {
        prompt_tokens: 0, // TGI doesn't always report this
        completion_tokens: response.details?.generated_tokens || 0,
        total_tokens: response.details?.generated_tokens || 0,
      },
    };
  }

  /**
   * Stream a chat completion request
   */
  public async chatCompletionStream(
    request: IChatCompletionRequest,
    onChunk: (chunk: string) => void,
  ): Promise<void> {
    // Convert messages to TGI prompt format
    const prompt = this.messagesToPrompt(request.messages);

    const response = await this.fetch('/generate_stream', {
      method: 'POST',
      body: {
        inputs: prompt,
        parameters: {
          temperature: request.temperature,
          top_p: request.top_p,
          max_new_tokens: request.max_tokens || 1024,
          stop: Array.isArray(request.stop)
            ? request.stop
            : request.stop
            ? [request.stop]
            : undefined,
          do_sample: (request.temperature || 0) > 0,
        },
      },
      timeout: 300000,
    });

    if (!response.ok) {
      const error = await response.text();
      throw new Error(`HTTP ${response.status}: ${error}`);
    }

    const reader = response.body?.getReader();
    if (!reader) {
      throw new Error('No response body');
    }

    const decoder = new TextDecoder();
    const requestId = this.generateRequestId();
    const created = Math.floor(Date.now() / 1000);
    const model = this.config.models[0] || 'unknown';

    while (true) {
      const { done, value } = await reader.read();
      if (done) break;

      const text = decoder.decode(value);
      const lines = text.split('\n').filter((l) => l.startsWith('data:'));

      for (const line of lines) {
        try {
          const jsonStr = line.substring(5).trim();
          if (jsonStr === '[DONE]') {
            onChunk('data: [DONE]\n\n');
            continue;
          }

          const data = JSON.parse(jsonStr);

          // Convert to OpenAI streaming format
          const chunk = {
            id: requestId,
            object: 'chat.completion.chunk',
            created,
            model,
            choices: [
              {
                index: 0,
                delta: {
                  content: data.token?.text || '',
                } as Partial<IChatMessage>,
                finish_reason: data.details?.finish_reason ? 'stop' : null,
              },
            ],
          };

          onChunk(`data: ${JSON.stringify(chunk)}\n\n`);
        } catch {
          // Invalid JSON, skip
        }
      }
    }
  }

  /**
   * Convert chat messages to TGI prompt format
   */
  private messagesToPrompt(messages: IChatMessage[]): string {
    // Use a simple chat template
    // TGI can use model-specific templates via the Messages API
    let prompt = '';

    for (const message of messages) {
      switch (message.role) {
        case 'system':
          prompt += `System: ${message.content}\n\n`;
          break;
        case 'user':
          prompt += `User: ${message.content}\n\n`;
          break;
        case 'assistant':
          prompt += `Assistant: ${message.content}\n\n`;
          break;
      }
    }

    prompt += 'Assistant:';
    return prompt;
  }
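  /*
   * Illustrative sketch (not part of the original source): the prompt string
   * that messagesToPrompt() builds for a short conversation. Message contents
   * are placeholders.
   *
   *   [{ role: 'system', content: 'Be brief.' }, { role: 'user', content: 'Hi' }]
   *
   * becomes:
   *
   *   "System: Be brief.\n\nUser: Hi\n\nAssistant:"
   */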
  /**
   * Get TGI server info
   */
  public async getInfo(): Promise<ITgiInfoResponse | null> {
    try {
      return await this.fetchJson<ITgiInfoResponse>('/info');
    } catch {
      return null;
    }
  }

  /**
   * Get TGI metrics
   */
  public async getMetrics(): Promise<Record<string, number>> {
    try {
      const response = await this.fetch('/metrics', { timeout: 5000 });
      if (response.ok) {
        const text = await response.text();

        // Parse Prometheus metrics
        const metrics: Record<string, number> = {};
        const lines = text.split('\n');

        for (const line of lines) {
          if (line.startsWith('#') || !line.trim()) continue;

          const match = line.match(/^(\w+)(?:\{[^}]*\})?\s+([\d.e+-]+)/);
          if (match) {
            metrics[match[1]] = parseFloat(match[2]);
          }
        }

        return metrics;
      }
    } catch {
      // Metrics endpoint may not be available
    }

    return {};
  }
}
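/*
 * Usage sketch (not part of the original module), assuming a Deno-style runtime
 * where `import.meta.main` is available and a running TGI container reachable
 * through BaseContainer's fetch helpers. The container id, model id, GPU id,
 * and the exact field names of IChatCompletionRequest (assumed to mirror the
 * OpenAI chat API) are illustrative assumptions.
 */
if (import.meta.main) {
  // Build a single-GPU config; NUM_SHARD is omitted because only one GPU is listed.
  const config = TgiContainer.createConfig(
    'tgi-example',        // hypothetical container id
    'tgi-example',        // hypothetical container name
    'org/example-model',  // placeholder HuggingFace model id
    ['0'],                // single GPU
  );
  const container = new TgiContainer(config);

  if (await container.isHealthy()) {
    const response = await container.chatCompletion({
      model: config.models[0],
      messages: [
        { role: 'system', content: 'You are a helpful assistant.' },
        { role: 'user', content: 'Say hello.' },
      ],
      temperature: 0.7,
      max_tokens: 128,
    });
    logger.info(response.choices[0]?.message.content ?? '');
  } else {
    logger.warn('TGI container is not healthy; start it before sending requests.');
  }
}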