diff --git a/test/api-router_test.ts b/test/api-router_test.ts index 83b9cf3..2977fe3 100644 --- a/test/api-router_test.ts +++ b/test/api-router_test.ts @@ -1,5 +1,6 @@ import { assertEquals } from 'jsr:@std/assert@^1.0.0'; import { EventEmitter } from 'node:events'; +import { AuthMiddleware } from '../ts/api/middleware/auth.ts'; import { ApiRouter } from '../ts/api/router.ts'; class TestResponse { @@ -49,6 +50,38 @@ function createRouter(): ApiRouter { {} as never, {} as never, ['valid-key'], + { + authMiddleware: new AuthMiddleware(['valid-key']), + sanityMiddleware: { + validateChatRequest() { + return { valid: true }; + }, + sanitizeChatRequest(body: Record) { + return body; + }, + validateEmbeddingsRequest() { + return { valid: true }; + }, + sanitizeEmbeddingsRequest(body: Record) { + return body; + }, + } as never, + chatHandler: { + async handleChatCompletion() { + throw new Error('chat handler should not run in this test'); + }, + } as never, + modelsHandler: { + async handleListModels() { + throw new Error('models handler should not run in this test'); + }, + } as never, + embeddingsHandler: { + async handleEmbeddings() { + throw new Error('embeddings handler should not run in this test'); + }, + } as never, + }, ); } diff --git a/ts/api/router.ts b/ts/api/router.ts index 501ea60..fbaee3a 100644 --- a/ts/api/router.ts +++ b/ts/api/router.ts @@ -22,6 +22,14 @@ interface IParsedRequestBody { body?: unknown; } +interface IApiRouterOptions { + chatHandler?: ChatHandler; + modelsHandler?: ModelsHandler; + embeddingsHandler?: EmbeddingsHandler; + authMiddleware?: AuthMiddleware; + sanityMiddleware?: SanityMiddleware; +} + /** * API Router - routes requests to handlers */ @@ -42,6 +50,7 @@ export class ApiRouter { modelLoader: ModelLoader, clusterCoordinator: ClusterCoordinator, apiKeys: string[], + options: IApiRouterOptions = {}, ) { this.containerManager = containerManager; this.modelRegistry = modelRegistry; @@ -49,22 +58,23 @@ export class ApiRouter { this.clusterCoordinator = clusterCoordinator; // Initialize handlers - this.chatHandler = new ChatHandler( + this.chatHandler = options.chatHandler || new ChatHandler( containerManager, modelRegistry, modelLoader, clusterCoordinator, ); - this.modelsHandler = new ModelsHandler(containerManager, modelRegistry, clusterCoordinator); - this.embeddingsHandler = new EmbeddingsHandler( + this.modelsHandler = + options.modelsHandler || new ModelsHandler(containerManager, modelRegistry, clusterCoordinator); + this.embeddingsHandler = options.embeddingsHandler || new EmbeddingsHandler( containerManager, modelRegistry, clusterCoordinator, ); // Initialize middleware - this.authMiddleware = new AuthMiddleware(apiKeys); - this.sanityMiddleware = new SanityMiddleware(modelRegistry); + this.authMiddleware = options.authMiddleware || new AuthMiddleware(apiKeys); + this.sanityMiddleware = options.sanityMiddleware || new SanityMiddleware(modelRegistry); } /**