Skip to content

Commit

Permalink
support stopping the chat
Browse files Browse the repository at this point in the history
  • Loading branch information
bugsz committed Nov 17, 2024
1 parent f6aeecf commit ead08b4
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 4 deletions.
15 changes: 14 additions & 1 deletion sotopia/ui/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,19 @@ returns:
| END_SIM | Client → Server | End simulation (payload: not needed) |
| FINISH_SIM | Server → Client | Terminate simulation (payload: not needed) |

**Error Type: TBD**

**Error Type**

| Error Code | Description |
|------------|-------------|
| NOT_AUTHORIZED | Authentication failure - invalid or expired token |
| SIMULATION_ALREADY_STARTED | Attempt to start a simulation that is already active |
| SIMULATION_NOT_STARTED | Operation attempted on an inactive simulation |
| RESOURCE_NOT_FOUND | Required env_id or agent_ids not found |
| SIMULATION_ERROR | Error occurred during simulation execution |
| SIMULATION_INTERRUPTED | The simulation is interruped |
| OTHER | Other unspecified errors |



**Implementation plan**: Currently only support LLM-LLM simulation based on [this function](https://github.com/sotopia-lab/sotopia/blob/19d39e068c3bca9246fc366e5759414f62284f93/sotopia/server.py#L108).
27 changes: 25 additions & 2 deletions sotopia/ui/fastapi_server.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
from fastapi import FastAPI, WebSocket
from typing import Literal, cast, Dict
from fastapi.middleware.cors import CORSMiddleware

from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog
import json
from websocket_utils import (
from sotopia.ui.websocket_utils import (
WebSocketSotopiaSimulator,
WSMessageType,
ErrorType,
WSMessage,
)
import uvicorn
import asyncio

app = FastAPI()

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
) # TODO: Whether allowing CORS for all origins


@app.get("/scenarios", response_model=list[EnvironmentProfile])
async def get_scenarios_all() -> list[EnvironmentProfile]:
Expand Down Expand Up @@ -172,7 +183,19 @@ async def websocket_endpoint(websocket: WebSocket, token: str) -> None:
).to_json()
await websocket.send_json(agent_message_to_pass_back)

# TODO There is no mechanism to stop the simulation
try:
data = await asyncio.wait_for(
websocket.receive_text(), timeout=0.01
)
msg = json.loads(data)
if msg.get("type") == WSMessageType.FINISH_SIM.value:
print("----- FINISH -----")
break
except asyncio.TimeoutError:
continue
except Exception as e:
print("Error in receiving message from client", e)

except Exception:
await websocket.send_json(
WSMessage(
Expand Down
215 changes: 215 additions & 0 deletions sotopia/ui/sotopia_demo_with_api.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<title>Sotopia Minimal</title>
<style>
body {
padding: 20px;
}
pre {
background: #f5f5f5;
padding: 10px;
}
.error {
color: red;
}
button {
margin-right: 10px;
}
</style>
</head>
<body>
<div>
<h3>Scenario (select one):</h3>
<select id="scenarioSelect">
<option value="">Select scenario...</option>
</select>
</div>

<div>
<h3>Agents (select exactly 2):</h3>
<div id="agentsList"></div>
</div>

<div>
<button id="startBtn" disabled>Start</button>
<button id="stopBtn" disabled>Stop</button>
</div>

<h3>Selected Data:</h3>
<pre id="selectedData"></pre>

<h3>Simulation Messages:</h3>
<pre id="messages"></pre>

<script>
const scenarioSelect = document.getElementById("scenarioSelect");
const agentsList = document.getElementById("agentsList");
const startBtn = document.getElementById("startBtn");
const stopBtn = document.getElementById("stopBtn");
const selectedData = document.getElementById("selectedData");
const messages = document.getElementById("messages");

let scenarios = [];
let agents = [];
let selectedAgents = [];
let activeSocket = null;

function updateSelectedDisplay() {
const selectedScenario = scenarios.find(
(s) => s.pk === scenarioSelect.value,
);
const selectedAgentData = agents.filter((a) =>
selectedAgents.includes(a.pk),
);

selectedData.textContent = JSON.stringify(
{
selected_scenario: selectedScenario,
selected_agents: selectedAgentData,
},
null,
2,
);

startBtn.disabled = !selectedScenario || selectedAgents.length !== 2;
}

async function cleanupWebSocket() {
if (activeSocket) {
activeSocket.close();
activeSocket = null;
await new Promise((resolve) => setTimeout(resolve, 1000));
}
}

async function initialize() {
try {
const [scenariosRes, agentsRes] = await Promise.all([
fetch("http://127.0.0.1:8800/scenarios"),
fetch("http://127.0.0.1:8800/agents"),
]);

scenarios = await scenariosRes.json();
agents = await agentsRes.json();

scenarioSelect.innerHTML = `
<option value="">Select scenario...</option>
${scenarios
.map((s) => `<option value="${s.pk}">${s.pk}</option>`)
.join("")}
`;

agentsList.innerHTML = agents
.map(
(a) => `
<label style="display: block; margin: 5px;">
<input type="checkbox" value="${a.pk}">
${a.pk}
</label>
`,
)
.join("");

agentsList
.querySelectorAll('input[type="checkbox"]')
.forEach((checkbox) => {
checkbox.addEventListener("change", (e) => {
if (e.target.checked) {
if (selectedAgents.length < 2) {
selectedAgents.push(e.target.value);
} else {
e.target.checked = false;
alert("You can only select 2 agents");
}
} else {
selectedAgents = selectedAgents.filter(
(id) => id !== e.target.value,
);
}
updateSelectedDisplay();
});
});
} catch (error) {
selectedData.textContent = "Error loading data: " + error.message;
}
}

scenarioSelect.addEventListener("change", updateSelectedDisplay);

stopBtn.addEventListener("click", async () => {
if (!activeSocket || activeSocket.readyState !== WebSocket.OPEN) return;

stopBtn.disabled = true;
const stopMessage = {
type: "FINISH_SIM",
data: "",
};
activeSocket.send(JSON.stringify(stopMessage));
messages.textContent += `Sent: ${JSON.stringify(
stopMessage,
null,
2,
)}\n\n`;

startBtn.disabled = false;
});

startBtn.addEventListener("click", async () => {
// await cleanupWebSocket();

messages.textContent = "Starting simulation...\n";

activeSocket = new WebSocket(
"ws://127.0.0.1:8800/ws/simulation?token=demo-token",
);
startBtn.disabled = true;

activeSocket.onopen = () => {
stopBtn.disabled = false;
const startMessage = {
type: "START_SIM",
data: {
env_id: scenarioSelect.value,
agent_ids: selectedAgents,
},
};
activeSocket.send(JSON.stringify(startMessage));
messages.textContent += `Sent: ${JSON.stringify(
startMessage,
null,
2,
)}\n\n`;
};

activeSocket.onmessage = (event) => {
const data = JSON.parse(event.data);
messages.textContent += JSON.stringify(data, null, 2) + "\n\n";
messages.scrollTop = messages.scrollHeight;

if (data.type === "END_SIM" || data.type === "ERROR") {
stopBtn.disabled = true;
startBtn.disabled = false;
// cleanupWebSocket();
}
};

activeSocket.onerror = async (error) => {
messages.textContent += "Error: " + error + "\n";
stopBtn.disabled = true;
startBtn.disabled = false;
// await cleanupWebSocket();
};

activeSocket.onclose = async () => {
stopBtn.disabled = true;
startBtn.disabled = false;
// await cleanupWebSocket();
};
});

initialize();
</script>
</body>
</html>
2 changes: 1 addition & 1 deletion sotopia/ui/websocket_test_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import websocket
import json
import rel
import rel # type: ignore
from sotopia.database import EnvironmentProfile, AgentProfile


Expand Down
1 change: 1 addition & 0 deletions sotopia/ui/websocket_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class WSMessageType(str, Enum):
ERROR = "ERROR"
START_SIM = "START_SIM"
END_SIM = "END_SIM"
FINISH_SIM = "FINISH_SIM"


class ErrorType(str, Enum):
Expand Down

0 comments on commit ead08b4

Please sign in to comment.