From cc3445c1a85363c03a57597a9417e66fbbd584ee Mon Sep 17 00:00:00 2001 From: Tomas Pilar Date: Mon, 25 Nov 2024 08:41:27 +0100 Subject: [PATCH] feat(agents): add streamlit agent (#87) Signed-off-by: Tomas Pilar --- migrations/Migration20241121182852.ts | 30 +++++ package.json | 2 +- pnpm-lock.yaml | 16 +-- seeders/DatabaseSeeder.ts | 19 ++- src/assistants/assistant.entity.ts | 13 +- src/assistants/assistants.service.ts | 36 +++++ src/assistants/dtos/assistant-create.ts | 6 + src/assistants/dtos/assistant.ts | 8 +- src/assistants/dtos/assistants-list.ts | 5 + src/runs/execution/constants.ts | 6 + .../execution/event-handlers/streaming.ts | 126 +++++++++++++++++- src/runs/execution/execute.ts | 42 +++--- src/runs/execution/factory.ts | 83 +++++++++++- src/runs/execution/helpers.ts | 49 ------- 14 files changed, 347 insertions(+), 94 deletions(-) create mode 100644 migrations/Migration20241121182852.ts diff --git a/migrations/Migration20241121182852.ts b/migrations/Migration20241121182852.ts new file mode 100644 index 0000000..5427540 --- /dev/null +++ b/migrations/Migration20241121182852.ts @@ -0,0 +1,30 @@ +/** + * Copyright 2024 IBM Corp. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Migration } from '@mikro-orm/migrations-mongodb'; + +import { Assistant } from '@/assistants/assistant.entity'; +import { Agent } from '@/runs/execution/constants'; + +export class Migration20241121182852 extends Migration { + async up(): Promise { + await this.getCollection(Assistant).updateMany( + {}, + { $set: { agent: Agent.BEE } }, + { session: this.ctx } + ); + } +} diff --git a/package.json b/package.json index 578c64f..6f5d41d 100644 --- a/package.json +++ b/package.json @@ -48,7 +48,7 @@ "@zilliz/milvus2-sdk-node": "^2.4.4", "ajv": "^8.17.1", "axios": "^1.7.7", - "bee-agent-framework": "0.0.41", + "bee-agent-framework": "0.0.42", "bee-observe-connector": "0.0.5", "bullmq": "5.8.1", "cache-manager": "^5.7.6", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index ed6722a..d41d2bc 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -84,11 +84,11 @@ importers: specifier: ^1.7.7 version: 1.7.7 bee-agent-framework: - specifier: 0.0.41 - version: 0.0.41(@bufbuild/protobuf@1.10.0)(@googleapis/customsearch@3.2.0)(@grpc/grpc-js@1.12.2)(@grpc/proto-loader@0.7.13)(@ibm-generative-ai/node-sdk@3.2.4)(ollama@0.5.9)(openai-chat-tokens@0.2.8)(openai@4.67.3(zod@3.23.8)) + specifier: 0.0.42 + version: 0.0.42(@bufbuild/protobuf@1.10.0)(@googleapis/customsearch@3.2.0)(@grpc/grpc-js@1.12.2)(@grpc/proto-loader@0.7.13)(@ibm-generative-ai/node-sdk@3.2.4)(ollama@0.5.9)(openai-chat-tokens@0.2.8)(openai@4.67.3(zod@3.23.8)) bee-observe-connector: specifier: 0.0.5 - version: 0.0.5(bee-agent-framework@0.0.41(@bufbuild/protobuf@1.10.0)(@googleapis/customsearch@3.2.0)(@grpc/grpc-js@1.12.2)(@grpc/proto-loader@0.7.13)(@ibm-generative-ai/node-sdk@3.2.4)(ollama@0.5.9)(openai-chat-tokens@0.2.8)(openai@4.67.3(zod@3.23.8))) + version: 0.0.5(bee-agent-framework@0.0.42(@bufbuild/protobuf@1.10.0)(@googleapis/customsearch@3.2.0)(@grpc/grpc-js@1.12.2)(@grpc/proto-loader@0.7.13)(@ibm-generative-ai/node-sdk@3.2.4)(ollama@0.5.9)(openai-chat-tokens@0.2.8)(openai@4.67.3(zod@3.23.8))) bullmq: specifier: 5.8.1 version: 5.8.1 @@ -1971,8 +1971,8 @@ packages: resolution: {integrity: sha512-4Bcg1P8xhUuqcii/S0Z9wiHIrQVPMermM1any+MX5GeGD7faD3/msQUDGLol9wOcz4/jbg/WJnGqoJF6LiBdtg==} engines: {node: '>=10.0.0'} - bee-agent-framework@0.0.41: - resolution: {integrity: sha512-fooL1KlyhT06pKo9Cp1Hq4wubbolndqmp/ENLk+bd6vDE1pKylxIo9y6upoT8o+TrkwVRhWVG0zac+ySGTuy5Q==} + bee-agent-framework@0.0.42: + resolution: {integrity: sha512-W6onNf9Zaj6GNpRTKJOMChfvSZccrBsmafuSh8I7Er1eSND9viu5RVybG/FiAI6omjdJGJnEZ+lTgY46rl+DdA==} peerDependencies: '@aws-sdk/client-bedrock-runtime': ^3.687.0 '@elastic/elasticsearch': ^8.0.0 @@ -7260,7 +7260,7 @@ snapshots: basic-ftp@5.0.5: {} - bee-agent-framework@0.0.41(@bufbuild/protobuf@1.10.0)(@googleapis/customsearch@3.2.0)(@grpc/grpc-js@1.12.2)(@grpc/proto-loader@0.7.13)(@ibm-generative-ai/node-sdk@3.2.4)(ollama@0.5.9)(openai-chat-tokens@0.2.8)(openai@4.67.3(zod@3.23.8)): + bee-agent-framework@0.0.42(@bufbuild/protobuf@1.10.0)(@googleapis/customsearch@3.2.0)(@grpc/grpc-js@1.12.2)(@grpc/proto-loader@0.7.13)(@ibm-generative-ai/node-sdk@3.2.4)(ollama@0.5.9)(openai-chat-tokens@0.2.8)(openai@4.67.3(zod@3.23.8)): dependencies: '@ai-zen/node-fetch-event-source': 2.1.4 '@connectrpc/connect': 1.6.1(@bufbuild/protobuf@1.10.0) @@ -7305,9 +7305,9 @@ snapshots: - debug - encoding - bee-observe-connector@0.0.5(bee-agent-framework@0.0.41(@bufbuild/protobuf@1.10.0)(@googleapis/customsearch@3.2.0)(@grpc/grpc-js@1.12.2)(@grpc/proto-loader@0.7.13)(@ibm-generative-ai/node-sdk@3.2.4)(ollama@0.5.9)(openai-chat-tokens@0.2.8)(openai@4.67.3(zod@3.23.8))): + bee-observe-connector@0.0.5(bee-agent-framework@0.0.42(@bufbuild/protobuf@1.10.0)(@googleapis/customsearch@3.2.0)(@grpc/grpc-js@1.12.2)(@grpc/proto-loader@0.7.13)(@ibm-generative-ai/node-sdk@3.2.4)(ollama@0.5.9)(openai-chat-tokens@0.2.8)(openai@4.67.3(zod@3.23.8))): dependencies: - bee-agent-framework: 0.0.41(@bufbuild/protobuf@1.10.0)(@googleapis/customsearch@3.2.0)(@grpc/grpc-js@1.12.2)(@grpc/proto-loader@0.7.13)(@ibm-generative-ai/node-sdk@3.2.4)(ollama@0.5.9)(openai-chat-tokens@0.2.8)(openai@4.67.3(zod@3.23.8)) + bee-agent-framework: 0.0.42(@bufbuild/protobuf@1.10.0)(@googleapis/customsearch@3.2.0)(@grpc/grpc-js@1.12.2)(@grpc/proto-loader@0.7.13)(@ibm-generative-ai/node-sdk@3.2.4)(ollama@0.5.9)(openai-chat-tokens@0.2.8)(openai@4.67.3(zod@3.23.8)) openapi-fetch: 0.11.3 remeda: 2.16.0 diff --git a/seeders/DatabaseSeeder.ts b/seeders/DatabaseSeeder.ts index 207e026..11fdf1f 100644 --- a/seeders/DatabaseSeeder.ts +++ b/seeders/DatabaseSeeder.ts @@ -36,6 +36,7 @@ import { PROJECT_ID_DEFAULT } from '@/config'; import { redactProjectKeyValue } from '@/administration/helpers'; +import { Agent } from '@/runs/execution/constants'; const USER_EXTERNAL_ID = 'test'; const PROJECT_API_KEY = `${API_KEY_PREFIX}testkey`; @@ -89,8 +90,9 @@ export class DatabaseSeeder extends Seeder { project: ref(project), redactedValue: redactProjectKeyValue(PROJECT_API_KEY) }); - const assistant = new Assistant({ + const beeAssistant = new Assistant({ model: getDefaultModel(), + agent: Agent.BEE, tools: [ { type: 'system', @@ -117,7 +119,20 @@ export class DatabaseSeeder extends Seeder { $ui_icon: 'Bee' } }); - em.persist([assistant, projectApiKey]); + const streamlitAssistant = new Assistant({ + model: getDefaultModel(), + agent: Agent.STREAMLIT, + tools: [], + name: 'Builder Assistant', + project: ref(project), + createdBy: ref(projectUser), + description: 'An example streamlit agent, tailored for building Streamlit applications.', + metadata: { + $ui_color: 'white', + $ui_icon: 'Bee' + } + }); + em.persist([beeAssistant, streamlitAssistant, projectApiKey]); await em.flush(); process.env.IN_SEEDER = undefined; } diff --git a/src/assistants/assistant.entity.ts b/src/assistants/assistant.entity.ts index a4cf855..de1f18c 100644 --- a/src/assistants/assistant.entity.ts +++ b/src/assistants/assistant.entity.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { Embedded, Entity, Property } from '@mikro-orm/core'; +import { Embedded, Entity, Enum, Property } from '@mikro-orm/core'; import { ProjectScopedEntity, ProjectScopedEntityInput } from '@/common/project-scoped.entity'; import { AnyToolUsage } from '@/tools/entities/tool-usages/tool-usage.entity.js'; @@ -28,6 +28,7 @@ import { AnyToolResource } from '@/tools/entities/tool-resources/tool-resource.e import { FileSearchResource } from '@/tools/entities/tool-resources/file-search-resources.entity.js'; import { UserResource } from '@/tools/entities/tool-resources/user-resource.entity'; import { SystemResource } from '@/tools/entities/tool-resources/system-resource.entity'; +import { Agent } from '@/runs/execution/constants'; @Entity() export class Assistant extends ProjectScopedEntity { @@ -35,6 +36,9 @@ export class Assistant extends ProjectScopedEntity { return 'asst'; } + @Enum(() => Agent) + agent: Agent; + // Union must be defined in alphabetical order, otherwise Mikro-ORM won't discovered the auto-created virtual polymorphic entity @Embedded({ object: true }) tools!: (CodeInterpreterUsage | FileSearchUsage | FunctionUsage | SystemUsage | UserUsage)[]; @@ -71,6 +75,7 @@ export class Assistant extends ProjectScopedEntity { tools, toolResources, model, + agent, topP, temperature, systemPromptOverwrite, @@ -84,6 +89,7 @@ export class Assistant extends ProjectScopedEntity { this.tools = tools; this.toolResources = toolResources; this.model = model; + this.agent = agent; this.topP = topP; this.temperature = temperature; this.systemPromptOverwrite = systemPromptOverwrite; @@ -93,5 +99,8 @@ export class Assistant extends ProjectScopedEntity { export type AssistantInput = ProjectScopedEntityInput & { tools: AnyToolUsage[]; toolResources?: AnyToolResource[]; -} & Pick & +} & Pick< + Assistant, + 'instructions' | 'name' | 'description' | 'model' | 'agent' | 'systemPromptOverwrite' + > & Partial>; diff --git a/src/assistants/assistants.service.ts b/src/assistants/assistants.service.ts index 4da7780..e6f9ee0 100644 --- a/src/assistants/assistants.service.ts +++ b/src/assistants/assistants.service.ts @@ -42,6 +42,7 @@ import { Tool, ToolType } from '@/tools/entities/tool/tool.entity.js'; import { getUpdatedValue } from '@/utils/update.js'; import { createDeleteResponse } from '@/utils/delete.js'; import { getDefaultModel } from '@/runs/execution/factory'; +import { Agent } from '@/runs/execution/constants.js'; export function toDto(assistant: Loaded): AssistantDto { return { @@ -55,6 +56,7 @@ export function toDto(assistant: Loaded): AssistantDto { metadata: assistant.metadata ?? {}, created_at: dayjs(assistant.createdAt).unix(), model: assistant.model, + agent: assistant.agent, top_p: assistant.topP, temperature: assistant.temperature, system_prompt: assistant.systemPromptOverwrite @@ -70,9 +72,23 @@ export async function createAssistant({ metadata, top_p, model, + agent, temperature, system_prompt_overwrite }: AssistantCreateBody): Promise { + if (agent === Agent.STREAMLIT) { + if (toolsParam) + throw new APIError({ + code: APIErrorCode.INVALID_INPUT, + message: 'Tools are currently not supported by Streamlit agent' + }); + if (tool_resources) + throw new APIError({ + code: APIErrorCode.INVALID_INPUT, + message: 'Tool resouces are currently not supported by Streamlit agent' + }); + } + const customToolIds = toolsParam.flatMap((toolUsage) => toolUsage.type === ToolType.USER ? toolUsage.user.tool.id : [] ); @@ -99,6 +115,7 @@ export async function createAssistant({ metadata, topP: top_p ?? undefined, model: model ?? getDefaultModel(), + agent, temperature: temperature ?? undefined, systemPromptOverwrite: system_prompt_overwrite ?? undefined }); @@ -131,6 +148,20 @@ export async function updateAssistant({ const assistant = await ORM.em.getRepository(Assistant).findOneOrFail({ id: assistant_id }); + + if (assistant.agent === Agent.STREAMLIT) { + if (tools) + throw new APIError({ + code: APIErrorCode.INVALID_INPUT, + message: 'Tools are currently not supported by Streamlit agent' + }); + if (tool_resources) + throw new APIError({ + code: APIErrorCode.INVALID_INPUT, + message: 'Tool resouces are currently not supported by Streamlit agent' + }); + } + assistant.name = getUpdatedValue(name, assistant.name); assistant.description = getUpdatedValue(description, assistant.description); assistant.instructions = getUpdatedValue(instructions, assistant.instructions); @@ -157,10 +188,15 @@ export async function listAssistants({ before, order, order_by, + agent, search }: AssistantsListQuery): Promise { const where: FilterQuery = {}; + if (agent) { + where.agent = agent; + } + if (search) { const regexp = new RegExp(search, 'i'); where.$or = [{ description: regexp }, { name: regexp }]; diff --git a/src/assistants/dtos/assistant-create.ts b/src/assistants/dtos/assistant-create.ts index 7c14e2a..4100dbe 100644 --- a/src/assistants/dtos/assistant-create.ts +++ b/src/assistants/dtos/assistant-create.ts @@ -21,6 +21,7 @@ import { assistantSchema } from './assistant.js'; import { toolUsageSchema } from '@/tools/dtos/tools-usage.js'; import { metadataSchema } from '@/schema.js'; import { toolResourcesSchema } from '@/tools/dtos/tool-resources.js'; +import { Agent } from '@/runs/execution/constants.js'; export const assistantCreateBodySchema = { type: 'object', @@ -55,6 +56,11 @@ export const assistantCreateBodySchema = { model: { type: 'string' }, + agent: { + type: 'string', + enum: Object.values(Agent), + default: Agent.BEE + }, top_p: { type: 'number', nullable: true, diff --git a/src/assistants/dtos/assistant.ts b/src/assistants/dtos/assistant.ts index a1dc0e4..d9bc9fc 100644 --- a/src/assistants/dtos/assistant.ts +++ b/src/assistants/dtos/assistant.ts @@ -19,6 +19,7 @@ import { FromSchema, JSONSchema } from 'json-schema-to-ts'; import { toolUsageSchema } from '@/tools/dtos/tools-usage.js'; import { metadataSchema } from '@/schema.js'; import { toolResourcesSchema } from '@/tools/dtos/tool-resources.js'; +import { Agent } from '@/runs/execution/constants'; export const assistantSchema = { type: 'object', @@ -31,7 +32,8 @@ export const assistantSchema = { 'description', 'metadata', 'created_at', - 'model' + 'model', + 'agent' ], properties: { id: { type: 'string' }, @@ -62,6 +64,10 @@ export const assistantSchema = { model: { type: 'string' }, + agent: { + type: 'string', + enum: Object.values(Agent) + }, top_p: { type: 'number', nullable: true diff --git a/src/assistants/dtos/assistants-list.ts b/src/assistants/dtos/assistants-list.ts index bb6dab8..ae9e09e 100644 --- a/src/assistants/dtos/assistants-list.ts +++ b/src/assistants/dtos/assistants-list.ts @@ -19,6 +19,7 @@ import { FromSchema, JSONSchema } from 'json-schema-to-ts'; import { assistantSchema } from './assistant.js'; import { createPaginationQuerySchema, withPagination } from '@/schema.js'; +import { Agent } from '@/runs/execution/constants.js'; export const assistantsListQuerySchema = { type: 'object', @@ -29,6 +30,10 @@ export const assistantsListQuerySchema = { type: 'boolean', default: false }, + agent: { + type: 'string', + enum: Object.values(Agent) + }, search: { type: 'string', nullable: true diff --git a/src/runs/execution/constants.ts b/src/runs/execution/constants.ts index fdf080c..356cdf7 100644 --- a/src/runs/execution/constants.ts +++ b/src/runs/execution/constants.ts @@ -17,6 +17,12 @@ export const RUN_EXPIRATION_MILLISECONDS = 10 * 60 * 1000; export const STATUS_POLL_INTERVAL = 5 * 1000; +export const Agent = { + BEE: 'bee', + STREAMLIT: 'streamlit' +} as const; +export type Agent = (typeof Agent)[keyof typeof Agent]; + export const LLMBackend = { OLLAMA: 'ollama', IBM_VLLM: 'ibm-vllm', diff --git a/src/runs/execution/event-handlers/streaming.ts b/src/runs/execution/event-handlers/streaming.ts index acd0a0f..54fe1c8 100644 --- a/src/runs/execution/event-handlers/streaming.ts +++ b/src/runs/execution/event-handlers/streaming.ts @@ -15,12 +15,18 @@ */ import { FrameworkError, Version } from 'bee-agent-framework'; -import { EventMeta, Emitter } from 'bee-agent-framework/emitter/emitter'; +import { EventMeta, Emitter, Callback } from 'bee-agent-framework/emitter/emitter'; import { ref } from '@mikro-orm/core'; import { Role } from 'bee-agent-framework/llms/primitives/message'; import { BeeCallbacks } from 'bee-agent-framework/agents/bee/types'; import { Summary } from 'prom-client'; import { ToolError } from 'bee-agent-framework/tools/base'; +import { + StreamlitEvents as StreamlitEventsFramework, + StreamlitRunOutput +} from 'bee-agent-framework/agents/experimental/streamlit/agent'; + +import { Agent } from '../constants'; import { AgentContext } from '@/runs/execution/execute.js'; import { getLogger } from '@/logger.js'; @@ -50,9 +56,9 @@ const agentToolExecutionTime = new Summary({ registers: [jobRegistry] }); -export function createStreamingHandler(ctx: AgentContext) { +export function createBeeStreamingHandler(ctx: AgentContext) { return (emitter: Emitter) => { - const logger = getLogger().child({ runId: ctx.run.id }); + const logger = getLogger().child({ runId: ctx.run.id, agent: Agent.BEE }); let toolExecutionEnd: (() => number) | null = null; @@ -185,7 +191,10 @@ export function createStreamingHandler(ctx: AgentContext) { ctx.message.content = data.final_answer ?? ctx.message.content; ctx.message.status = MessageStatus.COMPLETED; await ORM.em.flush(); - await ctx.publish({ event: 'thread.message.completed', data: toMessageDto(ctx.message) }); + await ctx.publish({ + event: 'thread.message.completed', + data: toMessageDto(ctx.message) + }); ctx.message = undefined; } if ( @@ -210,7 +219,10 @@ export function createStreamingHandler(ctx: AgentContext) { } else { ctx.runStep.status = RunStepStatus.FAILED; await ORM.em.flush(); - await ctx.publish({ event: 'thread.run.step.failed', data: toRunStepDto(ctx.runStep) }); + await ctx.publish({ + event: 'thread.run.step.failed', + data: toRunStepDto(ctx.runStep) + }); } ctx.runStep = undefined; ctx.toolCall = undefined; @@ -252,7 +264,10 @@ export function createStreamingHandler(ctx: AgentContext) { event: 'thread.run.step.in_progress', data: toRunStepDto(ctx.runStep) }); - await ctx.publish({ event: 'thread.message.created', data: toMessageDto(ctx.message) }); + await ctx.publish({ + event: 'thread.message.created', + data: toMessageDto(ctx.message) + }); await ctx.publish({ event: 'thread.message.in_progress', data: toMessageDto(ctx.message) @@ -317,6 +332,98 @@ export function createStreamingHandler(ctx: AgentContext) { }; } +export function createStreamlitStreamingHandler(ctx: AgentContext) { + const logger = getLogger().child({ runId: ctx.run.id, agent: Agent.STREAMLIT }); + return (emitter: Emitter) => { + emitter.on('newToken', async ({ delta }) => { + if (!ctx.message) { + ctx.message = new Message({ + project: ctx.run.project, + role: Role.ASSISTANT, + content: '', + thread: ctx.run.thread, + run: ref(ctx.run), + createdBy: ctx.run.createdBy, + status: MessageStatus.IN_PROGRESS + }); + if (ctx.runStep) + logger.warn( + { step: ctx.runStep.id }, + 'Message creation has started while previous run step has not finished' + ); + ctx.runStep = new RunStep({ + project: ctx.run.project, + run: ref(ctx.run), + thread: ctx.run.thread, + assistant: ctx.run.assistant, + createdBy: ctx.run.createdBy, + details: new RunStepMessageCreation({ message: ref(ctx.message) }) + }); + await ORM.em.persistAndFlush([ctx.message, ctx.runStep]); + await ctx.publish({ + event: 'thread.run.step.created', + data: toRunStepDto(ctx.runStep) + }); + await ctx.publish({ + event: 'thread.run.step.in_progress', + data: toRunStepDto(ctx.runStep) + }); + await ctx.publish({ event: 'thread.message.created', data: toMessageDto(ctx.message) }); + await ctx.publish({ + event: 'thread.message.in_progress', + data: toMessageDto(ctx.message) + }); + } + ctx.message.content += delta; + await ctx.publish({ + event: 'thread.message.delta', + data: { + id: ctx.message.id, + object: 'thread.message.delta', + delta: { + role: Role.ASSISTANT, + content: [ + { + index: 0, + type: 'text', + text: { + value: delta + } + } + ] + } + } + }); + }); + emitter.on('success', async () => { + if (!ctx.message) { + logger.warn('Agent success with missing message'); + return; + } + if (!ctx.runStep) { + logger.warn('Agent success with missing run step'); + return; + } + + ctx.message.status = MessageStatus.COMPLETED; + // TODO add artifact to message + await ORM.em.flush(); + await ctx.publish({ + event: 'thread.message.completed', + data: toMessageDto(ctx.message) + }); + + ctx.runStep.status = RunStepStatus.COMPLETED; + await ORM.em.flush(); + await ctx.publish({ + event: 'thread.run.step.completed', + data: toRunStepDto(ctx.runStep) + }); + }); + emitter.on('error', ({ error }) => onError(ctx)(error)); + }; +} + const onError = (ctx: AgentContext) => async (error: Error) => { if (ctx.message) { ctx.message.status = MessageStatus.INCOMPLETE; @@ -358,3 +465,10 @@ const createEventFromMeta = (meta: EventMeta) => { }) : undefined; }; + +type StreamlitEvents = StreamlitEventsFramework & { + error?: Callback<{ + error: Error; + }>; + success?: Callback; +}; diff --git a/src/runs/execution/execute.ts b/src/runs/execution/execute.ts index e36d44c..e097df3 100644 --- a/src/runs/execution/execute.ts +++ b/src/runs/execution/execute.ts @@ -21,6 +21,8 @@ import { Summary } from 'prom-client'; import dayjs from 'dayjs'; import { isTruthy } from 'remeda'; import { createObserveConnector } from 'bee-observe-connector'; +import { BeeAgent } from 'bee-agent-framework/agents/bee/agent'; +import { TokenMemory } from 'bee-agent-framework/memory/tokenMemory'; import { Run } from '../entities/run.entity.js'; import { toRunDto } from '../runs.service.js'; @@ -28,10 +30,10 @@ import { toRunDto } from '../runs.service.js'; import { addFileToToolResource, checkFileExistsOnToolResource, - createAgent, createToolResource } from './helpers.js'; import { getTools } from './tools/helpers.js'; +import { createAgentRun, createChatLLM } from './factory.js'; import { ORM } from '@/database.js'; import { getLogger } from '@/logger.js'; @@ -44,7 +46,6 @@ import { Trace } from '@/observe/entities/trace.entity.js'; import { RunStep } from '@/run-steps/entities/run-step.entity.js'; import { Message } from '@/messages/message.entity.js'; import { AnyToolCall } from '@/tools/entities/tool-calls/tool-call.entity.js'; -import { createStreamingHandler } from '@/runs/execution/event-handlers/streaming.js'; import { LoadedRun } from '@/runs/execution/types.js'; import { UserResource } from '@/tools/entities/tool-resources/user-resource.entity.js'; import { SystemResource } from '@/tools/entities/tool-resources/system-resource.entity.js'; @@ -59,7 +60,7 @@ const agentExecutionTime = new Summary({ }); export type AgentContext = { - run: Loaded; + run: Loaded; publish: ReturnType; runStep?: Loaded; message?: Loaded; @@ -126,8 +127,9 @@ export async function executeRun(run: LoadedRun) { const context = { run, publish } as AgentContext; const tools = await getTools(run, context); - const agent = createAgent(run, tools); - await agent.memory.addMany(messages); + const llm = createChatLLM(run); + const memory = new TokenMemory({ llm }); + await memory.addMany(messages); const cancellationController = new AbortController(); const unsub = watchForCancellation(Run, run, () => cancellationController.abort()); @@ -136,25 +138,18 @@ export async function executeRun(run: LoadedRun) { try { const endAgentExecutionTimer = agentExecutionTime.labels({ framework: Version }).startTimer(); - const agentRunPromise = agent - .run( - { prompt: null }, // messages have been loaded to agent's memory - { - // eslint-disable-next-line @typescript-eslint/ban-ts-comment - // @ts-ignore - signal: AbortSignal.any([cancellationController.signal, expirationSignal]), - execution: { - totalMaxRetries: 10, - maxRetriesPerStep: 3, - maxIterations: 10 - } - } - ) - .observe(createStreamingHandler(context)); + const [agentRun, agent] = createAgentRun( + run, + { llm, tools, memory }, + { + signal: AbortSignal.any([cancellationController.signal, expirationSignal]), + ctx: context + } + ); // apply observe middleware only when the observe API is enabled - if (BEE_OBSERVE_API_URL && BEE_OBSERVE_API_AUTH_KEY) { - agentRunPromise.middleware( + if (BEE_OBSERVE_API_URL && BEE_OBSERVE_API_AUTH_KEY && agent instanceof BeeAgent) { + (agentRun as ReturnType).middleware( createObserveConnector({ api: { baseUrl: BEE_OBSERVE_API_URL, @@ -180,7 +175,7 @@ export async function executeRun(run: LoadedRun) { ); } - await agentRunPromise; + await agentRun; endAgentExecutionTimer(); run.complete(); @@ -208,6 +203,5 @@ export async function executeRun(run: LoadedRun) { } finally { await publish({ event: 'done', data: '[DONE]' }); unsub(); - agent.destroy(); } } diff --git a/src/runs/execution/factory.ts b/src/runs/execution/factory.ts index 0381cb4..2a1d270 100644 --- a/src/runs/execution/factory.ts +++ b/src/runs/execution/factory.ts @@ -33,10 +33,26 @@ import { WatsonXChatLLMPresetModel } from 'bee-agent-framework/adapters/watsonx/ import { BAMLLM } from 'bee-agent-framework/adapters/bam/llm'; import { IBMvLLM } from 'bee-agent-framework/adapters/ibm-vllm/llm'; import { WatsonXLLM } from 'bee-agent-framework/adapters/watsonx/llm'; +import { ZodType } from 'zod'; +import { PromptTemplate } from 'bee-agent-framework'; +import { AnyTool } from 'bee-agent-framework/tools/base'; +import { GraniteBeeAgent } from 'bee-agent-framework/agents/granite/agent'; +import { StreamlitAgent } from 'bee-agent-framework/agents/experimental/streamlit/agent'; +import { GraniteBeeSystemPrompt } from 'bee-agent-framework/agents/granite/prompts'; +import { BeeAgent } from 'bee-agent-framework/agents/bee/agent'; +import { BeeSystemPrompt } from 'bee-agent-framework/agents/bee/prompts'; +import { ChatLLM, ChatLLMOutput } from 'bee-agent-framework/llms/chat'; +import { BaseMemory } from 'bee-agent-framework/memory/base'; +import { StreamlitAgentSystemPrompt } from 'bee-agent-framework/agents/experimental/streamlit/prompts'; import { Run } from '../entities/run.entity'; -import { LLMBackend } from './constants'; +import { Agent, LLMBackend } from './constants'; +import { + createBeeStreamingHandler, + createStreamlitStreamingHandler +} from './event-handlers/streaming'; +import { AgentContext } from './execute'; import { BAM_API_KEY, @@ -215,3 +231,68 @@ export function createCodeLLM(backend: LLMBackend = LLM_BACKEND) { return undefined; } } + +export function createAgentRun( + run: Loaded, + { llm, tools, memory }: { llm: ChatLLM; tools: AnyTool[]; memory: BaseMemory }, + { signal, ctx }: { signal: AbortSignal; ctx: AgentContext } +) { + const runArgs = [ + { prompt: null }, // messages have been loaded to agent's memory + { + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore + signal, + execution: { + totalMaxRetries: 10, + maxRetriesPerStep: 3, + maxIterations: 10 + } + } + ] as const; + switch (run.assistant.$.agent) { + case Agent.BEE: { + const agent = run.model.includes('granite') + ? new GraniteBeeAgent({ + llm, + memory, + tools, + templates: { system: getPromptTemplate(run, GraniteBeeSystemPrompt) } + }) + : new BeeAgent({ + llm, + memory, + tools, + templates: { system: getPromptTemplate(run, BeeSystemPrompt) } + }); + return [agent.run(...runArgs).observe(createBeeStreamingHandler(ctx)), agent]; + } + case Agent.STREAMLIT: { + const agent = new StreamlitAgent({ + llm, + memory, + templates: { system: getPromptTemplate(run, StreamlitAgentSystemPrompt) } + }); + return [agent.run(...runArgs).observe(createStreamlitStreamingHandler(ctx)), agent]; + } + } +} + +function getPromptTemplate( + run: Loaded, + promptTemplate: PromptTemplate +): PromptTemplate { + const instructions = run.additionalInstructions + ? `${run.instructions} ${run.additionalInstructions}` + : run.instructions; + return promptTemplate.fork((input) => ({ + ...input, + ...(run.assistant.$.systemPromptOverwrite + ? { template: run.assistant.$.systemPromptOverwrite } + : {}), + defaults: { + ...input.defaults, + ...(instructions ? { instructions } : {}) + } + })); +} diff --git a/src/runs/execution/helpers.ts b/src/runs/execution/helpers.ts index 5e68282..cc24e27 100644 --- a/src/runs/execution/helpers.ts +++ b/src/runs/execution/helpers.ts @@ -15,17 +15,7 @@ */ import { Loaded, ref, Ref } from '@mikro-orm/core'; -import { BeeSystemPrompt } from 'bee-agent-framework/agents/bee/prompts'; import { unique } from 'remeda'; -import { PromptTemplate } from 'bee-agent-framework/template'; -import { GraniteBeeSystemPrompt } from 'bee-agent-framework/agents/granite/prompts'; -import { ZodType } from 'zod'; -import { GraniteBeeAgent } from 'bee-agent-framework/agents/granite/agent'; -import { TokenMemory } from 'bee-agent-framework/memory/tokenMemory'; -import { BeeAgent } from 'bee-agent-framework/agents/bee/agent'; -import { AnyTool } from 'bee-agent-framework/tools/base'; - -import { Run } from '../entities/run.entity.js'; import { ORM } from '@/database.js'; import { File } from '@/files/entities/file.entity.js'; @@ -40,7 +30,6 @@ import { FileContainer } from '@/files/entities/files-container.entity.js'; import { VectorStore } from '@/vector-stores/entities/vector-store.entity.js'; import { UserResource } from '@/tools/entities/tool-resources/user-resource.entity.js'; import { SystemResource } from '@/tools/entities/tool-resources/system-resource.entity.js'; -import { createChatLLM } from '@/runs/execution/factory'; export function getRunVectorStores( assistant: Loaded, @@ -55,44 +44,6 @@ export function getRunVectorStores( return unique(vectorStores); // filter out duplicates } -export function createAgent(run: Loaded, tools: AnyTool[]) { - const llm = createChatLLM(run); - if (run.model.includes('granite')) { - return new GraniteBeeAgent({ - llm, - memory: new TokenMemory({ llm }), - tools, - templates: { system: getPromptTemplate(run, GraniteBeeSystemPrompt) } - }); - } else { - return new BeeAgent({ - llm, - memory: new TokenMemory({ llm }), - tools, - templates: { system: getPromptTemplate(run, BeeSystemPrompt) } - }); - } -} - -function getPromptTemplate( - run: Loaded, - promptTemplate: PromptTemplate -): PromptTemplate { - const instructions = run.additionalInstructions - ? `${run.instructions} ${run.additionalInstructions}` - : run.instructions; - return promptTemplate.fork((input) => ({ - ...input, - ...(run.assistant.$.systemPromptOverwrite - ? { template: run.assistant.$.systemPromptOverwrite } - : {}), - defaults: { - ...input.defaults, - ...(instructions ? { instructions } : {}) - } - })); -} - export async function checkFileExistsOnToolResource( toolResource: AnyToolResource, file: Ref