From ead08b405ef0ffe33c37d5a5d26570a1eb7c7fe3 Mon Sep 17 00:00:00 2001 From: Zhe Su <360307598@qq.com> Date: Sun, 17 Nov 2024 13:04:25 -0500 Subject: [PATCH] support stopping the chat --- sotopia/ui/README.md | 15 +- sotopia/ui/fastapi_server.py | 27 +++- sotopia/ui/sotopia_demo_with_api.html | 215 ++++++++++++++++++++++++++ sotopia/ui/websocket_test_client.py | 2 +- sotopia/ui/websocket_utils.py | 1 + 5 files changed, 256 insertions(+), 4 deletions(-) create mode 100644 sotopia/ui/sotopia_demo_with_api.html diff --git a/sotopia/ui/README.md b/sotopia/ui/README.md index 6f71c461..c5c9c84a 100644 --- a/sotopia/ui/README.md +++ b/sotopia/ui/README.md @@ -158,6 +158,19 @@ returns: | END_SIM | Client → Server | End simulation (payload: not needed) | | FINISH_SIM | Server → Client | Terminate simulation (payload: not needed) | -**Error Type: TBD** + +**Error Type** + +| Error Code | Description | +|------------|-------------| +| NOT_AUTHORIZED | Authentication failure - invalid or expired token | +| SIMULATION_ALREADY_STARTED | Attempt to start a simulation that is already active | +| SIMULATION_NOT_STARTED | Operation attempted on an inactive simulation | +| RESOURCE_NOT_FOUND | Required env_id or agent_ids not found | +| SIMULATION_ERROR | Error occurred during simulation execution | +| SIMULATION_INTERRUPTED | The simulation is interruped | +| OTHER | Other unspecified errors | + + **Implementation plan**: Currently only support LLM-LLM simulation based on [this function](https://github.com/sotopia-lab/sotopia/blob/19d39e068c3bca9246fc366e5759414f62284f93/sotopia/server.py#L108). diff --git a/sotopia/ui/fastapi_server.py b/sotopia/ui/fastapi_server.py index 9fdbfc98..e9320db0 100644 --- a/sotopia/ui/fastapi_server.py +++ b/sotopia/ui/fastapi_server.py @@ -1,17 +1,28 @@ from fastapi import FastAPI, WebSocket from typing import Literal, cast, Dict +from fastapi.middleware.cors import CORSMiddleware + from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog import json -from websocket_utils import ( +from sotopia.ui.websocket_utils import ( WebSocketSotopiaSimulator, WSMessageType, ErrorType, WSMessage, ) import uvicorn +import asyncio app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) # TODO: Whether allowing CORS for all origins + @app.get("/scenarios", response_model=list[EnvironmentProfile]) async def get_scenarios_all() -> list[EnvironmentProfile]: @@ -172,7 +183,19 @@ async def websocket_endpoint(websocket: WebSocket, token: str) -> None: ).to_json() await websocket.send_json(agent_message_to_pass_back) - # TODO There is no mechanism to stop the simulation + try: + data = await asyncio.wait_for( + websocket.receive_text(), timeout=0.01 + ) + msg = json.loads(data) + if msg.get("type") == WSMessageType.FINISH_SIM.value: + print("----- FINISH -----") + break + except asyncio.TimeoutError: + continue + except Exception as e: + print("Error in receiving message from client", e) + except Exception: await websocket.send_json( WSMessage( diff --git a/sotopia/ui/sotopia_demo_with_api.html b/sotopia/ui/sotopia_demo_with_api.html new file mode 100644 index 00000000..4ee6887a --- /dev/null +++ b/sotopia/ui/sotopia_demo_with_api.html @@ -0,0 +1,215 @@ + + + + + Sotopia Minimal + + + +
+

Scenario (select one):

+ +
+ +
+

Agents (select exactly 2):

+
+
+ +
+ + +
+ +

Selected Data:

+

+
+    

Simulation Messages:

+

+
+    
+  
+
diff --git a/sotopia/ui/websocket_test_client.py b/sotopia/ui/websocket_test_client.py
index 8405a48c..26f64332 100644
--- a/sotopia/ui/websocket_test_client.py
+++ b/sotopia/ui/websocket_test_client.py
@@ -1,6 +1,6 @@
 import websocket
 import json
-import rel
+import rel  # type: ignore
 from sotopia.database import EnvironmentProfile, AgentProfile
 
 
diff --git a/sotopia/ui/websocket_utils.py b/sotopia/ui/websocket_utils.py
index 587f9029..e59ae2cd 100644
--- a/sotopia/ui/websocket_utils.py
+++ b/sotopia/ui/websocket_utils.py
@@ -22,6 +22,7 @@ class WSMessageType(str, Enum):
     ERROR = "ERROR"
     START_SIM = "START_SIM"
     END_SIM = "END_SIM"
+    FINISH_SIM = "FINISH_SIM"
 
 
 class ErrorType(str, Enum):