From 5e5a9c0ac7a3520bc3b1a634dbd473e00618f1a5 Mon Sep 17 00:00:00 2001 From: Zhe Su <360307598@qq.com> Date: Wed, 20 Nov 2024 18:31:51 -0500 Subject: [PATCH] update the returned message types --- sotopia/ui/websocket_utils.py | 100 +++++++++++++++++++++++++++------- 1 file changed, 81 insertions(+), 19 deletions(-) diff --git a/sotopia/ui/websocket_utils.py b/sotopia/ui/websocket_utils.py index e59ae2cd..54c59de6 100644 --- a/sotopia/ui/websocket_utils.py +++ b/sotopia/ui/websocket_utils.py @@ -94,6 +94,22 @@ def get_env_agents( return env, agents, environment_messages +def parse_reasoning(reasoning: str, num_agents: int) -> tuple[list[str], str]: + """Parse the reasoning string into a dictionary.""" + sep_token = "SEPSEP" + for i in range(1, num_agents + 1): + reasoning = ( + reasoning.replace(f"Agent {i} comments:\n", sep_token) + .strip(" ") + .strip("\n") + ) + all_chunks = reasoning.split(sep_token) + general_comment = all_chunks[0].strip(" ").strip("\n") + comment_chunks = all_chunks[-num_agents:] + + return comment_chunks, general_comment + + class WebSocketSotopiaSimulator: def __init__(self, env_id: str, agent_ids: list[str]) -> None: self.env, self.agents, self.environment_messages = get_env_agents( @@ -166,7 +182,7 @@ async def run_one_step(self) -> AsyncGenerator[dict[str, Any], None]: ] ) - messages_in_this_turn = [] + messages_in_this_turn: list[dict[str, str]] = [] for sender, receiver, message in self.messages[-2]: if receiver == "Environment": if sender != "Environment": @@ -174,31 +190,77 @@ async def run_one_step(self) -> AsyncGenerator[dict[str, Any], None]: continue else: if "said:" in message: - messages_in_this_turn.append(f"{sender} {message}") + messages_in_this_turn.append( + { + "role": sender, + "type": "said", + "content": message, + } + ) else: - messages_in_this_turn.append(f"{sender}: {message}") + messages_in_this_turn.append( + { + "role": sender, + "type": "action", + "content": message, + } + ) else: - messages_in_this_turn.append(message) - print("\n".join(messages_in_this_turn)) - yield { - "role": "agent", # TODO separate agent 1 and 2 - "type": "action", - "content": messages_in_this_turn[0], - } + messages_in_this_turn.append( + { + "role": "Environment", + "type": "environment", + "content": message, + } + ) + + yield messages_in_this_turn[0] done = all(terminated.values()) - reasoning = info[self.env.agents[0]]["comments"] yield { - "role": "agent", # TODO separate agent 1 and 2 - "type": "comment", - "content": reasoning, + "role": "System", + "type": "divider", + "content": "End Simulation", } + + reasoning = info[self.env.agents[0]]["comments"] rewards = [ info[agent_name]["complete_rating"] for agent_name in self.env.agents ] - yield { - "role": "agent", # TODO separate agent 1 and 2 - "type": "comment", - "content": rewards, - } + reasoning_per_agent, general_comment = parse_reasoning( + reasoning, 2 + ) # TODO: support multiple in the future + + for idx, reasoning in enumerate(reasoning_per_agent): + reasoning_lines = reasoning.split("\n") + new_reasoning = "" + for reasoning_line in reasoning_lines: + dimension = reasoning_line.split(":")[0] + new_reasoning += ( + ( + f"**{dimension}**: {':'.join(reasoning_line.split(':')[1:])}" + + "\n" + ) + if dimension != "" + else reasoning_line + "\n" + ) + yield { + "role": f"Agent {idx + 1}", + "type": "comment", + "content": f"**Agent {idx + 1} reasoning**:\n{new_reasoning}\n\n**Rewards**: {str(rewards[idx])}", + } + + # yield { + # "role": "agent", # TODO separate agent 1 and 2 + # "type": "comment", + # "content": reasoning, + # } + # rewards = [ + # info[agent_name]["complete_rating"] for agent_name in self.env.agents + # ] + # yield { + # "role": "agent", # TODO separate agent 1 and 2 + # "type": "comment", + # "content": rewards, + # }