initial

ts/containers/tgi.ts (new file, 417 lines)
@@ -0,0 +1,417 @@
/**
 * TGI Container (Text Generation Inference)
 *
 * Manages HuggingFace Text Generation Inference containers.
 */

import type {
  IContainerConfig,
  ILoadedModel,
  TContainerType,
} from '../interfaces/container.ts';
import type {
  IChatCompletionRequest,
  IChatCompletionResponse,
  IChatCompletionChoice,
  IChatMessage,
} from '../interfaces/api.ts';
import { CONTAINER_IMAGES, CONTAINER_PORTS } from '../constants.ts';
import { logger } from '../logger.ts';
import { BaseContainer, type TModelPullProgress } from './base-container.ts';

/**
 * TGI info response
 */
interface ITgiInfoResponse {
  model_id: string;
  model_sha: string;
  model_dtype: string;
  model_device_type: string;
  max_concurrent_requests: number;
  max_best_of: number;
  max_stop_sequences: number;
  max_input_length: number;
  max_total_tokens: number;
  version: string;
}

/**
 * TGI generate request
 */
interface ITgiGenerateRequest {
  inputs: string;
  parameters?: {
    temperature?: number;
    top_p?: number;
    max_new_tokens?: number;
    stop?: string[];
    do_sample?: boolean;
    return_full_text?: boolean;
  };
}

/**
 * TGI generate response
 */
interface ITgiGenerateResponse {
  generated_text: string;
  details?: {
    finish_reason: string;
    generated_tokens: number;
    seed?: number;
  };
}

/**
 * TGI container implementation
 *
 * TGI is optimized for:
 * - Production deployments
 * - Flash Attention support
 * - Quantization (bitsandbytes, GPTQ, AWQ)
 * - Multiple GPU support with tensor parallelism
 */
export class TgiContainer extends BaseContainer {
  public readonly type: TContainerType = 'tgi';
  public readonly displayName = 'TGI';
  public readonly defaultImage = CONTAINER_IMAGES.TGI;
  public readonly defaultPort = CONTAINER_PORTS.TGI;

  constructor(config: IContainerConfig) {
    super(config);

    // Set defaults if not provided
    if (!config.image) {
      config.image = this.defaultImage;
    }
    if (!config.port) {
      config.port = this.defaultPort;
    }

    // Add default volume for model cache
    if (!config.volumes || config.volumes.length === 0) {
      config.volumes = [`modelgrid-tgi-${config.id}:/data`];
    }
  }

  /**
   * Create TGI container configuration
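   *
   * The example below is an illustrative sketch only; the model id, token, and GPU ids
   * are placeholders, not values from this project:
   *
   * @example
   * const cfg = TgiContainer.createConfig('tgi-a', 'llama-a', 'org/example-7b-instruct', ['0', '1'], {
   *   env: { HF_TOKEN: '<hf token>', QUANTIZE: 'awq' },
   * });
   * // cfg.env.MODEL_ID === 'org/example-7b-instruct'
   * // cfg.env.NUM_SHARD === '2' (one shard per GPU id)
   * // cfg.env.QUANTIZE === 'awq'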
   */
  public static createConfig(
    id: string,
    name: string,
    modelName: string,
    gpuIds: string[],
    options: Partial<IContainerConfig> = {},
  ): IContainerConfig {
    const env: Record<string, string> = {
      MODEL_ID: modelName,
      PORT: String(options.port || CONTAINER_PORTS.TGI),
      HUGGING_FACE_HUB_TOKEN: options.env?.HF_TOKEN || options.env?.HUGGING_FACE_HUB_TOKEN || '',
      ...options.env,
    };

    // Add GPU configuration
    if (gpuIds.length > 1) {
      env.NUM_SHARD = String(gpuIds.length);
    }

    // Add quantization if specified
    if (options.env?.QUANTIZE) {
      env.QUANTIZE = options.env.QUANTIZE;
    }

    return {
      id,
      name,
      type: 'tgi',
      image: options.image || CONTAINER_IMAGES.TGI,
      gpuIds,
      port: options.port || CONTAINER_PORTS.TGI,
      externalPort: options.externalPort,
      models: [modelName],
      env,
      volumes: options.volumes || [`modelgrid-tgi-${id}:/data`],
      autoStart: options.autoStart ?? true,
      restartPolicy: options.restartPolicy || 'unless-stopped',
      memoryLimit: options.memoryLimit,
      cpuLimit: options.cpuLimit,
      command: options.command,
    };
  }

  /**
   * Check if TGI is healthy
   */
  public async isHealthy(): Promise<boolean> {
    try {
      const response = await this.fetch('/health', { timeout: 5000 });
      return response.ok;
    } catch {
      return false;
    }
  }

  /**
   * List available models
   * TGI serves a single model per instance
   */
  public async listModels(): Promise<string[]> {
    try {
      const info = await this.fetchJson<ITgiInfoResponse>('/info');
      return [info.model_id];
    } catch (error) {
      logger.warn(`Failed to get TGI info: ${error instanceof Error ? error.message : String(error)}`);
      return this.config.models || [];
    }
  }

  /**
   * Get loaded models with details
   */
  public async getLoadedModels(): Promise<ILoadedModel[]> {
    try {
      const info = await this.fetchJson<ITgiInfoResponse>('/info');
      return [{
        name: info.model_id,
        size: 0, // TGI doesn't expose model size
        format: info.model_dtype,
        loaded: true,
        requestCount: 0,
      }];
    } catch {
      return this.config.models.map((name) => ({
        name,
        size: 0,
        loaded: true,
        requestCount: 0,
      }));
    }
  }

  /**
   * Pull a model
   * TGI downloads models automatically at startup
   */
  public async pullModel(modelName: string, onProgress?: TModelPullProgress): Promise<boolean> {
    logger.info(`TGI downloads models at startup. Model: ${modelName}`);
    logger.info('To use a different model, create a new TGI container.');

    if (onProgress) {
      onProgress({
        model: modelName,
        status: 'TGI models are loaded at container startup',
        percent: 100,
      });
    }

    return true;
  }

  /**
   * Remove a model
   * TGI serves a single model per instance
   */
  public async removeModel(modelName: string): Promise<boolean> {
    logger.info('TGI serves a single model per instance.');
    logger.info(`To remove model ${modelName}, stop and remove this container.`);
    return true;
  }

  /**
   * Send a chat completion request
   * Convert OpenAI format to TGI format
   */
  public async chatCompletion(request: IChatCompletionRequest): Promise<IChatCompletionResponse> {
    // Convert messages to TGI prompt format
    const prompt = this.messagesToPrompt(request.messages);

    const tgiRequest: ITgiGenerateRequest = {
      inputs: prompt,
      parameters: {
        temperature: request.temperature,
        top_p: request.top_p,
        max_new_tokens: request.max_tokens || 1024,
        stop: Array.isArray(request.stop) ? request.stop : request.stop ? [request.stop] : undefined,
        do_sample: (request.temperature || 0) > 0,
        return_full_text: false,
      },
    };

    const response = await this.fetchJson<ITgiGenerateResponse>('/generate', {
      method: 'POST',
      body: tgiRequest,
      timeout: 300000, // 5 minutes
    });

    // Convert to OpenAI format
    const created = Math.floor(Date.now() / 1000);

    const choice: IChatCompletionChoice = {
      index: 0,
      message: {
        role: 'assistant',
        content: response.generated_text,
      },
      // TGI's finish_reason is 'length', 'eos_token', or 'stop_sequence'; treat both stop
      // variants as a natural stop and fall back to 'length' otherwise.
      finish_reason:
        response.details?.finish_reason === 'eos_token' || response.details?.finish_reason === 'stop_sequence'
          ? 'stop'
          : 'length',
    };

    return {
      id: this.generateRequestId(),
      object: 'chat.completion',
      created,
      model: this.config.models[0] || 'unknown',
      choices: [choice],
      usage: {
        prompt_tokens: 0, // TGI doesn't always report this
        completion_tokens: response.details?.generated_tokens || 0,
        total_tokens: response.details?.generated_tokens || 0,
      },
    };
  }

  /**
   * Stream a chat completion request
   */
  public async chatCompletionStream(
    request: IChatCompletionRequest,
    onChunk: (chunk: string) => void,
  ): Promise<void> {
    // Convert messages to TGI prompt format
    const prompt = this.messagesToPrompt(request.messages);

    const response = await this.fetch('/generate_stream', {
      method: 'POST',
      body: {
        inputs: prompt,
        parameters: {
          temperature: request.temperature,
          top_p: request.top_p,
          max_new_tokens: request.max_tokens || 1024,
          stop: Array.isArray(request.stop) ? request.stop : request.stop ? [request.stop] : undefined,
          do_sample: (request.temperature || 0) > 0,
        },
      },
      timeout: 300000,
    });

    if (!response.ok) {
      const error = await response.text();
      throw new Error(`HTTP ${response.status}: ${error}`);
    }

    const reader = response.body?.getReader();
    if (!reader) {
      throw new Error('No response body');
    }

    const decoder = new TextDecoder();
    const requestId = this.generateRequestId();
    const created = Math.floor(Date.now() / 1000);
    const model = this.config.models[0] || 'unknown';

    // Buffer partial lines so SSE events split across network chunks are not lost.
    let buffer = '';

    while (true) {
      const { done, value } = await reader.read();
      if (done) break;

      buffer += decoder.decode(value, { stream: true });
      const parts = buffer.split('\n');
      buffer = parts.pop() ?? '';
      const lines = parts.filter((l) => l.startsWith('data:'));

      for (const line of lines) {
        try {
          const jsonStr = line.substring(5).trim();
          if (jsonStr === '[DONE]') {
            onChunk('data: [DONE]\n\n');
            continue;
          }

          const data = JSON.parse(jsonStr);

          // Convert to OpenAI streaming format
          const chunk = {
            id: requestId,
            object: 'chat.completion.chunk',
            created,
            model,
            choices: [
              {
                index: 0,
                delta: {
                  content: data.token?.text || '',
                } as Partial<IChatMessage>,
                finish_reason: data.details?.finish_reason ? 'stop' : null,
              },
            ],
          };

          onChunk(`data: ${JSON.stringify(chunk)}\n\n`);
        } catch {
          // Invalid JSON, skip
        }
      }
    }
  }

  /**
   * Convert chat messages to TGI prompt format
   */
  private messagesToPrompt(messages: IChatMessage[]): string {
    // Use a simple chat template
    // TGI can use model-specific templates via the Messages API
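    // For illustration, a two-message conversation is flattened by the code below like so
    // (shown with escaped newlines):
    //   [{ role: 'system', content: 'Be terse.' }, { role: 'user', content: 'Hi' }]
    //   -> "System: Be terse.\n\nUser: Hi\n\nAssistant:"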
    let prompt = '';

    for (const message of messages) {
      switch (message.role) {
        case 'system':
          prompt += `System: ${message.content}\n\n`;
          break;
        case 'user':
          prompt += `User: ${message.content}\n\n`;
          break;
        case 'assistant':
          prompt += `Assistant: ${message.content}\n\n`;
          break;
      }
    }

    prompt += 'Assistant:';
    return prompt;
  }

  /**
   * Get TGI server info
   */
  public async getInfo(): Promise<ITgiInfoResponse | null> {
    try {
      return await this.fetchJson<ITgiInfoResponse>('/info');
    } catch {
      return null;
    }
  }

  /**
   * Get TGI metrics
   */
  public async getMetrics(): Promise<Record<string, unknown>> {
    try {
      const response = await this.fetch('/metrics', { timeout: 5000 });
      if (response.ok) {
        const text = await response.text();
        // Parse Prometheus metrics
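        // Sketch of the mapping done below (the metric name is illustrative, not a
        // guaranteed TGI metric): a line such as `tgi_request_count 42` becomes
        // { tgi_request_count: 42 }; label sets in braces and '#' comment lines are dropped.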
        const metrics: Record<string, unknown> = {};
        const lines = text.split('\n');
        for (const line of lines) {
          if (line.startsWith('#') || !line.trim()) continue;
          const match = line.match(/^(\w+)(?:\{[^}]*\})?\s+([\d.e+-]+)/);
          if (match) {
            metrics[match[1]] = parseFloat(match[2]);
          }
        }
        return metrics;
      }
    } catch {
      // Metrics endpoint may not be available
    }
    return {};
  }
}
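
// Usage sketch (illustrative only): the model id, GPU ids, and token are placeholders,
// and starting/registering the container is assumed to be handled by BaseContainer and
// the surrounding application.
//
//   const config = TgiContainer.createConfig('tgi-1', 'tgi-1', 'org/example-7b-instruct', ['0'], {
//     env: { HF_TOKEN: '<hf token>' },
//   });
//   const tgi = new TgiContainer(config);
//   if (await tgi.isHealthy()) {
//     const reply = await tgi.chatCompletion({
//       model: config.models[0],
//       messages: [{ role: 'user', content: 'Hello' }],
//     });
//   }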