From 1797a53ad9a681b5daadca58acba8ca044085efd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Jane=C4=8Dek?= Date: Tue, 17 Dec 2024 12:39:31 +0100 Subject: [PATCH] feat(tools): add llm and calculator tools (#132) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Lukáš Janeček Co-authored-by: Lukáš Janeček --- pnpm-lock.yaml | 23 +++++++------ src/runs/execution/tools/helpers.ts | 33 ++++++++++++++++++- .../entities/tool-calls/system-call.entity.ts | 4 ++- src/tools/tools.service.ts | 30 +++++++++++++++++ 4 files changed, 76 insertions(+), 14 deletions(-) diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index c64470f..1a595c9 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -46,10 +46,10 @@ importers: version: 6.2.9 '@mikro-orm/migrations-mongodb': specifier: 6.2.9 - version: 6.2.9(@mikro-orm/core@6.2.9)(@types/node@20.16.14)(gcp-metadata@6.1.0(encoding@0.1.13))(socks@2.8.3) + version: 6.2.9(@mikro-orm/core@6.2.9)(@types/node@20.16.14)(socks@2.8.3) '@mikro-orm/mongodb': specifier: 6.2.9 - version: 6.2.9(@mikro-orm/core@6.2.9)(gcp-metadata@6.1.0(encoding@0.1.13))(socks@2.8.3) + version: 6.2.9(@mikro-orm/core@6.2.9)(socks@2.8.3) '@mikro-orm/reflection': specifier: 6.2.9 version: 6.2.9(@mikro-orm/core@6.2.9) @@ -6193,12 +6193,12 @@ snapshots: - supports-color - tedious - '@mikro-orm/migrations-mongodb@6.2.9(@mikro-orm/core@6.2.9)(@types/node@20.16.14)(gcp-metadata@6.1.0(encoding@0.1.13))(socks@2.8.3)': + '@mikro-orm/migrations-mongodb@6.2.9(@mikro-orm/core@6.2.9)(@types/node@20.16.14)(socks@2.8.3)': dependencies: '@mikro-orm/core': 6.2.9 - '@mikro-orm/mongodb': 6.2.9(@mikro-orm/core@6.2.9)(gcp-metadata@6.1.0(encoding@0.1.13))(socks@2.8.3) + '@mikro-orm/mongodb': 6.2.9(@mikro-orm/core@6.2.9)(socks@2.8.3) fs-extra: 11.2.0 - mongodb: 6.7.0(gcp-metadata@6.1.0(encoding@0.1.13))(socks@2.8.3) + mongodb: 6.7.0(socks@2.8.3) umzug: 3.8.0(@types/node@20.16.14) transitivePeerDependencies: - '@aws-sdk/credential-providers' @@ -6210,11 +6210,11 @@ snapshots: - snappy - socks - '@mikro-orm/mongodb@6.2.9(@mikro-orm/core@6.2.9)(gcp-metadata@6.1.0(encoding@0.1.13))(socks@2.8.3)': + '@mikro-orm/mongodb@6.2.9(@mikro-orm/core@6.2.9)(socks@2.8.3)': dependencies: '@mikro-orm/core': 6.2.9 bson: 6.8.0 - mongodb: 6.7.0(gcp-metadata@6.1.0(encoding@0.1.13))(socks@2.8.3) + mongodb: 6.7.0(socks@2.8.3) transitivePeerDependencies: - '@aws-sdk/credential-providers' - '@mongodb-js/zstd' @@ -8692,7 +8692,7 @@ snapshots: debug: 4.3.6 enhanced-resolve: 5.17.1 eslint: 8.57.0 - eslint-module-utils: 2.8.1(@typescript-eslint/parser@7.18.0(eslint@8.57.0)(typescript@5.5.4))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.6.1(@typescript-eslint/parser@7.18.0(eslint@8.57.0)(typescript@5.5.4))(eslint-plugin-import@2.29.1)(eslint@8.57.0))(eslint@8.57.0) + eslint-module-utils: 2.8.1(@typescript-eslint/parser@7.18.0(eslint@8.57.0)(typescript@5.5.4))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.6.1)(eslint@8.57.0) eslint-plugin-import: 2.29.1(@typescript-eslint/parser@7.18.0(eslint@8.57.0)(typescript@5.5.4))(eslint-import-resolver-typescript@3.6.1)(eslint@8.57.0) fast-glob: 3.3.2 get-tsconfig: 4.7.6 @@ -8704,7 +8704,7 @@ snapshots: - eslint-import-resolver-webpack - supports-color - eslint-module-utils@2.8.1(@typescript-eslint/parser@7.18.0(eslint@8.57.0)(typescript@5.5.4))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.6.1(@typescript-eslint/parser@7.18.0(eslint@8.57.0)(typescript@5.5.4))(eslint-plugin-import@2.29.1)(eslint@8.57.0))(eslint@8.57.0): + eslint-module-utils@2.8.1(@typescript-eslint/parser@7.18.0(eslint@8.57.0)(typescript@5.5.4))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.6.1)(eslint@8.57.0): dependencies: debug: 3.2.7 optionalDependencies: @@ -8725,7 +8725,7 @@ snapshots: doctrine: 2.1.0 eslint: 8.57.0 eslint-import-resolver-node: 0.3.9 - eslint-module-utils: 2.8.1(@typescript-eslint/parser@7.18.0(eslint@8.57.0)(typescript@5.5.4))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.6.1(@typescript-eslint/parser@7.18.0(eslint@8.57.0)(typescript@5.5.4))(eslint-plugin-import@2.29.1)(eslint@8.57.0))(eslint@8.57.0) + eslint-module-utils: 2.8.1(@typescript-eslint/parser@7.18.0(eslint@8.57.0)(typescript@5.5.4))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.6.1)(eslint@8.57.0) hasown: 2.0.2 is-core-module: 2.15.0 is-glob: 4.0.3 @@ -10046,13 +10046,12 @@ snapshots: '@types/whatwg-url': 11.0.5 whatwg-url: 13.0.0 - mongodb@6.7.0(gcp-metadata@6.1.0(encoding@0.1.13))(socks@2.8.3): + mongodb@6.7.0(socks@2.8.3): dependencies: '@mongodb-js/saslprep': 1.1.8 bson: 6.8.0 mongodb-connection-string-url: 3.0.1 optionalDependencies: - gcp-metadata: 6.1.0(encoding@0.1.13) socks: 2.8.3 ms@2.1.2: {} diff --git a/src/runs/execution/tools/helpers.ts b/src/runs/execution/tools/helpers.ts index 7bce44f..5573779 100644 --- a/src/runs/execution/tools/helpers.ts +++ b/src/runs/execution/tools/helpers.ts @@ -33,13 +33,16 @@ import { LLMChatTemplates } from 'bee-agent-framework/adapters/shared/llmChatTem import { DuckDuckGoSearchTool } from 'bee-agent-framework/tools/search/duckDuckGoSearch'; import { SearchToolOptions, SearchToolOutput } from 'bee-agent-framework/tools/search/base'; import { PromptTemplate } from 'bee-agent-framework/template'; +import { CalculatorTool } from 'bee-agent-framework/tools/calculator'; +import { LLMTool } from 'bee-agent-framework/tools/llm'; import { AgentContext } from '../execute.js'; import { getRunVectorStores } from '../helpers.js'; import { CodeInterpreterTool as CodeInterpreterUserTool } from '../../../tools/entities/tool/code-interpreter-tool.entity.js'; import { ApiTool as ApiCallUserTool } from '../../../tools/entities/tool/api-tool.entity.js'; -import { createCodeLLM } from '../factory.js'; +import { createChatLLM, createCodeLLM } from '../factory.js'; import { RedisCache } from '../cache.js'; +import { getDefaultModel } from '../constants.js'; import { createPythonStorage } from './python-tool-storage.js'; import { FunctionTool, FunctionToolOutput } from './function.js'; @@ -128,6 +131,22 @@ export async function getTools(run: LoadedRun, context: AgentContext): Promise tool.type === ToolType.SYSTEM && tool.toolId === SystemTools.LLM + ); + if (llmUsage) + tools.push( + new LLMTool({ + llm: createChatLLM({ model: getDefaultModel() }) + }) + ); + + const calculatorUsage = run.tools.find( + (tool): tool is SystemUsage => + tool.type === ToolType.SYSTEM && tool.toolId === SystemTools.CALCULATOR + ); + if (calculatorUsage) tools.push(new CalculatorTool()); + const weatherUsage = run.tools.find( (tool): tool is SystemUsage => tool.type === ToolType.SYSTEM && tool.toolId === SystemTools.WEATHER @@ -342,6 +361,16 @@ export async function createToolCall( toolId: SystemTools.READ_FILE, input: await tool.parse(input) }); + } else if (tool instanceof LLMTool) { + return new SystemCall({ + toolId: SystemTools.LLM, + input: await tool.parse(input) + }); + } else if (tool instanceof CalculatorTool) { + return new SystemCall({ + toolId: SystemTools.CALCULATOR, + input: await tool.parse(input) + }); } else if (tool instanceof FunctionTool) { return new FunctionCall({ name: tool.name, arguments: JSON.stringify(input) }); } else if (tool instanceof CustomTool || tool instanceof ApiCallTool) { @@ -387,6 +416,8 @@ export async function finalizeToolCall( toolCall.output = result; break; } + case SystemTools.CALCULATOR: + case SystemTools.LLM: case SystemTools.READ_FILE: { if (!(result instanceof StringToolOutput)) throw new TypeError(); toolCall.output = result.result; diff --git a/src/tools/entities/tool-calls/system-call.entity.ts b/src/tools/entities/tool-calls/system-call.entity.ts index 877495d..16b3749 100644 --- a/src/tools/entities/tool-calls/system-call.entity.ts +++ b/src/tools/entities/tool-calls/system-call.entity.ts @@ -25,7 +25,9 @@ export enum SystemTools { WIKIPEDIA = 'wikipedia', WEATHER = 'weather', ARXIV = 'arxiv', - READ_FILE = 'read_file' + READ_FILE = 'read_file', + LLM = 'llm', + CALCULATOR = 'calculator' } @Embeddable({ discriminatorValue: ToolType.SYSTEM }) diff --git a/src/tools/tools.service.ts b/src/tools/tools.service.ts index ab69517..8c250f7 100644 --- a/src/tools/tools.service.ts +++ b/src/tools/tools.service.ts @@ -19,6 +19,8 @@ import { CustomTool, CustomToolCreateError } from 'bee-agent-framework/tools/cus import dayjs from 'dayjs'; import mime from 'mime/lite'; import { WikipediaTool } from 'bee-agent-framework/tools/search/wikipedia'; +import { LLMTool } from 'bee-agent-framework/tools/llm'; +import { CalculatorTool } from 'bee-agent-framework/tools/calculator'; import { Tool as FrameworkTool } from 'bee-agent-framework/tools/base'; import { ZodTypeAny } from 'zod'; import { zodToJsonSchema } from 'zod-to-json-schema'; @@ -72,6 +74,8 @@ import { createCodeInterpreterConnectionOptions } from '@/runs/execution/tools/h import { ReadFileTool } from '@/runs/execution/tools/read-file-tool.js'; import { snakeToCamel } from '@/utils/strings.js'; import { createSearchTool } from '@/runs/execution/tools/search-tool'; +import { createChatLLM } from '@/runs/execution/factory.js'; +import { getDefaultModel } from '@/runs/execution/constants.js'; type SystemTool = Pick & { type: ToolType; @@ -462,6 +466,10 @@ function getSystemTools() { }); const fileSearch = new FileSearchTool({ vectorStores: [], maxNumResults: 0 }); const readFile = new ReadFileTool({ files: [], fileSize: 0 }); + const llmTool = new LLMTool({ + llm: createChatLLM({ model: getDefaultModel() }) + }); + const calculatorTool = new CalculatorTool(); const systemTools = new Map(); @@ -569,6 +577,26 @@ function getSystemTools() { userDescription: 'Execute Python code for various tasks, including data analysis, file processing, and visualizations. Supports the installation of any library such as NumPy, Pandas, SciPy, and Matplotlib. Users can create new files or convert existing files, which are then made available for download.' }); + systemTools.set(SystemTools.LLM, { + type: ToolType.SYSTEM, + id: SystemTools.LLM, + createdAt: new Date('2024-12-12'), + ...llmTool, + inputSchema: llmTool.inputSchema.bind(llmTool), + isExternal: false, + userDescription: + 'Uses expert LLM to work with data in the existing conversation (classification, entity extraction, summarization, ...)' + }); + systemTools.set(SystemTools.CALCULATOR, { + type: ToolType.SYSTEM, + id: SystemTools.CALCULATOR, + createdAt: new Date('2024-12-12'), + ...calculatorTool, + inputSchema: calculatorTool.inputSchema.bind(calculatorTool), + isExternal: false, + userDescription: + 'A calculator tool that performs basic arithmetic operations like addition, subtraction, multiplication, and division. Only use the calculator tool if you need to perform a calculation.' + }); return systemTools; } @@ -591,6 +619,8 @@ export async function listTools({ allSystemTools.get(SystemTools.WIKIPEDIA), allSystemTools.get(SystemTools.WEATHER), allSystemTools.get(SystemTools.ARXIV), + allSystemTools.get(SystemTools.LLM), + allSystemTools.get(SystemTools.CALCULATOR), allSystemTools.get('read_file') ] : [];