From c18bc9037d1b07ceabd901a79a59e53a28b7745f Mon Sep 17 00:00:00 2001 From: Aymeric Date: Fri, 20 Dec 2024 16:18:40 +0100 Subject: [PATCH] =?UTF-8?q?Add=20E2B=20code=20interpreter=20=F0=9F=A5=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 7 +- Dockerfile | 4 +- e2b.Dockerfile | 5 + e2b.toml | 16 +++ examples/docker_example.py | 4 +- examples/dummytool.py | 2 +- examples/e2b_example.py | 44 +++++++ src/agents/__init__.py | 16 ++- src/agents/agents.py | 65 ++++------- .../{tools => default_tools}/__init__.py | 0 .../base.py} | 69 +---------- src/agents/{tools => default_tools}/search.py | 10 +- src/agents/docker_alternative.py | 2 +- src/agents/docker_python_executor.py | 4 +- src/agents/e2b_executor.py | 103 +++++++++++++++++ src/agents/gradio_ui.py | 6 +- src/agents/local_python_executor.py | 108 ++++++++++++++---- src/agents/prompts.py | 2 +- src/agents/tool_validation.py | 59 +++++++--- src/agents/{tool.py => tools.py} | 83 +++----------- src/agents/types.py | 9 +- src/agents/utils.py | 21 +++- tests/test_agents.py | 2 +- tests/test_tools_common.py | 2 +- 24 files changed, 400 insertions(+), 243 deletions(-) create mode 100644 e2b.Dockerfile create mode 100644 e2b.toml create mode 100644 examples/e2b_example.py rename src/agents/{tools => default_tools}/__init__.py (100%) rename src/agents/{default_tools.py => default_tools/base.py} (73%) rename src/agents/{tools => default_tools}/search.py (86%) create mode 100644 src/agents/e2b_executor.py rename src/agents/{tool.py => tools.py} (94%) diff --git a/.gitignore b/.gitignore index b9ee3baf..33c003c6 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,7 @@ sdist/ var/ wheels/ share/python-wheels/ +node_modules/ *.egg-info/ .installed.cfg *.egg @@ -156,4 +157,8 @@ dmypy.json cython_debug/ # PyCharm -#.idea/ \ No newline at end of file +#.idea/ + +# Archive +archive/ +savedir/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 67d2b42b..a4ff4b8b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Base Python image -FROM python:3.9-slim +FROM python:3.12-slim # Set working directory WORKDIR /app @@ -7,8 +7,6 @@ WORKDIR /app # Install build dependencies RUN apt-get update && apt-get install -y \ build-essential \ - gcc \ - g++ \ zlib1g-dev \ libjpeg-dev \ libpng-dev \ diff --git a/e2b.Dockerfile b/e2b.Dockerfile new file mode 100644 index 00000000..b9721dc6 --- /dev/null +++ b/e2b.Dockerfile @@ -0,0 +1,5 @@ +# You can use most Debian-based base images +FROM e2bdev/code-interpreter:latest + +# Install dependencies and customize sandbox +RUN pip install git+https://github.com/huggingface/agents.git \ No newline at end of file diff --git a/e2b.toml b/e2b.toml new file mode 100644 index 00000000..74018dc7 --- /dev/null +++ b/e2b.toml @@ -0,0 +1,16 @@ +# This is a config for E2B sandbox template. +# You can use template ID (qywp2ctmu2q7jzprcf4j) to create a sandbox: + +# Python SDK +# from e2b import Sandbox, AsyncSandbox +# sandbox = Sandbox("qywp2ctmu2q7jzprcf4j") # Sync sandbox +# sandbox = await AsyncSandbox.create("qywp2ctmu2q7jzprcf4j") # Async sandbox + +# JS SDK +# import { Sandbox } from 'e2b' +# const sandbox = await Sandbox.create('qywp2ctmu2q7jzprcf4j') + +team_id = "f8776d3a-df2f-4a1d-af48-68c2e13b3b87" +start_cmd = "/root/.jupyter/start-up.sh" +dockerfile = "e2b.Dockerfile" +template_id = "qywp2ctmu2q7jzprcf4j" diff --git a/examples/docker_example.py b/examples/docker_example.py index 39384095..ddfb50ef 100644 --- a/examples/docker_example.py +++ b/examples/docker_example.py @@ -1,8 +1,8 @@ -from agents.tools.search import DuckDuckGoSearchTool +from agents.default_tools.search import DuckDuckGoSearchTool from agents.docker_alternative import DockerPythonInterpreter -from agents.tool import Tool +from agents.tools import Tool class DummyTool(Tool): name = "echo" diff --git a/examples/dummytool.py b/examples/dummytool.py index f9176299..75edfd92 100644 --- a/examples/dummytool.py +++ b/examples/dummytool.py @@ -1,4 +1,4 @@ -from agents.tool import Tool +from agents.tools import Tool class DummyTool(Tool): diff --git a/examples/e2b_example.py b/examples/e2b_example.py new file mode 100644 index 00000000..74d219fc --- /dev/null +++ b/examples/e2b_example.py @@ -0,0 +1,44 @@ +from agents import Tool, CodeAgent +from agents.default_tools.search import VisitWebpageTool +from dotenv import load_dotenv +load_dotenv() + +LAUNCH_GRADIO = False + +class GetCatImageTool(Tool): + name="get_cat_image" + description = "Get a cat image" + inputs = {} + output_type = "image" + + def __init__(self): + super().__init__() + self.url = "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png" + + def forward(self): + from PIL import Image + import requests + from io import BytesIO + + response = requests.get(self.url) + + return Image.open(BytesIO(response.content)) + +get_cat_image = GetCatImageTool() + + +agent = CodeAgent( + tools = [get_cat_image, VisitWebpageTool()], + additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search", + use_e2b_executor=False +) + +if LAUNCH_GRADIO: + from agents.gradio_ui import GradioUI + + GradioUI(agent).launch() +else: + agent.run( + "Return me an image of Lincoln's preferred pet", + additional_context="Here is a webpage about US presidents and pets: https://www.9lives.com/blog/a-history-of-cats-in-the-white-house/" + ) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 40932d66..ba8f8f1f 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -24,22 +24,28 @@ if TYPE_CHECKING: from .agents import * - from .default_tools import * + from .default_tools.base import * + from .default_tools.search import * from .gradio_ui import * from .llm_engines import * from .local_python_executor import * from .monitoring import * from .prompts import * - from .tools.search import * - from .tool import * + from .tools import * from .types import * from .utils import * + from .default_tools.search import * else: import sys - _file = globals()["__file__"] + import_structure = define_import_structure(_file) + import_structure[""]= {"__version__": __version__} sys.modules[__name__] = _LazyModule( - __name__, _file, define_import_structure(_file), module_spec=__spec__ + __name__, + _file, + import_structure, + module_spec=__spec__, + extra_objects={"__version__": __version__} ) diff --git a/src/agents/agents.py b/src/agents/agents.py index acfdc013..0ec2a11b 100644 --- a/src/agents/agents.py +++ b/src/agents/agents.py @@ -27,7 +27,7 @@ from .utils import console, parse_code_blob, parse_json_tool_call, truncate_content from .types import AgentAudio, AgentImage -from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool +from .default_tools.base import FinalAnswerTool from .llm_engines import HfApiEngine, MessageRole from .monitoring import Monitor from .prompts import ( @@ -42,8 +42,9 @@ SYSTEM_PROMPT_PLAN_UPDATE, SYSTEM_PROMPT_PLAN, ) -from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code -from .tool import ( +from .local_python_executor import BASE_BUILTIN_MODULES, LocalPythonExecutor +from .e2b_executor import E2BExecutor +from .tools import ( DEFAULT_TOOL_DESCRIPTION_TEMPLATE, Tool, get_tool_description_with_args, @@ -169,17 +170,6 @@ def format_prompt_with_managed_agents_descriptions( else: return prompt_template.replace(agent_descriptions_placeholder, "") - -def format_prompt_with_imports( - prompt_template: str, authorized_imports: List[str] -) -> str: - if "<>" not in prompt_template: - raise AgentError( - "Tag '<>' should be provided in the prompt." - ) - return prompt_template.replace("<>", str(authorized_imports)) - - class BaseAgent: def __init__( self, @@ -264,11 +254,6 @@ def initialize_system_prompt(self): self.system_prompt = format_prompt_with_managed_agents_descriptions( self.system_prompt, self.managed_agents ) - if hasattr(self, "authorized_imports"): - self.system_prompt = format_prompt_with_imports( - self.system_prompt, - list(set(LIST_SAFE_MODULES) | set(getattr(self, "authorized_imports"))), - ) return self.system_prompt @@ -439,9 +424,7 @@ def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any: tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox). arguments (Dict[str, str]): Arguments passed to the Tool. """ - available_tools = self.toolbox.tools - if self.managed_agents is not None: - available_tools = {**available_tools, **self.managed_agents} + available_tools = {**self.toolbox.tools, **self.managed_agents} if tool_name not in available_tools: error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(available_tools.keys())}." console.print(f"[bold red]{error_msg}") @@ -674,8 +657,6 @@ def planning_step(self, task, is_first_step: bool, iteration: int): ), managed_agents_descriptions=( show_agents_descriptions(self.managed_agents) - if self.managed_agents is not None - else "" ), answer_facts=answer_facts, ), @@ -729,8 +710,6 @@ def planning_step(self, task, is_first_step: bool, iteration: int): ), managed_agents_descriptions=( show_agents_descriptions(self.managed_agents) - if self.managed_agents is not None - else "" ), facts_update=facts_update, remaining_steps=(self.max_iterations - iteration), @@ -891,6 +870,7 @@ def __init__( grammar: Optional[Dict[str, str]] = None, additional_authorized_imports: Optional[List[str]] = None, planning_interval: Optional[int] = None, + use_e2b_executor: bool = False, **kwargs, ): if llm_engine is None: @@ -909,17 +889,24 @@ def __init__( **kwargs, ) - self.python_evaluator = evaluate_python_code self.additional_authorized_imports = ( additional_authorized_imports if additional_authorized_imports else [] ) + all_tools = {**self.toolbox.tools, **self.managed_agents} + if use_e2b_executor: + self.python_executor = E2BExecutor(self.additional_authorized_imports, list(all_tools.values())) + else: + self.python_executor = LocalPythonExecutor(self.additional_authorized_imports, all_tools) self.authorized_imports = list( - set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports) + set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports) ) + if "{{authorized_imports}}" not in self.system_prompt: + raise AgentError( + "Tag '{{authorized_imports}}' should be provided in the prompt." + ) self.system_prompt = self.system_prompt.replace( "{{authorized_imports}}", str(self.authorized_imports) ) - self.custom_tools = {} def step(self, log_entry: ActionStep) -> Union[None, Any]: """ @@ -991,22 +978,12 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]: ) try: - static_tools = { - **BASE_PYTHON_TOOLS.copy(), - **self.toolbox.tools, - } - if self.managed_agents is not None: - static_tools = {**static_tools, **self.managed_agents} - output = self.python_evaluator( + output, execution_logs = self.python_executor( code_action, - static_tools=static_tools, - custom_tools=self.custom_tools, - state=self.state, - authorized_imports=self.authorized_imports, ) - if len(self.state["print_outputs"]) > 0: - console.print(Group(Text("Print outputs:", style="bold"), Text(self.state["print_outputs"]))) - observation = "Print outputs:\n" + self.state["print_outputs"] + if len(execution_logs) > 0: + console.print(Group(Text("Execution logs:", style="bold"), Text(execution_logs))) + observation = "Execution logs:\n" + execution_logs if output is not None: truncated_output = truncate_content( str(output) @@ -1026,7 +1003,7 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]: console.print(Group(Text("Final answer:", style="bold"), Text(str(output), style="bold green"))) log_entry.action_output = output return output - return None + class ManagedAgent: diff --git a/src/agents/tools/__init__.py b/src/agents/default_tools/__init__.py similarity index 100% rename from src/agents/tools/__init__.py rename to src/agents/default_tools/__init__.py diff --git a/src/agents/default_tools.py b/src/agents/default_tools/base.py similarity index 73% rename from src/agents/default_tools.py rename to src/agents/default_tools/base.py index ca41136d..65d6a9e2 100644 --- a/src/agents/default_tools.py +++ b/src/agents/default_tools/base.py @@ -15,75 +15,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -import math from dataclasses import dataclass -from math import sqrt from typing import Dict from huggingface_hub import hf_hub_download, list_spaces from transformers.utils import is_offline_mode -from .local_python_executor import LIST_SAFE_MODULES, evaluate_python_code -from .tool import TOOL_CONFIG_FILE, Tool - - -def custom_print(*args): - return None - - -BASE_PYTHON_TOOLS = { - "print": custom_print, - "isinstance": isinstance, - "range": range, - "float": float, - "int": int, - "bool": bool, - "str": str, - "set": set, - "list": list, - "dict": dict, - "tuple": tuple, - "round": round, - "ceil": math.ceil, - "floor": math.floor, - "log": math.log, - "exp": math.exp, - "sin": math.sin, - "cos": math.cos, - "tan": math.tan, - "asin": math.asin, - "acos": math.acos, - "atan": math.atan, - "atan2": math.atan2, - "degrees": math.degrees, - "radians": math.radians, - "pow": math.pow, - "sqrt": sqrt, - "len": len, - "sum": sum, - "max": max, - "min": min, - "abs": abs, - "enumerate": enumerate, - "zip": zip, - "reversed": reversed, - "sorted": sorted, - "all": all, - "any": any, - "map": map, - "filter": filter, - "ord": ord, - "chr": chr, - "next": next, - "iter": iter, - "divmod": divmod, - "callable": callable, - "getattr": getattr, - "hasattr": hasattr, - "setattr": setattr, - "issubclass": issubclass, - "type": type, -} +from ..local_python_executor import BASE_BUILTIN_MODULES, BASE_PYTHON_TOOLS, evaluate_python_code +from ..tools import TOOL_CONFIG_FILE, Tool @dataclass @@ -136,10 +75,10 @@ class PythonInterpreterTool(Tool): def __init__(self, *args, authorized_imports=None, **kwargs): if authorized_imports is None: - self.authorized_imports = list(set(LIST_SAFE_MODULES)) + self.authorized_imports = list(set(BASE_BUILTIN_MODULES)) else: self.authorized_imports = list( - set(LIST_SAFE_MODULES) | set(authorized_imports) + set(BASE_BUILTIN_MODULES) | set(authorized_imports) ) self.inputs = { "code": { diff --git a/src/agents/tools/search.py b/src/agents/default_tools/search.py similarity index 86% rename from src/agents/tools/search.py rename to src/agents/default_tools/search.py index 01af0cd6..fad2d469 100644 --- a/src/agents/tools/search.py +++ b/src/agents/default_tools/search.py @@ -16,15 +16,11 @@ # limitations under the License. import re -import requests -from requests.exceptions import RequestException - from ..tools import Tool - class DuckDuckGoSearchTool(Tool): name = "web_search" - description = """Perform a web search based on your query (think a Google search) then returns the top search results as a list of dict elements. + description = """Performs a web search based on your query (think a Google search) then returns the top search results as a list of dict elements. Each result has keys 'title', 'href' and 'body'.""" inputs = { "query": {"type": "string", "description": "The search query to perform."} @@ -56,9 +52,11 @@ class VisitWebpageTool(Tool): def forward(self, url: str) -> str: try: from markdownify import markdownify + import requests + from requests.exceptions import RequestException except ImportError: raise ImportError( - "You must install package `markdownify` to run this tool: for instance run `pip install markdownify`." + "You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`." ) try: # Send a GET request to the URL diff --git a/src/agents/docker_alternative.py b/src/agents/docker_alternative.py index 711d2c8a..b035c7ee 100644 --- a/src/agents/docker_alternative.py +++ b/src/agents/docker_alternative.py @@ -3,7 +3,7 @@ import warnings import socket -from agents.tool import Tool +from agents.tools import Tool class DockerPythonInterpreter: def __init__(self): diff --git a/src/agents/docker_python_executor.py b/src/agents/docker_python_executor.py index 96d4446e..b15b9291 100644 --- a/src/agents/docker_python_executor.py +++ b/src/agents/docker_python_executor.py @@ -343,7 +343,7 @@ def stop(self, remove: bool = False): def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any: - from .local_python_executor import evaluate_python_code, LIST_SAFE_MODULES + from .local_python_executor import evaluate_python_code, BASE_BUILTIN_MODULES """Execute code locally with state transfer.""" state_manager = StateManager(work_dir) @@ -363,7 +363,7 @@ def execute_locally(code: str, work_dir: Path, tools: Dict[str, Any]) -> Any: tools, {}, namespace, - LIST_SAFE_MODULES, + BASE_BUILTIN_MODULES, ) # Save state for Docker diff --git a/src/agents/e2b_executor.py b/src/agents/e2b_executor.py new file mode 100644 index 00000000..e7d199ad --- /dev/null +++ b/src/agents/e2b_executor.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +# coding=utf-8 + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dotenv import load_dotenv +import textwrap +import base64 +from io import BytesIO +from PIL import Image + +from e2b_code_interpreter import Sandbox +from typing import Dict, List, Callable, Tuple, Any +from .tool_validation import validate_tool_attributes +from .utils import instance_to_source, BASE_BUILTIN_MODULES +from .tools import Tool +from .types import AgentImage + +load_dotenv() + + +class E2BExecutor(): + def __init__(self, additional_imports: List[str], tools: List[Tool]): + self.custom_tools = {} + self.sbx = Sandbox() # "qywp2ctmu2q7jzprcf4j") + # TODO: validate installing agents package or not + # print("Installing agents package on remote executor...") + # self.sbx.commands.run( + # "pip install git+https://github.com/huggingface/agents.git", + # timeout=300 + # ) + # print("Installation of agents package finished.") + if len(additional_imports) > 0: + execution = self.sbx.commands.run("pip install " + " ".join(additional_imports)) + if execution.error: + raise Exception(f"Error installing dependencies: {execution.error}") + else: + print("Installation succeeded!") + + tool_codes = [] + for tool in tools: + validate_tool_attributes(tool.__class__, check_imports=False) + tool_code = instance_to_source(tool, base_cls=Tool) + tool_code = tool_code.replace("from agents.tools import Tool", "") + tool_code += f"\n{tool.name} = {tool.__class__.__name__}()\n" + tool_codes.append(tool_code) + + tool_definition_code = "\n".join([f"import {module}" for module in BASE_BUILTIN_MODULES]) + tool_definition_code += textwrap.dedent(""" +class Tool: + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def forward(self, *args, **kwargs): + pass # to be implemented in child class + """) + tool_definition_code += "\n\n".join(tool_codes) + + tool_definition_execution = self.run_code_raise_errors(tool_definition_code) + print(tool_definition_execution.logs) + + def run_code_raise_errors(self, code: str): + execution = self.sbx.run_code( + code, + ) + if execution.error: + logs = 'Executing code yielded an error:' + logs += execution.error.name + logs += execution.error.value + logs += execution.error.traceback + raise ValueError(logs) + return execution + + def __call__(self, code_action: str) -> Tuple[Any, Any]: + execution = self.run_code_raise_errors(code_action) + execution_logs = "\n".join([str(log) for log in execution.logs.stdout]) + if not execution.results: + return None, execution_logs + else: + for result in execution.results: + if result.is_main_result: + for attribute_name in ['jpeg', 'png']: + if getattr(result, attribute_name) is not None: + image_output = getattr(result, attribute_name) + decoded_bytes = base64.b64decode(image_output.encode('utf-8')) + return Image.open(BytesIO(decoded_bytes)), execution_logs + for attribute_name in ['chart', 'data', 'html', 'javascript', 'json', 'latex', 'markdown', 'pdf', 'svg', 'text']: + if getattr(result, attribute_name) is not None: + return getattr(result, attribute_name), execution_logs + raise ValueError("No main result returned by executor!") + +__all__ = ["E2BExecutor"] \ No newline at end of file diff --git a/src/agents/gradio_ui.py b/src/agents/gradio_ui.py index 176fda48..e3324519 100644 --- a/src/agents/gradio_ui.py +++ b/src/agents/gradio_ui.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .types import AgentAudio, AgentImage, AgentText +from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types from .agents import BaseAgent, AgentStep, ActionStep import gradio as gr @@ -58,7 +58,7 @@ def stream_to_gradio( for message in pull_messages_from_step(step_log, test_mode=test_mode): yield message - final_answer = step_log # Last log is the run's final_answer + final_answer = handle_agent_output_types(step_log) # Last log is the run's final_answer if isinstance(final_answer, AgentText): yield gr.ChatMessage( @@ -93,7 +93,7 @@ def interact_with_agent(self, prompt, messages): yield messages yield messages - def run(self): + def launch(self): with gr.Blocks() as demo: stored_message = gr.State([]) chatbot = gr.Chatbot( diff --git a/src/agents/local_python_executor.py b/src/agents/local_python_executor.py index 938e1846..f31b6d0d 100644 --- a/src/agents/local_python_executor.py +++ b/src/agents/local_python_executor.py @@ -19,12 +19,12 @@ import difflib from collections.abc import Mapping from importlib import import_module -from typing import Any, Callable, Dict, List, Optional - +from typing import Any, Callable, Dict, List, Optional, Tuple +import math import numpy as np import pandas as pd -from .utils import truncate_content +from .utils import truncate_content, BASE_BUILTIN_MODULES class InterpreterError(ValueError): @@ -43,24 +43,66 @@ class InterpreterError(ValueError): and issubclass(getattr(builtins, name), BaseException) } - -LIST_SAFE_MODULES = [ - "random", - "collections", - "math", - "time", - "queue", - "itertools", - "re", - "stat", - "statistics", - "unicodedata", -] - PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000 OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000 +def custom_print(*args): + return None + +BASE_PYTHON_TOOLS = { + "print": custom_print, + "isinstance": isinstance, + "range": range, + "float": float, + "int": int, + "bool": bool, + "str": str, + "set": set, + "list": list, + "dict": dict, + "tuple": tuple, + "round": round, + "ceil": math.ceil, + "floor": math.floor, + "log": math.log, + "exp": math.exp, + "sin": math.sin, + "cos": math.cos, + "tan": math.tan, + "asin": math.asin, + "acos": math.acos, + "atan": math.atan, + "atan2": math.atan2, + "degrees": math.degrees, + "radians": math.radians, + "pow": math.pow, + "sqrt": math.sqrt, + "len": len, + "sum": sum, + "max": max, + "min": min, + "abs": abs, + "enumerate": enumerate, + "zip": zip, + "reversed": reversed, + "sorted": sorted, + "all": all, + "any": any, + "map": map, + "filter": filter, + "ord": ord, + "chr": chr, + "next": next, + "iter": iter, + "divmod": divmod, + "callable": callable, + "getattr": getattr, + "hasattr": hasattr, + "setattr": setattr, + "issubclass": issubclass, + "type": type, +} class BreakException(Exception): pass @@ -771,7 +813,7 @@ def evaluate_ast( state: Dict[str, Any], static_tools: Dict[str, Callable], custom_tools: Dict[str, Callable], - authorized_imports: List[str] = LIST_SAFE_MODULES, + authorized_imports: List[str] = BASE_BUILTIN_MODULES, ): """ Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given @@ -949,7 +991,7 @@ def evaluate_python_code( static_tools: Optional[Dict[str, Callable]] = None, custom_tools: Optional[Dict[str, Callable]] = None, state: Optional[Dict[str, Any]] = None, - authorized_imports: List[str] = LIST_SAFE_MODULES, + authorized_imports: List[str] = BASE_BUILTIN_MODULES, ): """ Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set @@ -1001,4 +1043,30 @@ def evaluate_python_code( raise InterpreterError(msg) -__all__ = ["evaluate_python_code"] +class LocalPythonExecutor(): + def __init__(self, additional_authorized_imports: List[str], tools: Dict): + self.custom_tools = {} + self.state = {} + self.additional_authorized_imports = additional_authorized_imports + self.authorized_imports = list( + set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports) + ) + # Add base trusted tools to list + self.static_tools = { + **tools, + **BASE_PYTHON_TOOLS.copy(), + } + # TODO: assert self.authorized imports are all installed locally + + def __call__(self, code_action: str) -> Tuple[Any, str]: + output = evaluate_python_code( + code_action, + static_tools=self.static_tools, + custom_tools=self.custom_tools, + state=self.state, + authorized_imports=self.authorized_imports, + ) + logs = self.state["print_outputs"] + return output, logs + +__all__ = ["evaluate_python_code", "LocalPythonExecutor"] diff --git a/src/agents/prompts.py b/src/agents/prompts.py index 1d5a7dbf..05721f24 100644 --- a/src/agents/prompts.py +++ b/src/agents/prompts.py @@ -370,7 +370,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"): 5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters. 6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'. 7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables. -8. You can use imports in your code, but only from the following list of modules: <> +8. You can use imports in your code, but only from the following list of modules: {{authorized_imports}} 9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist. 10. Don't give up! You're in charge of solving the task, not providing directions to solve it. diff --git a/src/agents/tool_validation.py b/src/agents/tool_validation.py index 7e25afe7..cd714754 100644 --- a/src/agents/tool_validation.py +++ b/src/agents/tool_validation.py @@ -5,27 +5,30 @@ from pathlib import Path from typing import List, Set, Dict import textwrap +from .utils import BASE_BUILTIN_MODULES _BUILTIN_NAMES = set(vars(builtins)) -def is_local_import(module_name: str) -> bool: +IMPORTED_PACKAGES = BASE_BUILTIN_MODULES + +def is_installed_package(module_name: str) -> bool: """ - Check if an import is from a local file or a package. - Returns True if it's a local file import. + Check if an import is from an installed package. + Returns False if it's not found or a local file import. """ try: spec = importlib.util.find_spec(module_name) if spec is None: - return True # If we can't find the module, assume it's local + return False # If we can't find the module, assume it's local # If the module is found and has a file path, check if it's in site-packages if spec.origin and 'site-packages' not in spec.origin: # Check if it's a .py file in the current directory or subdirectories - return spec.origin.endswith('.py') - + return not spec.origin.endswith('.py') + return False except ImportError: - return True # If there's an import error, assume it's local + return False # If there's an import error, assume it's local class MethodChecker(ast.NodeVisitor): """ @@ -33,7 +36,7 @@ class MethodChecker(ast.NodeVisitor): - only uses defined names - contains no local imports (e.g. numpy is ok but local_script is not) """ - def __init__(self, class_attributes: Set[str]): + def __init__(self, class_attributes: Set[str], check_imports: bool = True): self.undefined_names = set() self.imports = {} self.from_imports = {} @@ -41,6 +44,7 @@ def __init__(self, class_attributes: Set[str]): self.arg_names = set() self.class_attributes = class_attributes self.errors = [] + self.check_imports = check_imports def visit_arguments(self, node): """Collect function arguments""" @@ -53,16 +57,16 @@ def visit_arguments(self, node): def visit_Import(self, node): for name in node.names: actual_name = name.asname or name.name - if is_local_import(actual_name): - self.errors.append(f"Local import '{actual_name}'") + if not is_installed_package(actual_name) and self.check_imports: + self.errors.append(f"Package not found in importlib, might be a local install: '{actual_name}'") self.imports[actual_name] = name.name def visit_ImportFrom(self, node): module = node.module or "" for name in node.names: actual_name = name.asname or name.name - if is_local_import(module): - self.errors.append(f"Local import '{module}'") + if not is_installed_package(module) and self.check_imports: + self.errors.append(f"Package not found in importlib, might be a local install: '{module}'") self.from_imports[actual_name] = (module, name.name) def visit_Assign(self, node): @@ -71,6 +75,20 @@ def visit_Assign(self, node): self.assigned_names.add(target.id) self.visit(node.value) + def visit_With(self, node): + """Track aliases in 'with' statements (the 'y' in 'with X as y')""" + for item in node.items: + if item.optional_vars: # This is the 'y' in 'with X as y' + if isinstance(item.optional_vars, ast.Name): + self.assigned_names.add(item.optional_vars.id) + self.generic_visit(node) + + def visit_ExceptHandler(self, node): + """Track exception aliases (the 'e' in 'except Exception as e')""" + if node.name: # This is the 'e' in 'except Exception as e' + self.assigned_names.add(node.name) + self.generic_visit(node) + def visit_AnnAssign(self, node): """Track annotated assignments.""" if isinstance(node.target, ast.Name): @@ -97,6 +115,7 @@ def visit_Name(self, node): if isinstance(node.ctx, ast.Load): if not ( node.id in _BUILTIN_NAMES + or node.id in IMPORTED_PACKAGES or node.id in self.arg_names or node.id == "self" or node.id in self.class_attributes @@ -110,17 +129,18 @@ def visit_Call(self, node): if isinstance(node.func, ast.Name): if not ( node.func.id in _BUILTIN_NAMES + or node.func.id in IMPORTED_PACKAGES or node.func.id in self.arg_names or node.func.id == "self" or node.func.id in self.class_attributes or node.func.id in self.imports or node.func.id in self.from_imports or node.func.id in self.assigned_names - ): + ): self.errors.append(f"Name '{node.func.id}' is undefined.") self.generic_visit(node) -def validate_tool_attributes(cls) -> None: +def validate_tool_attributes(cls, check_imports: bool = True) -> None: """ Validates that a Tool class follows the proper patterns: 0. __init__ takes no argument (args chosen at init are not traceable so we cannot rebuild the source code for them, make them class attributes!). @@ -156,8 +176,17 @@ def __init__(self): self.imported_names = set() self.complex_attributes = set() self.class_attributes = set() + self.in_method = False + + def visit_FunctionDef(self, node): + old_context = self.in_method + self.in_method = True + self.generic_visit(node) + self.in_method = old_context def visit_Assign(self, node): + if self.in_method: + return # Track class attributes for target in node.targets: if isinstance(target, ast.Name): @@ -182,7 +211,7 @@ def visit_Assign(self, node): # Run checks on all methods for node in class_node.body: if isinstance(node, ast.FunctionDef): - method_checker = MethodChecker(class_level_checker.class_attributes) + method_checker = MethodChecker(class_level_checker.class_attributes, check_imports=check_imports) method_checker.visit(node) errors += [f"- {node.name}: {error}" for error in method_checker.errors] diff --git a/src/agents/tool.py b/src/agents/tools.py similarity index 94% rename from src/agents/tool.py rename to src/agents/tools.py index 76713946..56404130 100644 --- a/src/agents/tool.py +++ b/src/agents/tools.py @@ -28,6 +28,7 @@ from functools import lru_cache, wraps from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union, Set +import math from huggingface_hub import ( create_repo, @@ -48,7 +49,7 @@ is_vision_available, ) from transformers.dynamic_module_utils import get_imports -from .types import ImageType, handle_agent_inputs, handle_agent_outputs +from .types import ImageType, handle_agent_input_types, handle_agent_output_types from .utils import instance_to_source from .tool_validation import validate_tool_attributes, MethodChecker @@ -66,7 +67,6 @@ TOOL_CONFIG_FILE = "tool_config.json" - def get_repo_type(repo_id, repo_type=None, **hub_kwargs): if repo_type is not None: return repo_type @@ -197,12 +197,15 @@ def validate_arguments(self, do_validate_forward: bool = True): def forward(self, *args, **kwargs): return NotImplementedError("Write this method in your subclass of `Tool`.") - def __call__(self, *args, **kwargs): + def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs): if not self.is_initialized: self.setup() - args, kwargs = handle_agent_inputs(*args, **kwargs) + if sanitize_inputs_outputs: + args, kwargs = handle_agent_input_types(*args, **kwargs) outputs = self.forward(*args, **kwargs) - return handle_agent_outputs(outputs, self.output_type) + if sanitize_inputs_outputs: + outputs = handle_agent_output_types(outputs, self.output_type) + return outputs def setup(self): """ @@ -266,9 +269,7 @@ def replacement(match): forward_source_code = add_self_argument(forward_source_code) forward_source_code = forward_source_code.replace("@tool", "").strip() tool_code += "\n\n" + textwrap.indent(forward_source_code, " ") - - with open(tool_file, "w", encoding="utf-8") as f: - f.write(tool_code) + else: # If the tool was not created by the @tool decorator, it was made by subclassing Tool if type(self).__name__ in ["SpaceToolWrapper", "LangChainToolWrapper", "GradioToolWrapper"]: raise ValueError( @@ -278,8 +279,9 @@ def replacement(match): validate_tool_attributes(self.__class__) tool_code = instance_to_source(self, base_cls=Tool) - with open(tool_file, "w", encoding="utf-8") as f: - f.write(tool_code) + + with open(tool_file, "w", encoding="utf-8") as f: + f.write(tool_code) # Save app file app_file = os.path.join(output_dir, "app.py") @@ -719,6 +721,9 @@ def launch_gradio_demo(tool: Tool): "number": gr.Textbox, } + def fn(*args, **kwargs): + return tool(*args, **kwargs, sanitize_inputs_outputs=True) + gradio_inputs = [] for input_name, input_details in tool.inputs.items(): input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[ @@ -733,7 +738,7 @@ def launch_gradio_demo(tool: Tool): gradio_output = output_gradio_componentclass(label="Output") gr.Interface( - fn=tool, # This works because `tool` has a __call__ method + fn=fn, inputs=gradio_inputs, outputs=gradio_output, title=tool.name, @@ -823,61 +828,6 @@ def inner(func): return inner -## Will move to the Hub -class EndpointClient: - def __init__(self, endpoint_url: str, token: Optional[str] = None): - self.headers = { - **build_hf_headers(token=token), - "Content-Type": "application/json", - } - self.endpoint_url = endpoint_url - - @staticmethod - def encode_image(image): - _bytes = io.BytesIO() - image.save(_bytes, format="PNG") - b64 = base64.b64encode(_bytes.getvalue()) - return b64.decode("utf-8") - - @staticmethod - def decode_image(raw_image): - if not is_vision_available(): - raise ImportError( - "This tool returned an image but Pillow is not installed. Please install it (`pip install Pillow`)." - ) - - from PIL import Image - - b64 = base64.b64decode(raw_image) - _bytes = io.BytesIO(b64) - return Image.open(_bytes) - - def __call__( - self, - inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None, - params: Optional[Dict] = None, - data: Optional[bytes] = None, - output_image: bool = False, - ) -> Any: - # Build payload - payload = {} - if inputs: - payload["inputs"] = inputs - if params: - payload["parameters"] = params - - # Make API call - response = get_session().post( - self.endpoint_url, headers=self.headers, json=payload, data=data - ) - - # By default, parse the response for the user. - if output_image: - return self.decode_image(response.content) - else: - return response.json() - - class ToolCollection: """ Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox. @@ -1063,4 +1013,5 @@ def __repr__(self): "load_tool", "launch_gradio_demo", "Toolbox", + "ToolCollection", ] diff --git a/src/agents/types.py b/src/agents/types.py index 33970d74..b61bffcd 100644 --- a/src/agents/types.py +++ b/src/agents/types.py @@ -16,6 +16,7 @@ import pathlib import tempfile import uuid +from io import BytesIO import numpy as np @@ -105,6 +106,8 @@ def __init__(self, value): if isinstance(value, ImageType): self._raw = value + elif isinstance(value, bytes): + self._raw = Image.open(BytesIO(value)) elif isinstance(value, (str, pathlib.Path)): self._path = value elif isinstance(value, torch.Tensor): @@ -241,13 +244,13 @@ def to_string(self): AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio} -INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage} +INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage, np.ndarray: AgentAudio} if is_torch_available(): INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio -def handle_agent_inputs(*args, **kwargs): +def handle_agent_input_types(*args, **kwargs): args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args] kwargs = { k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items() @@ -255,7 +258,7 @@ def handle_agent_inputs(*args, **kwargs): return args, kwargs -def handle_agent_outputs(output, output_type=None): +def handle_agent_output_types(output, output_type=None): if output_type in AGENT_TYPE_MAPPING: # If the class has defined outputs, we can map directly according to the class definition decoded_outputs = AGENT_TYPE_MAPPING[output_type](output) diff --git a/src/agents/utils.py b/src/agents/utils.py index a5fc8ca7..cf3d324f 100644 --- a/src/agents/utils.py +++ b/src/agents/utils.py @@ -34,7 +34,18 @@ def is_pygments_available(): console = Console() - +BASE_BUILTIN_MODULES = [ + "random", + "collections", + "math", + "time", + "queue", + "itertools", + "re", + "stat", + "statistics", + "unicodedata", +] def parse_json_blob(json_blob: str) -> Dict[str, str]: try: first_accolade_index = json_blob.find("{") @@ -190,7 +201,10 @@ def instance_to_source(instance, base_cls=None): for name, value in class_attrs.items(): if isinstance(value, str): - class_lines.append(f' {name} = "{value}"') + if "\n" in value: + class_lines.append(f' {name} = """{value}"""') + else: + class_lines.append(f' {name} = "{value}"') else: class_lines.append(f' {name} = {repr(value)}') @@ -230,7 +244,8 @@ def instance_to_source(instance, base_cls=None): final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}") # Add discovered imports - final_lines.extend(required_imports) + for package in required_imports: + final_lines.append(f"import {package}") if final_lines: # Add empty line after imports final_lines.append("") diff --git a/tests/test_agents.py b/tests/test_agents.py index 93154c05..539f4cf9 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -29,7 +29,7 @@ Toolbox, ToolCall, ) -from agents.tool import tool +from agents.tools import tool from agents.default_tools import PythonInterpreterTool from transformers.testing_utils import get_tests_dir diff --git a/tests/test_tools_common.py b/tests/test_tools_common.py index 7d3b7307..4875f9ff 100644 --- a/tests/test_tools_common.py +++ b/tests/test_tools_common.py @@ -26,7 +26,7 @@ AgentImage, AgentText, ) -from agents.tool import Tool, tool, AUTHORIZED_TYPES +from agents.tools import Tool, tool, AUTHORIZED_TYPES from transformers.testing_utils import get_tests_dir