diff --git a/examples/experimental/nodes/initial_message_node.py b/examples/experimental/nodes/initial_message_node.py index 9cb7f63ca..9ff4c3bdf 100644 --- a/examples/experimental/nodes/initial_message_node.py +++ b/examples/experimental/nodes/initial_message_node.py @@ -18,6 +18,7 @@ def __init__( input_tick_channel: str, output_channels: list[str], env_scenario: str, + node_name: str, redis_url: str = "redis://localhost:6379/0", ): super().__init__( @@ -26,6 +27,7 @@ def __init__( (output_channel, Text) for output_channel in output_channels ], redis_url=redis_url, + node_name=node_name, ) self.env_scenario = env_scenario self.output_channels = output_channels diff --git a/examples/experimental/sotopia_original_replica/llm_agent_sotopia.py b/examples/experimental/sotopia_original_replica/llm_agent_sotopia.py new file mode 100644 index 000000000..abe959294 --- /dev/null +++ b/examples/experimental/sotopia_original_replica/llm_agent_sotopia.py @@ -0,0 +1,113 @@ +import logging +import sys +from rich.logging import RichHandler + +from aact import NodeFactory + +from sotopia.experimental.agents.base_agent import BaseAgent +from sotopia.experimental.agents.datamodels import Observation, AgentAction + +from sotopia.generation_utils import agenerate +from sotopia.generation_utils.generate import StrOutputParser + +# Check Python version +if sys.version_info >= (3, 11): + pass +else: + pass + +# Configure logging +FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" +logging.basicConfig( + level=logging.WARNING, + format=FORMAT, + datefmt="[%X]", + handlers=[RichHandler()], +) + + +@NodeFactory.register("llm_agent") +class LLMAgent(BaseAgent[Observation, AgentAction]): + def __init__( + self, + input_channels: list[str], + output_channel: str, + query_interval: int, + agent_name: str, + node_name: str, + goal: str, + model_name: str, + redis_url: str, + ): + super().__init__( + [(input_channel, Observation) for input_channel in input_channels], + [(output_channel, AgentAction)], + redis_url, + node_name, + ) + self.output_channel = output_channel + self.query_interval = query_interval + self.count_ticks = 0 + self.message_history: list[Observation] = [] + self.name = agent_name + self.model_name = model_name + self.goal = goal + + def _format_message_history(self, message_history: list[Observation]) -> str: + ## TODO: akhatua Fix the mapping of action to be gramatically correct + return "\n".join(message.to_natural_language() for message in message_history) + + async def aact(self, obs: Observation) -> AgentAction: + if obs.turn_number == -1: + return AgentAction( + agent_name=self.name, + output_channel=self.output_channel, + action_type="none", + argument=self.model_name, + ) + + self.message_history.append(obs) + + if len(obs.available_actions) == 1 and "none" in obs.available_actions: + return AgentAction( + agent_name=self.name, + output_channel=self.output_channel, + action_type="none", + argument="", + ) + elif len(obs.available_actions) == 1 and "leave" in obs.available_actions: + self.shutdown_event.set() + return AgentAction( + agent_name=self.name, + output_channel=self.output_channel, + action_type="leave", + argument="", + ) + else: + history = self._format_message_history(self.message_history) + action: str = await agenerate( + model_name=self.model_name, + template="Imagine that you are a friend of the other persons. Here is the " + "conversation between you and them.\n" + "You are {agent_name} in the conversation.\n" + "{message_history}\n" + "and you plan to {goal}.\n" + "You can choose to interrupt the other person " + "by saying something or not to interrupt by outputting notiong. What would you say? " + "Please only output a sentence or not outputting anything." + "{format_instructions}", + input_values={ + "message_history": history, + "goal": self.goal, + "agent_name": self.name, + }, + temperature=0.7, + output_parser=StrOutputParser(), + ) + + return AgentAction( + agent_name=self.name, + output_channel=self.output_channel, + action_type="speak", + argument=action, + ) diff --git a/examples/experimental/sotopia_original_replica/origin.svg b/examples/experimental/sotopia_original_replica/origin.svg new file mode 100644 index 000000000..78717b14a --- /dev/null +++ b/examples/experimental/sotopia_original_replica/origin.svg @@ -0,0 +1 @@ +

examples/experimental/sotopia_original_replica/origin.toml

Jane:moderator

Jack:moderator

moderator:Jane

moderator:Jack

Jane:Jack

Jack:Jane

Agent:Runtime

'Jane'

'moderator'

'Jack'

'chat_print'

diff --git a/examples/experimental/sotopia_original_replica/origin.toml b/examples/experimental/sotopia_original_replica/origin.toml new file mode 100644 index 000000000..7bf225273 --- /dev/null +++ b/examples/experimental/sotopia_original_replica/origin.toml @@ -0,0 +1,52 @@ +redis_url = "redis://localhost:6379/0" +extra_modules = ["examples.experimental.sotopia_original_replica.llm_agent_sotopia", "examples.experimental.nodes.chat_print_node", "sotopia.experimental.agents.moderator"] + + +[[nodes]] +node_name = "moderator" +node_class = "moderator" + +[nodes.node_args] +output_channels = ["moderator:Jane", "moderator:Jack"] +input_channels = ["Jane:moderator", "Jack:moderator"] +agent_backgrounds = {"Jane" = "", "Jack" = ""} +agent_mapping = {"moderator:Jane" = "Jane", "moderator:Jack" = "Jack"} +scenario = "Two friends are sitting in a cafe and catching up with each other's lives." +max_turns = 2 +push_to_db = false + +[[nodes]] +node_name = "Jack" +node_class = "llm_agent" + +[nodes.node_args] +query_interval = 5 +input_channels = ["moderator:Jack"] +output_channel = "Jack:moderator" +goal = "Your goal is to borrow 5000 dollars from Jane." +model_name = "gpt-4o-mini" +agent_name = "Jack" + + +[[nodes]] +node_name = "Jane" +node_class = "llm_agent" + +[nodes.node_args] +query_interval = 7 +output_channel = "Jane:moderator" +input_channels = ["moderator:Jane"] +goal = "Your goal is to help Jack however, you are in a finicial crisis yourself and can only afford to give him 500 dollars." +model_name = "gpt-4o-mini" +agent_name = "Jane" + +[[nodes]] +node_name = "chat_print" +node_class = "chat_print" + +[nodes.node_args.print_channel_types] +"Jane:moderator" = "agent_action" +"Jack:moderator" = "agent_action" + +[nodes.node_args] +env_agents = ["Jack", "Jane"] diff --git a/examples/experimental/sotopia_original_replica/readme.md b/examples/experimental/sotopia_original_replica/readme.md new file mode 100644 index 000000000..cb3931dc7 --- /dev/null +++ b/examples/experimental/sotopia_original_replica/readme.md @@ -0,0 +1,13 @@ +To run this example, please use aact to launch. + +```bash +aact run-dataflow examples/experimental/sotopia_original_replica/origin.toml +``` + +To view the flow of the information, please run: + +```bash +aact draw-dataflow examples/experimental/sotopia_original_replica/origin.toml --svg-path examples/experimental/sotopia_original_replica/origin.svg +``` + +![Alt text](./origin.svg) diff --git a/pyproject.toml b/pyproject.toml index 57af6cc3a..b9edcc942 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,9 @@ plugins = [ module = "transformers.*" ignore_missing_imports = true +[tool.uv.sources] +aact = { git = "https://github.com/ProKil/aact" , branch = "feature/node-manager" } + [tool.pytest.ini_options] testpaths = ["tests"] python_files = "test_*.py" diff --git a/sotopia/experimental/agents/base_agent.py b/sotopia/experimental/agents/base_agent.py index a7bbafae6..6d9466bbc 100644 --- a/sotopia/experimental/agents/base_agent.py +++ b/sotopia/experimental/agents/base_agent.py @@ -22,11 +22,13 @@ def __init__( input_channel_types: list[tuple[str, type[T_agent_observation]]], output_channel_types: list[tuple[str, type[T_agent_action]]], redis_url: str = "redis://localhost:6379/0", + node_name: str = "base_agent", ): super().__init__( input_channel_types=input_channel_types, output_channel_types=output_channel_types, redis_url=redis_url, + node_name=node_name, ) self.observation_queue: asyncio.Queue[T_agent_observation] = asyncio.Queue() diff --git a/sotopia/experimental/agents/datamodels.py b/sotopia/experimental/agents/datamodels.py new file mode 100644 index 000000000..a243a52a3 --- /dev/null +++ b/sotopia/experimental/agents/datamodels.py @@ -0,0 +1,42 @@ +from sotopia.messages import ActionType +from aact.messages import DataModel, DataModelFactory +from pydantic import Field + + +@DataModelFactory.register("observation") +class Observation(DataModel): + agent_name: str = Field(description="the name of the agent") + last_turn: str = Field(description="the last turn of the conversation") + turn_number: int = Field(description="the turn number of the conversation") + available_actions: list[ActionType] = Field(description="the available actions") + + def to_natural_language(self) -> str: + if self.turn_number == 0: + return f"\n{self.last_turn}\nConversation Starts:\n" + else: + return f"Turn #{self.turn_number-1}: {self.last_turn}\n" + + +@DataModelFactory.register("agent_action") +class AgentAction(DataModel): + agent_name: str = Field(description="the name of the agent") + output_channel: str = Field(description="the name of the output channel") + action_type: ActionType = Field( + description="whether to speak at this turn or choose to not do anything" + ) + argument: str = Field( + description="the utterance if choose to speak, the expression or gesture if choose non-verbal communication, or the physical action if choose action" + ) + + def to_natural_language(self) -> str: + match self.action_type: + case "none": + return "did nothing" + case "speak": + return f'said: "{self.argument}"' + case "non-verbal communication": + return f"[{self.action_type}] {self.argument}" + case "action": + return f"[{self.action_type}] {self.argument}" + case "leave": + return "left the conversation" diff --git a/sotopia/experimental/agents/moderator.py b/sotopia/experimental/agents/moderator.py new file mode 100644 index 000000000..ce57fb38b --- /dev/null +++ b/sotopia/experimental/agents/moderator.py @@ -0,0 +1,270 @@ +import asyncio +import sys + + +if sys.version_info < (3, 11): + from typing_extensions import Self +else: + from typing import Self + + +from aact import Message, NodeFactory, Node +from aact.messages import DataModel, DataModelFactory + +from typing import Literal, Any, AsyncIterator +from pydantic import Field + +from sotopia.database import EpisodeLog +from .datamodels import AgentAction, Observation +from sotopia.messages import ActionType + + +@DataModelFactory.register("observations") +class Observations(DataModel): + observations_map: dict[str, Observation] = Field( + description="the observations of the agents" + ) + + +@NodeFactory.register("moderator") +class Moderator(Node[AgentAction, Observation]): + def __init__( + self, + input_channels: list[str], + output_channels: list[str], + scenario: str, + agent_mapping: dict[str, str], + node_name: str, + agent_backgrounds: dict[str, str], + redis_url: str = "redis://localhost:6379/0", + action_order: Literal["simultaneous", "round-robin", "random"] = "round-robin", + available_actions: list[ActionType] = [ + "none", + "speak", + "non-verbal communication", + "action", + "leave", + ], + max_turns: int = 20, + push_to_db: bool = False, + ): + super().__init__( + input_channel_types=[ + (input_channel, AgentAction) for input_channel in input_channels + ], + output_channel_types=[ + (output_channel, Observation) for output_channel in output_channels + ], + redis_url=redis_url, + node_name=node_name, + ) + self.observation_queue: asyncio.Queue[AgentAction] = asyncio.Queue() + self.task_scheduler: asyncio.Task[None] | None = None + self.shutdown_event: asyncio.Event = asyncio.Event() + self.agent_mapping: dict[str, str] = agent_mapping + self.action_order: Literal["simultaneous", "round-robin", "random"] = ( + action_order + ) + self.available_actions: list[ActionType] = available_actions + self.turn_number: int = 0 + self.max_turns: int = max_turns + self.current_agent_index: int = 0 + self.scenario: str = scenario + self.agents: list[str] = list(agent_mapping.values()) + self.agent_models: dict[str, str] = {} + self.agents_awake: dict[str, bool] = {name: False for name in self.agents} + self.all_agents_awake: asyncio.Event = asyncio.Event() + self.message_history: list[list[tuple[str, str, str]]] = [ + [("Environment", "Environment", self.scenario)] + ] + self.push_to_db = push_to_db + self.agent_backgrounds = agent_backgrounds + + if self.action_order == "round-robin": + pass + else: + raise NotImplementedError( + "the selected action order is currently not implemented" + ) + + async def __aenter__(self) -> Self: + print(self.scenario) + asyncio.create_task(self.booting()) + self.task_scheduler = asyncio.create_task(self._task_scheduler()) + return await super().__aenter__() + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.shutdown_event.set() + if self.task_scheduler is not None: + self.task_scheduler.cancel() + return await super().__aexit__(exc_type, exc_value, traceback) + + async def send(self, observations: Observations) -> None: + for output_channel, output_channel_type in self.output_channel_types.items(): + if output_channel in observations.observations_map: + await self.r.publish( + output_channel, + Message[output_channel_type]( # type:ignore[valid-type] + data=observations.observations_map[output_channel] + ).model_dump_json(), + ) + + async def event_handler( + self, channel: str, message: Message[AgentAction] + ) -> AsyncIterator[tuple[str, Message[Observation]]]: + if channel in self.input_channel_types: + await self.observation_queue.put(message.data) + else: + raise ValueError(f"Invalid channel: {channel}") + yield "", self.output_type() + + async def _task_scheduler(self) -> None: + await self.all_agents_awake.wait() + while not self.shutdown_event.is_set(): + observation = await self.observation_queue.get() + action_or_none = await self.aact(observation) + if action_or_none is not None: + await self.send(action_or_none) + self.observation_queue.task_done() + + async def booting(self) -> None: + """ + 1. send checking message to agents for every 0.1 seconds, until all agents are awake + - this message has turn_number of -1 for identification, agents should not record this into actual message_history + - if the agent booted succesfully, he is expected to return its model name for record. + 2. (under round-robin action order)after all agents are awake, send agent[0] a message to allow the agent to start speaking + """ + while not self.all_agents_awake.is_set(): + await self.send( + Observations( + observations_map={ + output_channel: Observation( + agent_name="moderator", + last_turn=self.scenario, + turn_number=-1, + available_actions=["none"], + ) + for output_channel, agent_name in self.agent_mapping.items() + } + ) + ) + await asyncio.sleep(0.1) + while not self.observation_queue.empty(): + agent_action = await self.observation_queue.get() + self.agents_awake[agent_action.agent_name] = True + self.agent_models[agent_action.agent_name] = agent_action.argument + if False not in self.agents_awake.values(): + self.all_agents_awake.set() + + if self.action_order == "round-robin": + await self.send( + Observations( + observations_map={ + output_channel: Observation( + agent_name="moderator", + last_turn=self.agent_backgrounds[agent_name], + turn_number=0, + available_actions=self.available_actions + if agent_name == self.agents[0] + else ["none"], + ) + for output_channel, agent_name in self.agent_mapping.items() + } + ) + ) + self.current_agent_index += 1 + + async def wrap_up_and_stop(self) -> None: + if self.push_to_db: + await self.save() + await asyncio.sleep(0.5) + print("stopping all agents") + await self.r.publish( + f"shutdown:{self.node_name}", + "shutdown", + ) + + async def save(self) -> EpisodeLog: + """ + save the EpisodeLog to redis, without evaluating + TODO: specify what to be added inside tag + TODO: update the code so that EpisodeLog.render_for_humans() can work + -currently it cannot work because no AgentProfile has been uploaded to redis + -such a process should be done back in the agents' end + -also the current agentslist is consist of names, but not uuid's of agents + """ + epilog = EpisodeLog( + environment=self.scenario, + agents=self.agents, + tag=None, + models=list(self.agent_models.values()), + messages=self.message_history, + reasoning="", + rewards=[0] * len(self.agents), + rewards_prompt="", + ) + epilog.save() + # print(epilog.render_for_humans()) + return epilog + + async def aact(self, agent_action: AgentAction) -> Observations | None: + if agent_action.action_type == "leave": + self.agents_awake[agent_action.agent_name] = False + if True not in self.agents_awake.values(): + await self.wrap_up_and_stop() + return None + if agent_action.action_type == "none": + return None + + if len(self.message_history) == 1: + self.message_history[0].append( + ( + agent_action.agent_name, + "Environment", + agent_action.to_natural_language(), + ) + ) + else: + self.message_history.append( + [ + ( + agent_action.agent_name, + "Environment", + agent_action.to_natural_language(), + ) + ] + ) + + if self.turn_number < self.max_turns: + self.turn_number += 1 + else: + return Observations( + observations_map={ + output_channel: Observation( + agent_name="moderator", + last_turn=self.scenario, + turn_number=self.turn_number + 1, + available_actions=["leave"], + ) + for output_channel, agent_name in self.agent_mapping.items() + } + ) + + observations_map: dict[str, Observation] = {} + for output_channel, output_channel_type in self.output_channel_types.items(): + agent_name = self.agent_mapping[output_channel] + available_actions: list[ActionType] = ["none"] + if self.action_order == "round-robin": + if agent_name == self.agents[self.current_agent_index]: + available_actions = self.available_actions + + observation = Observation( + agent_name=agent_name, + last_turn=agent_action.to_natural_language(), + turn_number=self.turn_number, + available_actions=available_actions, + ) + observations_map[output_channel] = observation + self.current_agent_index = (self.current_agent_index + 1) % len(self.agents) + + return Observations(observations_map=observations_map) diff --git a/tests/experimental/test_agent.py b/tests/experimental/test_agent.py index 020c2131b..834c4286c 100644 --- a/tests/experimental/test_agent.py +++ b/tests/experimental/test_agent.py @@ -19,11 +19,13 @@ async def aact(self, observation: Tick) -> Tick: @pytest.mark.asyncio async def test_base_agent() -> None: async with ReturnPlusOneAgent( + node_name="test_base_agent", input_channel_types=[("input", Tick)], output_channel_types=[("output", Tick)], redis_url="redis://localhost:6379/0", ) as agent1: async with ReturnPlusOneAgent( + node_name="test_base_agent_2", input_channel_types=[("output", Tick)], output_channel_types=[("final", Tick)], redis_url="redis://localhost:6379/0", diff --git a/uv.lock b/uv.lock index 5017e0e00..217200dd2 100644 --- a/uv.lock +++ b/uv.lock @@ -10,9 +10,10 @@ resolution-markers = [ [[package]] name = "aact" version = "0.0.10" -source = { registry = "https://pypi.org/simple" } +source = { git = "https://github.com/ProKil/aact?branch=feature%2Fnode-manager#56cd2a2aad8a0e806e4f3a170e848cb1e1ad0720" } dependencies = [ { name = "aiofiles" }, + { name = "aiohttp" }, { name = "aiostream" }, { name = "numpy" }, { name = "pydantic" }, @@ -22,10 +23,6 @@ dependencies = [ { name = "tomlkit", marker = "python_full_version < '3.11'" }, { name = "typer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6e/9f/2b32aca3e2fe614df4e04a074870b6b27ef037af62f639b0e4d0b33abb31/aact-0.0.10.tar.gz", hash = "sha256:0cde5360d27bab002a43e9895c4006bfa541f6c2db798412f4aad1fdb685632e", size = 113329 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/31/18/32beed32416f8c9618ed4fc42e33eef94d7c181caf59c6909b3841047006/aact-0.0.10-py3-none-any.whl", hash = "sha256:2c1959666270acc681aafc1452aa089cb26a24a0871b01faa7761fa300b2fc9a", size = 29102 }, -] [[package]] name = "absl-py" @@ -3144,7 +3141,7 @@ dev = [ [package.metadata] requires-dist = [ - { name = "aact" }, + { name = "aact", git = "https://github.com/ProKil/aact?branch=feature%2Fnode-manager" }, { name = "absl-py", specifier = ">=2.0.0,<3.0.0" }, { name = "anthropic", marker = "extra == 'anthropic'" }, { name = "beartype", specifier = ">=0.14.0,<0.20.0" },