/**
 * Driver Manager
 *
 * Coordinates detection and installation of GPU drivers across all vendors.
 */

import type { IDriverStatus, TGpuVendor } from '../interfaces/gpu.ts';
import { logger } from '../logger.ts';
import { GpuDetector } from '../hardware/gpu-detector.ts';
import { BaseDriver, type IDriverInstallOptions } from './base-driver.ts';
import { NvidiaDriver } from './nvidia.ts';
import { AmdDriver } from './amd.ts';
import { IntelDriver } from './intel.ts';

/**
 * Driver Manager - coordinates GPU driver management.
 *
 * Holds one driver implementation per supported vendor (NVIDIA, AMD, Intel)
 * and dispatches to the ones matching the GPUs reported by `GpuDetector`.
 * GPUs whose vendor is `'unknown'` are reported but never managed.
 */
export class DriverManager {
  private readonly gpuDetector: GpuDetector;
  private readonly drivers: Map<TGpuVendor, BaseDriver>;

  constructor() {
    this.gpuDetector = new GpuDetector();
    this.drivers = new Map<TGpuVendor, BaseDriver>([
      ['nvidia', new NvidiaDriver()],
      ['amd', new AmdDriver()],
      ['intel', new IntelDriver()],
    ]);
  }

  /**
   * Get the driver manager for a specific vendor.
   *
   * @param vendor - GPU vendor to look up.
   * @returns The vendor's driver, or `undefined` when the vendor is not
   *   managed (e.g. `'unknown'`).
   */
  public getDriver(vendor: TGpuVendor): BaseDriver | undefined {
    return this.drivers.get(vendor);
  }

  /**
   * Get the status of the drivers for every detected, supported GPU vendor.
   *
   * @returns Map from vendor to its driver status, in first-detected order.
   *   Vendors with no detected GPU are absent.
   */
  public async getAllDriverStatus(): Promise<Map<TGpuVendor, IDriverStatus>> {
    const statuses = new Map<TGpuVendor, IDriverStatus>();

    // Only check drivers for vendors that are actually present
    for (const vendor of await this.detectManagedVendors()) {
      const driver = this.drivers.get(vendor);
      if (driver) {
        statuses.set(vendor, await driver.getStatus());
      }
    }

    return statuses;
  }

  /**
   * Check drivers for all detected GPUs.
   *
   * When no GPU is detected, both flags are `false` and the only issue is
   * `'No GPUs detected'`. GPUs with an unknown vendor, or vendors without a
   * driver manager, add an entry to `issues` but do not clear either flag.
   *
   * @returns `allInstalled` / `allContainerReady` flags across the managed
   *   vendors, plus human-readable issue descriptions (including each
   *   driver's own `status.issues`).
   */
  public async checkAllDrivers(): Promise<{
    allInstalled: boolean;
    allContainerReady: boolean;
    issues: string[];
  }> {
    const gpus = await this.gpuDetector.detectGpus();
    const issues: string[] = [];
    let allInstalled = true;
    let allContainerReady = true;

    if (gpus.length === 0) {
      issues.push('No GPUs detected');
      return { allInstalled: false, allContainerReady: false, issues };
    }

    // Group GPUs by vendor
    const vendorCounts = new Map<TGpuVendor, number>();
    for (const gpu of gpus) {
      vendorCounts.set(gpu.vendor, (vendorCounts.get(gpu.vendor) ?? 0) + 1);
    }

    // Check each vendor
    for (const [vendor, count] of vendorCounts) {
      if (vendor === 'unknown') {
        issues.push(`${count} GPU(s) with unknown vendor - cannot manage drivers`);
        continue;
      }

      const driver = this.drivers.get(vendor);
      if (!driver) {
        issues.push(`No driver manager for ${vendor}`);
        continue;
      }

      const status = await driver.getStatus();

      if (!status.installed) {
        allInstalled = false;
        issues.push(`${driver.displayName} driver not installed for ${count} GPU(s)`);
      }

      if (!status.containerSupport) {
        allContainerReady = false;
        issues.push(`${driver.displayName} container support not configured`);
      }

      // Add specific issues
      issues.push(...status.issues);
    }

    return { allInstalled, allContainerReady, issues };
  }

  /**
   * Install drivers for all detected, supported GPUs.
   *
   * @param options - Install options. Omitted flags default to
   *   `installToolkit: true`, `installContainerSupport: true` and
   *   `nonInteractive: false`.
   * @returns `true` only if every vendor's install succeeded. Returns `false`
   *   (after logging) when no supported GPU is detected.
   */
  public async installAllDrivers(options: Partial<IDriverInstallOptions> = {}): Promise<boolean> {
    const fullOptions: IDriverInstallOptions = {
      installToolkit: options.installToolkit ?? true,
      installContainerSupport: options.installContainerSupport ?? true,
      nonInteractive: options.nonInteractive ?? false,
      driverVersion: options.driverVersion,
      toolkitVersion: options.toolkitVersion,
    };

    const vendors = await this.detectManagedVendors();
    if (vendors.size === 0) {
      logger.error('No supported GPUs detected');
      return false;
    }

    let allSuccess = true;
    for (const vendor of vendors) {
      const driver = this.drivers.get(vendor);
      if (!driver) continue;

      logger.info(`Installing ${driver.displayName} drivers...`);
      const success = await driver.install(fullOptions);
      if (!success) {
        allSuccess = false;
        logger.error(`Failed to install ${driver.displayName} drivers`);
      }
    }

    return allSuccess;
  }

  /**
   * Install container support for all detected, supported GPUs.
   *
   * @returns `true` only if every vendor succeeded. Returns `false` (after
   *   logging) when no supported GPU is detected, matching
   *   `installAllDrivers`, so callers never see success when nothing was set up.
   */
  public async installContainerSupport(): Promise<boolean> {
    const vendors = await this.detectManagedVendors();
    if (vendors.size === 0) {
      logger.error('No supported GPUs detected');
      return false;
    }

    let allSuccess = true;
    for (const vendor of vendors) {
      const driver = this.drivers.get(vendor);
      if (!driver) continue;

      const success = await driver.installContainerSupport();
      if (!success) {
        allSuccess = false;
        logger.error(`Failed to install ${driver.displayName} container support`);
      }
    }

    return allSuccess;
  }

  /**
   * Print a driver status summary: one box per vendor via the driver's own
   * `logStatus()`, and a warning box for GPUs with an unknown vendor.
   */
  public async printDriverStatus(): Promise<void> {
    const gpus = await this.gpuDetector.detectGpus();

    if (gpus.length === 0) {
      logger.logBox('Driver Status', ['No GPUs detected'], 50, 'warning');
      return;
    }

    // Group by vendor, preserving first-detected order
    const vendorGpus = new Map<TGpuVendor, typeof gpus>();
    for (const gpu of gpus) {
      const group = vendorGpus.get(gpu.vendor);
      if (group) {
        group.push(gpu);
      } else {
        vendorGpus.set(gpu.vendor, [gpu]);
      }
    }

    // Print status for each vendor
    for (const [vendor, gpuList] of vendorGpus) {
      if (vendor === 'unknown') {
        logger.logBox('Unknown GPUs', [
          `${gpuList.length} GPU(s) with unknown vendor`,
          'Manual driver installation may be required',
        ], 50, 'warning');
        continue;
      }

      const driver = this.drivers.get(vendor);
      if (driver) {
        await driver.logStatus();
      }
    }
  }

  /**
   * Get Docker run arguments for GPU support.
   *
   * Only one vendor's passthrough scheme is emitted: that of the first target
   * GPU with a known vendor. Device lists include only that vendor's GPUs.
   *
   * @param gpuIds - Restrict to GPUs with these ids. When omitted, all
   *   detected GPUs are considered. An explicit empty list yields no args.
   * @returns Arguments to splice into a `docker run` command line; empty when
   *   no supported GPU is targeted.
   */
  public async getDockerGpuArgs(gpuIds?: string[]): Promise<string[]> {
    const gpus = await this.gpuDetector.detectGpus();
    const args: string[] = [];

    // Filter to specific GPUs if provided
    const targetGpus = gpuIds ? gpus.filter((g) => gpuIds.includes(g.id)) : gpus;

    // Skip unknown-vendor devices when choosing the vendor, so an unrecognised
    // device listed first does not suppress passthrough for supported GPUs.
    const primary = targetGpus.find((g) => g.vendor !== 'unknown');
    if (!primary) {
      return args;
    }
    const vendor = primary.vendor;

    // Other vendors' indices are meaningless in this vendor's device syntax
    const vendorGpus = targetGpus.filter((g) => g.vendor === vendor);

    switch (vendor) {
      case 'nvidia':
        // NVIDIA uses nvidia-docker runtime
        args.push('--runtime=nvidia');
        if (gpuIds && gpuIds.length > 0) {
          // Use specific GPU indices; the inner quotes keep the comma-separated
          // list as a single value for Docker's --gpus parser
          const indices = vendorGpus.map((g) => g.index).join(',');
          args.push(`--gpus="device=${indices}"`);
        } else {
          args.push('--gpus=all');
        }
        break;

      case 'amd':
        // AMD uses device passthrough.
        // NOTE(review): assumes gpu.index matches the DRM render-node order
        // (renderD128 + index) — confirm against GpuDetector.
        args.push('--device=/dev/kfd');
        for (const gpu of vendorGpus) {
          args.push(`--device=/dev/dri/renderD${128 + gpu.index}`);
        }
        args.push('--group-add=video');
        args.push('--security-opt=seccomp=unconfined');
        break;

      case 'intel':
        // Intel uses device passthrough (same render-node index assumption as AMD)
        for (const gpu of vendorGpus) {
          args.push(`--device=/dev/dri/renderD${128 + gpu.index}`);
        }
        args.push('--group-add=render');
        break;
    }

    return args;
  }

  /**
   * Detect GPUs and return the vendors this manager can act on (everything
   * except `'unknown'`), in first-detected order.
   */
  private async detectManagedVendors(): Promise<Set<TGpuVendor>> {
    const gpus = await this.gpuDetector.detectGpus();
    return new Set<TGpuVendor>(gpus.map((g) => g.vendor).filter((v) => v !== 'unknown'));
  }
}