Skip to content

Commit

Permalink
add get_acc_tree to osworldaci (ubuntu); update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
alckasoc committed Jan 14, 2025
1 parent d365b04 commit b44ed02
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 6 deletions.
23 changes: 17 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,21 @@ This will show a user query prompt where you can enter your query and interact w
To deploy Agent S on MacOS, Windows, or Linux (Ubuntu):

```
platform = "Darwin" # or "Windows"
import pyautogui
import io
from gui_agents.core.AgentS import GraphSearchAgent
import platform
if platform == "Darwin":
if platform.system() == "Darwin":
from gui_agents.aci.MacOSACI import MacOSACI, UIElement
grounding_agent = MacOSACI()
elif platform == "Windows":
elif platform.system() == "Windows":
from gui_agents.aci.WindowsOSACI import WindowsACI, UIElement
grounding_agent = WindowsACI()
elif platform.system() == "Linux":
from gui_agents.aci.OSWorldACI import OSWorldACI, get_acc_tree
grounding_agent = OSWorldACI()
else:
raise ValueError("Unsupported platform")
Expand All @@ -152,25 +159,29 @@ engine_params = {
agent = GraphSearchAgent(
engine_params,
grounding_agent,
platform=platform,
platform="ubuntu", # "macos", "windows"
action_space="pyautogui",
observation_type="mixed",
search_engine="Perplexica"
)
# Get screenshot.
screenshot = pyautogui.screenshot()
buffered = io.BytesIO()
buffered = io.BytesIO()
screenshot.save(buffered, format="PNG")
screenshot_bytes = buffered.getvalue()
# Get accessibility tree.
acc_tree = UIElement.systemWideElement()
if platform.system() != "Linux":
acc_tree = UIElement.systemWideElement()
elif platform.system() == "Linux":
acc_tree = get_acc_tree()
obs = {
"screenshot": screenshot_bytes,
"accessibility_tree": acc_tree,
}
instruction = "Close VS Code"
info, action = agent.predict(instruction=instruction, observation=obs)
Expand Down
200 changes: 200 additions & 0 deletions gui_agents/aci/OSWorldACI.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,27 @@

from gui_agents.aci.ACI import ACI

import platform

if platform.system() == "Linux":
import pyatspi
from pyatspi import Accessible, StateType, STATE_SHOWING
from pyatspi import Action as ATAction
from pyatspi import Component # , Document
from pyatspi import Text as ATText
from pyatspi import Value as ATValue

from pyatspi import Accessible, StateType
from lxml.etree import _Element
from typing import Optional, Dict, Any, List

import platform
from typing import Any, Optional, Sequence
import lxml.etree
from flask import jsonify
from lxml.etree import _Element
import concurrent.futures

logger = logging.getLogger("desktopenv.agent")

install_tmux_cmd = """import subprocess
Expand Down Expand Up @@ -655,3 +676,182 @@ def done(self):
def fail(self):
    """Signal that the current task ended in failure.

    Returns:
        The literal string ``"FAIL"``.
    """
    return "FAIL"


_accessibility_ns_map_ubuntu = {
"st": "https://accessibility.ubuntu.example.org/ns/state",
"attr": "https://accessibility.ubuntu.example.org/ns/attributes",
"cp": "https://accessibility.ubuntu.example.org/ns/component",
"doc": "https://accessibility.ubuntu.example.org/ns/document",
"docattr": "https://accessibility.ubuntu.example.org/ns/document/attributes",
"txt": "https://accessibility.ubuntu.example.org/ns/text",
"val": "https://accessibility.ubuntu.example.org/ns/value",
"act": "https://accessibility.ubuntu.example.org/ns/action",
}

MAX_WIDTH = 1024
MAX_DEPTH = 50

# Ref: https://github.com/xlang-ai/OSWorld/blob/main/desktop_env/server/main.py#L793
def _create_atspi_node(node: Accessible, depth: int = 0, flag: Optional[str] = None) -> _Element:
    """Serialize an AT-SPI accessible and its descendants into an lxml element.

    The element's tag is the node's role name (spaces replaced by hyphens,
    ``"unknown"`` when the role name is empty). States, object attributes,
    screen extents, value and action information are stored as XML attributes
    namespaced through ``_accessibility_ns_map_ubuntu``; the node's text, if
    any, becomes the element text.

    Args:
        node: The accessible object to serialize.
        depth: Current depth in the tree; children are not visited once it
            equals ``MAX_DEPTH``.
        flag: Application hint inherited from an ancestor. It is set to
            ``"calc"`` below a ``document spreadsheet`` node and to
            ``"thunderbird"`` below the Thunderbird application node. Only
            ``"calc"`` changes how children are traversed.

    Returns:
        The XML element for ``node`` with the serialized children appended.
    """
    ns_map = _accessibility_ns_map_ubuntu

    def qname(prefix: str, local_name: str) -> str:
        # Clark notation ("{namespace-uri}local-name"), as accepted by lxml.
        return "{{{}}}{}".format(ns_map[prefix], local_name)

    attribute_dict: Dict[str, Any] = {"name": node.name}

    # States: enum names look like "STATE_SHOWING"; keep the lower-cased part
    # after the first underscore.
    for st in node.getState().get_states():
        state_name = StateType._enum_lookup[st].split("_", maxsplit=1)[1].lower()
        if not state_name:
            continue
        attribute_dict[qname("st", state_name)] = "true"

    # Attributes
    for attribute_name, attribute_value in node.get_attributes().items():
        if not attribute_name:
            continue
        attribute_dict[qname("attr", attribute_name)] = attribute_value

    # Component: screen extents are only recorded for visible, showing nodes.
    if attribute_dict.get(qname("st", "visible")) == "true" \
            and attribute_dict.get(qname("st", "showing")) == "true":
        try:
            component = node.queryComponent()
        except NotImplementedError:
            pass
        else:
            bbox = component.getExtents(pyatspi.XY_SCREEN)
            attribute_dict[qname("cp", "screencoord")] = str(tuple(bbox[0:2]))
            attribute_dict[qname("cp", "size")] = str(tuple(bbox[2:]))

    # Text: only text shown on the current screen is available.
    text = ""
    try:
        text_obj = node.queryText()
        text = text_obj.getText(0, text_obj.characterCount)
        # U+FFFC ("Object Replacement Character", seen in Thunderbird and
        # elsewhere) and U+FFFD ("Replacement Character") are placeholders
        # for unspecified objects; strip both.
        text = text.replace("\ufffc", "").replace("\ufffd", "")
    except NotImplementedError:
        pass

    # Image, Selection, Value, Action
    try:
        node.queryImage()
        attribute_dict["image"] = "true"
    except NotImplementedError:
        pass

    try:
        node.querySelection()
        attribute_dict["selection"] = "true"
    except NotImplementedError:
        pass

    try:
        value = node.queryValue()
    except NotImplementedError:
        pass
    else:
        for attr_name, property_name in (
            ("value", "currentValue"),
            ("min", "minimumValue"),
            ("max", "maximumValue"),
            ("step", "minimumIncrement"),
        ):
            # Each property is best-effort: one that cannot be read is
            # omitted instead of failing the whole node.
            try:
                attribute_dict[qname("val", attr_name)] = str(getattr(value, property_name))
            except Exception:
                pass

    try:
        action = node.queryAction()
        for i in range(action.nActions):
            action_name = action.getName(i).replace(" ", "-")
            attribute_dict[qname("act", action_name + "_desc")] = action.getDescription(i)
            attribute_dict[qname("act", action_name + "_kb")] = action.getKeyBinding(i)
    except NotImplementedError:
        pass

    # Add from here if we need more attributes in the future...

    raw_role_name: str = node.getRoleName().strip()
    node_role_name = (raw_role_name or "unknown").replace(" ", "-")

    if not flag:
        if raw_role_name == "document spreadsheet":
            flag = "calc"
        if raw_role_name == "application" and node.name == "Thunderbird":
            flag = "thunderbird"

    xml_node = lxml.etree.Element(
        node_role_name,
        attrib=attribute_dict,
        nsmap=ns_map,
    )

    if text:
        xml_node.text = text

    if depth == MAX_DEPTH:
        return xml_node

    if flag == "calc" and node_role_name == "table":
        # A Calc sheet exposes every cell as a child, indexed row-major:
        #   maximum columns: 1024 for LibreOffice <= 7.3, otherwise 16384
        #   maximum rows:    1048576
        # so only the cells currently showing are serialized.
        max_columns = 1024 if _get_libreoffice_version() < (7, 4) else 16384
        max_rows = 104_8576

        index_base = 0  # child index of column 0 in the current row
        first_showing = False
        column_base = None  # column of the first showing cell found
        for r in range(max_rows):
            for clm in range(column_base or 0, max_columns):
                cell: Accessible = node[index_base + clm]
                if cell.getState().contains(STATE_SHOWING):
                    cell_xml = _create_atspi_node(cell, depth + 1, flag)
                    if not first_showing:
                        column_base = clm
                        first_showing = True
                    xml_node.append(cell_xml)
                elif (first_showing and column_base is not None) or clm >= 500:
                    # Either the run of showing cells in this row has ended,
                    # or 500 columns were probed without finding one.
                    break
            # Stop when this row had no showing cell at the first visible
            # column, or when nothing showing was found within 500 rows.
            if (first_showing and clm == column_base) or (not first_showing and r >= 500):
                break
            index_base += max_columns
        return xml_node

    try:
        for i, ch in enumerate(node):
            if i == MAX_WIDTH:
                break
            xml_node.append(_create_atspi_node(ch, depth + 1, flag))
    except Exception:
        # Best effort: keep the children collected so far when a child cannot
        # be read, but leave a trace instead of swallowing the error silently.
        logger.debug("Stopped reading children of %r node", node_role_name, exc_info=True)
    return xml_node


def _get_libreoffice_version() -> tuple:
    """Return the LibreOffice version as a tuple of ints, e.g. ``(7, 3, 7, 2)``.

    A module-level ``libreoffice_version_tuple`` is used when one is already
    defined. Otherwise the version is parsed from ``libreoffice --version``
    and cached under that name for later calls.
    """
    global libreoffice_version_tuple
    try:
        return libreoffice_version_tuple
    except NameError:
        pass

    import re
    import subprocess

    version: tuple = ()
    try:
        completed = subprocess.run(
            ["libreoffice", "--version"],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            text=True,
            timeout=10,
            check=False,
        )
        match = re.search(r"\d+(?:\.\d+)+", completed.stdout or "")
        if match:
            version = tuple(int(part) for part in match.group().split("."))
    except (OSError, subprocess.SubprocessError):
        logger.debug("Could not run `libreoffice --version`", exc_info=True)

    if not version:
        # NOTE(review): with no detectable version this assumes the newer
        # 16384-column sheet layout (>= 7.4); confirm this default fits the
        # target environment.
        logger.warning("LibreOffice version could not be determined; assuming >= 7.4")
        version = (7, 4)

    libreoffice_version_tuple = version
    return version

def get_acc_tree() -> str:
    """Return the desktop accessibility tree serialized as an XML string.

    Each application under AT-SPI desktop 0 is converted by
    ``_create_atspi_node`` on a thread pool, and the resulting subtrees are
    attached to a ``desktop-frame`` root element.

    Returns:
        The XML document as a unicode string.
    """
    desktop: Accessible = pyatspi.Registry.getDesktop(0)
    xml_node = lxml.etree.Element("desktop-frame", nsmap=_accessibility_ns_map_ubuntu)

    # NOTE(review): desktop iteration is guarded against None entries, which
    # pyatspi is commonly reported to yield for applications that went away;
    # confirm against the pyatspi version in use.
    app_nodes = [app_node for app_node in desktop if app_node is not None]

    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(_create_atspi_node, app_node, 1) for app_node in app_nodes]
        # Collect in submission order, not completion order, so the
        # applications appear in a stable order from one call to the next.
        for future in futures:
            try:
                xml_node.append(future.result())
            except Exception:
                # One unreadable application must not discard the whole tree.
                logger.warning(
                    "Skipping an application whose accessibility tree could not be read",
                    exc_info=True,
                )

    return lxml.etree.tostring(xml_node, encoding="unicode")

0 comments on commit b44ed02

Please sign in to comment.