256 lines
7.1 KiB
TypeScript
256 lines
7.1 KiB
TypeScript
/**
 * GPU Handler
 *
 * CLI commands for GPU management.
 */
|
|
|
|
import { logger } from '../logger.ts';
|
|
import { theme } from '../colors.ts';
|
|
import { GpuDetector } from '../hardware/gpu-detector.ts';
|
|
import { SystemInfo } from '../hardware/system-info.ts';
|
|
import { DriverManager } from '../drivers/driver-manager.ts';
|
|
import type { ITableColumn } from '../logger.ts';
|
|
|
|
/**
 * Handler for GPU-related CLI commands.
 *
 * Each public method implements one sub-command: it queries the GPU detector
 * and/or the driver manager and renders the result through the shared
 * `logger`, coloured with `theme`. All output goes to the logger; the methods
 * themselves resolve to `void`.
 */
export class GpuHandler {
  private readonly gpuDetector: GpuDetector;
  private readonly systemInfo: SystemInfo;
  private readonly driverManager: DriverManager;

  constructor() {
    this.gpuDetector = new GpuDetector();
    this.systemInfo = new SystemInfo();
    this.driverManager = new DriverManager();
  }

  /**
   * List detected GPUs.
   *
   * Prints a warning box with possible reasons when no GPU is found,
   * otherwise a table with one row per detected GPU.
   */
  public async list(): Promise<void> {
    logger.log('');
    logger.info('Detecting GPUs...');
    logger.log('');

    const gpus = await this.gpuDetector.detectGpus();

    if (gpus.length === 0) {
      logger.logBox(
        'No GPUs Detected',
        [
          'No GPUs were found on this system.',
          '',
          theme.dim('Possible reasons:'),
          ' - No discrete GPU installed',
          ' - GPU drivers not installed',
          ' - GPU not properly connected',
        ],
        60,
        'warning',
      );
      return;
    }

    const rows = gpus.map((gpu) => ({
      id: gpu.id,
      vendor: this.formatVendor(gpu.vendor),
      model: gpu.model,
      // The value is divided by 1024 and labelled GB, i.e. it is treated as MB.
      vram: `${Math.round(gpu.vram / 1024)} GB`,
      // `||` (not `??`) on purpose: an empty version string is also shown as N/A.
      driver: gpu.driverVersion || theme.dim('N/A'),
      cuda: gpu.cudaVersion || theme.dim('N/A'),
      pci: gpu.pciSlot,
    }));

    const columns: ITableColumn[] = [
      { header: 'ID', key: 'id', align: 'left' },
      { header: 'Vendor', key: 'vendor', align: 'left' },
      { header: 'Model', key: 'model', align: 'left', color: theme.highlight },
      { header: 'VRAM', key: 'vram', align: 'right', color: theme.info },
      { header: 'Driver', key: 'driver', align: 'left' },
      { header: 'CUDA', key: 'cuda', align: 'left' },
      { header: 'PCI', key: 'pci', align: 'left', color: theme.dim },
    ];

    logger.info(`Found ${gpus.length} GPU(s):`);
    logger.log('');
    logger.logTable(columns, rows);
    logger.log('');
  }

  /**
   * Show GPU status and utilization.
   *
   * Prints one box per GPU with utilization and memory bars, temperature and
   * power draw. Logs a warning and returns when no GPU status is available.
   */
  public async status(): Promise<void> {
    logger.log('');
    logger.info('GPU Status');
    logger.log('');

    const gpuStatus = await this.gpuDetector.getGpuStatus();

    if (gpuStatus.length === 0) {
      logger.warn('No GPUs detected');
      return;
    }

    for (const gpu of gpuStatus) {
      // Guard the division: a reported total of 0 would yield Infinity/NaN.
      const memoryPercent = gpu.memoryTotal > 0 ? (gpu.memoryUsed / gpu.memoryTotal) * 100 : 0;

      const utilizationBar = this.createProgressBar(gpu.utilization, 30);
      const memoryBar = this.createProgressBar(memoryPercent, 30);

      logger.logBoxTitle(`GPU ${gpu.id}: ${gpu.name}`, 70, 'info');
      logger.logBoxLine(`Utilization: ${utilizationBar} ${gpu.utilization.toFixed(1)}%`);
      logger.logBoxLine(`Memory: ${memoryBar} ${Math.round(gpu.memoryUsed)}/${Math.round(gpu.memoryTotal)} MB`);
      logger.logBoxLine(`Temperature: ${this.formatTemperature(gpu.temperature)}`);
      logger.logBoxLine(`Power: ${gpu.powerDraw.toFixed(0)}W / ${gpu.powerLimit.toFixed(0)}W`);
      logger.logBoxEnd();
      logger.log('');
    }
  }

  /**
   * Check GPU driver status.
   *
   * Prints one box per distinct GPU vendor showing whether a driver is
   * installed and, if so, its version, runtime version and container support.
   * Vendors without a registered driver are reported with a warning.
   */
  public async drivers(): Promise<void> {
    logger.log('');
    logger.info('GPU Driver Status');
    logger.log('');

    // System info is queried before GPU detection, but nothing below reads the
    // result. NOTE(review): confirm whether this probe is still required (e.g.
    // for a side effect) or can be dropped.
    await this.systemInfo.getSystemInfo();

    const gpus = await this.gpuDetector.detectGpus();

    if (gpus.length === 0) {
      logger.warn('No GPUs detected');
      return;
    }

    // Several GPUs may share a vendor; report each vendor's driver only once.
    const vendors = new Set(gpus.map((g) => g.vendor));

    for (const vendor of vendors) {
      const driver = this.driverManager.getDriver(vendor);
      if (!driver) {
        logger.warn(`No driver support for ${vendor}`);
        continue;
      }

      const status = await driver.getStatus();

      logger.logBoxTitle(`${this.formatVendor(vendor)} Driver`, 60, status.installed ? 'success' : 'warning');
      logger.logBoxLine(`Installed: ${status.installed ? theme.success('Yes') : theme.error('No')}`);

      if (status.installed) {
        logger.logBoxLine(`Version: ${status.version || 'Unknown'}`);
        logger.logBoxLine(`Runtime: ${status.runtimeVersion || 'Unknown'}`);
        logger.logBoxLine(`Container Support: ${status.containerSupport ? theme.success('Yes') : theme.warning('No')}`);
      } else {
        logger.logBoxLine('');
        logger.logBoxLine(theme.dim('Run `modelgrid gpu install` to install drivers'));
      }

      logger.logBoxEnd();
      logger.log('');
    }
  }

  /**
   * Install GPU drivers.
   *
   * For every distinct vendor among the detected GPUs, runs the vendor's
   * driver installer and, when that succeeds, its container-support setup.
   * Failures are logged and the loop continues with the next vendor.
   */
  public async install(): Promise<void> {
    logger.log('');
    logger.info('Installing GPU Drivers');
    logger.log('');

    const gpus = await this.gpuDetector.detectGpus();

    if (gpus.length === 0) {
      logger.error('No GPUs detected - cannot install drivers');
      return;
    }

    // Several GPUs may share a vendor; install each vendor's driver only once.
    const vendors = new Set(gpus.map((g) => g.vendor));

    for (const vendor of vendors) {
      const driver = this.driverManager.getDriver(vendor);
      if (!driver) {
        logger.warn(`No driver installer for ${vendor}`);
        continue;
      }

      logger.info(`Installing ${this.formatVendor(vendor)} drivers...`);

      const success = await driver.install();

      if (success) {
        logger.success(`${this.formatVendor(vendor)} drivers installed successfully`);

        // Container setup is only attempted after a successful driver install.
        logger.info('Setting up container support...');
        const containerSuccess = await driver.setupContainer();

        if (containerSuccess) {
          logger.success('Container support configured');
        } else {
          logger.warn('Container support setup failed - GPU passthrough may not work');
        }
      } else {
        logger.error(`Failed to install ${this.formatVendor(vendor)} drivers`);
      }

      logger.log('');
    }
  }

  /**
   * Format vendor name for display.
   *
   * @param vendor - Vendor identifier (`'nvidia'`, `'amd'`, `'intel'`, or other).
   * @returns The themed display name for known vendors; unknown identifiers
   *   are returned unchanged.
   */
  private formatVendor(vendor: string): string {
    switch (vendor) {
      case 'nvidia':
        return theme.gpuNvidia('NVIDIA');
      case 'amd':
        return theme.gpuAmd('AMD');
      case 'intel':
        return theme.gpuIntel('Intel');
      default:
        return vendor;
    }
  }

  /**
   * Create a coloured progress bar.
   *
   * The percentage is clamped to the 0–100 range (NaN is treated as 0) so the
   * bar is always exactly `width` cells wide; without the clamp an
   * out-of-range value would pass a negative or infinite count to
   * `String.prototype.repeat`, which throws a `RangeError`.
   *
   * @param percent - Fill level in percent.
   * @param width - Total bar width in characters; fractions are floored and
   *   negative values are treated as 0.
   * @returns The bar, coloured as error at >= 90 %, warning at >= 70 %,
   *   success otherwise.
   */
  private createProgressBar(percent: number, width: number): string {
    const cells = Math.max(0, Math.floor(width));
    const level = Number.isNaN(percent) ? 0 : Math.min(100, Math.max(0, percent));

    const filled = Math.round((level / 100) * cells);
    const bar = '█'.repeat(filled) + '░'.repeat(cells - filled);

    if (level >= 90) {
      return theme.error(bar);
    } else if (level >= 70) {
      return theme.warning(bar);
    } else {
      return theme.success(bar);
    }
  }

  /**
   * Format temperature with color coding.
   *
   * @param temp - Temperature; rendered with a `°C` suffix.
   * @returns The text coloured as error at >= 85, warning at >= 70, success
   *   otherwise.
   */
  private formatTemperature(temp: number): string {
    const tempStr = `${temp}°C`;

    if (temp >= 85) {
      return theme.error(tempStr);
    } else if (temp >= 70) {
      return theme.warning(tempStr);
    } else {
      return theme.success(tempStr);
    }
  }
}
|