278 lines
7.5 KiB
TypeScript
278 lines
7.5 KiB
TypeScript
/**
 * GPU Handler
 *
 * CLI commands for GPU management.
 */
|
|
|
|
import { logger } from '../logger.ts';
|
|
import { theme } from '../colors.ts';
|
|
import { GpuDetector } from '../hardware/gpu-detector.ts';
|
|
import { SystemInfo } from '../hardware/system-info.ts';
|
|
import { DriverManager } from '../drivers/driver-manager.ts';
|
|
import type { ITableColumn } from '../logger.ts';
|
|
|
|
/**
|
|
* Handler for GPU-related CLI commands
|
|
*/
|
|
/**
 * Handler for GPU-related CLI commands.
 *
 * Every public command prints its results through `logger`; none of them
 * return data to the caller.
 */
export class GpuHandler {
  private gpuDetector: GpuDetector;
  private systemInfo: SystemInfo;
  private driverManager: DriverManager;

  constructor() {
    this.gpuDetector = new GpuDetector();
    this.systemInfo = new SystemInfo();
    this.driverManager = new DriverManager();
  }

  /**
   * List detected GPUs as a table.
   *
   * When no GPU is detected, prints a warning box with likely causes instead.
   */
  public async list(): Promise<void> {
    logger.log('');
    logger.info('Detecting GPUs...');
    logger.log('');

    const gpus = await this.gpuDetector.detectGpus();

    if (gpus.length === 0) {
      logger.logBox(
        'No GPUs Detected',
        [
          'No GPUs were found on this system.',
          '',
          theme.dim('Possible reasons:'),
          ' - No discrete GPU installed',
          ' - GPU drivers not installed',
          ' - GPU not properly connected',
        ],
        60,
        'warning',
      );
      return;
    }

    const rows = gpus.map((gpu) => ({
      id: gpu.id,
      vendor: this.formatVendor(gpu.vendor),
      model: gpu.model,
      // NOTE(review): assumes gpu.vram is reported in MB — confirm against GpuDetector.
      vram: `${Math.round(gpu.vram / 1024)} GB`,
      driver: gpu.driverVersion || theme.dim('N/A'),
      cuda: gpu.cudaVersion || theme.dim('N/A'),
      pci: gpu.pciSlot,
    }));

    const columns: ITableColumn[] = [
      { header: 'ID', key: 'id', align: 'left' },
      { header: 'Vendor', key: 'vendor', align: 'left' },
      { header: 'Model', key: 'model', align: 'left', color: theme.highlight },
      { header: 'VRAM', key: 'vram', align: 'right', color: theme.info },
      { header: 'Driver', key: 'driver', align: 'left' },
      { header: 'CUDA', key: 'cuda', align: 'left' },
      { header: 'PCI', key: 'pci', align: 'left', color: theme.dim },
    ];

    logger.info(`Found ${gpus.length} GPU(s):`);
    logger.log('');
    logger.logTable(columns, rows);
    logger.log('');
  }

  /**
   * Show per-GPU utilization, memory, temperature, and power in a box each.
   *
   * The model name is looked up by matching status entries to detected GPUs by
   * id; 'Unknown GPU' is shown when no match is found.
   */
  public async status(): Promise<void> {
    logger.log('');
    logger.info('GPU Status');
    logger.log('');

    const gpuInfo = await this.gpuDetector.detectGpus();
    const gpuStatus = await this.gpuDetector.getAllGpuStatus();

    if (gpuStatus.size === 0) {
      logger.warn('No GPUs detected');
      return;
    }

    for (const [gpuId, status] of gpuStatus) {
      const info = gpuInfo.find((gpu) => gpu.id === gpuId);

      // A zero total would yield NaN/Infinity here, and Infinity makes the bar
      // renderer throw; treat an unknown total as 0% used.
      const memoryPercent = status.memoryTotal > 0
        ? (status.memoryUsed / status.memoryTotal) * 100
        : 0;

      const utilizationBar = this.createProgressBar(status.utilization, 30);
      const memoryBar = this.createProgressBar(memoryPercent, 30);

      logger.logBoxTitle(`GPU ${status.id}: ${info?.model || 'Unknown GPU'}`, 70, 'info');
      logger.logBoxLine(`Utilization: ${utilizationBar} ${status.utilization.toFixed(1)}%`);
      logger.logBoxLine(
        `Memory: ${memoryBar} ${Math.round(status.memoryUsed)}/${
          Math.round(status.memoryTotal)
        } MB`,
      );
      logger.logBoxLine(`Temperature: ${this.formatTemperature(status.temperature)}`);
      logger.logBoxLine(
        `Power: ${status.powerUsage.toFixed(0)}W / ${status.powerLimit.toFixed(0)}W`,
      );
      logger.logBoxEnd();
      logger.log('');
    }
  }

  /**
   * Report the driver status for each distinct GPU vendor that was detected.
   *
   * Vendors without a driver implementation in `DriverManager` produce a warning
   * and are skipped.
   */
  public async drivers(): Promise<void> {
    logger.log('');
    logger.info('GPU Driver Status');
    logger.log('');

    // NOTE(review): the result of getSystemInfo() was never read in this method.
    // The call is retained so any side effects or ordering it has stay unchanged;
    // confirm whether it can be removed.
    await this.systemInfo.getSystemInfo();

    const gpus = await this.gpuDetector.detectGpus();

    if (gpus.length === 0) {
      logger.warn('No GPUs detected');
      return;
    }

    // One report per vendor, even if several GPUs share it.
    const vendors = new Set(gpus.map((g) => g.vendor));

    for (const vendor of vendors) {
      const driver = this.driverManager.getDriver(vendor);
      if (!driver) {
        logger.warn(`No driver support for ${vendor}`);
        continue;
      }

      const status = await driver.getStatus();

      logger.logBoxTitle(
        `${this.formatVendor(vendor)} Driver`,
        60,
        status.installed ? 'success' : 'warning',
      );
      logger.logBoxLine(
        `Installed: ${status.installed ? theme.success('Yes') : theme.error('No')}`,
      );

      if (status.installed) {
        logger.logBoxLine(`Version: ${status.version || 'Unknown'}`);
        logger.logBoxLine(`Runtime: ${status.containerRuntimeVersion || 'Unknown'}`);
        logger.logBoxLine(
          `Container Support: ${
            status.containerSupport ? theme.success('Yes') : theme.warning('No')
          }`,
        );
      } else {
        logger.logBoxLine('');
        logger.logBoxLine(theme.dim('Run `modelgrid gpu install` to install drivers'));
      }

      logger.logBoxEnd();
      logger.log('');
    }
  }

  /**
   * Install drivers (with toolkit and container support requested) for each
   * distinct GPU vendor that was detected.
   *
   * Failures are logged per vendor; the loop continues with the next vendor.
   */
  public async install(): Promise<void> {
    logger.log('');
    logger.info('Installing GPU Drivers');
    logger.log('');

    const gpus = await this.gpuDetector.detectGpus();

    if (gpus.length === 0) {
      logger.error('No GPUs detected - cannot install drivers');
      return;
    }

    const vendors = new Set(gpus.map((g) => g.vendor));

    for (const vendor of vendors) {
      const driver = this.driverManager.getDriver(vendor);
      if (!driver) {
        logger.warn(`No driver installer for ${vendor}`);
        continue;
      }

      logger.info(`Installing ${this.formatVendor(vendor)} drivers...`);

      const success = await driver.install({
        installToolkit: true,
        installContainerSupport: true,
        nonInteractive: false,
      });

      if (success) {
        logger.success(`${this.formatVendor(vendor)} drivers installed successfully`);

        // NOTE(review): install() above was already called with
        // installContainerSupport: true; confirm whether this explicit call is
        // redundant or intentionally re-applies the configuration.
        logger.info('Setting up container support...');
        const containerSuccess = await driver.installContainerSupport();

        if (containerSuccess) {
          logger.success('Container support configured');
        } else {
          logger.warn('Container support setup failed - GPU passthrough may not work');
        }
      } else {
        logger.error(`Failed to install ${this.formatVendor(vendor)} drivers`);
      }

      logger.log('');
    }
  }

  /**
   * Format a vendor id for display, applying the vendor's theme color for the
   * known ids 'nvidia', 'amd', and 'intel'; any other value is returned as-is.
   */
  private formatVendor(vendor: string): string {
    switch (vendor) {
      case 'nvidia':
        return theme.gpuNvidia('NVIDIA');
      case 'amd':
        return theme.gpuAmd('AMD');
      case 'intel':
        return theme.gpuIntel('Intel');
      default:
        return vendor;
    }
  }

  /**
   * Render a fixed-width text progress bar, colored by fill level.
   *
   * @param percent - Fill level in percent. Values outside 0–100 are clamped,
   *   and non-finite values (NaN, ±Infinity) are drawn as an empty bar, because
   *   `String.prototype.repeat` throws a RangeError for negative or infinite
   *   counts.
   * @param width - Total number of characters in the bar.
   * @returns The bar wrapped in `theme.error` (>= 90%), `theme.warning`
   *   (>= 70%), or `theme.success` otherwise.
   */
  private createProgressBar(percent: number, width: number): string {
    const clamped = Number.isFinite(percent) ? Math.min(Math.max(percent, 0), 100) : 0;
    const filled = Math.round((clamped / 100) * width);
    const bar = '█'.repeat(filled) + '░'.repeat(width - filled);

    if (clamped >= 90) {
      return theme.error(bar);
    } else if (clamped >= 70) {
      return theme.warning(bar);
    } else {
      return theme.success(bar);
    }
  }

  /**
   * Format a temperature as `<n>°C`, colored by threshold: error at 85 or
   * above, warning at 70 or above, success otherwise.
   */
  private formatTemperature(temp: number): string {
    const tempStr = `${temp}°C`;

    if (temp >= 85) {
      return theme.error(tempStr);
    } else if (temp >= 70) {
      return theme.warning(tempStr);
    } else {
      return theme.success(tempStr);
    }
  }
}
|