smartai/ts/provider.ollama.ts

import * as plugins from './plugins.js';
import * as paths from './paths.js';
import { MultiModalModel } from './abstract.classes.multimodal.js';
import type { ChatOptions, ChatResponse, ChatMessage } from './abstract.classes.multimodal.js';
export interface IOllamaProviderOptions {
  baseUrl?: string;
  model?: string;
}
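
/**
 * Chat provider backed by a locally running Ollama server.
 *
 * Minimal usage sketch (assumes an Ollama server on the default port and the
 * `ChatOptions` fields used by `chat()` below; purely illustrative):
 *
 *   const provider = new OllamaProvider({ model: 'llama2' });
 *   await provider.start();
 *   const reply = await provider.chat({
 *     systemMessage: 'You are a helpful assistant.',
 *     messageHistory: [],
 *     userMessage: 'Hello!',
 *   });
 *   console.log(reply.message);
 */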
export class OllamaProvider extends MultiModalModel {
  private options: IOllamaProviderOptions;
  private baseUrl: string;
  private model: string;

  constructor(optionsArg: IOllamaProviderOptions = {}) {
    super();
    this.options = optionsArg;
    this.baseUrl = optionsArg.baseUrl || 'http://localhost:11434';
    this.model = optionsArg.model || 'llama2';
  }
  async start() {
    // Verify that the Ollama server is reachable before use
    try {
      const response = await fetch(`${this.baseUrl}/api/tags`);
      if (!response.ok) {
        throw new Error('Failed to connect to Ollama server');
      }
    } catch (error) {
      throw new Error(`Failed to connect to Ollama server at ${this.baseUrl}: ${(error as Error).message}`);
    }
  }

  async stop() {}
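
  /**
   * Streams a chat exchange with Ollama.
   *
   * The input stream is expected to carry newline-delimited JSON, one message
   * per line, e.g. `{"role":"user","content":"Hello"}`; the returned stream
   * yields the assistant's reply as plain text chunks.
   */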
  public async chatStream(input: ReadableStream<Uint8Array>): Promise<ReadableStream<string>> {
    // Create a TextDecoder to handle incoming chunks
    const decoder = new TextDecoder();
    let buffer = '';
    let currentMessage: { role: string; content: string } | null = null;

    // Capture provider settings here: inside the TransformStream callbacks,
    // `this` refers to the transformer object, not to this provider instance.
    const baseUrl = this.baseUrl;
    const model = this.model;

    // Create a TransformStream to process the input
    const transform = new TransformStream<Uint8Array, string>({
      async transform(chunk, controller) {
        buffer += decoder.decode(chunk, { stream: true });

        // Try to parse complete newline-delimited JSON messages from the buffer
        while (true) {
          const newlineIndex = buffer.indexOf('\n');
          if (newlineIndex === -1) break;

          const line = buffer.slice(0, newlineIndex);
          buffer = buffer.slice(newlineIndex + 1);

          if (line.trim()) {
            try {
              const message = JSON.parse(line);
              currentMessage = {
                role: message.role || 'user',
                content: message.content || '',
              };
            } catch (e) {
              console.error('Failed to parse message:', e);
            }
          }
        }

        // If we have a complete message, send it to Ollama
        if (currentMessage) {
          const response = await fetch(`${baseUrl}/api/chat`, {
            method: 'POST',
            headers: {
              'Content-Type': 'application/json',
            },
            body: JSON.stringify({
              model,
              messages: [{ role: currentMessage.role, content: currentMessage.content }],
              stream: true,
            }),
          });

          // Process each chunk of Ollama's streaming NDJSON response
          const reader = response.body?.getReader();
          if (reader) {
            try {
              while (true) {
                const { done, value } = await reader.read();
                if (done) break;

                const chunk = new TextDecoder().decode(value);
                const lines = chunk.split('\n');

                for (const line of lines) {
                  if (line.trim()) {
                    try {
                      const parsed = JSON.parse(line);
                      const content = parsed.message?.content;
                      if (content) {
                        controller.enqueue(content);
                      }
                    } catch (e) {
                      console.error('Failed to parse Ollama response:', e);
                    }
                  }
                }
              }
            } finally {
              reader.releaseLock();
            }
          }

          currentMessage = null;
        }
      },

      flush(controller) {
        // Handle any trailing data that was not terminated by a newline
        if (buffer) {
          try {
            const message = JSON.parse(buffer);
            controller.enqueue(message.content || '');
          } catch (e) {
            console.error('Failed to parse remaining buffer:', e);
          }
        }
      },
    });

    // Connect the input to our transform stream
    return input.pipeThrough(transform);
  }
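
  // Note on the response shape relied on below: with `stream: false`, Ollama's
  // /api/chat endpoint returns a single JSON object, roughly (abbreviated):
  //   { "message": { "role": "assistant", "content": "..." }, "done": true }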
  // Implements the non-streaming chat interaction
  public async chat(optionsArg: ChatOptions): Promise<ChatResponse> {
    // Format messages for Ollama
    const messages = [
      { role: 'system', content: optionsArg.systemMessage },
      ...optionsArg.messageHistory,
      { role: 'user', content: optionsArg.userMessage },
    ];

    // Make the API call to Ollama
    const response = await fetch(`${this.baseUrl}/api/chat`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify({
        model: this.model,
        messages: messages,
        stream: false,
      }),
    });

    if (!response.ok) {
      throw new Error(`Ollama API error: ${response.statusText}`);
    }

    const result = await response.json();
    return {
      role: 'assistant' as const,
      message: result.message.content,
    };
  }
  public async audio(optionsArg: { message: string }): Promise<NodeJS.ReadableStream> {
    throw new Error('Audio generation is not supported by Ollama.');
  }
}
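
// Rough sketch of driving chatStream() with an NDJSON request stream (names
// and the `for await` consumption are illustrative; the latter assumes a
// runtime where ReadableStream is async-iterable, e.g. Node.js 18+):
//
//   const provider = new OllamaProvider();
//   const encoder = new TextEncoder();
//   const requestStream = new ReadableStream<Uint8Array>({
//     start(controller) {
//       controller.enqueue(encoder.encode(JSON.stringify({ role: 'user', content: 'Hi there' }) + '\n'));
//       controller.close();
//     },
//   });
//   const replyStream = await provider.chatStream(requestStream);
//   for await (const textChunk of replyStream) {
//     process.stdout.write(textChunk);
//   }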