Files
modelgrid/ts/cli/gpu-handler.ts
Juergen Kunz daaf6559e3
Some checks failed
CI / Type Check & Lint (push) Failing after 5s
CI / Build Test (Current Platform) (push) Failing after 5s
CI / Build All Platforms (push) Successful in 49s
initial
2026-01-30 03:16:57 +00:00

256 lines
7.1 KiB
TypeScript

/**
* GPU Handler
*
* CLI commands for GPU management.
*/
import { logger } from '../logger.ts';
import { theme } from '../colors.ts';
import { GpuDetector } from '../hardware/gpu-detector.ts';
import { SystemInfo } from '../hardware/system-info.ts';
import { DriverManager } from '../drivers/driver-manager.ts';
import type { ITableColumn } from '../logger.ts';
/**
 * Handler for GPU-related CLI commands.
 *
 * Each public method implements one GPU command (`list`, `status`,
 * `drivers`, `install`) and reports its results through the shared `logger`.
 */
export class GpuHandler {
  private readonly gpuDetector: GpuDetector;
  private readonly systemInfo: SystemInfo;
  private readonly driverManager: DriverManager;

  constructor() {
    this.gpuDetector = new GpuDetector();
    this.systemInfo = new SystemInfo();
    this.driverManager = new DriverManager();
  }

  /**
   * List detected GPUs as a table (ID, vendor, model, VRAM, driver, CUDA, PCI).
   *
   * When no GPU is detected, prints a warning box with likely causes instead.
   */
  public async list(): Promise<void> {
    logger.log('');
    logger.info('Detecting GPUs...');
    logger.log('');

    const gpus = await this.gpuDetector.detectGpus();
    if (gpus.length === 0) {
      logger.logBox(
        'No GPUs Detected',
        [
          'No GPUs were found on this system.',
          '',
          theme.dim('Possible reasons:'),
          ' - No discrete GPU installed',
          ' - GPU drivers not installed',
          ' - GPU not properly connected',
        ],
        60,
        'warning',
      );
      return;
    }

    const rows = gpus.map((gpu) => ({
      id: gpu.id,
      vendor: this.formatVendor(gpu.vendor),
      model: gpu.model,
      // NOTE(review): assumes `vram` is reported in MB; shown rounded to whole GB.
      vram: `${Math.round(gpu.vram / 1024)} GB`,
      // `||` (not `??`) so empty strings also render as N/A.
      driver: gpu.driverVersion || theme.dim('N/A'),
      cuda: gpu.cudaVersion || theme.dim('N/A'),
      pci: gpu.pciSlot,
    }));

    const columns: ITableColumn[] = [
      { header: 'ID', key: 'id', align: 'left' },
      { header: 'Vendor', key: 'vendor', align: 'left' },
      { header: 'Model', key: 'model', align: 'left', color: theme.highlight },
      { header: 'VRAM', key: 'vram', align: 'right', color: theme.info },
      { header: 'Driver', key: 'driver', align: 'left' },
      { header: 'CUDA', key: 'cuda', align: 'left' },
      { header: 'PCI', key: 'pci', align: 'left', color: theme.dim },
    ];

    logger.info(`Found ${gpus.length} GPU(s):`);
    logger.log('');
    logger.logTable(columns, rows);
    logger.log('');
  }

  /**
   * Show per-GPU utilization, memory, temperature and power as a box each.
   *
   * Prints a warning and returns when no GPU status is available.
   */
  public async status(): Promise<void> {
    logger.log('');
    logger.info('GPU Status');
    logger.log('');

    const gpuStatus = await this.gpuDetector.getGpuStatus();
    if (gpuStatus.length === 0) {
      logger.warn('No GPUs detected');
      return;
    }

    for (const gpu of gpuStatus) {
      // Guard against a zero total (e.g. an unreadable value) so the
      // percentage is 0 rather than NaN/Infinity.
      const memoryPercent = gpu.memoryTotal > 0 ? (gpu.memoryUsed / gpu.memoryTotal) * 100 : 0;
      const utilizationBar = this.createProgressBar(gpu.utilization, 30);
      const memoryBar = this.createProgressBar(memoryPercent, 30);

      logger.logBoxTitle(`GPU ${gpu.id}: ${gpu.name}`, 70, 'info');
      logger.logBoxLine(`Utilization: ${utilizationBar} ${gpu.utilization.toFixed(1)}%`);
      logger.logBoxLine(`Memory: ${memoryBar} ${Math.round(gpu.memoryUsed)}/${Math.round(gpu.memoryTotal)} MB`);
      logger.logBoxLine(`Temperature: ${this.formatTemperature(gpu.temperature)}`);
      logger.logBoxLine(`Power: ${gpu.powerDraw.toFixed(0)}W / ${gpu.powerLimit.toFixed(0)}W`);
      logger.logBoxEnd();
      logger.log('');
    }
  }

  /**
   * Report driver status once per detected GPU vendor.
   *
   * For installed drivers, shows version, runtime version and container
   * support. Otherwise, points the user at `modelgrid gpu install`.
   */
  public async drivers(): Promise<void> {
    logger.log('');
    logger.info('GPU Driver Status');
    logger.log('');

    // NOTE(review): the system info result is not used by this command; the
    // probe is still awaited so any side effects / errors it raises are kept.
    await this.systemInfo.getSystemInfo();

    const gpus = await this.gpuDetector.detectGpus();
    if (gpus.length === 0) {
      logger.warn('No GPUs detected');
      return;
    }

    // Deduplicate so each vendor's driver is queried only once.
    const vendors = new Set(gpus.map((g) => g.vendor));
    for (const vendor of vendors) {
      const driver = this.driverManager.getDriver(vendor);
      if (!driver) {
        logger.warn(`No driver support for ${vendor}`);
        continue;
      }

      const status = await driver.getStatus();
      logger.logBoxTitle(`${this.formatVendor(vendor)} Driver`, 60, status.installed ? 'success' : 'warning');
      logger.logBoxLine(`Installed: ${status.installed ? theme.success('Yes') : theme.error('No')}`);
      if (status.installed) {
        logger.logBoxLine(`Version: ${status.version || 'Unknown'}`);
        logger.logBoxLine(`Runtime: ${status.runtimeVersion || 'Unknown'}`);
        logger.logBoxLine(`Container Support: ${status.containerSupport ? theme.success('Yes') : theme.warning('No')}`);
      } else {
        logger.logBoxLine('');
        logger.logBoxLine(theme.dim('Run `modelgrid gpu install` to install drivers'));
      }
      logger.logBoxEnd();
      logger.log('');
    }
  }

  /**
   * Install drivers once per detected GPU vendor.
   *
   * After each successful install, it also configures container support.
   * A failed install or container setup is logged, and the loop moves on to
   * the next vendor.
   */
  public async install(): Promise<void> {
    logger.log('');
    logger.info('Installing GPU Drivers');
    logger.log('');

    const gpus = await this.gpuDetector.detectGpus();
    if (gpus.length === 0) {
      logger.error('No GPUs detected - cannot install drivers');
      return;
    }

    // Deduplicate so each vendor's installer runs only once.
    const vendors = new Set(gpus.map((g) => g.vendor));
    for (const vendor of vendors) {
      const driver = this.driverManager.getDriver(vendor);
      if (!driver) {
        logger.warn(`No driver installer for ${vendor}`);
        continue;
      }

      logger.info(`Installing ${this.formatVendor(vendor)} drivers...`);
      const success = await driver.install();
      if (success) {
        logger.success(`${this.formatVendor(vendor)} drivers installed successfully`);

        logger.info('Setting up container support...');
        const containerSuccess = await driver.setupContainer();
        if (containerSuccess) {
          logger.success('Container support configured');
        } else {
          logger.warn('Container support setup failed - GPU passthrough may not work');
        }
      } else {
        logger.error(`Failed to install ${this.formatVendor(vendor)} drivers`);
      }
      logger.log('');
    }
  }

  /**
   * Map a vendor id to a colored display name.
   *
   * @param vendor - Vendor id (`'nvidia'`, `'amd'` or `'intel'` are recognized).
   * @returns The themed display name, or the raw id for unknown vendors.
   */
  private formatVendor(vendor: string): string {
    switch (vendor) {
      case 'nvidia':
        return theme.gpuNvidia('NVIDIA');
      case 'amd':
        return theme.gpuAmd('AMD');
      case 'intel':
        return theme.gpuIntel('Intel');
      default:
        return vendor;
    }
  }

  /**
   * Render a fixed-width, color-coded progress bar.
   *
   * @param percent - Fill percentage. Values are clamped to [0, 100], and NaN
   *   counts as 0.
   * @param width - Total bar width in characters.
   * @returns The bar, colored red at >= 90%, yellow at >= 70%, green otherwise.
   */
  private createProgressBar(percent: number, width: number): string {
    // Clamp before computing repeat counts: a percent above 100 would make
    // `empty` negative, and Infinity (e.g. from x / 0) would make `filled`
    // infinite. String.prototype.repeat throws RangeError for both.
    const clamped = Number.isNaN(percent) ? 0 : Math.min(100, Math.max(0, percent));
    const filled = Math.round((clamped / 100) * width);
    const empty = width - filled;
    const bar = '█'.repeat(filled) + '░'.repeat(empty);

    if (clamped >= 90) {
      return theme.error(bar);
    } else if (clamped >= 70) {
      return theme.warning(bar);
    } else {
      return theme.success(bar);
    }
  }

  /**
   * Format a temperature with color coding.
   *
   * @param temp - Temperature in °C.
   * @returns The text, colored red at >= 85°C, yellow at >= 70°C, green otherwise.
   */
  private formatTemperature(temp: number): string {
    const tempStr = `${temp}°C`;
    if (temp >= 85) {
      return theme.error(tempStr);
    } else if (temp >= 70) {
      return theme.warning(tempStr);
    } else {
      return theme.success(tempStr);
    }
  }
}