diff --git a/seeders/DatabaseSeeder.ts b/seeders/DatabaseSeeder.ts index 60c8172..56dcd7a 100644 --- a/seeders/DatabaseSeeder.ts +++ b/seeders/DatabaseSeeder.ts @@ -92,7 +92,7 @@ export class DatabaseSeeder extends Seeder { redactedValue: redactProjectKeyValue(PROJECT_API_KEY) }); const beeAssistant = new Assistant({ - model: defaultAIProvider.createChatLLM().modelId, + model: defaultAIProvider.createAssistantBackend().modelId, agent: Agent.BEE, tools: [ { @@ -121,7 +121,7 @@ export class DatabaseSeeder extends Seeder { } }); const streamlitAssistant = new Assistant({ - model: defaultAIProvider.createChatLLM().modelId, + model: defaultAIProvider.createAssistantBackend().modelId, agent: Agent.STREAMLIT, tools: [], name: 'Builder Assistant', diff --git a/src/assistants/assistants.service.ts b/src/assistants/assistants.service.ts index aaadecf..47613f7 100644 --- a/src/assistants/assistants.service.ts +++ b/src/assistants/assistants.service.ts @@ -114,7 +114,7 @@ export async function createAssistant({ description: description ?? undefined, metadata, topP: top_p ?? undefined, - model: model ?? defaultAIProvider.createChatLLM().modelId, + model: model ?? defaultAIProvider.createAssistantBackend().modelId, agent, temperature: temperature ?? undefined, systemPromptOverwrite: system_prompt_overwrite ?? undefined diff --git a/src/chat/chat.service.ts b/src/chat/chat.service.ts index c13bcaf..0f326b9 100644 --- a/src/chat/chat.service.ts +++ b/src/chat/chat.service.ts @@ -55,7 +55,7 @@ export async function createChatCompletion({ messages, response_format }: ChatCompletionCreateBody): Promise { - const llm = defaultAIProvider.createChatLLM({ model }); + const llm = defaultAIProvider.createChatBackend({ model }); const chat = new Chat({ model: llm.modelId, messages, responseFormat: response_format }); await ORM.em.persistAndFlush(chat); try { diff --git a/src/runs/execution/execute.ts b/src/runs/execution/execute.ts index 2e68ea5..114992e 100644 --- a/src/runs/execution/execute.ts +++ b/src/runs/execution/execute.ts @@ -122,7 +122,7 @@ export async function executeRun(run: LoadedRun) { const context = { run, publish } as AgentContext; const tools = await getTools(run, context); - const llm = defaultAIProvider.createChatLLM(run); + const llm = defaultAIProvider.createAssistantBackend(run); const memory = new TokenMemory({ llm }); await memory.addMany(messages); diff --git a/src/runs/execution/provider.ts b/src/runs/execution/provider.ts index fd0ff51..f56aecb 100644 --- a/src/runs/execution/provider.ts +++ b/src/runs/execution/provider.ts @@ -50,9 +50,10 @@ interface AIProvider< ChatLLMType extends ChatLLM, LLMType extends LLM = any > { - createChatLLM: (params?: ChatLLMParams) => ChatLLMType; - createCodeLLM: (params?: { model?: string }) => LLMType | void; - createEmbeddingModel?: (params?: { model?: string }) => EmbeddingModel; + createChatBackend: (params?: ChatLLMParams) => ChatLLMType; + createAssistantBackend: (params?: ChatLLMParams) => ChatLLMType; + createCodeBackend: (params?: { model?: string }) => LLMType | void; + createEmbeddingBackend?: (params?: { model?: string }) => EmbeddingModel; } export class BamAIProvider implements AIProvider { @@ -62,7 +63,10 @@ export class BamAIProvider implements AIProvider { BamAIProvider.client ??= new BAMClient({ apiKey: BAM_API_KEY ?? undefined }); } - createChatLLM({ model = 'meta-llama/llama-3-1-70b-instruct', ...params }: ChatLLMParams = {}) { + createChatBackend({ + model = 'meta-llama/llama-3-1-70b-instruct', + ...params + }: ChatLLMParams = {}) { return BAMChatLLM.fromPreset(model as BAMChatLLMPresetModel, { client: BamAIProvider.client, parameters: (parameters) => ({ @@ -74,7 +78,11 @@ export class BamAIProvider implements AIProvider { }); } - createCodeLLM({ model = 'meta-llama/llama-3-1-70b-instruct' } = {}) { + createAssistantBackend(params?: ChatLLMParams) { + return this.createChatBackend(params); + } + + createCodeBackend({ model = 'meta-llama/llama-3-1-70b-instruct' } = {}) { return new BAMLLM({ client: BamAIProvider.client, modelId: model, @@ -86,7 +94,7 @@ export class BamAIProvider implements AIProvider { }); } - createEmbeddingModel({ model = 'baai/bge-large-en-v1.5' } = {}) { + createEmbeddingBackend({ model = 'baai/bge-large-en-v1.5' } = {}) { return new BAMLLM({ client: BamAIProvider.client, modelId: model }); } } @@ -98,7 +106,7 @@ export class OllamaAIProvider implements AIProvider { OllamaAIProvider.client ??= new Ollama({ host: OLLAMA_URL ?? undefined }); } - createChatLLM({ model: modelId = 'llama3.1', ...params }: ChatLLMParams = {}) { + createChatBackend({ model: modelId = 'llama3.1', ...params }: ChatLLMParams = {}) { return new OllamaChatLLM({ client: OllamaAIProvider.client, modelId, @@ -109,9 +117,13 @@ export class OllamaAIProvider implements AIProvider { } }); } - createCodeLLM() {} + createAssistantBackend(params?: ChatLLMParams) { + return this.createChatBackend(params); + } - createEmbeddingModel({ model: modelId = 'nomic-embed-text' } = {}) { + createCodeBackend() {} + + createEmbeddingBackend({ model: modelId = 'nomic-embed-text' } = {}) { return new OllamaLLM({ client: OllamaAIProvider.client, modelId }); } } @@ -122,7 +134,7 @@ export class OpenAIProvider implements AIProvider { OpenAIProvider.client ??= new OpenAI({ apiKey: OPENAI_API_KEY ?? undefined }); } - createChatLLM({ model = 'gpt-4o', ...params }: ChatLLMParams = {}) { + createChatBackend({ model = 'gpt-4o', ...params }: ChatLLMParams = {}) { return new OpenAIChatLLM({ client: OpenAIProvider.client, modelId: model as OpenAI.ChatModel, @@ -134,9 +146,13 @@ export class OpenAIProvider implements AIProvider { }); } - createCodeLLM() {} + createAssistantBackend(params?: ChatLLMParams) { + return this.createChatBackend(params); + } + + createCodeBackend() {} - createEmbeddingModel({ model = 'text-embedding-3-large' } = {}) { + createEmbeddingBackend({ model = 'text-embedding-3-large' } = {}) { return { chatLLM: new OpenAIChatLLM({ client: OpenAIProvider.client, @@ -170,7 +186,10 @@ export class IBMvLLMAIProvider implements AIProvider { }); } - createChatLLM({ model = IBMVllmModel.LLAMA_3_1_70B_INSTRUCT, ...params }: ChatLLMParams = {}) { + createChatBackend({ + model = IBMVllmModel.LLAMA_3_1_70B_INSTRUCT, + ...params + }: ChatLLMParams = {}) { return IBMVllmChatLLM.fromPreset(model as IBMVllmChatLLMPresetModel, { client: IBMvLLMAIProvider.client, parameters: (parameters) => ({ @@ -188,7 +207,11 @@ export class IBMvLLMAIProvider implements AIProvider { }); } - createCodeLLM({ model: modelId = 'meta-llama/llama-3-1-70b-instruct' } = {}) { + createAssistantBackend(params?: ChatLLMParams) { + return this.createChatBackend(params); + } + + createCodeBackend({ model: modelId = 'meta-llama/llama-3-1-70b-instruct' } = {}) { return new IBMvLLM({ client: IBMvLLMAIProvider.client, modelId, @@ -199,7 +222,7 @@ export class IBMvLLMAIProvider implements AIProvider { }); } - createEmbeddingModel({ model: modelId = 'baai/bge-large-en-v1.5' } = {}) { + createEmbeddingBackend({ model: modelId = 'baai/bge-large-en-v1.5' } = {}) { return new IBMvLLM({ client: IBMvLLMAIProvider.client, modelId }); } } @@ -215,7 +238,10 @@ export class WatsonxAIProvider implements AIProvider }; } - createChatLLM({ model = 'meta-llama/llama-3-1-70b-instruct', ...params }: ChatLLMParams = {}) { + createChatBackend({ + model = 'meta-llama/llama-3-1-70b-instruct', + ...params + }: ChatLLMParams = {}) { return WatsonXChatLLM.fromPreset(model as WatsonXChatLLMPresetModel, { ...this.credentials, parameters: (parameters) => ({ @@ -227,7 +253,11 @@ export class WatsonxAIProvider implements AIProvider }); } - createCodeLLM({ model: modelId = 'meta-llama/llama-3-1-70b-instruct' } = {}) { + createAssistantBackend(params?: ChatLLMParams) { + return this.createChatBackend(params); + } + + createCodeBackend({ model: modelId = 'meta-llama/llama-3-1-70b-instruct' } = {}) { return new WatsonXLLM({ ...this.credentials, modelId, @@ -239,7 +269,7 @@ export class WatsonxAIProvider implements AIProvider }); } - createEmbeddingModel({ model: modelId = 'ibm/slate-30m-english-rtrvr-v2' } = {}) { + createEmbeddingBackend({ model: modelId = 'ibm/slate-30m-english-rtrvr-v2' } = {}) { return new WatsonXLLM({ ...this.credentials, modelId, region: WATSONX_REGION ?? undefined }); } } diff --git a/src/runs/execution/tools/file-search-tool.ts b/src/runs/execution/tools/file-search-tool.ts index 968988f..ae80d53 100644 --- a/src/runs/execution/tools/file-search-tool.ts +++ b/src/runs/execution/tools/file-search-tool.ts @@ -94,9 +94,11 @@ export class FileSearchTool extends Tool { const vectorStoreClient = getVectorStoreClient(); - const embeddingModel = defaultAIProvider.createEmbeddingModel(); + const embeddingModel = defaultAIProvider.createEmbeddingBackend(); - const [embedding] = (await embeddingModel.embed([query], { signal: run.signal })).embeddings; + const { + embeddings: [embedding] + } = await embeddingModel.embed([query], { signal: run.signal }); if (this.vectorStores.some((vectorStore) => vectorStore.expired)) { throw new Error('Some of the vector stores are expired'); diff --git a/src/runs/execution/tools/helpers.ts b/src/runs/execution/tools/helpers.ts index fcaa87f..f7756d1 100644 --- a/src/runs/execution/tools/helpers.ts +++ b/src/runs/execution/tools/helpers.ts @@ -134,7 +134,7 @@ export async function getTools(run: LoadedRun, context: AgentContext): Promise tool.type === ToolType.SYSTEM && tool.toolId === SystemTools.LLM ); if (llmUsage) { - tools.push(new LLMTool({ llm: defaultAIProvider.createChatLLM() })); + tools.push(new LLMTool({ llm: defaultAIProvider.createChatBackend() })); } const calculatorUsage = run.tools.find( @@ -186,7 +186,7 @@ export async function getTools(run: LoadedRun, context: AgentContext): Promise container.file.$); if (codeInterpreterUsage) { - const codeLLM = defaultAIProvider.createCodeLLM(); + const codeLLM = defaultAIProvider.createCodeBackend(); tools.push( new PythonTool({ codeInterpreter, diff --git a/src/runs/execution/tools/wikipedia-tool.ts b/src/runs/execution/tools/wikipedia-tool.ts index 2af4132..b9862dd 100644 --- a/src/runs/execution/tools/wikipedia-tool.ts +++ b/src/runs/execution/tools/wikipedia-tool.ts @@ -36,7 +36,7 @@ export function wikipediaTool( maxResults = 5 ): AnyTool { // LLM to perform text embedding - const embeddingLLM = defaultAIProvider.createEmbeddingModel(); + const embeddingLLM = defaultAIProvider.createEmbeddingBackend(); // Similarity tool to calculate the similarity between a query and a set of wikipedia passages const similarity = new SimilarityTool({ diff --git a/src/tools/tools.service.ts b/src/tools/tools.service.ts index c8d3ab6..03905f3 100644 --- a/src/tools/tools.service.ts +++ b/src/tools/tools.service.ts @@ -465,7 +465,7 @@ function getSystemTools() { }); const fileSearch = new FileSearchTool({ vectorStores: [], maxNumResults: 0 }); const readFile = new ReadFileTool({ files: [], fileSize: 0 }); - const llmTool = new LLMTool({ llm: defaultAIProvider.createChatLLM() }); + const llmTool = new LLMTool({ llm: defaultAIProvider.createChatBackend() }); const calculatorTool = new CalculatorTool(); const systemTools = new Map(); diff --git a/src/users/users.service.ts b/src/users/users.service.ts index 983c32c..7dd0013 100644 --- a/src/users/users.service.ts +++ b/src/users/users.service.ts @@ -131,7 +131,7 @@ export async function createUser({ user.defaultProject = ORM.em.getRepository(Project).getReference(project.id, { wrapped: true }); const assistant = new Assistant({ - model: defaultAIProvider.createChatLLM().modelId, + model: defaultAIProvider.createAssistantBackend().modelId, agent: Agent.BEE, tools: [ new SystemUsage({ toolId: SystemTools.WEB_SEARCH }), diff --git a/src/vector-store-files/execution/client.ts b/src/vector-store-files/execution/client.ts index c6ae646..242e3a7 100644 --- a/src/vector-store-files/execution/client.ts +++ b/src/vector-store-files/execution/client.ts @@ -28,7 +28,7 @@ export type DocumentType = z.ZodType; export function getVectorStoreClient(): VectorStoreClient { return new MilvusVectorStore({ - modelName: defaultAIProvider.createEmbeddingModel().modelId, + modelName: defaultAIProvider.createEmbeddingBackend().modelId, documentSchema: DocumentSchema }); } diff --git a/src/vector-store-files/execution/process-file.ts b/src/vector-store-files/execution/process-file.ts index e7d6617..4ca2e91 100644 --- a/src/vector-store-files/execution/process-file.ts +++ b/src/vector-store-files/execution/process-file.ts @@ -64,7 +64,7 @@ export async function processVectorStoreFile(vectorStoreFile: Loaded) { - const embeddingAdapter = defaultAIProvider.createEmbeddingModel(); + const embeddingAdapter = defaultAIProvider.createEmbeddingBackend(); for await (const items of source) { const output = await embeddingAdapter.embed(items, { signal: controller.signal }); yield output.embeddings.map((embedding, idx) => ({