Files
modelgrid/ts/cli/gpu-handler.ts
T

278 lines
7.5 KiB
TypeScript
Raw Normal View History

2026-01-30 03:16:57 +00:00
/**
* GPU Handler
*
* CLI commands for GPU management.
*/
import { logger } from '../logger.ts';
import { theme } from '../colors.ts';
import { GpuDetector } from '../hardware/gpu-detector.ts';
import { SystemInfo } from '../hardware/system-info.ts';
import { DriverManager } from '../drivers/driver-manager.ts';
import type { ITableColumn } from '../logger.ts';
/**
 * Handler for GPU-related CLI commands.
 *
 * Each public method implements one GPU sub-command and reports its results
 * through `logger`; none of them return data to the caller.
 */
export class GpuHandler {
  private gpuDetector: GpuDetector;
  // NOTE(review): no method in this class currently reads this field; it is
  // still constructed so construction behaviour stays unchanged.
  private systemInfo: SystemInfo;
  private driverManager: DriverManager;

  constructor() {
    this.gpuDetector = new GpuDetector();
    this.systemInfo = new SystemInfo();
    this.driverManager = new DriverManager();
  }

  /**
   * List detected GPUs as a table (ID, vendor, model, VRAM, driver, CUDA, PCI slot).
   *
   * Prints a warning box with possible causes when detection finds no GPUs.
   */
  public async list(): Promise<void> {
    logger.log('');
    logger.info('Detecting GPUs...');
    logger.log('');

    const gpus = await this.gpuDetector.detectGpus();

    if (gpus.length === 0) {
      logger.logBox(
        'No GPUs Detected',
        [
          'No GPUs were found on this system.',
          '',
          theme.dim('Possible reasons:'),
          ' - No discrete GPU installed',
          ' - GPU drivers not installed',
          ' - GPU not properly connected',
        ],
        60,
        'warning',
      );
      return;
    }

    const rows = gpus.map((gpu) => ({
      id: gpu.id,
      vendor: this.formatVendor(gpu.vendor),
      model: gpu.model,
      // NOTE(review): assumes gpu.vram is reported in MB (divided by 1024 for GB) — confirm.
      vram: `${Math.round(gpu.vram / 1024)} GB`,
      driver: gpu.driverVersion || theme.dim('N/A'),
      cuda: gpu.cudaVersion || theme.dim('N/A'),
      pci: gpu.pciSlot,
    }));

    const columns: ITableColumn[] = [
      { header: 'ID', key: 'id', align: 'left' },
      { header: 'Vendor', key: 'vendor', align: 'left' },
      { header: 'Model', key: 'model', align: 'left', color: theme.highlight },
      { header: 'VRAM', key: 'vram', align: 'right', color: theme.info },
      { header: 'Driver', key: 'driver', align: 'left' },
      { header: 'CUDA', key: 'cuda', align: 'left' },
      { header: 'PCI', key: 'pci', align: 'left', color: theme.dim },
    ];

    logger.info(`Found ${gpus.length} GPU(s):`);
    logger.log('');
    logger.logTable(columns, rows);
    logger.log('');
  }

  /**
   * Show per-GPU utilization, memory, temperature and power as one box per GPU.
   *
   * The model name comes from `detectGpus()` and is matched by GPU id. It falls
   * back to 'Unknown GPU' when no detected GPU has that id.
   */
  public async status(): Promise<void> {
    logger.log('');
    logger.info('GPU Status');
    logger.log('');

    const gpuInfo = await this.gpuDetector.detectGpus();
    const gpuStatus = await this.gpuDetector.getAllGpuStatus();

    if (gpuStatus.size === 0) {
      logger.warn('No GPUs detected');
      return;
    }

    for (const [gpuId, status] of gpuStatus) {
      const info = gpuInfo.find((gpu) => gpu.id === gpuId);

      // Avoid dividing by zero when a device reports a total memory of 0.
      const memoryPercent = status.memoryTotal > 0
        ? (status.memoryUsed / status.memoryTotal) * 100
        : 0;

      const utilizationBar = this.createProgressBar(status.utilization, 30);
      const memoryBar = this.createProgressBar(memoryPercent, 30);

      logger.logBoxTitle(`GPU ${status.id}: ${info?.model || 'Unknown GPU'}`, 70, 'info');
      logger.logBoxLine(`Utilization: ${utilizationBar} ${status.utilization.toFixed(1)}%`);
      logger.logBoxLine(
        `Memory: ${memoryBar} ${Math.round(status.memoryUsed)}/${
          Math.round(status.memoryTotal)
        } MB`,
      );
      logger.logBoxLine(`Temperature: ${this.formatTemperature(status.temperature)}`);
      logger.logBoxLine(
        `Power: ${status.powerUsage.toFixed(0)}W / ${status.powerLimit.toFixed(0)}W`,
      );
      logger.logBoxEnd();
      logger.log('');
    }
  }

  /**
   * Report driver status once per GPU vendor present on the system.
   *
   * For each vendor with a registered driver, prints whether it is installed.
   * When installed, also prints version, container runtime version and
   * container support. Otherwise it points the user at `modelgrid gpu install`.
   * Vendors without a registered driver produce a warning and are skipped.
   */
  public async drivers(): Promise<void> {
    logger.log('');
    logger.info('GPU Driver Status');
    logger.log('');

    const gpus = await this.gpuDetector.detectGpus();

    if (gpus.length === 0) {
      logger.warn('No GPUs detected');
      return;
    }

    // Deduplicate so each vendor is reported once, however many of its GPUs are present.
    const vendors = new Set(gpus.map((g) => g.vendor));

    for (const vendor of vendors) {
      const driver = this.driverManager.getDriver(vendor);

      if (!driver) {
        logger.warn(`No driver support for ${vendor}`);
        continue;
      }

      const status = await driver.getStatus();

      logger.logBoxTitle(
        `${this.formatVendor(vendor)} Driver`,
        60,
        status.installed ? 'success' : 'warning',
      );
      logger.logBoxLine(
        `Installed: ${status.installed ? theme.success('Yes') : theme.error('No')}`,
      );

      if (status.installed) {
        logger.logBoxLine(`Version: ${status.version || 'Unknown'}`);
        logger.logBoxLine(`Runtime: ${status.containerRuntimeVersion || 'Unknown'}`);
        logger.logBoxLine(
          `Container Support: ${
            status.containerSupport ? theme.success('Yes') : theme.warning('No')
          }`,
        );
      } else {
        logger.logBoxLine('');
        logger.logBoxLine(theme.dim('Run `modelgrid gpu install` to install drivers'));
      }

      logger.logBoxEnd();
      logger.log('');
    }
  }

  /**
   * Install drivers once per GPU vendor present on the system.
   *
   * For each vendor with a registered driver, installs the driver and toolkit
   * interactively (`nonInteractive: false`). On success it then sets up
   * container support. A failure for one vendor is logged and the loop moves
   * on to the next vendor. Exceptions thrown by the driver are not caught here.
   */
  public async install(): Promise<void> {
    logger.log('');
    logger.info('Installing GPU Drivers');
    logger.log('');

    const gpus = await this.gpuDetector.detectGpus();

    if (gpus.length === 0) {
      logger.error('No GPUs detected - cannot install drivers');
      return;
    }

    const vendors = new Set(gpus.map((g) => g.vendor));

    for (const vendor of vendors) {
      const driver = this.driverManager.getDriver(vendor);

      if (!driver) {
        logger.warn(`No driver installer for ${vendor}`);
        continue;
      }

      logger.info(`Installing ${this.formatVendor(vendor)} drivers...`);

      const success = await driver.install({
        installToolkit: true,
        installContainerSupport: true,
        nonInteractive: false,
      });

      if (success) {
        logger.success(`${this.formatVendor(vendor)} drivers installed successfully`);

        // NOTE(review): install() above already receives installContainerSupport: true;
        // confirm whether this explicit second setup step is still required.
        logger.info('Setting up container support...');
        const containerSuccess = await driver.installContainerSupport();

        if (containerSuccess) {
          logger.success('Container support configured');
        } else {
          logger.warn('Container support setup failed - GPU passthrough may not work');
        }
      } else {
        logger.error(`Failed to install ${this.formatVendor(vendor)} drivers`);
      }

      logger.log('');
    }
  }

  /**
   * Format a vendor id ('nvidia', 'amd', 'intel') as a colored display name.
   * Unknown vendor ids are returned unchanged.
   */
  private formatVendor(vendor: string): string {
    switch (vendor) {
      case 'nvidia':
        return theme.gpuNvidia('NVIDIA');
      case 'amd':
        return theme.gpuAmd('AMD');
      case 'intel':
        return theme.gpuIntel('Intel');
      default:
        return vendor;
    }
  }

  /**
   * Render a `width`-character progress bar colored by fill level.
   *
   * @param percent - fill percentage; clamped to [0, 100], NaN treated as 0
   * @param width - total number of bar characters
   * @returns the bar colored red at >= 90%, yellow at >= 70%, green otherwise
   */
  private createProgressBar(percent: number, width: number): string {
    // Clamp so the repeat counts below can never be negative or non-finite
    // (String#repeat throws RangeError for those). Readings outside [0, 100]
    // can occur, e.g. memoryUsed exceeding memoryTotal.
    const clamped = Number.isNaN(percent) ? 0 : Math.min(100, Math.max(0, percent));
    const filled = Math.round((clamped / 100) * width);
    const empty = width - filled;
    const bar = '█'.repeat(filled) + '░'.repeat(empty);

    if (clamped >= 90) {
      return theme.error(bar);
    } else if (clamped >= 70) {
      return theme.warning(bar);
    } else {
      return theme.success(bar);
    }
  }

  /**
   * Format a temperature in °C, colored red at >= 85, yellow at >= 70, green otherwise.
   */
  private formatTemperature(temp: number): string {
    const tempStr = `${temp}°C`;

    if (temp >= 85) {
      return theme.error(tempStr);
    } else if (temp >= 70) {
      return theme.warning(tempStr);
    } else {
      return theme.success(tempStr);
    }
  }
}