From ead08b405ef0ffe33c37d5a5d26570a1eb7c7fe3 Mon Sep 17 00:00:00 2001 From: Zhe Su <360307598@qq.com> Date: Sun, 17 Nov 2024 13:04:25 -0500 Subject: [PATCH] support stopping the chat --- sotopia/ui/README.md | 15 +- sotopia/ui/fastapi_server.py | 27 +++- sotopia/ui/sotopia_demo_with_api.html | 215 ++++++++++++++++++++++++++ sotopia/ui/websocket_test_client.py | 2 +- sotopia/ui/websocket_utils.py | 1 + 5 files changed, 256 insertions(+), 4 deletions(-) create mode 100644 sotopia/ui/sotopia_demo_with_api.html diff --git a/sotopia/ui/README.md b/sotopia/ui/README.md index 6f71c461..c5c9c84a 100644 --- a/sotopia/ui/README.md +++ b/sotopia/ui/README.md @@ -158,6 +158,19 @@ returns: | END_SIM | Client → Server | End simulation (payload: not needed) | | FINISH_SIM | Server → Client | Terminate simulation (payload: not needed) | -**Error Type: TBD** + +**Error Type** + +| Error Code | Description | +|------------|-------------| +| NOT_AUTHORIZED | Authentication failure - invalid or expired token | +| SIMULATION_ALREADY_STARTED | Attempt to start a simulation that is already active | +| SIMULATION_NOT_STARTED | Operation attempted on an inactive simulation | +| RESOURCE_NOT_FOUND | Required env_id or agent_ids not found | +| SIMULATION_ERROR | Error occurred during simulation execution | +| SIMULATION_INTERRUPTED | The simulation is interruped | +| OTHER | Other unspecified errors | + + **Implementation plan**: Currently only support LLM-LLM simulation based on [this function](https://github.com/sotopia-lab/sotopia/blob/19d39e068c3bca9246fc366e5759414f62284f93/sotopia/server.py#L108). diff --git a/sotopia/ui/fastapi_server.py b/sotopia/ui/fastapi_server.py index 9fdbfc98..e9320db0 100644 --- a/sotopia/ui/fastapi_server.py +++ b/sotopia/ui/fastapi_server.py @@ -1,17 +1,28 @@ from fastapi import FastAPI, WebSocket from typing import Literal, cast, Dict +from fastapi.middleware.cors import CORSMiddleware + from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog import json -from websocket_utils import ( +from sotopia.ui.websocket_utils import ( WebSocketSotopiaSimulator, WSMessageType, ErrorType, WSMessage, ) import uvicorn +import asyncio app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) # TODO: Whether allowing CORS for all origins + @app.get("/scenarios", response_model=list[EnvironmentProfile]) async def get_scenarios_all() -> list[EnvironmentProfile]: @@ -172,7 +183,19 @@ async def websocket_endpoint(websocket: WebSocket, token: str) -> None: ).to_json() await websocket.send_json(agent_message_to_pass_back) - # TODO There is no mechanism to stop the simulation + try: + data = await asyncio.wait_for( + websocket.receive_text(), timeout=0.01 + ) + msg = json.loads(data) + if msg.get("type") == WSMessageType.FINISH_SIM.value: + print("----- FINISH -----") + break + except asyncio.TimeoutError: + continue + except Exception as e: + print("Error in receiving message from client", e) + except Exception: await websocket.send_json( WSMessage( diff --git a/sotopia/ui/sotopia_demo_with_api.html b/sotopia/ui/sotopia_demo_with_api.html new file mode 100644 index 00000000..4ee6887a --- /dev/null +++ b/sotopia/ui/sotopia_demo_with_api.html @@ -0,0 +1,215 @@ + + +
+ +