Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into feature/ruff-uv
Browse files Browse the repository at this point in the history
  • Loading branch information
ProKil committed Apr 7, 2024
2 parents b209d8a + c234b03 commit aa15c1b
Show file tree
Hide file tree
Showing 5 changed files with 324 additions and 24 deletions.
4 changes: 4 additions & 0 deletions docs/all_the_issues.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

Large batch size may cause some episodes to be skipped. This is due to the fact that the server may not be able to handle the load. Try reducing the batch size. But you can also use the script in `examples/fix_missing_episodes.py` to fix the missing episodes.

## How to serialize the data saved in the database?

Check out `Episodes_to_CSV/JSON` in the `notebooks/redis_stats.ipynb` notebook.

## Where I can find the data?

For the full data:
Expand Down
84 changes: 84 additions & 0 deletions notebooks/redis_stats.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"source": [
"import os\n",
"import json\n",
"from typing import get_args\n",
"from tqdm.notebook import tqdm\n",
"import rich\n",
"import logging\n",
Expand All @@ -21,6 +22,89 @@
"from sotopia.database.env_agent_combo_storage import EnvAgentComboStorage"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Episodes to CSV/JSON"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Literal\n",
"\n",
"from pydantic import ValidationError\n",
"\n",
"\n",
"LLM_Name = Literal[\n",
" \"togethercomputer/llama-2-7b-chat\",\n",
" \"togethercomputer/llama-2-70b-chat\",\n",
" \"togethercomputer/mpt-30b-chat\",\n",
" \"gpt-3.5-turbo\",\n",
" \"text-davinci-003\",\n",
" \"gpt-4\",\n",
" \"gpt-4-turbo\",\n",
" \"human\",\n",
" \"redis\",\n",
"]\n",
"\n",
"\n",
"def _is_valid_episode_log_pk(pk: str) -> bool:\n",
" try:\n",
" episode = EpisodeLog.get(pk=pk)\n",
" except ValidationError:\n",
" return False\n",
" try:\n",
" tag = episode.tag\n",
" model_1, model_2, version = tag.split(\"_\", maxsplit=2)\n",
" if (\n",
" model_1 in get_args(LLM_Name)\n",
" and model_2 in get_args(LLM_Name)\n",
" and version == \"v0.0.1_clean\"\n",
" ):\n",
" return True\n",
" else:\n",
" return False\n",
" except (ValueError, AttributeError):\n",
" # ValueError: tag has less than 3 parts\n",
" # AttributeError: tag is None\n",
" return False\n",
"\n",
"\n",
"episodes: list[EpisodeLog] = [\n",
" EpisodeLog.get(pk=pk)\n",
" for pk in filter(_is_valid_episode_log_pk, EpisodeLog.all_pks())\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sotopia.database.serialization import episodes_to_csv\n",
"\n",
"\n",
"episodes_to_csv(episodes, \"../data/sotopia_episodes_v1.csv\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sotopia.database.serialization import episodes_to_json\n",
"\n",
"\n",
"episodes_to_json(episodes, \"../data/sotopia_episodes_v1.json\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
7 changes: 7 additions & 0 deletions sotopia/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
RelationshipProfile,
RelationshipType,
)
from .serialization import (
episodes_to_csv,
episodes_to_json,
)
from .session_transaction import MessageTransaction, SessionTransaction
from .waiting_room import MatchingInWaitingRoom

Expand All @@ -23,4 +27,7 @@
"SessionTransaction",
"MessageTransaction",
"MatchingInWaitingRoom",
"episodes_to_csv",
"episodes_to_json",
"get_rewards_from_episodes",
]
200 changes: 200 additions & 0 deletions sotopia/database/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import json

import pandas as pd
from pydantic import BaseModel, Field

from .logs import EpisodeLog
from .persistent_profile import AgentProfile, EnvironmentProfile


class TwoAgentEpisodeWithScenarioBackgroundGoals(BaseModel):
episode_id: str = Field(required=True)
scenario: str = Field(required=True)
codename: str = Field(required=True)
agents_background: dict[str, str] = Field(required=True)
social_goals: dict[str, str] = Field(required=True)
social_interactions: str = Field(required=True)
reasoning: str = Field(required=False)
rewards: list[dict[str, float]] = Field(required=False)


def _map_gender_to_adj(gender: str) -> str:
gender_to_adj = {
"Man": "male",
"Woman": "female",
"Nonbinary": "nonbinary",
}
if gender:
return gender_to_adj[gender]
else:
return ""


def get_rewards_from_episode(episode: EpisodeLog) -> list[dict[str, float]]:
assert (
len(episode.rewards) == 2
and (not isinstance(episode.rewards[0], float))
and (not isinstance(episode.rewards[1], float))
)
return [episode.rewards[0][1], episode.rewards[1][1]]


def get_scenario_from_episode(episode: EpisodeLog) -> str:
"""Get the scenario from the episode.
Args:
episode (EpisodeLog): The episode.
Returns:
str: The scenario.
"""
return EnvironmentProfile.get(pk=episode.environment).scenario


def get_codename_from_episode(episode: EpisodeLog) -> str:
"""Get the codename from the episode.
Args:
episode (EpisodeLog): The episode.
Returns:
str: The codename.
"""
return EnvironmentProfile.get(pk=episode.environment).codename


def get_agents_background_from_episode(episode: EpisodeLog) -> dict[str, str]:
"""Get the agents' background from the episode.
Args:
episode (EpisodeLog): The episode.
Returns:
list[str]: The agents' background.
"""
agents = [AgentProfile.get(pk=agent) for agent in episode.agents]

return {
f"{profile.first_name} {profile.last_name}": f"{profile.first_name} {profile.last_name} is a {profile.age}-year-old {_map_gender_to_adj(profile.gender)} {profile.occupation.lower()}. {profile.gender_pronoun} pronouns. {profile.public_info} Personality and values description: {profile.personality_and_values} {profile.first_name}'s secrets: {profile.secret}"
for profile in agents
}


def get_agent_name_to_social_goal_from_episode(
episode: EpisodeLog,
) -> dict[str, str]:
agents = [AgentProfile.get(agent) for agent in episode.agents]
agent_names = [
agent.first_name + " " + agent.last_name for agent in agents
]
environment = EnvironmentProfile.get(episode.environment)
agent_goals = {
agent_names[0]: environment.agent_goals[0],
agent_names[1]: environment.agent_goals[1],
}
return agent_goals


def get_social_interactions_from_episode(
episode: EpisodeLog,
) -> str:
assert isinstance(episode.tag, str)
list_of_social_interactions = episode.render_for_humans()[1]
if len(list_of_social_interactions) < 3:
return ""
if "script" in episode.tag.split("_"):
overall_social_interaction = list_of_social_interactions[1:-3]
else:
overall_social_interaction = list_of_social_interactions[0:-3]
# only get the sentence after "Conversation Starts:\n\n"
starter_msg_list = overall_social_interaction[0].split(
"Conversation Starts:\n\n"
)
if len(starter_msg_list) < 3:
overall_social_interaction = list_of_social_interactions[1:-3]
# raise ValueError("The starter message is not in the expected format")
else:
overall_social_interaction[0] = starter_msg_list[-1]
return "\n\n".join(overall_social_interaction)


def episodes_to_csv(
episodes: list[EpisodeLog], csv_file_path: str = "episodes.csv"
) -> None:
"""Save episodes to a csv file.
Args:
episodes (list[EpisodeLog]): List of episodes.
filepath (str, optional): The file path. Defaults to "episodes.csv".
"""
data = {
"episode_id": [episode.pk for episode in episodes],
"scenario": [
get_scenario_from_episode(episode) for episode in episodes
],
"codename": [
get_codename_from_episode(episode) for episode in episodes
],
"agents_background": [
get_agents_background_from_episode(episode) for episode in episodes
],
"social_goals": [
get_agent_name_to_social_goal_from_episode(episode)
for episode in episodes
],
"social_interactions": [
get_social_interactions_from_episode(episode)
for episode in episodes
],
}
df = pd.DataFrame(data)
df.to_csv(csv_file_path, index=False)


def episodes_to_json(
episodes: list[EpisodeLog], jsonl_file_path: str = "episodes.jsonl"
) -> None:
"""Save episodes to a json file.
Args:
episodes (list[EpisodeLog]): List of episodes.
filepath (str, optional): The file path. Defaults to "episodes.json".
"""
with open(jsonl_file_path, "w") as f:
for episode in episodes:
data = TwoAgentEpisodeWithScenarioBackgroundGoals(
episode_id=episode.pk,
scenario=get_scenario_from_episode(episode),
codename=get_codename_from_episode(episode),
agents_background=get_agents_background_from_episode(episode),
social_goals=get_agent_name_to_social_goal_from_episode(
episode
),
social_interactions=get_social_interactions_from_episode(
episode
),
reasoning=episode.reasoning,
rewards=get_rewards_from_episode(episode),
)
json.dump(dict(data), f)
f.write("\n")


def jsonl_to_episodes(
jsonl_file_path: str,
) -> list[TwoAgentEpisodeWithScenarioBackgroundGoals]:
"""Load episodes from a jsonl file.
Args:
jsonl_file_path (str): The file path.
Returns:
list[TwoAgentEpisodeWithScenarioBackgroundGoals]: List of episodes.
"""
episodes = []
with open(jsonl_file_path, "r") as f:
for line in f:
data = json.loads(line)
episode = TwoAgentEpisodeWithScenarioBackgroundGoals(**data)
episodes.append(episode)
return episodes
53 changes: 29 additions & 24 deletions tests/database/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,30 +52,8 @@ def _test_create_episode_log_setup_and_tear_down() -> Generator[None, None, None
EpisodeLog.delete("tmppk_episode_log")


def test_get_agent_by_name(
_test_create_episode_log_setup_and_tear_down: Any,
) -> None:
agent_profile = AgentProfile.find(AgentProfile.first_name == "John").all()
assert agent_profile[0].pk == "tmppk_agent1"


def test_create_episode_log(
_test_create_episode_log_setup_and_tear_down: Any,
) -> None:
try:
_ = EpisodeLog(
environment="",
agents=["", ""],
messages=[],
rewards=[[0, 0, 0]],
reasoning=[""],
rewards_prompt="",
)
assert False
except Exception as e:
assert isinstance(e, ValidationError)

episode_log = EpisodeLog(
def create_dummy_episode_log() -> EpisodeLog:
episode = EpisodeLog(
environment="env",
agents=["tmppk_agent1", "tmppk_agent2"],
messages=[
Expand Down Expand Up @@ -124,6 +102,33 @@ def test_create_episode_log(
pk="tmppk_episode_log",
rewards_prompt="",
)
return episode


def test_get_agent_by_name(
_test_create_episode_log_setup_and_tear_down: Any,
) -> None:
agent_profile = AgentProfile.find(AgentProfile.first_name == "John").all()
assert agent_profile[0].pk == "tmppk_agent1"


def test_create_episode_log(
_test_create_episode_log_setup_and_tear_down: Any,
) -> None:
try:
_ = EpisodeLog(
environment="",
agents=["", ""],
messages=[],
rewards=[[0, 0, 0]],
reasoning=[""],
rewards_prompt="",
)
assert False
except Exception as e:
assert isinstance(e, ValidationError)

episode_log = create_dummy_episode_log()
episode_log.save()
assert episode_log.pk == "tmppk_episode_log"
retrieved_episode_log: EpisodeLog = EpisodeLog.get(episode_log.pk)
Expand Down

0 comments on commit aa15c1b

Please sign in to comment.