251 lines
7.8 KiB
TypeScript
251 lines
7.8 KiB
TypeScript
|
|
import type { JSONObject, JSONValue, LanguageModelV3Middleware, LanguageModelV3Prompt } from '@ai-sdk/provider';
|
||
|
|
import type { TSmartAiProviderOptions } from './smartai.interfaces.js';
|
||
|
|
|
||
|
|
/**
 * Namespaces for which a message-level prompt-cache marker can be generated.
 * Each member doubles as the top-level providerOptions key the marker is stored under.
 */
export type TSmartAiMessageCacheProvider =
  | 'alibaba'
  | 'anthropic'
  | 'bedrock'
  | 'copilot'
  | 'openaiCompatible'
  | 'openrouter';
|
||
|
|
|
||
|
|
/**
 * Requested cache lifetime. `'1h'` becomes an Anthropic cache-control TTL; `'in_memory'` and
 * `'24h'` are forwarded as OpenAI's `promptCacheRetention`. `'ephemeral'` is the short default.
 */
export type TSmartAiCacheRetention = 'ephemeral' | 'in_memory' | '1h' | '24h';
|
||
|
|
|
||
|
|
/** Explicit prompt-cache settings; every field is optional. */
export interface ISmartAiCacheOptions {
  /** Marker namespace to emit. Normally inferred from the model rather than set by hand. */
  provider?: TSmartAiMessageCacheProvider;
  /** Opt-in longer retention; leave unset to keep the short default. */
  retention?: TSmartAiCacheRetention;
  /** Stable session/request key for providers that support request-level prompt-cache affinity. */
  key?: string;
}
|
||
|
|
|
||
|
|
/**
 * Caller-facing cache switch: `false` turns caching off, `true`/`'auto'` (or leaving it unset)
 * enables it with defaults, and an options object enables it with explicit settings.
 */
export type TSmartAiCacheSetting = ISmartAiCacheOptions | 'auto' | boolean;
|
||
|
|
|
||
|
|
/** Type guard: true for any non-null object that is not an array. */
function isObject(input: unknown): input is Record<string, unknown> {
  if (input === null || Array.isArray(input)) {
    return false;
  }
  return typeof input === 'object';
}
|
||
|
|
|
||
|
|
/**
 * Deep-merges `overrides` on top of `defaults`, returning a new object.
 *
 * The top level is always a fresh shallow copy of `defaults`. Where both sides hold a nested
 * object under the same key, the two are merged recursively. Any other override value,
 * including arrays and primitives, replaces the default outright.
 */
function mergeJsonDefaults(defaults: JSONObject, overrides?: JSONObject): JSONObject {
  const merged: JSONObject = { ...defaults };
  if (!overrides) {
    return merged;
  }

  for (const [key, overrideValue] of Object.entries(overrides)) {
    const baseValue = merged[key];
    merged[key] =
      isObject(baseValue) && isObject(overrideValue)
        ? mergeJsonDefaults(baseValue as JSONObject, overrideValue as JSONObject)
        : (overrideValue as JSONValue);
  }

  return merged;
}
|
||
|
|
|
||
|
|
/**
 * Combines two provider-option bags; `overrides` wins wherever both define a value.
 *
 * If only one side is present, that object is returned as-is (not copied).
 * Returns `undefined` when both are absent.
 */
export function mergeSmartAiProviderOptions(
  defaults?: TSmartAiProviderOptions,
  overrides?: TSmartAiProviderOptions,
): TSmartAiProviderOptions | undefined {
  if (defaults && overrides) {
    return mergeJsonDefaults(defaults as JSONObject, overrides as JSONObject) as TSmartAiProviderOptions;
  }
  return defaults ? defaults : overrides;
}
|
||
|
|
|
||
|
|
/**
 * Normalizes a cache setting into concrete options.
 *
 * `false` yields `undefined` (caching disabled). The enable-with-defaults forms (`undefined`,
 * `true`, `'auto'`) yield an empty options object. An options object is returned unchanged.
 */
function cacheOptionsFromSetting(cache: TSmartAiCacheSetting | undefined): ISmartAiCacheOptions | undefined {
  switch (cache) {
    case false:
      return undefined;
    case undefined:
    case true:
    case 'auto':
      return {};
    default:
      return cache;
  }
}
|
||
|
|
|
||
|
|
/**
 * Infers which message-cache marker namespace applies to a provider/model pair.
 *
 * The provider string is searched first, using case-insensitive substring matches checked in a
 * fixed priority order. If no provider rule matches, a model id mentioning `claude` or
 * `anthropic` falls back to `'anthropic'`. Returns `undefined` when nothing matches.
 */
export function resolveSmartAiCacheProvider(provider?: string, modelId?: string): TSmartAiMessageCacheProvider | undefined {
  const providerText = (provider ?? '').toLowerCase();

  // Order matters: e.g. an OpenRouter provider serving a Claude model must resolve to 'openrouter'.
  const providerRules: ReadonlyArray<readonly [TSmartAiMessageCacheProvider, readonly string[]]> = [
    ['openrouter', ['openrouter']],
    ['bedrock', ['bedrock']],
    ['copilot', ['copilot']],
    ['alibaba', ['alibaba']],
    ['openaiCompatible', ['openai-compatible', 'openaicompatible']],
    ['anthropic', ['anthropic']],
  ];

  for (const [target, needles] of providerRules) {
    if (needles.some((needle) => providerText.includes(needle))) {
      return target;
    }
  }

  const modelText = (modelId ?? '').toLowerCase();
  return modelText.includes('claude') || modelText.includes('anthropic') ? 'anthropic' : undefined;
}
|
||
|
|
|
||
|
|
/**
 * Builds the per-message cache marker for `provider`, nested under a top-level key equal to the
 * provider name (e.g. `{ bedrock: { cachePoint: { type: 'default' } } }`).
 *
 * Only the Anthropic marker reads `options`: a `'1h'` retention adds `ttl: '1h'` to its
 * `cacheControl`. Every other retention value, and every other provider, uses the plain marker.
 * A fresh object is returned on each call.
 */
export function getSmartAiMessageCacheProviderOptions(
  provider: TSmartAiMessageCacheProvider,
  options: ISmartAiCacheOptions = {},
): TSmartAiProviderOptions {
  const wrap = (marker: JSONObject): TSmartAiProviderOptions =>
    ({ [provider]: marker }) as TSmartAiProviderOptions;

  switch (provider) {
    case 'anthropic':
      return wrap({
        cacheControl: options.retention === '1h' ? { type: 'ephemeral', ttl: '1h' } : { type: 'ephemeral' },
      });
    case 'openrouter':
    case 'alibaba':
      return wrap({ cacheControl: { type: 'ephemeral' } });
    case 'bedrock':
      return wrap({ cachePoint: { type: 'default' } });
    case 'openaiCompatible':
      return wrap({ cache_control: { type: 'ephemeral' } });
    case 'copilot':
      return wrap({ copilot_cache_control: { type: 'ephemeral' } });
  }
}
|
||
|
|
|
||
|
|
/**
 * True when the cache marker should be attached to the whole message (Anthropic, Bedrock)
 * rather than to the message's last content part.
 */
function shouldUseMessageLevelOptions(provider: TSmartAiMessageCacheProvider): boolean {
  const messageLevelProviders: string[] = ['anthropic', 'bedrock'];
  return messageLevelProviders.includes(provider);
}
|
||
|
|
|
||
|
|
/**
 * Returns a shallow copy of `item` whose `providerOptions` are `defaults` deep-merged with the
 * item's existing options. Values already set on the item take precedence over `defaults`.
 */
function applyProviderOptionsDefaults<T extends { providerOptions?: TSmartAiProviderOptions }>(
  item: T,
  defaults: TSmartAiProviderOptions,
): T {
  const providerOptions = mergeSmartAiProviderOptions(defaults, item.providerOptions);
  return { ...item, providerOptions };
}
|
||
|
|
|
||
|
|
/**
 * True for content parts whose `type` is `tool-approval-request` or `tool-approval-response`.
 * Used to avoid attaching part-level cache options to tool-approval parts.
 */
function isToolApprovalPart(part: unknown): boolean {
  if (!isObject(part)) {
    return false;
  }
  switch (part.type) {
    case 'tool-approval-request':
    case 'tool-approval-response':
      return true;
    default:
      return false;
  }
}
|
||
|
|
|
||
|
|
/**
 * Returns a copy of `message` tagged with the cache marker for `provider`.
 *
 * Anthropic and Bedrock always receive the marker at message level. Other providers receive it
 * on the last element of an array `content`. The code falls back to message level when the
 * content is a string, an empty array, or ends in a tool-approval part. In every case, options
 * already present on the tagged item override the generated marker.
 */
function applyCacheToMessage(
  message: LanguageModelV3Prompt[number],
  provider: TSmartAiMessageCacheProvider,
  options: ISmartAiCacheOptions,
): LanguageModelV3Prompt[number] {
  const marker = getSmartAiMessageCacheProviderOptions(provider, options);
  const { content } = message;

  if (shouldUseMessageLevelOptions(provider) || !Array.isArray(content) || content.length === 0) {
    return applyProviderOptionsDefaults(message, marker);
  }

  const lastIndex = content.length - 1;
  if (isToolApprovalPart(content[lastIndex])) {
    return applyProviderOptionsDefaults(message, marker);
  }

  // Only the final part is tagged; all earlier parts pass through untouched.
  const arrayMessage = message as Extract<LanguageModelV3Prompt[number], { content: unknown[] }>;
  const taggedContent = content.map((part, index) =>
    index === lastIndex ? applyProviderOptionsDefaults(part, marker) : part,
  );
  return {
    ...arrayMessage,
    content: taggedContent as typeof arrayMessage.content,
  } as LanguageModelV3Prompt[number];
}
|
||
|
|
|
||
|
|
/**
 * Adds prompt-cache markers to at most four messages: the first two `system` messages and the
 * last two non-system messages. The marker namespace comes from `options.provider` and defaults
 * to `'anthropic'`.
 *
 * Returns the original `prompt` array unchanged when there is nothing to mark; otherwise returns
 * a new array in which only the selected messages are replaced.
 */
export function applySmartAiPromptCaching(
  prompt: LanguageModelV3Prompt,
  options: ISmartAiCacheOptions = {},
): LanguageModelV3Prompt {
  const provider = options.provider ?? 'anthropic';

  const systemPositions: number[] = [];
  const conversationPositions: number[] = [];
  for (const [index, message] of prompt.entries()) {
    if (message.role === 'system') {
      systemPositions.push(index);
    } else {
      conversationPositions.push(index);
    }
  }

  const cachePositions = new Set<number>([
    ...systemPositions.slice(0, 2),
    ...conversationPositions.slice(-2),
  ]);
  if (cachePositions.size === 0) {
    return prompt;
  }

  return prompt.map((message, index) =>
    cachePositions.has(index) ? applyCacheToMessage(message, provider, options) : message,
  ) as LanguageModelV3Prompt;
}
|
||
|
|
|
||
|
|
/**
 * Creates a v3 language-model middleware that rewrites each call's prompt through
 * `applySmartAiPromptCaching` with the given options. All other call params pass through unchanged.
 */
export function createSmartAiCachingMiddleware(options: ISmartAiCacheOptions = {}): LanguageModelV3Middleware {
  return {
    specificationVersion: 'v3',
    async transformParams({ params }) {
      const cachedPrompt = applySmartAiPromptCaching(params.prompt, options);
      return { ...params, prompt: cachedPrompt };
    },
  };
}
|
||
|
|
|
||
|
|
/**
 * Reports whether `provider` identifies the first-party OpenAI provider. Matching is
 * case-insensitive and accepts three forms: the bare id `openai`, an `openai.`-prefixed id
 * (e.g. `openai.responses`), or the `@ai-sdk/openai` package name.
 *
 * The package-name match requires that `openai` is not followed by a word character or `-`.
 * Without that boundary, `@ai-sdk/openai-compatible` would match. That package is a different
 * provider, which `resolveSmartAiCacheProvider` routes to the `openaiCompatible` path, so it must
 * not receive `openai` request options.
 */
function isOpenAiProvider(provider?: string): boolean {
  const providerLower = provider?.toLowerCase() ?? '';
  if (providerLower === 'openai' || providerLower.startsWith('openai.')) {
    return true;
  }
  // `(?![\w-])` rejects `@ai-sdk/openai-compatible` while still accepting
  // `@ai-sdk/openai`, `@ai-sdk/openai@x.y.z` and `@ai-sdk/openai/...`.
  return /@ai-sdk\/openai(?![\w-])/.test(providerLower);
}
|
||
|
|
|
||
|
|
/**
 * Builds request-level cache provider options. Only the OpenAI provider is handled; every other
 * provider, and a disabled cache setting, yields `undefined`.
 *
 * For OpenAI the result always sets `store: false`. The cache key is `cache.key` if given,
 * otherwise `sessionId`; when it is non-empty it is sent as `promptCacheKey`. An explicit
 * `'24h'` or `'in_memory'` retention is forwarded as `promptCacheRetention`. Otherwise, when a
 * key is present, retention defaults to `'in_memory'`. `modelId` is accepted but not used here.
 */
export function getSmartAiCacheProviderOptions(input: {
  provider?: string;
  modelId?: string;
  cache?: TSmartAiCacheSetting;
  sessionId?: string;
}): TSmartAiProviderOptions | undefined {
  const cacheOptions = cacheOptionsFromSetting(input.cache);
  if (!cacheOptions || !isOpenAiProvider(input.provider)) {
    return undefined;
  }

  const promptCacheKey = cacheOptions.key ?? input.sessionId;
  const { retention } = cacheOptions;

  const openai: JSONObject = { store: false };
  if (promptCacheKey) {
    openai.promptCacheKey = promptCacheKey;
  }
  if (retention === '24h' || retention === 'in_memory') {
    openai.promptCacheRetention = retention;
  } else if (promptCacheKey) {
    openai.promptCacheRetention = 'in_memory';
  }

  return { openai };
}
|
||
|
|
|
||
|
|
/**
 * Returns the caller's `providerOptions` layered over the request-level cache defaults from
 * `getSmartAiCacheProviderOptions`. Caller-supplied values win on conflicts. Returns `undefined`
 * only when there are neither defaults nor caller options.
 */
export function applySmartAiCacheProviderOptions(input: {
  provider?: string;
  modelId?: string;
  providerOptions?: TSmartAiProviderOptions;
  cache?: TSmartAiCacheSetting;
  sessionId?: string;
}): TSmartAiProviderOptions | undefined {
  const cacheDefaults = getSmartAiCacheProviderOptions(input);
  return mergeSmartAiProviderOptions(cacheDefaults, input.providerOptions);
}
|