diff --git a/sotopia/ui/__init__.py b/sotopia/ui/__init__.py new file mode 100644 index 000000000..3543d7423 --- /dev/null +++ b/sotopia/ui/__init__.py @@ -0,0 +1,17 @@ +from .fastapi_server import ( + get_scenarios_all, + get_scenarios, + get_agents_all, + get_agents, + get_episodes_all, + get_episodes, +) + +__all__ = [ + "get_scenarios_all", + "get_scenarios", + "get_agents_all", + "get_agents", + "get_episodes_all", + "get_episodes", +] diff --git a/sotopia/ui/fastapi_server.py b/sotopia/ui/fastapi_server.py index 92261175c..a41622b7c 100644 --- a/sotopia/ui/fastapi_server.py +++ b/sotopia/ui/fastapi_server.py @@ -1,8 +1,6 @@ -from fastapi import FastAPI, WebSocket +from fastapi import FastAPI from typing import Literal, cast, Dict from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog -import json -from ws_utils import WebSocketSotopiaSimulator, WSMessageType, ErrorType, WSMessage import uvicorn app = FastAPI() @@ -117,76 +115,5 @@ async def delete_scenario(scenario_id: str) -> str: ] = {} # TODO check whether this is the correct way to store the active simulations -@app.websocket("/ws/simulation") -async def websocket_endpoint(websocket: WebSocket, token: str): - if not token: - await websocket.close(code=1008, reason="Missing token") - return - - # TODO check the validity of the token - - await websocket.accept() - simulation_started = False - - while True: - raw_message = await websocket.receive_text() - client_start_msg = json.loads(raw_message) - msg_type = client_start_msg.get("type") - - if msg_type == WSMessageType.START_SIM.value: - if simulation_started: - await websocket.send_json( - WSMessage( - type=WSMessageType.ERROR, - data={"type": ErrorType.SIMULATION_ALREADY_STARTED}, - ).to_json() - ) - continue - - simulation_started = True - active_simulations[token] = True - - try: - simulator = WebSocketSotopiaSimulator( - env_id=client_start_msg["data"]["env_id"], - agent_ids=client_start_msg["data"]["agent_ids"], - ) - except Exception: - await websocket.send_json( - WSMessage( - type=WSMessageType.ERROR, data={"type": ErrorType.OTHER} - ).to_json() - ) - break - - try: - async for message in simulator.run_one_step(): - agent_message_to_pass_back = WSMessage( - type=WSMessageType.SERVER_MSG, - data=message, - ).to_json() - await websocket.send_json(agent_message_to_pass_back) - - # TODO There is no mechanism to stop the simulation - except Exception: - await websocket.send_json( - WSMessage( - type=WSMessageType.ERROR, - data={"type": ErrorType.SIMULATION_ISSUE}, - ).to_json() - ) - - end_simulation_message = WSMessage( - type=WSMessageType.END_SIM, data={} - ).to_json() - await websocket.send_json(end_simulation_message) - simulation_started = False - active_simulations[token] = False - - else: - # TODO if that is not a start message, check other possibilities - pass - - if __name__ == "__main__": uvicorn.run(app, host="127.0.0.1", port=8800) diff --git a/sotopia/ui/websocket_test_client.py b/sotopia/ui/websocket_test_client.py deleted file mode 100644 index bb8ea7128..000000000 --- a/sotopia/ui/websocket_test_client.py +++ /dev/null @@ -1,45 +0,0 @@ -import websocket -import json -import rel -from sotopia.database import EnvironmentProfile, AgentProfile - - -def on_message(ws, message): - msg = json.loads(message) - print("\nReceived message:", json.dumps(msg, indent=2)) - - -def on_error(ws, error): - print("Error:", error) - - -def on_close(ws, close_status_code, close_msg): - print("Connection closed") - - -def on_open(ws): - agent_ids = [agent.pk for agent in AgentProfile.find().all()[:2]] - env_id = EnvironmentProfile.find().all()[0].pk - - print("Connection established, sending START_SIM message...") - start_message = { - "type": "START_SIM", - "data": {"env_id": env_id, "agent_ids": agent_ids}, - } - ws.send(json.dumps(start_message)) - - -if __name__ == "__main__": - websocket.enableTrace(True) - - ws = websocket.WebSocketApp( - "ws://localhost:8800/ws/simulation?token=test_token", - on_open=on_open, - on_message=on_message, - on_error=on_error, - on_close=on_close, - ) - - ws.run_forever(dispatcher=rel) - rel.signal(2, rel.abort) # Ctrl+C to abort - rel.dispatch() diff --git a/sotopia/ui/ws_utils.py b/sotopia/ui/ws_utils.py deleted file mode 100644 index 4c88d4e8a..000000000 --- a/sotopia/ui/ws_utils.py +++ /dev/null @@ -1,199 +0,0 @@ -from sotopia.envs.evaluators import ( - EvaluationForTwoAgents, - ReachGoalLLMEvaluator, - RuleBasedTerminatedEvaluator, - SotopiaDimensions, -) -from sotopia.agents import Agents, LLMAgent -from sotopia.messages import Observation, AgentAction -from sotopia.envs import ParallelSotopiaEnv -from sotopia.database import EnvironmentProfile, AgentProfile - -from enum import Enum -from typing import TypedDict -from pydantic import BaseModel -import asyncio -from typing import AsyncGenerator - - -class WSMessageType(str, Enum): - SERVER_MSG = "SERVER_MSG" - CLIENT_MSG = "CLIENT_MSG" - ERROR = "ERROR" - START_SIM = "START_SIM" - END_SIM = "END_SIM" - - -class ErrorType(str, Enum): - NOT_AUTHORIZED = "NOT_AUTHORIZED" - SIMULATION_ALREADY_STARTED = "SIMULATION_ALREADY_STARTED" - SIMULATION_NOT_STARTED = "SIMULATION_NOT_STARTED" - SIMULATION_ISSUE = "SIMULATION_ISSUE" - OTHER = "OTHER" - - -class MessageForRendering(TypedDict): - role: str - type: str - content: str - - -class WSMessage(BaseModel): - type: WSMessageType - data: dict - - model_config = {"arbitrary_types_allowed": True, "protected_namespaces": ()} - - def to_json(self) -> dict: - return { - "type": self.type.value, # TODO check whether we want to use the enum value or the enum itself - "data": self.data, - } - - -def get_env_agents( - env_id: str, agent_ids: list[str] -) -> tuple[ParallelSotopiaEnv, Agents, dict[str, Observation]]: - environment_profile = EnvironmentProfile.find().all()[0] - agent_profiles = AgentProfile.find().all()[:2] - - agent_list = [ - LLMAgent( - agent_profile=agent_profile, - model_name="gpt-4o-mini", - ) - for agent_idx, agent_profile in enumerate(agent_profiles) - ] - for idx, goal in enumerate(environment_profile.agent_goals): - agent_list[idx].goal = goal - - agents = Agents({agent.agent_name: agent for agent in agent_list}) - env = ParallelSotopiaEnv( - action_order="round-robin", - model_name="gpt-4o-mini", - evaluators=[ - RuleBasedTerminatedEvaluator(max_turn_number=20, max_stale_turn=2), - ], - terminal_evaluators=[ - ReachGoalLLMEvaluator( - "gpt-4o", - EvaluationForTwoAgents[SotopiaDimensions], - ), - ], - env_profile=environment_profile, - ) - - environment_messages = env.reset(agents=agents, omniscient=False) - agents.reset() - - return env, agents, environment_messages - - -class WebSocketSotopiaSimulator: - def __init__(self, env_id, agent_ids) -> None: - self.env, self.agents, self.environment_messages = get_env_agents( - env_id, agent_ids - ) - self.messages: list[list[tuple[str, str, str]]] = [] - self.messages.append( - [ - ( - "Environment", - agent_name, - self.environment_messages[agent_name].to_natural_language(), - ) - for agent_name in self.env.agents - ] - ) - for index, agent_name in enumerate(self.env.agents): - self.agents[agent_name].goal = self.env.profile.agent_goals[index] - - async def run_one_step(self) -> AsyncGenerator[dict[str, any], None]: - done = False - - turn = self.messages[-1] - messages_for_rendering = [ - {"role": "Background Info", "type": "info", "content": turn[0][2]}, - {"role": "Background Info", "type": "info", "content": turn[1][2]}, - {"role": "System", "type": "divider", "content": "Start Simulation"}, - ] - for msg in messages_for_rendering: - yield msg - - while not done: - # gather agent messages - agent_messages: dict[str, AgentAction] = dict() - actions = await asyncio.gather( - *[ - self.agents[agent_name].aact(self.environment_messages[agent_name]) - for agent_name in self.env.agents - ] - ) - - for idx, agent_name in enumerate(self.env.agents): - agent_messages[agent_name] = actions[idx] - - self.messages[-1].append( - ( - agent_name, - "Environment", - agent_messages[agent_name].to_natural_language(), - ) - ) - - # send agent messages to environment - ( - self.environment_messages, - rewards_in_turn, - terminated, - ___, - info, - ) = await self.env.astep(agent_messages) - - self.messages.append( - [ - ( - "Environment", - agent_name, - self.environment_messages[agent_name].to_natural_language(), - ) - for agent_name in self.env.agents - ] - ) - - messages_in_this_turn = [] - for sender, receiver, message in self.messages[-2]: - if receiver == "Environment": - if sender != "Environment": - if "did nothing" in message: - continue - else: - if "said:" in message: - messages_in_this_turn.append(f"{sender} {message}") - else: - messages_in_this_turn.append(f"{sender}: {message}") - else: - messages_in_this_turn.append(message) - print("\n".join(messages_in_this_turn)) - yield { - "role": "agent", # TODO separate agent 1 and 2 - "type": "action", - "content": messages_in_this_turn[0], - } - - done = all(terminated.values()) - - reasoning = info[self.env.agents[0]]["comments"] - yield { - "role": "agent", # TODO separate agent 1 and 2 - "type": "comment", - "content": reasoning, - } - rewards = [ - info[agent_name]["complete_rating"] for agent_name in self.env.agents - ] - yield { - "role": "agent", # TODO separate agent 1 and 2 - "type": "comment", - "content": rewards, - } diff --git a/tests/ui/test_fastapi.py b/tests/ui/test_fastapi.py new file mode 100644 index 000000000..f09a5a125 --- /dev/null +++ b/tests/ui/test_fastapi.py @@ -0,0 +1,11 @@ +from fastapi import FastAPI +from fastapi.testclient import TestClient + +app = FastAPI() +client = TestClient(app) + + +def test_get_scenarios_all() -> None: + response = client.get("/scenarios/") + assert response.status_code == 200 + assert response.json() == {"msg": "Hello World"}