Skip to content

Commit

Permalink
Merge branch 'main' into feature/multiparty
Browse files Browse the repository at this point in the history
  • Loading branch information
XuhuiZhou committed Jan 16, 2025
2 parents 8fc24b5 + 3b9e842 commit a67166e
Show file tree
Hide file tree
Showing 12 changed files with 105 additions and 93 deletions.
4 changes: 2 additions & 2 deletions docs/pages/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ or manual setup:
tar -xvzf redis-stack-server.tar.gz
export PATH=$(pwd)/redis-stack-server-7.2.0-v10/bin:$PATH
# if you are using Ubuntu 22.04, please do an extra step
wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2.22_amd64.deb
sudo dpkg -i libssl1.1_1.1.1f-1ubuntu2.22_amd64.deb
wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb
sudo dpkg -i libssl1.1_1.1.1f-1ubuntu2_amd64.deb
```

### Start the server
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from sotopia.generation_utils import agenerate
from sotopia.generation_utils.generate import StrOutputParser

from time import sleep
# Check Python version
if sys.version_info >= (3, 11):
pass
Expand Down Expand Up @@ -56,15 +55,17 @@ def __init__(
self.count_ticks: int = 0
self.message_history: list[Observation] = []
self.goal: str = goal
self.model_name: str = model_name
self.model_name: str = model_name
self.agent_profile_pk: str | None = agent_pk
self.name: str | None = agent_name
self.background: dict[str,Any] | None = background
self.background: dict[str, Any] | None = background
self.awake: bool = False

def set_profile(self, use_pk_value: bool) -> None:
if not use_pk_value:
assert (self.background is not None and self.name is not None), "Background and name must be provided"
assert (
self.background is not None and self.name is not None
), "Background and name must be provided"
if " " in self.name:
first_name, last_name = self.name.split(" ", 1)
else:
Expand All @@ -87,7 +88,7 @@ def _format_message_history(self, message_history: list[Observation]) -> str:

async def aact(self, obs: Observation) -> AgentAction:
if obs.turn_number == -1:
if(self.awake):
if self.awake:
return AgentAction(
agent_name=self.name,
output_channel=self.output_channel,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "sotopia"
version = "0.1.2"
version = "0.1.3"
description = "A platform for simulating and evaluating social interaction."
authors = [
{ name = "Hao Zhu", email = "[email protected]" },
Expand Down
4 changes: 2 additions & 2 deletions sotopia/api/fastapi_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ async def send_error(
)


async def run_simulation(
async def nonstreaming_simulation(
episode_pk: str,
simulation_request: SimulationRequest,
simulation_status: NonStreamingSimulationStatus,
Expand Down Expand Up @@ -548,7 +548,7 @@ def simulate(simulation_request: SimulationRequest) -> Response:
simulation_status.save()
queue = rq.Queue("default", connection=get_redis_connection())
queue.enqueue(
run_simulation,
nonstreaming_simulation,
episode_pk=episode_pk,
simulation_request=simulation_request,
simulation_status=simulation_status,
Expand Down
10 changes: 10 additions & 0 deletions sotopia/cli/install/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,16 @@ def install(
subprocess.run(
"tar -xvzf redis-stack-server.tar.gz", shell=True, check=True
)
subprocess.run(
"wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb",
shell=True,
check=True,
)
subprocess.run(
"sudo dpkg -i libssl1.1_1.1.1f-1ubuntu2_amd64.deb",
shell=True,
check=True,
)
if load_database:
Path("./redis-stack-server-7.2.0-v10/var/db/redis-stack").mkdir(
parents=True, exist_ok=True
Expand Down
4 changes: 3 additions & 1 deletion sotopia/database/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class BaseEpisodeLog(BaseModel):
models: list[str] | None = Field(index=True, default=[])
messages: list[list[tuple[str, str, str]]] # Messages arranged by turn
reasoning: str = Field(default="")
rewards: list[tuple[float, dict[str, float]] | float | dict[str, dict]] # Rewards arranged by turn
rewards: list[
tuple[float, dict[str, float]] | float | dict[str, dict]
] # Rewards arranged by turn
rewards_prompt: str

@model_validator(mode="after")
Expand Down
39 changes: 23 additions & 16 deletions sotopia/experimental/agents/evaluators.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import json

from aact import NodeFactory, Node, Message
from aact import NodeFactory
from .base_agent import BaseAgent
from .logs import EpisodeLog
from .datamodels import AgentAction, Observation
from sotopia.database.persistent_profile import AgentProfile

from typing import AsyncIterator, Generic, TypeVar, Type, Any
from typing import Generic, TypeVar, Type, Any
from pydantic import BaseModel, Field
from asyncio import Event

from sotopia.envs.evaluators import GoalDimension
from sotopia.generation_utils.generate import agenerate
Expand All @@ -26,8 +25,12 @@

T_eval_dim = TypeVar("T_eval_dim", bound=BaseModel)


class EvaluationForMutiAgents(BaseModel, Generic[T_eval_dim]):
agents_evaluation: dict[str, T_eval_dim] = Field(description="the evaluation for each agent, the key is the agent name,be sure to include every agent in the agent list, the value should follow the evaluation dimension format")
agents_evaluation: dict[str, T_eval_dim] = Field(
description="the evaluation for each agent, the key is the agent name,be sure to include every agent in the agent list, the value should follow the evaluation dimension format"
)


@NodeFactory.register("evaluator")
class Evaluator(BaseAgent[Observation, AgentAction]):
Expand Down Expand Up @@ -57,10 +60,14 @@ def __init__(
self.reward_prompt = reward_prompt
self.temperature = temperature
if eval_dim_class == "GoalDimension":
self.response_format_class:Type[BaseModel] = EvaluationForMutiAgents[GoalDimension]
self.response_format_class: Type[BaseModel] = EvaluationForMutiAgents[
GoalDimension
]
else:
raise ValueError(f"the eval_dim_class : {eval_dim_class} is not implemented")
#TODO: need a registry for the evaluation dimension class, so dimension can be initialized with a str
raise ValueError(
f"the eval_dim_class : {eval_dim_class} is not implemented"
)
# TODO: need a registry for the evaluation dimension class, so dimension can be initialized with a str

async def aact(self, content: Observation) -> AgentAction:
epilog = EpisodeLog(**json.loads(content.last_turn))
Expand All @@ -70,24 +77,24 @@ async def aact(self, content: Observation) -> AgentAction:
agent_name="evaluator",
output_channel=f"evaluator:{content.agent_name}",
action_type="speak",
argument=json.dumps({
"reward":json.dumps(result),
"reward_prompt":self.reward_prompt
})
argument=json.dumps(
{"reward": json.dumps(result), "reward_prompt": self.reward_prompt}
),
)


async def aevaluate(self, episode: EpisodeLog) -> dict[str, Any]:
#TODO: below is a temporary implementation, need to replaced by using render_for_humans in EpisodeLog
history = "\n".join(f"{msg[0][0]} said: {msg[0][2]}"for msg in episode.messages)
# TODO: below is a temporary implementation; it needs to be replaced by render_for_humans in EpisodeLog
history = "\n".join(
f"{msg[0][0]} said: {msg[0][2]}" for msg in episode.messages
)
agent_list = []
for pk in episode.agents:
agent = AgentProfile.get(pk)
name = agent.first_name+" "+agent.last_name
name = agent.first_name + " " + agent.last_name
name = name.strip()
agent_list.append(name)

res:BaseModel = await agenerate(
res: BaseModel = await agenerate(
model_name=self.model_name,
template=self.reward_prompt,
input_values=dict(history=history, agent_list=str(agent_list)),
Expand Down
42 changes: 24 additions & 18 deletions sotopia/experimental/agents/moderator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def __init__(
super().__init__(
input_channel_types=[
(input_channel, AgentAction) for input_channel in input_channels
]+[(channel[0], AgentAction) for channel in evaluator_channels],
]
+ [(channel[0], AgentAction) for channel in evaluator_channels],
output_channel_types=[
(output_channel, Observation) for output_channel in output_channels
],
Expand Down Expand Up @@ -84,9 +85,11 @@ def __init__(
self.use_pk_value: bool = use_pk_value

self.evaluate_episode: bool = evaluate_episode
assert (not self.evaluate_episode) or len(evaluator_channels) > 0, "if evaluate_episode is True, evaluator_channels should not be empty"
assert (not self.evaluate_episode) or len(
evaluator_channels
) > 0, "if evaluate_episode is True, evaluator_channels should not be empty"

self.epilog: EpisodeLog | None = None # will be initialized in booting process
self.epilog: EpisodeLog | None = None # will be initialized in booting process

if self.action_order == "round-robin":
pass
Expand Down Expand Up @@ -163,7 +166,7 @@ async def booting(self) -> None:
await asyncio.sleep(0.2)
while not self.observation_queue.empty():
agent_action = await self.observation_queue.get()
if(not self.agents_awake[agent_action.agent_name]):
if not self.agents_awake[agent_action.agent_name]:
self.agents_awake[agent_action.agent_name] = True
args: dict[str, Any] = json.loads(agent_action.argument)
self.agents_pk[agent_action.agent_name] = args["pk"]
Expand All @@ -177,10 +180,8 @@ async def booting(self) -> None:
agents=list(self.agents_pk.values()),
tag=self.tag,
models=list(self.agent_models.values()),
messages=[[
("Environment", "Environment", self.scenario)
]],
rewards=[0.0]*len(self.agents),
messages=[[("Environment", "Environment", self.scenario)]],
rewards=[0.0] * len(self.agents),
rewards_prompt="",
)
if self.action_order == "round-robin":
Expand All @@ -205,7 +206,7 @@ async def wrap_up_and_stop(self) -> None:
try:
await asyncio.sleep(0.1)
print("all agents have left, wrap up and stop")
self.shutdown_event.set() # this will disable the task scheduler
self.shutdown_event.set() # this will disable the task scheduler
if self.evaluate_episode:
epilog = await self.aeval(self.epilog)
if self.push_to_db:
Expand Down Expand Up @@ -237,23 +238,28 @@ async def aeval(self, epilog: EpisodeLog) -> EpisodeLog:

for evaluator_channel in self.evaluator_channels:
print(evaluator_channel[1])
await self.r.publish(evaluator_channel[1], Message[Observation](data=Observation(
agent_name="moderator",
last_turn=epilog.model_dump_json(),
turn_number=self.turn_number,
available_actions=self.available_actions,
)).model_dump_json()
await self.r.publish(
evaluator_channel[1],
Message[Observation](
data=Observation(
agent_name="moderator",
last_turn=epilog.model_dump_json(),
turn_number=self.turn_number,
available_actions=self.available_actions,
)
).model_dump_json(),
)


print("episode eval started")

for _ in range(len(self.evaluator_channels)): # the queue will take in input and output from this channel
for _ in range(
len(self.evaluator_channels)
): # the queue will take in input and output from this channel
raw_res = await self.observation_queue.get()
res = json.loads(raw_res.argument)
epilog.rewards = res["reward"]
epilog.rewards_prompt = res["reward_prompt"]

print("episode eval finished")
return epilog

Expand Down
74 changes: 30 additions & 44 deletions tests/api/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sotopia.messages import SimpleMessage
from sotopia.api.fastapi_server import app
import pytest
from typing import Generator, Callable
from typing import Generator, Callable, Any

client = TestClient(app)

Expand Down Expand Up @@ -341,46 +341,32 @@ def test_delete_evaluation_dimension(create_mock_data: Callable[[], None]) -> No
assert isinstance(response.json(), str)


# def test_simulate(create_mock_data: Callable[[], None]) -> None:
# response = client.post(
# "/simulate",
# json={
# "env_id": "tmppk_env_profile",
# "agent_ids": ["tmppk_agent1", "tmppk_agent2"],
# "models": [
# # "custom/llama3.2:1b@http://localhost:8000/v1",
# # "custom/llama3.2:1b@http://localhost:8000/v1",
# # "custom/llama3.2:1b@http://localhost:8000/v1"
# "gpt-4o-mini",
# "gpt-4o-mini",
# "gpt-4o-mini",
# ],
# "max_turns": 2,
# "tag": "test_tag",
# },
# )
# assert response.status_code == 200
# assert isinstance(response.json(), str)
# max_retries = 20
# retry_count = 0
# while retry_count < max_retries:
# try:
# status = NonStreamingSimulationStatus.find(
# NonStreamingSimulationStatus.episode_pk == response.json()
# ).all()[0]
# assert isinstance(status, NonStreamingSimulationStatus)
# print(status)
# if status.status == "Error":
# raise Exception("Error running simulation")
# elif status.status == "Completed":
# # EpisodeLog.get(response.json())
# break
# # Status is "Started", keep polling
# time.sleep(1)
# retry_count += 1
# except Exception as e:
# print(f"Error checking simulation status: {e}")
# time.sleep(1)
# retry_count += 1
# else:
# raise TimeoutError("Simulation timed out after 10 retries")
def test_websocket_simulate(create_mock_data: Callable[[], None]) -> None:
"""Exercise the `/ws/simulation` websocket endpoint with a start/stream/finish exchange.

Sends a START_SIM message, reads streamed JSON messages until two have been
collected (asserting each has type SERVER_MSG), then sends FINISH_SIM.
"""
# NOTE(review): `create_mock_data` is requested but never called here —
# presumably a pytest fixture that seeds the "tmppk_*" records; confirm.
# NOTE(review): the model string points at localhost:8000 — presumably a
# locally served model must be reachable for this test to pass; confirm.
LOCAL_MODEL = "custom/llama3.2:1b@http://localhost:8000/v1"
with client.websocket_connect("/ws/simulation?token=test") as websocket:
# Start request: one environment, two agents (both on LOCAL_MODEL), and
# the same model as evaluator.
start_msg = {
"type": "START_SIM",
"data": {
"env_id": "tmppk_env_profile",
"agent_ids": ["tmppk_agent1", "tmppk_agent2"],
"agent_models": [LOCAL_MODEL, LOCAL_MODEL],
"evaluator_model": LOCAL_MODEL,
"evaluation_dimension_list_name": "test_dimension_list",
},
}
websocket.send_json(start_msg)

# check the streaming response, stop once we have received 2 messages
messages: list[dict[str, Any]] = []
while len(messages) < 2:
message = websocket.receive_json()
# Any message whose type is not SERVER_MSG fails the test; the assertion
# message includes the full payload.
assert (
message["type"] == "SERVER_MSG"
), f"Expected SERVER_MSG, got {message['type']}, full msg: {message}"
messages.append(message)

# send the end message
end_msg = {
"type": "FINISH_SIM",
}
websocket.send_json(end_msg)
2 changes: 1 addition & 1 deletion ui/modal_streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)
.pip_install("streamlit~=1.40.2", "uv")
.run_commands(
"rm -rf sotopia && git clone https://github.com/sotopia-lab/sotopia.git && cd sotopia && git checkout demo && uv pip install pyproject.toml --system && pip install -e . && cd ui/streamlit_ui",
"rm -rf sotopia && git clone https://github.com/sotopia-lab/sotopia.git && cd sotopia && git checkout demo && uv pip install pyproject.toml --system && pip install -e . && cd ui",
force_build=True,
)
# .pip_install("pydantic==2.8.2")
Expand Down
4 changes: 2 additions & 2 deletions ui/pages/intro.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
st.markdown(
"""
For larger scale experiments you may need to use the API instead of the Streamlit UI.
- The API documentation for current set of Sotopia is [here](https://sotopia-lab--sotopia-fastapi-webapi-serve.modal.run/)
- The API documentation for current set of Sotopia is [here](https://sotopia-lab--sotopia-fastapi-webapi-serve.modal.run/docs)
- When you are hosting your own API, find it in `{YOUR_API_BASE}/docs`.
- Also see [Sotopia examples](https://github.com/sotopia-lab/sotopia/example) for more information.
- Also see [Sotopia examples](https://github.com/sotopia-lab/sotopia/tree/main/examples) for more information.
"""
)
st.markdown("Current API Base: " + st.session_state.API_BASE)
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit a67166e

Please sign in to comment.