427 lines
14 KiB
TypeScript
427 lines
14 KiB
TypeScript
|
|
import type {
|
||
|
|
LanguageModelV3,
|
||
|
|
LanguageModelV3CallOptions,
|
||
|
|
LanguageModelV3GenerateResult,
|
||
|
|
LanguageModelV3StreamResult,
|
||
|
|
LanguageModelV3StreamPart,
|
||
|
|
LanguageModelV3Prompt,
|
||
|
|
LanguageModelV3Content,
|
||
|
|
LanguageModelV3Usage,
|
||
|
|
LanguageModelV3FinishReason,
|
||
|
|
} from '@ai-sdk/provider';
|
||
|
|
import type { ISmartAiOptions, IOllamaModelOptions } from './smartai.interfaces.js';
|
||
|
|
|
||
|
|
/**
 * Chat message in the shape Ollama's /api/chat endpoint accepts and returns.
 */
interface IOllamaMessage {
  // Conversation role; this module produces 'system' | 'user' | 'assistant' | 'tool'.
  role: string;
  // Plain-text body of the message.
  content: string;
  // Base64-encoded image payloads attached to a user message.
  images?: string[];
  // Tool invocations emitted by the assistant; arguments are already-parsed JSON objects.
  tool_calls?: Array<{
    function: { name: string; arguments: Record<string, unknown> };
  }>;
  // Reasoning ("thinking") text produced by models with think support.
  thinking?: string;
}
|
||
|
|
|
||
|
|
/**
 * Function-tool definition in Ollama's tool schema.
 */
interface IOllamaTool {
  // Only function tools are supported.
  type: 'function';
  function: {
    // Tool name the model uses to invoke it.
    name: string;
    // Human-readable description for the model; may be empty.
    description: string;
    // Parameter schema (taken verbatim from the AI SDK tool's inputSchema).
    parameters: Record<string, unknown>;
  };
}
|
||
|
|
|
||
|
|
/**
|
||
|
|
* Convert AI SDK V3 prompt messages to Ollama's message format.
|
||
|
|
*/
|
||
|
|
function convertPromptToOllamaMessages(prompt: LanguageModelV3Prompt): IOllamaMessage[] {
|
||
|
|
const messages: IOllamaMessage[] = [];
|
||
|
|
|
||
|
|
for (const msg of prompt) {
|
||
|
|
if (msg.role === 'system') {
|
||
|
|
// System message content is a plain string in V3
|
||
|
|
messages.push({ role: 'system', content: msg.content });
|
||
|
|
} else if (msg.role === 'user') {
|
||
|
|
let text = '';
|
||
|
|
const images: string[] = [];
|
||
|
|
for (const part of msg.content) {
|
||
|
|
if (part.type === 'text') {
|
||
|
|
text += part.text;
|
||
|
|
} else if (part.type === 'file' && part.mediaType?.startsWith('image/')) {
|
||
|
|
// Handle image files — Ollama expects base64 images
|
||
|
|
if (typeof part.data === 'string') {
|
||
|
|
images.push(part.data);
|
||
|
|
} else if (part.data instanceof Uint8Array) {
|
||
|
|
images.push(Buffer.from(part.data).toString('base64'));
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
const m: IOllamaMessage = { role: 'user', content: text };
|
||
|
|
if (images.length > 0) m.images = images;
|
||
|
|
messages.push(m);
|
||
|
|
} else if (msg.role === 'assistant') {
|
||
|
|
let text = '';
|
||
|
|
let thinking = '';
|
||
|
|
const toolCalls: IOllamaMessage['tool_calls'] = [];
|
||
|
|
for (const part of msg.content) {
|
||
|
|
if (part.type === 'text') {
|
||
|
|
text += part.text;
|
||
|
|
} else if (part.type === 'reasoning') {
|
||
|
|
thinking += part.text;
|
||
|
|
} else if (part.type === 'tool-call') {
|
||
|
|
const args = typeof part.input === 'string'
|
||
|
|
? JSON.parse(part.input as string)
|
||
|
|
: (part.input as Record<string, unknown>);
|
||
|
|
toolCalls.push({
|
||
|
|
function: {
|
||
|
|
name: part.toolName,
|
||
|
|
arguments: args,
|
||
|
|
},
|
||
|
|
});
|
||
|
|
}
|
||
|
|
}
|
||
|
|
const m: IOllamaMessage = { role: 'assistant', content: text };
|
||
|
|
if (toolCalls.length > 0) m.tool_calls = toolCalls;
|
||
|
|
if (thinking) m.thinking = thinking;
|
||
|
|
messages.push(m);
|
||
|
|
} else if (msg.role === 'tool') {
|
||
|
|
for (const part of msg.content) {
|
||
|
|
if (part.type === 'tool-result') {
|
||
|
|
let resultContent = '';
|
||
|
|
if (part.output) {
|
||
|
|
if (part.output.type === 'text') {
|
||
|
|
resultContent = part.output.value;
|
||
|
|
} else if (part.output.type === 'json') {
|
||
|
|
resultContent = JSON.stringify(part.output.value);
|
||
|
|
}
|
||
|
|
}
|
||
|
|
messages.push({ role: 'tool', content: resultContent });
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return messages;
|
||
|
|
}
|
||
|
|
|
||
|
|
/**
|
||
|
|
* Convert AI SDK V3 tools to Ollama's tool format.
|
||
|
|
*/
|
||
|
|
function convertToolsToOllamaTools(tools: LanguageModelV3CallOptions['tools']): IOllamaTool[] | undefined {
|
||
|
|
if (!tools || tools.length === 0) return undefined;
|
||
|
|
|
||
|
|
return tools
|
||
|
|
.filter((t): t is Extract<typeof t, { type: 'function' }> => t.type === 'function')
|
||
|
|
.map(t => ({
|
||
|
|
type: 'function' as const,
|
||
|
|
function: {
|
||
|
|
name: t.name,
|
||
|
|
description: t.description ?? '',
|
||
|
|
parameters: t.inputSchema as Record<string, unknown>,
|
||
|
|
},
|
||
|
|
}));
|
||
|
|
}
|
||
|
|
|
||
|
|
function makeUsage(promptTokens?: number, completionTokens?: number): LanguageModelV3Usage {
|
||
|
|
return {
|
||
|
|
inputTokens: {
|
||
|
|
total: promptTokens,
|
||
|
|
noCache: undefined,
|
||
|
|
cacheRead: undefined,
|
||
|
|
cacheWrite: undefined,
|
||
|
|
},
|
||
|
|
outputTokens: {
|
||
|
|
total: completionTokens,
|
||
|
|
text: completionTokens,
|
||
|
|
reasoning: undefined,
|
||
|
|
},
|
||
|
|
};
|
||
|
|
}
|
||
|
|
|
||
|
|
function makeFinishReason(reason?: string): LanguageModelV3FinishReason {
|
||
|
|
if (reason === 'tool_calls' || reason === 'tool-calls') {
|
||
|
|
return { unified: 'tool-calls', raw: reason };
|
||
|
|
}
|
||
|
|
return { unified: 'stop', raw: reason ?? 'stop' };
|
||
|
|
}
|
||
|
|
|
||
|
|
let idCounter = 0;
|
||
|
|
function generateId(): string {
|
||
|
|
return `ollama-${Date.now()}-${idCounter++}`;
|
||
|
|
}
|
||
|
|
|
||
|
|
/**
 * Custom LanguageModelV3 implementation for Ollama.
 * Calls Ollama's native /api/chat endpoint directly to support
 * think, num_ctx, temperature, and other model options.
 *
 * @param options - smartai configuration; `baseUrl` defaults to the local
 *   Ollama daemon, `model` selects the Ollama model id, and
 *   `ollamaOptions` carries model-level defaults.
 * @returns a LanguageModelV3 with doGenerate and doStream implementations.
 */
export function createOllamaModel(options: ISmartAiOptions): LanguageModelV3 {
  // Default to a locally running Ollama daemon when no base URL is set.
  const baseUrl = options.baseUrl ?? 'http://localhost:11434';
  const modelId = options.model;
  // Copy so the per-model default applied below does not mutate caller state.
  const ollamaOpts: IOllamaModelOptions = { ...options.ollamaOptions };

  // Apply default temperature of 0.55 for Qwen models
  // (only when the caller has not configured one explicitly).
  if (modelId.toLowerCase().includes('qwen') && ollamaOpts.temperature === undefined) {
    ollamaOpts.temperature = 0.55;
  }

  const model: LanguageModelV3 = {
    specificationVersion: 'v3',
    provider: 'ollama',
    modelId,
    // No URL-based file parts are supported; images must be inlined.
    supportedUrls: {},

    /**
     * Non-streaming generation: POST /api/chat with stream: false,
     * then translate the single JSON response into V3 content parts.
     * @throws Error when the HTTP response is not ok.
     */
    async doGenerate(callOptions: LanguageModelV3CallOptions): Promise<LanguageModelV3GenerateResult> {
      const messages = convertPromptToOllamaMessages(callOptions.prompt);
      const tools = convertToolsToOllamaTools(callOptions.tools);

      // Start from configured defaults, then let call-level options win.
      const ollamaModelOptions: Record<string, unknown> = { ...ollamaOpts };
      // Override with call-level options if provided
      if (callOptions.temperature !== undefined) ollamaModelOptions.temperature = callOptions.temperature;
      if (callOptions.topP !== undefined) ollamaModelOptions.top_p = callOptions.topP;
      if (callOptions.topK !== undefined) ollamaModelOptions.top_k = callOptions.topK;
      if (callOptions.maxOutputTokens !== undefined) ollamaModelOptions.num_predict = callOptions.maxOutputTokens;
      if (callOptions.seed !== undefined) ollamaModelOptions.seed = callOptions.seed;
      if (callOptions.stopSequences) ollamaModelOptions.stop = callOptions.stopSequences;
      // Remove think from options — it goes at the top level
      const { think, ...modelOpts } = ollamaModelOptions;

      const requestBody: Record<string, unknown> = {
        model: modelId,
        messages,
        stream: false,
        options: modelOpts,
      };

      // Add think parameter at the top level (Ollama API requirement)
      if (ollamaOpts.think !== undefined) {
        requestBody.think = ollamaOpts.think;
      }

      if (tools) requestBody.tools = tools;

      const response = await fetch(`${baseUrl}/api/chat`, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify(requestBody),
        signal: callOptions.abortSignal,
      });

      if (!response.ok) {
        // Include the response body for diagnosability.
        const body = await response.text();
        throw new Error(`Ollama API error ${response.status}: ${body}`);
      }

      const result = await response.json() as Record<string, unknown>;
      const message = result.message as Record<string, unknown>;

      // Build content array
      const content: LanguageModelV3Content[] = [];

      // Add reasoning if present
      if (message.thinking && typeof message.thinking === 'string') {
        content.push({ type: 'reasoning', text: message.thinking });
      }

      // Add text content
      if (message.content && typeof message.content === 'string') {
        content.push({ type: 'text', text: message.content });
      }

      // Add tool calls if present
      if (Array.isArray(message.tool_calls)) {
        for (const tc of message.tool_calls as Array<Record<string, unknown>>) {
          const fn = tc.function as Record<string, unknown>;
          content.push({
            type: 'tool-call',
            // Ollama does not return call ids, so synthesize one.
            toolCallId: generateId(),
            toolName: fn.name as string,
            input: JSON.stringify(fn.arguments),
          });
        }
      }

      // Finish reason is derived from the presence of tool calls; Ollama's
      // own done_reason field is not consulted here.
      const finishReason = Array.isArray(message.tool_calls) && (message.tool_calls as unknown[]).length > 0
        ? makeFinishReason('tool_calls')
        : makeFinishReason('stop');

      return {
        content,
        finishReason,
        usage: makeUsage(
          (result.prompt_eval_count as number) ?? undefined,
          (result.eval_count as number) ?? undefined,
        ),
        warnings: [],
        request: { body: requestBody },
      };
    },

    /**
     * Streaming generation: POST /api/chat with stream: true. Ollama
     * streams newline-delimited JSON objects; each line is translated
     * into V3 stream parts (reasoning-start/-delta/-end, text-start/
     * -delta/-end, tool-call, finish).
     * @throws Error when the HTTP response is not ok (before streaming).
     */
    async doStream(callOptions: LanguageModelV3CallOptions): Promise<LanguageModelV3StreamResult> {
      const messages = convertPromptToOllamaMessages(callOptions.prompt);
      const tools = convertToolsToOllamaTools(callOptions.tools);

      // Same option merging as doGenerate: call-level values override defaults.
      const ollamaModelOptions: Record<string, unknown> = { ...ollamaOpts };
      if (callOptions.temperature !== undefined) ollamaModelOptions.temperature = callOptions.temperature;
      if (callOptions.topP !== undefined) ollamaModelOptions.top_p = callOptions.topP;
      if (callOptions.topK !== undefined) ollamaModelOptions.top_k = callOptions.topK;
      if (callOptions.maxOutputTokens !== undefined) ollamaModelOptions.num_predict = callOptions.maxOutputTokens;
      if (callOptions.seed !== undefined) ollamaModelOptions.seed = callOptions.seed;
      if (callOptions.stopSequences) ollamaModelOptions.stop = callOptions.stopSequences;
      // think travels at the top level of the request, not inside options.
      const { think, ...modelOpts } = ollamaModelOptions;

      const requestBody: Record<string, unknown> = {
        model: modelId,
        messages,
        stream: true,
        options: modelOpts,
      };

      if (ollamaOpts.think !== undefined) {
        requestBody.think = ollamaOpts.think;
      }

      if (tools) requestBody.tools = tools;

      const response = await fetch(`${baseUrl}/api/chat`, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify(requestBody),
        signal: callOptions.abortSignal,
      });

      if (!response.ok) {
        const body = await response.text();
        throw new Error(`Ollama API error ${response.status}: ${body}`);
      }

      const reader = response.body!.getReader();
      const decoder = new TextDecoder();

      // Stable ids shared by every delta of the same logical part.
      const textId = generateId();
      const reasoningId = generateId();
      // Emission state: *-start parts are sent lazily on first delta.
      let textStarted = false;
      let reasoningStarted = false;
      let hasToolCalls = false;
      let closed = false;

      const stream = new ReadableStream<LanguageModelV3StreamPart>({
        // NOTE(review): this pull drains the entire HTTP body in a single
        // invocation (the while loop below runs to completion) rather than
        // one chunk per pull; 'closed' guards against re-entry afterwards.
        async pull(controller) {
          if (closed) return;

          // Parse one NDJSON line from Ollama and enqueue matching parts.
          const processLine = (line: string) => {
            if (!line.trim()) return;
            let json: Record<string, unknown>;
            try {
              json = JSON.parse(line);
            } catch {
              // Skip malformed lines rather than failing the stream.
              return;
            }

            const msg = json.message as Record<string, unknown> | undefined;

            // Handle thinking/reasoning content
            if (msg?.thinking && typeof msg.thinking === 'string') {
              if (!reasoningStarted) {
                reasoningStarted = true;
                controller.enqueue({ type: 'reasoning-start', id: reasoningId });
              }
              controller.enqueue({ type: 'reasoning-delta', id: reasoningId, delta: msg.thinking });
            }

            // Handle text content
            if (msg?.content && typeof msg.content === 'string') {
              // First text delta also closes a still-open reasoning part.
              if (reasoningStarted && !textStarted) {
                controller.enqueue({ type: 'reasoning-end', id: reasoningId });
              }
              if (!textStarted) {
                textStarted = true;
                controller.enqueue({ type: 'text-start', id: textId });
              }
              controller.enqueue({ type: 'text-delta', id: textId, delta: msg.content });
            }

            // Handle tool calls
            if (Array.isArray(msg?.tool_calls)) {
              hasToolCalls = true;
              for (const tc of msg!.tool_calls as Array<Record<string, unknown>>) {
                const fn = tc.function as Record<string, unknown>;
                // Synthesize an id; Ollama does not provide one.
                const callId = generateId();
                controller.enqueue({
                  type: 'tool-call',
                  toolCallId: callId,
                  toolName: fn.name as string,
                  input: JSON.stringify(fn.arguments),
                });
              }
            }

            // Handle done
            if (json.done) {
              // Close any part that was opened but never ended.
              if (reasoningStarted && !textStarted) {
                controller.enqueue({ type: 'reasoning-end', id: reasoningId });
              }
              if (textStarted) {
                controller.enqueue({ type: 'text-end', id: textId });
              }
              controller.enqueue({
                type: 'finish',
                finishReason: hasToolCalls
                  ? makeFinishReason('tool_calls')
                  : makeFinishReason('stop'),
                usage: makeUsage(
                  (json.prompt_eval_count as number) ?? undefined,
                  (json.eval_count as number) ?? undefined,
                ),
              });
              closed = true;
              controller.close();
            }
          };

          try {
            // Accumulate bytes and split on newlines; the tail after the
            // last newline is carried over to the next read.
            let buffer = '';
            while (true) {
              const { done, value } = await reader.read();
              if (done) {
                // Flush a trailing partial line, then synthesize a finish
                // if the server never sent a done:true line.
                if (buffer.trim()) processLine(buffer);
                if (!closed) {
                  controller.enqueue({
                    type: 'finish',
                    finishReason: makeFinishReason('stop'),
                    usage: makeUsage(undefined, undefined),
                  });
                  closed = true;
                  controller.close();
                }
                return;
              }

              buffer += decoder.decode(value, { stream: true });
              const lines = buffer.split('\n');
              buffer = lines.pop() || '';
              for (const line of lines) {
                processLine(line);
                // done:true may arrive before the HTTP stream ends.
                if (closed) return;
              }
            }
          } catch (error) {
            if (!closed) {
              controller.error(error);
              closed = true;
            }
          } finally {
            reader.releaseLock();
          }
        },
      });

      return {
        stream,
        request: { body: requestBody },
      };
    },
  };

  return model;
}
|