diff --git a/test/api-server_test.ts b/test/api-server_test.ts index 55588f1..2a7fd2f 100644 --- a/test/api-server_test.ts +++ b/test/api-server_test.ts @@ -47,15 +47,14 @@ Deno.test('ApiServer serves health metrics and authenticated model listings', as }; }, } as never, - ); - - (server as unknown as { - gpuDetector: { detectGpus: () => Promise }; - }).gpuDetector = { - async detectGpus() { - return [{ id: 'nvidia-0' }]; + { + gpuDetector: { + async detectGpus() { + return [{ id: 'nvidia-0' }]; + }, + } as never, }, - }; + ); await server.start(); @@ -142,15 +141,14 @@ Deno.test('ApiServer metrics expose 5xx counts for failing endpoints', async () }; }, } as never, - ); - - (server as unknown as { - gpuDetector: { detectGpus: () => Promise }; - }).gpuDetector = { - async detectGpus() { - return []; + { + gpuDetector: { + async detectGpus() { + return []; + }, + } as never, }, - }; + ); await server.start(); @@ -209,15 +207,14 @@ Deno.test('ApiServer enforces api rate limits while exempting health and metrics }; }, } as never, - ); - - (server as unknown as { - gpuDetector: { detectGpus: () => Promise }; - }).gpuDetector = { - async detectGpus() { - return []; + { + gpuDetector: { + async detectGpus() { + return []; + }, + } as never, }, - }; + ); await server.start(); diff --git a/ts/api/server.ts b/ts/api/server.ts index d26aecc..eefdfeb 100644 --- a/ts/api/server.ts +++ b/ts/api/server.ts @@ -17,6 +17,12 @@ import { ModelLoader } from '../models/loader.ts'; import { GpuDetector } from '../hardware/gpu-detector.ts'; import { ClusterHandler } from './handlers/cluster.ts'; +interface IApiServerOptions { + gpuDetector?: GpuDetector; + router?: ApiRouter; + clusterHandler?: ClusterHandler; +} + /** * API Server for ModelGrid */ @@ -42,15 +48,16 @@ export class ApiServer { modelRegistry: ModelRegistry, modelLoader: ModelLoader, clusterCoordinator: ClusterCoordinator, + options: IApiServerOptions = {}, ) { this.config = config; this.containerManager = containerManager; this.modelRegistry = modelRegistry; - this.gpuDetector = new GpuDetector(); + this.gpuDetector = options.gpuDetector || new GpuDetector(); this.modelLoader = modelLoader; this.clusterCoordinator = clusterCoordinator; - this.clusterHandler = new ClusterHandler(clusterCoordinator); - this.router = new ApiRouter( + this.clusterHandler = options.clusterHandler || new ClusterHandler(clusterCoordinator); + this.router = options.router || new ApiRouter( containerManager, modelRegistry, this.modelLoader,