/**
 * Model registry backed by list.modelgrid.com.
 * The catalog can also be loaded from a local JSON file (`file://` URL or absolute path).
 */
|
|
|
|
import * as fs from 'node:fs/promises';
|
|
import type { IModelCatalog, IModelCatalogEntry } from '../interfaces/catalog.ts';
|
|
import { MODEL_REGISTRY, TIMING } from '../constants.ts';
|
|
import { logger } from '../logger.ts';
|
|
|
|
/**
 * Registry of models loaded from a JSON catalog.
 *
 * The catalog source may be an HTTP(S) URL, a `file://` URL, or an absolute
 * filesystem path (starting with `/`). A successfully loaded catalog is cached
 * for `TIMING.GREENLIST_CACHE_DURATION_MS`. When loading fails, the last cached
 * catalog is served (even if expired), or a fallback catalog built from
 * `MODEL_REGISTRY.FALLBACK_CATALOG` if nothing was ever cached.
 */
export class ModelRegistry {
  /** Default time budget for a remote catalog request (response + body). */
  private static readonly DEFAULT_FETCH_TIMEOUT_MS = 30000;

  private catalogUrl: string;
  private cachedCatalog: IModelCatalog | null = null;
  private cacheTime = 0;
  private readonly fetchTimeoutMs: number;

  /**
   * @param catalogUrl - Catalog location: HTTP(S) URL, `file://` URL, or absolute path.
   * @param fetchTimeoutMs - Abort remote catalog requests after this many milliseconds.
   */
  constructor(
    catalogUrl: string = MODEL_REGISTRY.DEFAULT_CATALOG_URL,
    fetchTimeoutMs: number = ModelRegistry.DEFAULT_FETCH_TIMEOUT_MS,
  ) {
    this.catalogUrl = catalogUrl;
    this.fetchTimeoutMs = fetchTimeoutMs;
  }

  /** Points the registry at a new catalog source and drops any cached catalog. */
  public setCatalogUrl(url: string): void {
    this.catalogUrl = url;
    this.cachedCatalog = null;
    this.cacheTime = 0;
  }

  /**
   * Returns the catalog, using the cache while it is fresh.
   *
   * Never rejects for load failures. On failure it logs a warning and returns
   * the last cached catalog (possibly stale). If nothing is cached, it returns
   * the fallback catalog, which is not cached, so the next call retries.
   *
   * @param forceRefresh - Bypass the cache and reload from the source.
   */
  public async fetchCatalog(forceRefresh: boolean = false): Promise<IModelCatalog> {
    if (
      !forceRefresh &&
      this.cachedCatalog &&
      Date.now() - this.cacheTime < TIMING.GREENLIST_CACHE_DURATION_MS
    ) {
      return this.cachedCatalog;
    }

    // Snapshot the source: setCatalogUrl() may be called while we are awaiting.
    const source = this.catalogUrl;

    try {
      logger.dim(`Fetching model catalog from: ${source}`);
      const catalog = this.validateCatalog(await this.readCatalogSource(source));

      // Only cache if the registry still points at the source we read. Otherwise
      // a slow load of the old URL would populate the cache for the new one.
      if (source === this.catalogUrl) {
        this.cachedCatalog = catalog;
        this.cacheTime = Date.now();
      }

      logger.dim(`Loaded ${catalog.models.length} catalog models`);
      return catalog;
    } catch (error) {
      logger.warn(
        `Failed to fetch model catalog: ${error instanceof Error ? error.message : String(error)}`,
      );

      if (!this.cachedCatalog) {
        logger.dim('Using fallback catalog');
        return this.getFallbackCatalog();
      }

      return this.cachedCatalog;
    }
  }

  /** True when `modelName` matches a catalog entry's id or alias (normalized). */
  public async isModelListed(modelName: string): Promise<boolean> {
    return (await this.getModel(modelName)) !== null;
  }

  /**
   * Finds the entry whose id or any alias equals `modelName` after
   * normalization (see `normalizeModelName`).
   *
   * @returns The first matching entry, or `null` if none matches.
   */
  public async getModel(modelName: string): Promise<IModelCatalogEntry | null> {
    const catalog = await this.fetchCatalog();
    const normalized = this.normalizeModelName(modelName);

    const match = catalog.models.find((model) =>
      [model.id, ...(model.aliases ?? [])].some(
        (candidate) => this.normalizeModelName(candidate) === normalized,
      ),
    );

    return match ?? null;
  }

  /** Returns every catalog entry. */
  public async getAllModels(): Promise<IModelCatalogEntry[]> {
    const catalog = await this.fetchCatalog();
    return catalog.models;
  }

  /** Returns entries whose `engine` equals `engine`. */
  public async getModelsByEngine(engine: 'vllm'): Promise<IModelCatalogEntry[]> {
    const catalog = await this.fetchCatalog();
    return catalog.models.filter((model) => model.engine === engine);
  }

  /** Returns entries whose minimum VRAM requirement is at most `maxVramGb`. */
  public async getModelsWithinVram(maxVramGb: number): Promise<IModelCatalogEntry[]> {
    const catalog = await this.fetchCatalog();
    return catalog.models.filter((model) => model.requirements.minVramGb <= maxVramGb);
  }

  /** Returns the catalog engine for `modelName`, or `null` if it is not listed. */
  public async getRecommendedEngine(modelName: string): Promise<'vllm' | null> {
    const model = await this.getModel(modelName);
    return model ? model.engine : null;
  }

  /** Returns the minimum VRAM (GB) for `modelName`, or `null` if it is not listed. */
  public async getMinVram(modelName: string): Promise<number | null> {
    const model = await this.getModel(modelName);
    return model ? model.requirements.minVramGb : null;
  }

  /** True when `modelName` is listed and `availableVramGb` meets its minimum. */
  public async modelFitsInVram(modelName: string, availableVramGb: number): Promise<boolean> {
    const minVram = await this.getMinVram(modelName);
    if (minVram === null) {
      return false;
    }

    return availableVramGb >= minVram;
  }

  /**
   * Case-insensitive substring search over id, aliases, metadata summary and
   * metadata tags.
   */
  public async searchModels(pattern: string): Promise<IModelCatalogEntry[]> {
    const catalog = await this.fetchCatalog();
    const normalizedPattern = pattern.toLowerCase();

    return catalog.models.filter((model) =>
      model.id.toLowerCase().includes(normalizedPattern) ||
      model.aliases?.some((alias) => alias.toLowerCase().includes(normalizedPattern)) ||
      model.metadata?.summary?.toLowerCase().includes(normalizedPattern) ||
      model.metadata?.tags?.some((tag) => tag.toLowerCase().includes(normalizedPattern))
    );
  }

  /** Returns entries carrying at least one of `tags` (case-insensitive exact match). */
  public async getModelsByTags(tags: string[]): Promise<IModelCatalogEntry[]> {
    const catalog = await this.fetchCatalog();
    const normalizedTags = tags.map((tag) => tag.toLowerCase());

    return catalog.models.filter((model) =>
      model.metadata?.tags?.some((tag) => normalizedTags.includes(tag.toLowerCase()))
    );
  }

  /** Drops the cached catalog so the next lookup reloads from the source. */
  public clearCache(): void {
    this.cachedCatalog = null;
    this.cacheTime = 0;
  }

  /** Logs catalog version/size and the first few entries in a box. */
  public async printSummary(): Promise<void> {
    const catalog = await this.fetchCatalog();
    const maxListed = 10;

    logger.logBoxTitle('Model Catalog', 70, 'info');
    logger.logBoxLine(`Version: ${catalog.version}`);
    logger.logBoxLine(`Generated: ${catalog.generatedAt}`);
    logger.logBoxLine(`Total Models: ${catalog.models.length}`);
    logger.logBoxLine('');

    for (const model of catalog.models.slice(0, maxListed)) {
      logger.logBoxLine(
        `- ${model.id} (${model.requirements.minVramGb}GB, ${model.engine})`,
      );
    }

    if (catalog.models.length > maxListed) {
      logger.logBoxLine(`... and ${catalog.models.length - maxListed} more`);
    }

    logger.logBoxEnd();
  }

  /**
   * Reads and JSON-parses the raw catalog from `source`.
   * `file://` URLs and paths starting with `/` are read from disk; anything
   * else is requested over the network.
   *
   * @throws Error on I/O failure, non-2xx response, timeout, or invalid JSON.
   */
  private async readCatalogSource(source: string): Promise<unknown> {
    if (source.startsWith('file://') || source.startsWith('/')) {
      const location = source.startsWith('file://') ? new URL(source) : source;
      const content = await fs.readFile(location, 'utf-8');
      return JSON.parse(content);
    }

    return await this.fetchRemoteCatalog(source);
  }

  /**
   * Fetches and parses a remote JSON catalog, aborting after `fetchTimeoutMs`.
   *
   * @throws Error on non-2xx status, timeout, network failure, or invalid JSON.
   */
  private async fetchRemoteCatalog(url: string): Promise<unknown> {
    const controller = new AbortController();
    const timeout = setTimeout(() => controller.abort(), this.fetchTimeoutMs);

    try {
      const response = await fetch(url, {
        signal: controller.signal,
        headers: {
          Accept: 'application/json',
          'User-Agent': 'ModelGrid/1.0',
        },
      });

      if (!response.ok) {
        // Release the unread body; leaving it unconsumed leaks the connection.
        await response.body?.cancel().catch(() => undefined);
        throw new Error(`HTTP ${response.status}: ${response.statusText}`);
      }

      return await response.json();
    } catch (error) {
      // Replace the generic abort error with one that says what happened.
      if (controller.signal.aborted) {
        throw new Error(`Timed out after ${this.fetchTimeoutMs}ms`);
      }
      throw error;
    } finally {
      clearTimeout(timeout);
    }
  }

  /**
   * Checks the minimal shape the registry relies on: an object with a `models`
   * array. Individual entries are not validated.
   *
   * @throws Error when `value` does not have that shape.
   */
  private validateCatalog(value: unknown): IModelCatalog {
    if (typeof value !== 'object' || value === null) {
      throw new Error('Invalid catalog format: expected a JSON object');
    }

    if (!Array.isArray((value as { models?: unknown }).models)) {
      throw new Error('Invalid catalog format: missing models array');
    }

    return value as IModelCatalog;
  }

  /** Builds a catalog from the bundled `MODEL_REGISTRY.FALLBACK_CATALOG`. */
  private getFallbackCatalog(): IModelCatalog {
    return {
      version: '1.0',
      generatedAt: new Date().toISOString(),
      models: MODEL_REGISTRY.FALLBACK_CATALOG as unknown as IModelCatalogEntry[],
    };
  }

  /**
   * Lowercases `name` and strips every character outside `[a-z0-9:/._-]`
   * (whitespace included), so lookups ignore case and stray punctuation.
   */
  private normalizeModelName(name: string): string {
    return name
      .toLowerCase()
      .replace(/[^a-z0-9:/._-]/g, '')
      .trim();
  }
}
|