Files
modelgrid/ts/drivers/driver-manager.ts

268 lines
7.2 KiB
TypeScript
Raw Normal View History

2026-01-30 03:16:57 +00:00
/**
* Driver Manager
*
* Coordinates detection and installation of GPU drivers across all vendors.
*/
import type { IDriverStatus, TGpuVendor } from '../interfaces/gpu.ts';
import { logger } from '../logger.ts';
import { GpuDetector } from '../hardware/gpu-detector.ts';
import { BaseDriver, type IDriverInstallOptions } from './base-driver.ts';
import { NvidiaDriver } from './nvidia.ts';
import { AmdDriver } from './amd.ts';
import { IntelDriver } from './intel.ts';
/**
 * Driver Manager - coordinates GPU driver management
 *
 * Holds one driver handler per supported vendor (NVIDIA, AMD, Intel) together
 * with a GPU detector. Every public operation re-runs GPU detection and then
 * delegates to the handlers of the vendors that were actually found.
 */
export class DriverManager {
  private readonly gpuDetector: GpuDetector;
  private readonly drivers: Map<TGpuVendor, BaseDriver>;

  constructor() {
    this.gpuDetector = new GpuDetector();
    this.drivers = new Map([
      ['nvidia', new NvidiaDriver()],
      ['amd', new AmdDriver()],
      ['intel', new IntelDriver()],
    ]);
  }

  /**
   * Get driver manager for a specific vendor
   *
   * @param vendor Vendor to look up.
   * @returns The registered driver handler, or `undefined` if none is
   *          registered for that vendor.
   */
  public getDriver(vendor: TGpuVendor): BaseDriver | undefined {
    return this.drivers.get(vendor);
  }

  /**
   * Detect GPUs and return the distinct vendors found, excluding 'unknown'.
   * Vendors keep the order in which their first GPU was detected.
   */
  private async detectSupportedVendors(): Promise<Set<TGpuVendor>> {
    const gpus = await this.gpuDetector.detectGpus();
    return new Set(gpus.map((g) => g.vendor).filter((v) => v !== 'unknown'));
  }

  /**
   * Get status of all GPU drivers
   *
   * Only vendors with at least one detected GPU are queried; GPUs with an
   * unknown vendor and vendors without a registered handler are skipped.
   *
   * @returns Map from vendor to the status reported by its driver handler.
   */
  public async getAllDriverStatus(): Promise<Map<TGpuVendor, IDriverStatus>> {
    const statuses = new Map<TGpuVendor, IDriverStatus>();

    for (const vendor of await this.detectSupportedVendors()) {
      const driver = this.drivers.get(vendor);
      if (driver) {
        statuses.set(vendor, await driver.getStatus());
      }
    }

    return statuses;
  }

  /**
   * Check drivers for all detected GPUs
   *
   * @returns `allInstalled` / `allContainerReady` are false when no GPU is
   *          detected or when any queried driver reports the respective
   *          feature missing. GPUs with an unknown vendor, or a vendor without
   *          a registered handler, only add an entry to `issues`; they do not
   *          flip the two flags.
   */
  public async checkAllDrivers(): Promise<{
    allInstalled: boolean;
    allContainerReady: boolean;
    issues: string[];
  }> {
    const gpus = await this.gpuDetector.detectGpus();
    const issues: string[] = [];
    let allInstalled = true;
    let allContainerReady = true;

    if (gpus.length === 0) {
      issues.push('No GPUs detected');
      return { allInstalled: false, allContainerReady: false, issues };
    }

    // Group GPUs by vendor
    const vendorCounts = new Map<TGpuVendor, number>();
    for (const gpu of gpus) {
      vendorCounts.set(gpu.vendor, (vendorCounts.get(gpu.vendor) ?? 0) + 1);
    }

    // Check each vendor
    for (const [vendor, count] of vendorCounts) {
      if (vendor === 'unknown') {
        issues.push(`${count} GPU(s) with unknown vendor - cannot manage drivers`);
        continue;
      }

      const driver = this.drivers.get(vendor);
      if (!driver) {
        issues.push(`No driver manager for ${vendor}`);
        continue;
      }

      const status = await driver.getStatus();
      if (!status.installed) {
        allInstalled = false;
        issues.push(`${driver.displayName} driver not installed for ${count} GPU(s)`);
      }
      if (!status.containerSupport) {
        allContainerReady = false;
        issues.push(`${driver.displayName} container support not configured`);
      }

      // Add specific issues
      issues.push(...status.issues);
    }

    return { allInstalled, allContainerReady, issues };
  }

  /**
   * Install drivers for all detected GPUs
   *
   * @param options Partial install options. Missing flags default to
   *                `installToolkit: true`, `installContainerSupport: true`,
   *                `nonInteractive: false`; the version fields are passed
   *                through as given.
   * @returns `false` if no GPU with a known vendor is detected or if any
   *          vendor's installation reports failure; `true` otherwise.
   */
  public async installAllDrivers(options: Partial<IDriverInstallOptions> = {}): Promise<boolean> {
    const fullOptions: IDriverInstallOptions = {
      installToolkit: options.installToolkit ?? true,
      installContainerSupport: options.installContainerSupport ?? true,
      nonInteractive: options.nonInteractive ?? false,
      driverVersion: options.driverVersion,
      toolkitVersion: options.toolkitVersion,
    };

    const vendors = await this.detectSupportedVendors();
    if (vendors.size === 0) {
      logger.error('No supported GPUs detected');
      return false;
    }

    let allSuccess = true;
    for (const vendor of vendors) {
      const driver = this.drivers.get(vendor);
      if (!driver) continue;

      logger.info(`Installing ${driver.displayName} drivers...`);
      const success = await driver.install(fullOptions);
      if (!success) {
        // Keep going so the remaining vendors still get their drivers installed.
        allSuccess = false;
        logger.error(`Failed to install ${driver.displayName} drivers`);
      }
    }

    return allSuccess;
  }

  /**
   * Install container support for all GPUs
   *
   * @returns `false` if any vendor's handler reports failure. Note that this
   *          returns `true` when no GPU with a known vendor is detected, since
   *          there is then nothing to install.
   */
  public async installContainerSupport(): Promise<boolean> {
    let allSuccess = true;

    for (const vendor of await this.detectSupportedVendors()) {
      const driver = this.drivers.get(vendor);
      if (!driver) continue;

      const success = await driver.installContainerSupport();
      if (!success) {
        allSuccess = false;
      }
    }

    return allSuccess;
  }

  /**
   * Print driver status summary
   *
   * Logs a warning box when no GPUs are detected or for GPUs with an unknown
   * vendor; for every other vendor the registered handler logs its own status.
   */
  public async printDriverStatus(): Promise<void> {
    const gpus = await this.gpuDetector.detectGpus();

    if (gpus.length === 0) {
      logger.logBox('Driver Status', ['No GPUs detected'], 50, 'warning');
      return;
    }

    // Group by vendor
    const vendorGpus = new Map<TGpuVendor, typeof gpus>();
    for (const gpu of gpus) {
      const group = vendorGpus.get(gpu.vendor);
      if (group) {
        group.push(gpu);
      } else {
        vendorGpus.set(gpu.vendor, [gpu]);
      }
    }

    // Print status for each vendor
    for (const [vendor, gpuList] of vendorGpus) {
      if (vendor === 'unknown') {
        logger.logBox('Unknown GPUs', [
          `${gpuList.length} GPU(s) with unknown vendor`,
          'Manual driver installation may be required',
        ], 50, 'warning');
        continue;
      }

      const driver = this.drivers.get(vendor);
      if (driver) {
        await driver.logStatus();
      }
    }
  }

  /**
   * Get Docker run arguments for GPU support
   *
   * Arguments are generated for a single vendor: the vendor of the first
   * selected GPU whose vendor is known. Selected GPUs of any other vendor are
   * ignored, so device paths and indices of different vendors are never mixed
   * into one argument set.
   *
   * @param gpuIds Optional GPU ids to restrict to; when omitted all detected
   *               GPUs are considered.
   * @returns Docker CLI arguments, or an empty array when no selected GPU has
   *          a vendor this method knows how to expose.
   */
  public async getDockerGpuArgs(gpuIds?: string[]): Promise<string[]> {
    const gpus = await this.gpuDetector.detectGpus();
    const args: string[] = [];

    // Filter to specific GPUs if provided
    const selectedGpus = gpuIds
      ? gpus.filter((g) => gpuIds.includes(g.id))
      : gpus;

    // Skip leading unknown-vendor GPUs so a supported GPU further down the
    // list still yields arguments.
    const primaryGpu = selectedGpus.find((g) => g.vendor !== 'unknown');
    if (!primaryGpu) {
      return args;
    }

    const vendor = primaryGpu.vendor;
    // Only GPUs of the chosen vendor may contribute devices/indices.
    const targetGpus = selectedGpus.filter((g) => g.vendor === vendor);

    switch (vendor) {
      case 'nvidia':
        // NVIDIA uses nvidia-docker runtime
        args.push('--runtime=nvidia');
        if (gpuIds && gpuIds.length > 0) {
          // Use specific GPU indices
          const indices = targetGpus.map((g) => g.index).join(',');
          args.push(`--gpus="device=${indices}"`);
        } else {
          args.push('--gpus=all');
        }
        break;

      case 'amd':
        // AMD uses device passthrough
        args.push('--device=/dev/kfd');
        for (const gpu of targetGpus) {
          // NOTE(review): assumes render node numbering is 128 + gpu.index - confirm against the detector.
          args.push(`--device=/dev/dri/renderD${128 + gpu.index}`);
        }
        args.push('--group-add=video');
        args.push('--security-opt=seccomp=unconfined');
        break;

      case 'intel':
        // Intel uses device passthrough
        for (const gpu of targetGpus) {
          // NOTE(review): assumes render node numbering is 128 + gpu.index - confirm against the detector.
          args.push(`--device=/dev/dri/renderD${128 + gpu.index}`);
        }
        args.push('--group-add=render');
        break;
    }

    return args;
  }
}