Files
smartai/ts/smartai.provider.ollama.ts

427 lines
14 KiB
TypeScript

import type {
LanguageModelV3,
LanguageModelV3CallOptions,
LanguageModelV3GenerateResult,
LanguageModelV3StreamResult,
LanguageModelV3StreamPart,
LanguageModelV3Prompt,
LanguageModelV3Content,
LanguageModelV3Usage,
LanguageModelV3FinishReason,
} from '@ai-sdk/provider';
import type { ISmartAiOptions, IOllamaModelOptions } from './smartai.interfaces.js';
/**
 * A single chat message in the shape sent to Ollama's `/api/chat` endpoint.
 */
interface IOllamaMessage {
  /** Message author; this file emits `system`, `user`, `assistant` and `tool`. */
  role: string;
  /** Plain-text body of the message (may be empty, e.g. for pure tool-call turns). */
  content: string;
  /** Base64-encoded images attached to a user message. */
  images?: string[];
  /** Tool invocations the assistant requested in this turn. */
  tool_calls?: Array<{
    function: { name: string; arguments: Record<string, unknown> };
  }>;
  /** Assistant reasoning text, replayed back to thinking-capable models. */
  thinking?: string;
  /** For `tool` messages: name of the tool whose result `content` carries. */
  tool_name?: string;
}
/** The callable part of an Ollama tool definition. */
interface IOllamaToolFunction {
  /** Tool name the model uses when calling it. */
  name: string;
  /** Human-readable description; empty string when none was provided. */
  description: string;
  /** JSON schema describing the tool's arguments. */
  parameters: Record<string, unknown>;
}

/** A tool definition in the shape accepted by Ollama's `/api/chat` `tools` field. */
interface IOllamaTool {
  type: 'function';
  function: IOllamaToolFunction;
}
/**
 * Flatten a tool-result `output` into the plain string Ollama expects as the
 * `content` of a `tool` message.
 *
 * Error outputs are forwarded so the model can react to the failure, and the
 * text parts of a multi-part `content` output are joined with newlines.
 * Unrecognised output variants become `''`.
 */
function toolOutputToText(output: unknown): string {
  if (output === null || typeof output !== 'object') return '';
  // Read the output union loosely so variants this file does not know about degrade to '' instead of breaking the build.
  const { type, value, reason } = output as { type?: unknown; value?: unknown; reason?: unknown };
  switch (type) {
    case 'text':
    case 'error-text':
      return typeof value === 'string' ? value : '';
    case 'json':
    case 'error-json':
      return JSON.stringify(value) ?? '';
    case 'content': {
      if (!Array.isArray(value)) return '';
      const texts: string[] = [];
      for (const item of value) {
        if (item !== null && typeof item === 'object' && item.type === 'text' && typeof item.text === 'string') {
          texts.push(item.text);
        }
      }
      return texts.join('\n');
    }
    case 'execution-denied':
      return typeof reason === 'string' && reason
        ? `Tool execution denied: ${reason}`
        : 'Tool execution denied';
    default:
      return '';
  }
}

/**
 * Normalise a prompt tool-call `input` into the object Ollama expects for
 * `arguments`. String inputs are parsed as JSON. Anything that is not a JSON
 * object (including unparseable strings) becomes `{}`, rather than throwing
 * and aborting the whole request.
 */
function toolCallInputToArguments(input: unknown): Record<string, unknown> {
  let value: unknown = input;
  if (typeof input === 'string') {
    try {
      value = JSON.parse(input);
    } catch {
      return {};
    }
  }
  return value !== null && typeof value === 'object' && !Array.isArray(value)
    ? (value as Record<string, unknown>)
    : {};
}

/**
 * Convert AI SDK V3 prompt messages to Ollama's message format.
 *
 * - `system` content is passed through as-is.
 * - `user` text parts are concatenated; image file parts become base64 `images`.
 * - `assistant` text, reasoning (`thinking`) and tool calls are merged into one message.
 * - Each `tool-result` part becomes its own `tool` message tagged with `tool_name`.
 *
 * @param prompt The standardized V3 prompt.
 * @returns Messages ready for the `messages` field of `/api/chat`.
 */
function convertPromptToOllamaMessages(prompt: LanguageModelV3Prompt): IOllamaMessage[] {
  const messages: IOllamaMessage[] = [];
  for (const msg of prompt) {
    if (msg.role === 'system') {
      // System message content is a plain string in V3
      messages.push({ role: 'system', content: msg.content });
    } else if (msg.role === 'user') {
      let text = '';
      const images: string[] = [];
      for (const part of msg.content) {
        if (part.type === 'text') {
          text += part.text;
        } else if (part.type === 'file' && part.mediaType?.startsWith('image/')) {
          // Ollama expects base64 images; other data forms (e.g. URL objects) are skipped.
          if (typeof part.data === 'string') {
            images.push(part.data);
          } else if (part.data instanceof Uint8Array) {
            images.push(Buffer.from(part.data).toString('base64'));
          }
        }
      }
      const m: IOllamaMessage = { role: 'user', content: text };
      if (images.length > 0) m.images = images;
      messages.push(m);
    } else if (msg.role === 'assistant') {
      let text = '';
      let thinking = '';
      const toolCalls: NonNullable<IOllamaMessage['tool_calls']> = [];
      for (const part of msg.content) {
        if (part.type === 'text') {
          text += part.text;
        } else if (part.type === 'reasoning') {
          thinking += part.text;
        } else if (part.type === 'tool-call') {
          toolCalls.push({
            function: {
              name: part.toolName,
              arguments: toolCallInputToArguments(part.input),
            },
          });
        }
      }
      const m: IOllamaMessage = { role: 'assistant', content: text };
      if (toolCalls.length > 0) m.tool_calls = toolCalls;
      if (thinking) m.thinking = thinking;
      messages.push(m);
    } else if (msg.role === 'tool') {
      for (const part of msg.content) {
        if (part.type !== 'tool-result') continue;
        // Ollama's chat API accepts `tool_name` on tool messages to identify which tool produced the result.
        const toolMessage: IOllamaMessage & { tool_name?: string } = {
          role: 'tool',
          content: toolOutputToText(part.output),
          tool_name: part.toolName,
        };
        messages.push(toolMessage);
      }
    }
  }
  return messages;
}
/**
 * Convert AI SDK V3 tools to Ollama's tool format.
 *
 * Only `function` tools are converted; provider-defined tools have no
 * JSON-schema function form here and are skipped.
 *
 * @returns The converted tools, or `undefined` when no function tools remain
 *   (so callers omit the `tools` field instead of sending an empty array).
 */
function convertToolsToOllamaTools(tools: LanguageModelV3CallOptions['tools']): IOllamaTool[] | undefined {
  const converted: IOllamaTool[] = [];
  for (const tool of tools ?? []) {
    if (tool.type !== 'function') continue;
    converted.push({
      type: 'function',
      function: {
        name: tool.name,
        description: tool.description ?? '',
        parameters: tool.inputSchema as Record<string, unknown>,
      },
    });
  }
  return converted.length > 0 ? converted : undefined;
}
/**
 * Build an AI SDK V3 usage record from prompt/completion token counts.
 *
 * No cache or reasoning breakdown is derived: those fields are always
 * `undefined`, and every completion token is attributed to text output.
 *
 * @param promptTokens Tokens consumed by the prompt, if known.
 * @param completionTokens Tokens generated, if known.
 */
function makeUsage(promptTokens?: number, completionTokens?: number): LanguageModelV3Usage {
  const inputTokens = {
    total: promptTokens,
    noCache: undefined,
    cacheRead: undefined,
    cacheWrite: undefined,
  };
  const outputTokens = {
    total: completionTokens,
    text: completionTokens,
    reasoning: undefined,
  };
  return { inputTokens, outputTokens };
}
/**
 * Map a raw finish reason onto the AI SDK V3 finish reason, keeping the raw value.
 *
 * Accepts Ollama's `done_reason` values as well as the `'tool_calls'` /
 * `'tool-calls'` markers used by callers in this file.
 *
 * @param reason Raw reason; `undefined` is treated as a normal stop.
 */
function makeFinishReason(reason?: string): LanguageModelV3FinishReason {
  switch (reason) {
    case 'tool_calls':
    case 'tool-calls':
      return { unified: 'tool-calls', raw: reason };
    case 'length':
      // Ollama reports `length` when generation stopped at a token limit (e.g. `num_predict`).
      return { unified: 'length', raw: reason };
    default:
      return { unified: 'stop', raw: reason ?? 'stop' };
  }
}
// Sequence number appended to every id so ids minted in the same millisecond still differ.
let idCounter = 0;

/**
 * Mint an id of the form `ollama-<epoch ms>-<sequence>`, used for tool calls
 * and stream text/reasoning parts. Unique within this module instance.
 */
function generateId(): string {
  const sequence = idCounter;
  idCounter += 1;
  return 'ollama-' + String(Date.now()) + '-' + String(sequence);
}
/** Return `value` if it is a number, otherwise `undefined`. */
function numberOrUndefined(value: unknown): number | undefined {
  return typeof value === 'number' ? value : undefined;
}

/** Return `value` if it is a string, otherwise `undefined`. */
function stringOrUndefined(value: unknown): string | undefined {
  return typeof value === 'string' ? value : undefined;
}

/**
 * Convert one entry of an Ollama `message.tool_calls` array into an AI SDK
 * tool-call part with a freshly generated id. Missing fields degrade to an
 * empty tool name and `{}` arguments instead of `undefined`.
 */
function toToolCallPart(raw: unknown) {
  const record: Record<string, unknown> =
    raw !== null && typeof raw === 'object' ? (raw as Record<string, unknown>) : {};
  const fn = (record.function ?? {}) as Record<string, unknown>;
  const name = fn.name;
  const args = fn.arguments;
  return {
    type: 'tool-call' as const,
    toolCallId: generateId(),
    toolName: typeof name === 'string' ? name : '',
    // Arguments are expected as an object; pass a string through as-is to avoid double-encoding it.
    input: typeof args === 'string' ? args : JSON.stringify(args ?? {}),
  };
}

/**
 * Build the JSON body for Ollama's `/api/chat` endpoint.
 *
 * Model-level Ollama options are merged with per-call AI SDK settings (call
 * settings win). `think` is hoisted to the top level as Ollama requires.
 * `toolChoice: none` suppresses tools, and a JSON `responseFormat` is mapped
 * onto Ollama's `format` field.
 */
function buildOllamaChatRequest(
  modelId: string,
  ollamaOpts: IOllamaModelOptions,
  callOptions: LanguageModelV3CallOptions,
  stream: boolean,
): Record<string, unknown> {
  const modelOptions: Record<string, unknown> = { ...ollamaOpts };
  // Override with call-level options if provided
  if (callOptions.temperature !== undefined) modelOptions.temperature = callOptions.temperature;
  if (callOptions.topP !== undefined) modelOptions.top_p = callOptions.topP;
  if (callOptions.topK !== undefined) modelOptions.top_k = callOptions.topK;
  if (callOptions.maxOutputTokens !== undefined) modelOptions.num_predict = callOptions.maxOutputTokens;
  if (callOptions.seed !== undefined) modelOptions.seed = callOptions.seed;
  if (callOptions.presencePenalty !== undefined) modelOptions.presence_penalty = callOptions.presencePenalty;
  if (callOptions.frequencyPenalty !== undefined) modelOptions.frequency_penalty = callOptions.frequencyPenalty;
  if (callOptions.stopSequences) modelOptions.stop = callOptions.stopSequences;
  // `think` is a top-level request field in Ollama, not a model option.
  delete modelOptions.think;

  const requestBody: Record<string, unknown> = {
    model: modelId,
    messages: convertPromptToOllamaMessages(callOptions.prompt),
    stream,
    options: modelOptions,
  };
  if (ollamaOpts.think !== undefined) {
    requestBody.think = ollamaOpts.think;
  }

  // Ollama has no tool_choice setting; honour "none" by not offering any tools.
  if (callOptions.toolChoice?.type !== 'none') {
    const tools = convertToolsToOllamaTools(callOptions.tools);
    if (tools) requestBody.tools = tools;
  }

  // Ollama's `format` accepts either "json" or a JSON schema for structured output.
  if (callOptions.responseFormat?.type === 'json') {
    requestBody.format = callOptions.responseFormat.schema ?? 'json';
  }
  return requestBody;
}

/**
 * POST `body` to `${baseUrl}/api/chat`.
 *
 * @throws Error containing the HTTP status and response text on a non-2xx reply.
 */
async function postOllamaChat(
  baseUrl: string,
  body: Record<string, unknown>,
  abortSignal: AbortSignal | undefined,
): Promise<Response> {
  const response = await fetch(`${baseUrl}/api/chat`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify(body),
    signal: abortSignal,
  });
  if (!response.ok) {
    const errorText = await response.text();
    throw new Error(`Ollama API error ${response.status}: ${errorText}`);
  }
  return response;
}

/**
 * Translate Ollama's newline-delimited JSON chat stream into AI SDK V3 stream parts.
 *
 * Each `pull` reads from the HTTP body only until it has enqueued at least one
 * part, so consumer backpressure is respected. Cancelling the returned stream
 * cancels the HTTP body. Open text/reasoning blocks are always closed before
 * the final `finish` part, even when the body ends without a `done` line.
 */
function ollamaNdjsonToStreamParts(
  body: ReadableStream<Uint8Array>,
): ReadableStream<LanguageModelV3StreamPart> {
  const reader = body.getReader();
  const decoder = new TextDecoder();
  let buffer = '';
  let textId: string | undefined;
  let reasoningId: string | undefined;
  let hasToolCalls = false;
  let sawError = false;
  let finished = false;
  let emitted = 0;

  return new ReadableStream<LanguageModelV3StreamPart>({
    start(controller) {
      controller.enqueue({ type: 'stream-start', warnings: [] });
    },

    async pull(controller) {
      const emit = (part: LanguageModelV3StreamPart): void => {
        controller.enqueue(part);
        emitted += 1;
      };
      const endReasoning = (): void => {
        if (reasoningId !== undefined) {
          emit({ type: 'reasoning-end', id: reasoningId });
          reasoningId = undefined;
        }
      };
      const endText = (): void => {
        if (textId !== undefined) {
          emit({ type: 'text-end', id: textId });
          textId = undefined;
        }
      };
      const finish = (chunk: Record<string, unknown>): void => {
        endReasoning();
        endText();
        let finishReason: LanguageModelV3FinishReason;
        if (hasToolCalls) {
          finishReason = makeFinishReason('tool_calls');
        } else if (sawError) {
          finishReason = { unified: 'error', raw: 'error' };
        } else {
          finishReason = makeFinishReason(stringOrUndefined(chunk.done_reason));
        }
        emit({
          type: 'finish',
          finishReason,
          usage: makeUsage(numberOrUndefined(chunk.prompt_eval_count), numberOrUndefined(chunk.eval_count)),
        });
        finished = true;
        controller.close();
      };
      const handleLine = (line: string): void => {
        if (!line.trim()) return;
        let parsed: unknown;
        try {
          parsed = JSON.parse(line);
        } catch {
          return; // skip lines that are not valid JSON
        }
        if (parsed === null || typeof parsed !== 'object') return;
        const chunk = parsed as Record<string, unknown>;

        // Ollama can report failures mid-stream as an `{"error": ...}` line.
        if (chunk.error !== undefined) {
          sawError = true;
          emit({ type: 'error', error: chunk.error });
        }

        const message = (chunk.message ?? {}) as Record<string, unknown>;
        if (typeof message.thinking === 'string' && message.thinking) {
          if (reasoningId === undefined) {
            reasoningId = generateId();
            emit({ type: 'reasoning-start', id: reasoningId });
          }
          emit({ type: 'reasoning-delta', id: reasoningId, delta: message.thinking });
        }
        if (typeof message.content === 'string' && message.content) {
          // Visible text closes any reasoning block that preceded it.
          endReasoning();
          if (textId === undefined) {
            textId = generateId();
            emit({ type: 'text-start', id: textId });
          }
          emit({ type: 'text-delta', id: textId, delta: message.content });
        }
        const calls = Array.isArray(message.tool_calls) ? message.tool_calls : [];
        if (calls.length > 0) {
          hasToolCalls = true;
          for (const call of calls) emit(toToolCallPart(call));
        }
        if (chunk.done) finish(chunk);
      };

      const emittedBefore = emitted;
      try {
        // A pull that enqueues nothing would stall the stream, so keep reading until this pull produced output.
        while (!finished && emitted === emittedBefore) {
          const { done, value } = await reader.read();
          if (finished) break; // cancelled while waiting for data
          if (done) {
            buffer += decoder.decode();
            const rest = buffer;
            buffer = '';
            handleLine(rest);
            // Body ended without a `done: true` line: still close open blocks and finish.
            if (!finished) finish({});
            break;
          }
          buffer += decoder.decode(value, { stream: true });
          const lines = buffer.split('\n');
          buffer = lines.pop() ?? '';
          for (const line of lines) {
            handleLine(line);
            if (finished) break;
          }
        }
      } catch (error) {
        if (!finished) {
          finished = true;
          controller.error(error);
        }
      }
      if (finished) {
        // Release the HTTP body; harmless if it has already ended.
        void reader.cancel().catch(() => undefined);
      }
    },

    cancel(reason) {
      finished = true;
      // Cancelling the body aborts the HTTP response.
      return reader.cancel(reason);
    },
  });
}

/**
 * Custom LanguageModelV3 implementation for Ollama.
 * Calls Ollama's native /api/chat endpoint directly to support
 * think, num_ctx, temperature, and other model options.
 *
 * @param options Provider options. `baseUrl` defaults to `http://localhost:11434`;
 *   `ollamaOptions` are forwarded as Ollama model options (with `think` sent top-level).
 * @returns A V3 language model with `doGenerate` and `doStream`.
 */
export function createOllamaModel(options: ISmartAiOptions): LanguageModelV3 {
  // Strip trailing slashes so the endpoint is never `//api/chat`.
  const baseUrl = (options.baseUrl ?? 'http://localhost:11434').replace(/\/+$/, '');
  const modelId = options.model;
  const ollamaOpts: IOllamaModelOptions = { ...options.ollamaOptions };
  // Apply default temperature of 0.55 for Qwen models
  if (modelId.toLowerCase().includes('qwen') && ollamaOpts.temperature === undefined) {
    ollamaOpts.temperature = 0.55;
  }

  const model: LanguageModelV3 = {
    specificationVersion: 'v3',
    provider: 'ollama',
    modelId,
    supportedUrls: {},

    async doGenerate(callOptions: LanguageModelV3CallOptions): Promise<LanguageModelV3GenerateResult> {
      const requestBody = buildOllamaChatRequest(modelId, ollamaOpts, callOptions, false);
      const response = await postOllamaChat(baseUrl, requestBody, callOptions.abortSignal);
      const result = (await response.json()) as Record<string, unknown>;
      // Tolerate a reply without `message` instead of crashing on property access.
      const message = (result.message ?? {}) as Record<string, unknown>;

      const content: LanguageModelV3Content[] = [];
      if (typeof message.thinking === 'string' && message.thinking) {
        content.push({ type: 'reasoning', text: message.thinking });
      }
      if (typeof message.content === 'string' && message.content) {
        content.push({ type: 'text', text: message.content });
      }
      const toolCalls = Array.isArray(message.tool_calls) ? message.tool_calls : [];
      for (const call of toolCalls) {
        content.push(toToolCallPart(call));
      }

      return {
        content,
        finishReason: toolCalls.length > 0
          ? makeFinishReason('tool_calls')
          : makeFinishReason(stringOrUndefined(result.done_reason)),
        usage: makeUsage(numberOrUndefined(result.prompt_eval_count), numberOrUndefined(result.eval_count)),
        warnings: [],
        request: { body: requestBody },
      };
    },

    async doStream(callOptions: LanguageModelV3CallOptions): Promise<LanguageModelV3StreamResult> {
      const requestBody = buildOllamaChatRequest(modelId, ollamaOpts, callOptions, true);
      const response = await postOllamaChat(baseUrl, requestBody, callOptions.abortSignal);
      if (!response.body) {
        throw new Error('Ollama API error: streaming response has no body');
      }
      return {
        stream: ollamaNdjsonToStreamParts(response.body),
        request: { body: requestBody },
      };
    },
  };
  return model;
}