Skip to content

Commit

Permalink
Use sandbox_with.
Browse files Browse the repository at this point in the history
  • Loading branch information
Eric Patey committed Jan 13, 2025
1 parent e8c0d2c commit 2d32bb4
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 69 deletions.
25 changes: 23 additions & 2 deletions src/inspect_ai/tool/_tools/_computer/_common.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import json
from textwrap import dedent
from typing import Literal

from pydantic import BaseModel, Field

from inspect_ai._util.content import ContentText
from inspect_ai._util.error import PrerequisiteError
from inspect_ai.model import ContentImage
from inspect_ai.tool import ToolError, ToolResult
from inspect_ai.util import sandbox
from inspect_ai.util._sandbox.context import sandbox_with
from inspect_ai.util._sandbox.environment import SandboxEnvironment

Action = Literal[
"key",
Expand Down Expand Up @@ -38,7 +41,7 @@ async def _send_cmd(cmdTail: list[str], timeout: int | None = None) -> ToolResul

cmd = ["python3", "-m", "computer_tool.computer_tool", "--action"] + cmdTail

raw_exec_result = await sandbox().exec(cmd, timeout=timeout)
raw_exec_result = await (await computer_sandbox()).exec(cmd, timeout=timeout)

if not raw_exec_result.success:
raise RuntimeError(
Expand Down Expand Up @@ -111,3 +114,21 @@ async def press_key(key: str, timeout: int | None = None) -> ToolResult:

async def type(text: str, timeout: int | None = None) -> ToolResult:
return await _send_cmd(["type", "--text", text], timeout=timeout)


async def computer_sandbox() -> SandboxEnvironment:
sb = await sandbox_with("/opt/computer_tool/computer_tool.py")
if sb:
return sb
else:
raise PrerequisiteError(
dedent("""
The computer tool service was not found in any of the sandboxes for this sample. Please add the computer tool service to your configuration. For example, the following Docker compose file uses the (currently internal) inspect-computer-tool image as its default sandbox:
services:
default:
# Temporary internal image until the official one is available
image: "inspect-computer-tool"
init: true
""").strip()
)
128 changes: 61 additions & 67 deletions src/inspect_ai/tool/_tools/_computer/_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,72 +62,66 @@ async def execute(
Returns:
The output of the command. Many commands will include a screenshot reflecting the result of the command in their output.
"""
try:
if action in ("mouse_move", "left_click_drag"):
if coordinate is None:
raise ToolParsingError(f"coordinate is required for {action}")
if text is not None:
raise ToolParsingError(f"text is not accepted for {action}")
if not isinstance(coordinate, list) or len(coordinate) != 2:
raise ToolParsingError(f"{coordinate} must be a tuple of length 2")
if not all(isinstance(i, int) and i >= 0 for i in coordinate):
raise ToolParsingError(
f"{coordinate} must be a tuple of non-negative ints"
)

if action == "mouse_move":
return await common.mouse_move(
coordinate[0], coordinate[1], timeout=timeout
)
elif action == "left_click_drag":
return await common.left_click_drag(
coordinate[0], coordinate[1], timeout=timeout
)

if action in ("key", "type"):
if text is None:
raise ToolParsingError(f"text is required for {action}")
if coordinate is not None:
raise ToolParsingError(f"coordinate is not accepted for {action}")
if not isinstance(text, str):
raise ToolParsingError(output=f"{text} must be a string")

if action == "key":
return await common.press_key(text, timeout=timeout)
elif action == "type":
return await common.type(text, timeout=timeout)

if action in (
"left_click",
"right_click",
"double_click",
"middle_click",
"screenshot",
"cursor_position",
):
if text is not None:
raise ToolParsingError(f"text is not accepted for {action}")
if coordinate is not None:
raise ToolParsingError(f"coordinate is not accepted for {action}")

if action == "screenshot":
return await common.screenshot(timeout=timeout)
elif action == "cursor_position":
return await common.cursor_position(timeout=timeout)
elif action == "left_click":
return await common.left_click(timeout=timeout)
elif action == "right_click":
return await common.right_click(timeout=timeout)
elif action == "middle_click":
return await common.middle_click(timeout=timeout)
elif action == "double_click":
return await common.double_click(timeout=timeout)

raise ToolParsingError(f"Invalid action: {action}")

except ToolError:
raise
except Exception as e:
raise ToolError(str(e))
if action in ("mouse_move", "left_click_drag"):
if coordinate is None:
raise ToolParsingError(f"coordinate is required for {action}")
if text is not None:
raise ToolParsingError(f"text is not accepted for {action}")
if not isinstance(coordinate, list) or len(coordinate) != 2:
raise ToolParsingError(f"{coordinate} must be a tuple of length 2")
if not all(isinstance(i, int) and i >= 0 for i in coordinate):
raise ToolParsingError(
f"{coordinate} must be a tuple of non-negative ints"
)

if action == "mouse_move":
return await common.mouse_move(
coordinate[0], coordinate[1], timeout=timeout
)
elif action == "left_click_drag":
return await common.left_click_drag(
coordinate[0], coordinate[1], timeout=timeout
)

if action in ("key", "type"):
if text is None:
raise ToolParsingError(f"text is required for {action}")
if coordinate is not None:
raise ToolParsingError(f"coordinate is not accepted for {action}")
if not isinstance(text, str):
raise ToolParsingError(output=f"{text} must be a string")

if action == "key":
return await common.press_key(text, timeout=timeout)
elif action == "type":
return await common.type(text, timeout=timeout)

if action in (
"left_click",
"right_click",
"double_click",
"middle_click",
"screenshot",
"cursor_position",
):
if text is not None:
raise ToolParsingError(f"text is not accepted for {action}")
if coordinate is not None:
raise ToolParsingError(f"coordinate is not accepted for {action}")

if action == "screenshot":
return await common.screenshot(timeout=timeout)
elif action == "cursor_position":
return await common.cursor_position(timeout=timeout)
elif action == "left_click":
return await common.left_click(timeout=timeout)
elif action == "right_click":
return await common.right_click(timeout=timeout)
elif action == "middle_click":
return await common.middle_click(timeout=timeout)
elif action == "double_click":
return await common.double_click(timeout=timeout)

raise ToolParsingError(f"Invalid action: {action}")

return execute

0 comments on commit 2d32bb4

Please sign in to comment.