diff --git a/examples/hello_computer.py b/examples/hello_computer.py
new file mode 100644
index 000000000..670812115
--- /dev/null
+++ b/examples/hello_computer.py
@@ -0,0 +1,37 @@
+from inspect_ai import Task, task
+from inspect_ai.dataset import Sample
+from inspect_ai.solver import generate, use_tools
+from inspect_ai.tool import tool
+
+
+@tool
+def computer():
+    async def execute(
+        action: str,
+        text: str | None = None,
+        coordinate: list[int] | None = None,
+    ) -> str:
+        """Take an action using a computer.
+
+        Args:
+            action: Action to take.
+            text: Text related to the action.
+            coordinate: Coordinate related to the action.
+
+        Returns:
+            The action that was taken.
+        """
+        return action
+
+    return execute
+
+
+@task
+def hello_computer():
+    return Task(
+        dataset=[Sample(input="Call the computer tool with the action 'screenshot'")],
+        solver=[
+            use_tools([computer()]),
+            generate(),
+        ],
+    )
diff --git a/src/inspect_ai/_cli/eval.py b/src/inspect_ai/_cli/eval.py
index c1f68c33c..0bf0f4e78 100644
--- a/src/inspect_ai/_cli/eval.py
+++ b/src/inspect_ai/_cli/eval.py
@@ -364,6 +364,14 @@ def eval_options(func: Callable[..., Any]) -> Callable[..., click.Context]:
         help="Whether to enable parallel function calling during tool use (defaults to True) OpenAI and Groq only.",
         envvar="INSPECT_EVAL_PARALLEL_TOOL_CALLS",
     )
+    @click.option(
+        "--internal-tools/--no-internal-tools",
+        type=bool,
+        is_flag=True,
+        default=True,
+        help="Whether to automatically map tools to model internal implementations (e.g. 'computer' for anthropic).",
+        envvar="INSPECT_EVAL_INTERNAL_TOOLS",
+    )
     @click.option(
         "--max-tool-output",
         type=int,
@@ -438,6 +446,7 @@ def eval_command(
     logprobs: bool | None,
     top_logprobs: int | None,
     parallel_tool_calls: bool | None,
+    internal_tools: bool | None,
     max_tool_output: int | None,
     cache_prompt: str | None,
     reasoning_effort: str | None,
@@ -597,6 +606,7 @@ def eval_set_command(
     logprobs: bool | None,
     top_logprobs: int | None,
     parallel_tool_calls: bool | None,
+    internal_tools: bool | None,
     max_tool_output: int | None,
     cache_prompt: str | None,
     reasoning_effort: str | None,
@@ -835,6 +845,9 @@ def config_from_locals(locals: dict[str, Any]) -> GenerateConfigArgs:
         if key == "parallel_tool_calls":
             if value is not False:
                 value = None
+        if key == "internal_tools":
+            if value is not False:
+                value = None
         config[key] = value  # type: ignore
     return config
diff --git a/src/inspect_ai/_view/www/log-schema.json b/src/inspect_ai/_view/www/log-schema.json
index cb6c58beb..8c81c74e7 100644
--- a/src/inspect_ai/_view/www/log-schema.json
+++ b/src/inspect_ai/_view/www/log-schema.json
@@ -1065,6 +1065,7 @@
             "logprobs": null,
             "top_logprobs": null,
             "parallel_tool_calls": null,
+            "internal_tools": null,
             "max_tool_output": null,
             "cache_prompt": null,
             "reasoning_effort": null
@@ -2118,6 +2119,18 @@
           "default": null,
           "title": "Parallel Tool Calls"
         },
+        "internal_tools": {
+          "anyOf": [
+            {
+              "type": "boolean"
+            },
+            {
+              "type": "null"
+            }
+          ],
+          "default": null,
+          "title": "Internal Tools"
+        },
         "max_tool_output": {
           "anyOf": [
             {
@@ -2186,6 +2199,7 @@
         "logprobs",
         "top_logprobs",
         "parallel_tool_calls",
+        "internal_tools",
         "max_tool_output",
         "cache_prompt",
         "reasoning_effort"
@@ -4123,6 +4137,7 @@
       "best_of": null,
       "cache_prompt": null,
       "frequency_penalty": null,
+      "internal_tools": null,
       "logit_bias": null,
       "logprobs": null,
       "max_connections": null,
diff --git a/src/inspect_ai/_view/www/src/types/log.d.ts b/src/inspect_ai/_view/www/src/types/log.d.ts
index 7a8132c42..225b9c30f 100644
--- a/src/inspect_ai/_view/www/src/types/log.d.ts
+++ b/src/inspect_ai/_view/www/src/types/log.d.ts
@@ -77,6 +77,7 @@ export type NumChoices = number | null;
 export type Logprobs = boolean | null;
 export type TopLogprobs = number | null;
 export type ParallelToolCalls = boolean | null;
+export type InternalTools = boolean | null;
 export type MaxToolOutput = number | null;
 export type CachePrompt = "auto" | boolean | null;
 export type ReasoningEffort = ("low" | "medium" | "high") | null;
@@ -531,6 +532,7 @@ export interface GenerateConfig {
   logprobs: Logprobs;
   top_logprobs: TopLogprobs;
   parallel_tool_calls: ParallelToolCalls;
+  internal_tools: InternalTools;
   max_tool_output: MaxToolOutput;
   cache_prompt: CachePrompt;
   reasoning_effort: ReasoningEffort;
@@ -873,6 +875,7 @@ export interface GenerateConfig1 {
   logprobs: Logprobs;
   top_logprobs: TopLogprobs;
   parallel_tool_calls: ParallelToolCalls;
+  internal_tools: InternalTools;
   max_tool_output: MaxToolOutput;
   cache_prompt: CachePrompt;
   reasoning_effort: ReasoningEffort;
diff --git a/src/inspect_ai/model/_generate_config.py b/src/inspect_ai/model/_generate_config.py
index a29cc8ce4..3e931afdd 100644
--- a/src/inspect_ai/model/_generate_config.py
+++ b/src/inspect_ai/model/_generate_config.py
@@ -66,6 +66,9 @@ class GenerateConfigArgs(TypedDict, total=False):
     parallel_tool_calls: bool | None
     """Whether to enable parallel function calling during tool use (defaults to True). OpenAI and Groq only."""

+    internal_tools: bool | None
+    """Whether to automatically map tools to model internal implementations (e.g. 'computer' for anthropic)."""
+
     max_tool_output: int | None
     """Maximum tool output (in bytes). Defaults to 16 * 1024."""

@@ -136,6 +139,9 @@ class GenerateConfig(BaseModel):
     parallel_tool_calls: bool | None = Field(default=None)
     """Whether to enable parallel function calling during tool use (defaults to True). OpenAI and Groq only."""

+    internal_tools: bool | None = Field(default=None)
+    """Whether to automatically map tools to model internal implementations (e.g. 'computer' for anthropic)."""
+
     max_tool_output: int | None = Field(default=None)
     """Maximum tool output (in bytes). Defaults to 16 * 1024."""
diff --git a/src/inspect_ai/model/_providers/anthropic.py b/src/inspect_ai/model/_providers/anthropic.py
index 738031b51..5b9d696df 100644
--- a/src/inspect_ai/model/_providers/anthropic.py
+++ b/src/inspect_ai/model/_providers/anthropic.py
@@ -2,7 +2,7 @@
 import os
 from copy import copy
 from logging import getLogger
-from typing import Any, Literal, Tuple, cast
+from typing import Any, Literal, NotRequired, Tuple, TypedDict, cast

 from anthropic import (
     APIConnectionError,
@@ -142,7 +142,7 @@ def model_call() -> ModelCall:
             system_param,
             tools_param,
             messages,
-            cache_prompt,
+            computer_use,
         ) = await resolve_chat_input(self.model_name, input, tools, config)

         # prepare request params (assembled this way so we can log the raw model call)
@@ -158,13 +158,11 @@
         # additional options
         request = request | self.completion_params(config)

-        # caching header
-        if cache_prompt:
-            request["extra_headers"] = {
-                "anthropic-beta": "prompt-caching-2024-07-31"
-            }
+        # computer use beta
+        if computer_use:
+            request["extra_headers"] = {"anthropic-beta": "computer-use-2024-10-22"}

-        # call model
+        # make request
         message = await self.client.messages.create(**request, stream=False)

         # set response for ModelCall
@@ -256,6 +254,9 @@ def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | None:
         elif "content filtering" in error:
             content = "Sorry, but I am unable to help with that request."
             stop_reason = "content_filter"
+        else:
+            content = error
+            stop_reason = "unknown"

         if content and stop_reason:
             return ModelOutput.from_content(
@@ -268,12 +269,26 @@
         return None


+# native anthropic tool definitions for computer use beta
+# https://docs.anthropic.com/en/docs/build-with-claude/computer-use
+class ComputerUseToolParam(TypedDict):
+    type: str
+    name: str
+    display_width_px: NotRequired[int]
+    display_height_px: NotRequired[int]
+    display_number: NotRequired[int]
+
+
+# tools can be either a stock tool param or a special computer use tool param
+ToolParamDef = ToolParam | ComputerUseToolParam
+
+
 async def resolve_chat_input(
     model: str,
     input: list[ChatMessage],
     tools: list[ToolInfo],
     config: GenerateConfig,
-) -> Tuple[list[TextBlockParam] | None, list[ToolParam], list[MessageParam], bool]:
+) -> Tuple[list[TextBlockParam] | None, list[ToolParamDef], list[MessageParam], bool]:
     # extract system message
     system_messages, messages = split_system_messages(input, config)

@@ -286,14 +301,7 @@ async def resolve_chat_input(
     )

     # tools
-    tools_params = [
-        ToolParam(
-            name=tool.name,
-            description=tool.description,
-            input_schema=tool.parameters.model_dump(exclude_none=True),
-        )
-        for tool in tools
-    ]
+    tools_params, computer_use = tool_params_for_tools(tools, config)

     # system messages
     if len(system_messages) > 0:
@@ -343,10 +351,60 @@ async def resolve_chat_input(
             add_cache_control(cast(dict[str, Any], content[-1]))

     # return chat input
-    return system_param, tools_params, message_params, cache_prompt
+    return system_param, tools_params, message_params, computer_use
+
+
+def tool_params_for_tools(
+    tools: list[ToolInfo], config: GenerateConfig
+) -> tuple[list[ToolParamDef], bool]:
+    # tool params and computer_use bit to return
+    tool_params: list[ToolParamDef] = []
+    computer_use = False
+
+    # for each tool, check if it has a native computer use implementation and use that
+    # when available (noting that we need to set the computer use request header)
+    for tool in tools:
+        computer_use_tool = (
+            computer_use_tool_param(tool)
+            if config.internal_tools is not False
+            else None
+        )
+        if computer_use_tool:
+            tool_params.append(computer_use_tool)
+            computer_use = True
+        else:
+            tool_params.append(
+                ToolParam(
+                    name=tool.name,
+                    description=tool.description,
+                    input_schema=tool.parameters.model_dump(exclude_none=True),
+                )
+            )
+
+    return tool_params, computer_use
+
+
+def computer_use_tool_param(tool: ToolInfo) -> ComputerUseToolParam | None:
+    # check for compatible 'computer' tool
+    if tool.name == "computer" and (
+        sorted(tool.parameters.properties.keys())
+        == sorted(["action", "coordinate", "text"])
+    ):
+        return ComputerUseToolParam(
+            type="computer_20241022",
+            name="computer",
+            display_width_px=1024,
+            display_height_px=768,
+            display_number=1,
+        )
+    # not a computer_use tool
+    else:
+        return None


-def add_cache_control(param: TextBlockParam | ToolParam | dict[str, Any]) -> None:
+def add_cache_control(
+    param: TextBlockParam | ToolParam | ComputerUseToolParam | dict[str, Any],
+) -> None:
     cast(dict[str, Any], param)["cache_control"] = {"type": "ephemeral"}
diff --git a/tools/vscode/src/@types/log.d.ts b/tools/vscode/src/@types/log.d.ts
index 7a8132c42..225b9c30f 100644
--- a/tools/vscode/src/@types/log.d.ts
+++ b/tools/vscode/src/@types/log.d.ts
@@ -77,6 +77,7 @@ export type NumChoices = number | null;
 export type Logprobs = boolean | null;
 export type TopLogprobs = number | null;
 export type ParallelToolCalls = boolean | null;
+export type InternalTools = boolean | null;
 export type MaxToolOutput = number | null;
 export type CachePrompt = "auto" | boolean | null;
 export type ReasoningEffort = ("low" | "medium" | "high") | null;
@@ -531,6 +532,7 @@ export interface GenerateConfig {
   logprobs: Logprobs;
   top_logprobs: TopLogprobs;
   parallel_tool_calls: ParallelToolCalls;
+  internal_tools: InternalTools;
   max_tool_output: MaxToolOutput;
   cache_prompt: CachePrompt;
   reasoning_effort: ReasoningEffort;
@@ -873,6 +875,7 @@ export interface GenerateConfig1 {
   logprobs: Logprobs;
   top_logprobs: TopLogprobs;
   parallel_tool_calls: ParallelToolCalls;
+  internal_tools: InternalTools;
   max_tool_output: MaxToolOutput;
   cache_prompt: CachePrompt;
   reasoning_effort: ReasoningEffort;
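
To try the new option end to end (a sketch, assuming an Anthropic API key is configured; the model name below is illustrative, not mandated by this patch):

    # map the example's 'computer' tool to Anthropic's native computer-use tool (the default)
    inspect eval examples/hello_computer.py --model anthropic/claude-3-5-sonnet-20241022

    # opt out of the mapping and send the tool as a standard JSON-schema tool definition
    inspect eval examples/hello_computer.py --model anthropic/claude-3-5-sonnet-20241022 --no-internal-tools

Because internal_tools is part of GenerateConfigArgs, the same switch can be passed through eval() from Python:

    from inspect_ai import eval

    # internal_tools=False skips the computer_20241022 mapping added above
    eval(
        "examples/hello_computer.py",
        model="anthropic/claude-3-5-sonnet-20241022",
        internal_tools=False,
    )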