Skip to content

Commit

Permalink
add initial test
Browse files Browse the repository at this point in the history
  • Loading branch information
XuhuiZhou committed Nov 20, 2024
1 parent 7d2e08f commit 43260a1
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 318 deletions.
17 changes: 17 additions & 0 deletions sotopia/ui/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from .fastapi_server import (
get_scenarios_all,
get_scenarios,
get_agents_all,
get_agents,
get_episodes_all,
get_episodes,
)

__all__ = [
"get_scenarios_all",
"get_scenarios",
"get_agents_all",
"get_agents",
"get_episodes_all",
"get_episodes",
]
75 changes: 1 addition & 74 deletions sotopia/ui/fastapi_server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from fastapi import FastAPI, WebSocket
from fastapi import FastAPI
from typing import Literal, cast, Dict
from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog
import json
from ws_utils import WebSocketSotopiaSimulator, WSMessageType, ErrorType, WSMessage
import uvicorn

app = FastAPI()
Expand Down Expand Up @@ -117,76 +115,5 @@ async def delete_scenario(scenario_id: str) -> str:
] = {} # TODO check whether this is the correct way to store the active simulations


@app.websocket("/ws/simulation")
async def websocket_endpoint(websocket: WebSocket, token: str):
if not token:
await websocket.close(code=1008, reason="Missing token")
return

# TODO check the validity of the token

await websocket.accept()
simulation_started = False

while True:
raw_message = await websocket.receive_text()
client_start_msg = json.loads(raw_message)
msg_type = client_start_msg.get("type")

if msg_type == WSMessageType.START_SIM.value:
if simulation_started:
await websocket.send_json(
WSMessage(
type=WSMessageType.ERROR,
data={"type": ErrorType.SIMULATION_ALREADY_STARTED},
).to_json()
)
continue

simulation_started = True
active_simulations[token] = True

try:
simulator = WebSocketSotopiaSimulator(
env_id=client_start_msg["data"]["env_id"],
agent_ids=client_start_msg["data"]["agent_ids"],
)
except Exception:
await websocket.send_json(
WSMessage(
type=WSMessageType.ERROR, data={"type": ErrorType.OTHER}
).to_json()
)
break

try:
async for message in simulator.run_one_step():
agent_message_to_pass_back = WSMessage(
type=WSMessageType.SERVER_MSG,
data=message,
).to_json()
await websocket.send_json(agent_message_to_pass_back)

# TODO There is no mechanism to stop the simulation
except Exception:
await websocket.send_json(
WSMessage(
type=WSMessageType.ERROR,
data={"type": ErrorType.SIMULATION_ISSUE},
).to_json()
)

end_simulation_message = WSMessage(
type=WSMessageType.END_SIM, data={}
).to_json()
await websocket.send_json(end_simulation_message)
simulation_started = False
active_simulations[token] = False

else:
# TODO if that is not a start message, check other possibilities
pass


if __name__ == "__main__":
uvicorn.run(app, host="127.0.0.1", port=8800)
45 changes: 0 additions & 45 deletions sotopia/ui/websocket_test_client.py

This file was deleted.

199 changes: 0 additions & 199 deletions sotopia/ui/ws_utils.py

This file was deleted.

11 changes: 11 additions & 0 deletions tests/ui/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()
client = TestClient(app)


def test_get_scenarios_all() -> None:
response = client.get("/scenarios/")
assert response.status_code == 200
assert response.json() == {"msg": "Hello World"}

0 comments on commit 43260a1

Please sign in to comment.