268 lines
7.2 KiB
TypeScript
268 lines
7.2 KiB
TypeScript
|
|
/**
 * Driver Manager
 *
 * Coordinates detection and installation of GPU drivers across all vendors.
 */
|
||
|
|
|
||
|
|
import type { IDriverStatus, TGpuVendor } from '../interfaces/gpu.ts';
|
||
|
|
import { logger } from '../logger.ts';
|
||
|
|
import { GpuDetector } from '../hardware/gpu-detector.ts';
|
||
|
|
import { BaseDriver, type IDriverInstallOptions } from './base-driver.ts';
|
||
|
|
import { NvidiaDriver } from './nvidia.ts';
|
||
|
|
import { AmdDriver } from './amd.ts';
|
||
|
|
import { IntelDriver } from './intel.ts';
|
||
|
|
|
||
|
|
/**
 * Driver Manager - coordinates GPU driver management.
 *
 * Holds one driver handler per supported vendor (`nvidia`, `amd`, `intel`) and
 * asks a `GpuDetector` which GPUs are present before acting. GPUs reported
 * with the vendor `'unknown'` are never managed; they are only reported.
 */
export class DriverManager {
  /** Source of the GPU inventory; every public operation queries it afresh. */
  private readonly gpuDetector: GpuDetector;

  /** Vendor-specific driver handlers, keyed by GPU vendor. */
  private readonly drivers: Map<TGpuVendor, BaseDriver>;

  constructor() {
    this.gpuDetector = new GpuDetector();
    this.drivers = new Map<TGpuVendor, BaseDriver>([
      ['nvidia', new NvidiaDriver()],
      ['amd', new AmdDriver()],
      ['intel', new IntelDriver()],
    ]);
  }

  /**
   * Get the driver handler for a specific vendor.
   *
   * @param vendor - Vendor to look up.
   * @returns The registered handler, or `undefined` when none is registered
   *   for that vendor (for example `'unknown'`).
   */
  public getDriver(vendor: TGpuVendor): BaseDriver | undefined {
    return this.drivers.get(vendor);
  }

  /**
   * Get the driver status for every vendor that has at least one detected GPU.
   *
   * Vendors without a detected GPU, the `'unknown'` vendor, and vendors
   * without a registered handler are absent from the result.
   *
   * @returns Map from vendor to the status reported by its driver handler.
   */
  public async getAllDriverStatus(): Promise<Map<TGpuVendor, IDriverStatus>> {
    const statuses = new Map<TGpuVendor, IDriverStatus>();

    // Only check drivers for detected GPUs.
    for (const vendor of await this.detectSupportedVendors()) {
      const driver = this.drivers.get(vendor);
      if (!driver) continue;

      statuses.set(vendor, await driver.getStatus());
    }

    return statuses;
  }

  /**
   * Check drivers for all detected GPUs and collect human-readable issues.
   *
   * @returns
   *   - `allInstalled`: `false` when no GPU is detected or any checked driver
   *     reports it is not installed.
   *   - `allContainerReady`: `false` when no GPU is detected or any checked
   *     driver reports missing container support.
   *   - `issues`: messages produced here followed by each driver's own issues.
   *
   *   GPUs with an unknown vendor (or a vendor without a handler) add an issue
   *   but do not change the two flags.
   */
  public async checkAllDrivers(): Promise<{
    allInstalled: boolean;
    allContainerReady: boolean;
    issues: string[];
  }> {
    const gpus = await this.gpuDetector.detectGpus();
    const issues: string[] = [];

    if (gpus.length === 0) {
      issues.push('No GPUs detected');
      return { allInstalled: false, allContainerReady: false, issues };
    }

    // Count GPUs per vendor so each vendor is checked once.
    const vendorCounts = new Map<TGpuVendor, number>();
    for (const gpu of gpus) {
      vendorCounts.set(gpu.vendor, (vendorCounts.get(gpu.vendor) ?? 0) + 1);
    }

    let allInstalled = true;
    let allContainerReady = true;

    for (const [vendor, count] of vendorCounts) {
      if (vendor === 'unknown') {
        issues.push(`${count} GPU(s) with unknown vendor - cannot manage drivers`);
        continue;
      }

      const driver = this.drivers.get(vendor);
      if (!driver) {
        issues.push(`No driver manager for ${vendor}`);
        continue;
      }

      const status = await driver.getStatus();

      if (!status.installed) {
        allInstalled = false;
        issues.push(`${driver.displayName} driver not installed for ${count} GPU(s)`);
      }

      if (!status.containerSupport) {
        allContainerReady = false;
        issues.push(`${driver.displayName} container support not configured`);
      }

      // Append the driver's own, more specific issues.
      issues.push(...status.issues);
    }

    return { allInstalled, allContainerReady, issues };
  }

  /**
   * Install drivers for every detected GPU vendor that has a handler.
   *
   * Vendors are processed one after another; a failure is logged and does not
   * stop the remaining vendors from being attempted.
   *
   * @param options - Install options. Unset fields default to
   *   `installToolkit: true`, `installContainerSupport: true` and
   *   `nonInteractive: false`; the version fields are passed through as given.
   * @returns `true` when every attempted installation succeeded; `false` when
   *   any failed or when no GPU with a known vendor was detected.
   */
  public async installAllDrivers(options: Partial<IDriverInstallOptions> = {}): Promise<boolean> {
    const fullOptions: IDriverInstallOptions = {
      installToolkit: options.installToolkit ?? true,
      installContainerSupport: options.installContainerSupport ?? true,
      nonInteractive: options.nonInteractive ?? false,
      driverVersion: options.driverVersion,
      toolkitVersion: options.toolkitVersion,
    };

    const vendors = await this.detectSupportedVendors();
    if (vendors.size === 0) {
      logger.error('No supported GPUs detected');
      return false;
    }

    let allSuccess = true;

    for (const vendor of vendors) {
      const driver = this.drivers.get(vendor);
      if (!driver) continue;

      logger.info(`Installing ${driver.displayName} drivers...`);

      const success = await driver.install(fullOptions);
      if (!success) {
        allSuccess = false;
        logger.error(`Failed to install ${driver.displayName} drivers`);
      }
    }

    return allSuccess;
  }

  /**
   * Install container support for every detected GPU vendor that has a handler.
   *
   * @returns `true` when every attempted installation succeeded. This includes
   *   the case where no GPU with a known vendor was detected, because nothing
   *   is attempted then.
   */
  public async installContainerSupport(): Promise<boolean> {
    let allSuccess = true;

    for (const vendor of await this.detectSupportedVendors()) {
      const driver = this.drivers.get(vendor);
      if (!driver) continue;

      const success = await driver.installContainerSupport();
      if (!success) {
        allSuccess = false;
      }
    }

    return allSuccess;
  }

  /**
   * Print a driver status summary, one section per detected vendor.
   *
   * Logs a warning box when no GPU is detected and another for GPUs whose
   * vendor is unknown; for every other vendor the driver handler logs its own
   * status.
   */
  public async printDriverStatus(): Promise<void> {
    const gpus = await this.gpuDetector.detectGpus();

    if (gpus.length === 0) {
      logger.logBox('Driver Status', ['No GPUs detected'], 50, 'warning');
      return;
    }

    // Group GPUs by vendor, keeping first-seen vendor order.
    const vendorGpus = new Map<TGpuVendor, typeof gpus>();
    for (const gpu of gpus) {
      const group = vendorGpus.get(gpu.vendor);
      if (group) {
        group.push(gpu);
      } else {
        vendorGpus.set(gpu.vendor, [gpu]);
      }
    }

    for (const [vendor, gpuList] of vendorGpus) {
      if (vendor === 'unknown') {
        logger.logBox('Unknown GPUs', [
          `${gpuList.length} GPU(s) with unknown vendor`,
          'Manual driver installation may be required',
        ], 50, 'warning');
        continue;
      }

      const driver = this.drivers.get(vendor);
      if (driver) {
        await driver.logStatus();
      }
    }
  }

  /**
   * Get Docker run arguments for GPU support.
   *
   * Arguments are produced for a single vendor: the vendor of the first
   * selected GPU. Selected GPUs that belong to a different vendor are left out
   * so their indices cannot end up in another vendor's device list.
   *
   * @param gpuIds - Ids of the GPUs to expose. When omitted, all detected GPUs
   *   are considered. An empty array selects nothing.
   * @returns Docker CLI arguments, or an empty array when no GPU is selected or
   *   the first selected GPU has a vendor without a known argument scheme.
   */
  public async getDockerGpuArgs(gpuIds?: string[]): Promise<string[]> {
    const gpus = await this.gpuDetector.detectGpus();
    const args: string[] = [];

    // Narrow to the requested GPUs if ids were provided.
    const selectedGpus = gpuIds ? gpus.filter((gpu) => gpuIds.includes(gpu.id)) : gpus;

    const firstGpu = selectedGpus[0];
    if (!firstGpu) {
      return args;
    }

    // Single-vendor assumption: only the first GPU's vendor is served, and
    // only that vendor's GPUs may contribute device indices.
    const vendor = firstGpu.vendor;
    const targetGpus = selectedGpus.filter((gpu) => gpu.vendor === vendor);

    switch (vendor) {
      case 'nvidia':
        // NVIDIA uses the nvidia-docker runtime.
        args.push('--runtime=nvidia');
        if (gpuIds) {
          // Specific GPUs were requested (gpuIds is non-empty here, otherwise
          // nothing would have been selected). The embedded double quotes are
          // part of the argument value, exactly as before.
          const indices = targetGpus.map((gpu) => gpu.index).join(',');
          args.push(`--gpus="device=${indices}"`);
        } else {
          args.push('--gpus=all');
        }
        break;

      case 'amd':
        // AMD uses device passthrough.
        args.push('--device=/dev/kfd');
        for (const gpu of targetGpus) {
          // NOTE(review): assumes GPU index i maps to render node 128 + i — confirm.
          args.push(`--device=/dev/dri/renderD${128 + gpu.index}`);
        }
        args.push('--group-add=video');
        args.push('--security-opt=seccomp=unconfined');
        break;

      case 'intel':
        // Intel uses device passthrough.
        for (const gpu of targetGpus) {
          // NOTE(review): assumes GPU index i maps to render node 128 + i — confirm.
          args.push(`--device=/dev/dri/renderD${128 + gpu.index}`);
        }
        args.push('--group-add=render');
        break;

      default:
        // Unknown vendor: there is no argument scheme, so return no arguments.
        break;
    }

    return args;
  }

  /**
   * Detect GPUs and return the distinct vendors found, excluding `'unknown'`,
   * in first-detected order.
   */
  private async detectSupportedVendors(): Promise<Set<TGpuVendor>> {
    const gpus = await this.gpuDetector.detectGpus();
    return new Set(gpus.map((gpu) => gpu.vendor).filter((vendor) => vendor !== 'unknown'));
  }
}
|