From c234b03ed60112ca63806a7d876e5795d1e7fefa Mon Sep 17 00:00:00 2001 From: Xuhui Zhou Date: Sat, 6 Apr 2024 11:52:40 -0400 Subject: [PATCH] Serialize episodes into CSV and JSON (#30) * add csv& json * clean format * Refactor test_create_episode_log function and add create_dummy_episode_log helper function * Refactor database utils and test_database.py * Refactor imports and update type annotations in redis_stats.ipynb * Add serialization of data saved in the database to Episodes_to_CSV/JSON in notebooks/redis_stats.ipynb * Add rewards to the logged episodes in jsonl --------- Co-authored-by: Hao --- docs/all_the_issues.md | 4 + notebooks/redis_stats.ipynb | 75 ++++++++++- sotopia/database/__init__.py | 8 ++ sotopia/database/serialization.py | 200 ++++++++++++++++++++++++++++++ tests/database/test_database.py | 53 ++++---- 5 files changed, 314 insertions(+), 26 deletions(-) create mode 100644 sotopia/database/serialization.py diff --git a/docs/all_the_issues.md b/docs/all_the_issues.md index 80ce33d37..56994817b 100644 --- a/docs/all_the_issues.md +++ b/docs/all_the_issues.md @@ -3,6 +3,10 @@ Large batch size may cause some episodes to be skipped. This is due to the fact that the server may not be able to handle the load. Try reducing the batch size. But you can also use the script in `examples/fix_missing_episodes.py` to fix the missing episodes. +## How to serialize the data saved in the database? + +Check out `Episodes_to_CSV/JSON` in the `notebooks/redis_stats.ipynb` notebook. + ## Where I can find the data? For the full data: diff --git a/notebooks/redis_stats.ipynb b/notebooks/redis_stats.ipynb index d7ee5bc4c..057aa88f7 100644 --- a/notebooks/redis_stats.ipynb +++ b/notebooks/redis_stats.ipynb @@ -9,12 +9,13 @@ "import sys\n", "import os\n", "import json\n", + "from typing import get_args\n", "from tqdm.notebook import tqdm\n", "import rich\n", "import logging\n", + "from pydantic import ValidationError\n", "from collections import defaultdict, Counter\n", - "from sotopia.database.persistent_profile import AgentProfile, EnvironmentProfile, RelationshipProfile\n", - "from sotopia.database.logs import EpisodeLog\n", + "from sotopia.database import AgentProfile, EnvironmentProfile, RelationshipProfile, EpisodeLog, episodes_to_csv, episodes_to_json \n", "from sotopia.database.env_agent_combo_storage import EnvAgentComboStorage\n", "from collections import Counter \n", "from redis_om import Migrator\n", @@ -22,6 +23,76 @@ "from rich.terminal_theme import MONOKAI " ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Episodes to CSV/JSON" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "LLM_Name = Literal[\n", + " \"togethercomputer/llama-2-7b-chat\",\n", + " \"togethercomputer/llama-2-70b-chat\",\n", + " \"togethercomputer/mpt-30b-chat\",\n", + " \"gpt-3.5-turbo\",\n", + " \"text-davinci-003\",\n", + " \"gpt-4\",\n", + " \"gpt-4-turbo\",\n", + " \"human\",\n", + " \"redis\",\n", + "]\n", + "def _is_valid_episode_log_pk(pk: str) -> bool:\n", + " try:\n", + " episode = EpisodeLog.get(pk=pk)\n", + " except ValidationError:\n", + " return False\n", + " try:\n", + " tag = episode.tag\n", + " model_1, model_2, version = tag.split(\"_\", maxsplit=2)\n", + " if (\n", + " model_1 in get_args(LLM_Name)\n", + " and model_2 in get_args(LLM_Name)\n", + " and version == \"v0.0.1_clean\"\n", + " ):\n", + " return True\n", + " else:\n", + " return False\n", + " except (ValueError, AttributeError):\n", + " # ValueError: tag has less than 3 parts\n", + " # AttributeError: tag is None\n", + " return False\n", + "\n", + "\n", + "episodes: list[EpisodeLog] = [\n", + " EpisodeLog.get(pk=pk)\n", + " for pk in filter(_is_valid_episode_log_pk, EpisodeLog.all_pks())\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "episodes_to_csv(episodes, \"../data/sotopia_episodes_v1.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "episodes_to_json(episodes, \"../data/sotopia_episodes_v1.json\")" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/sotopia/database/__init__.py b/sotopia/database/__init__.py index 32c7037cc..591ab1e2e 100644 --- a/sotopia/database/__init__.py +++ b/sotopia/database/__init__.py @@ -7,6 +7,11 @@ RelationshipProfile, RelationshipType, ) +from .serialization import ( + episodes_to_csv, + episodes_to_json, + get_rewards_from_episode, +) from .session_transaction import MessageTransaction, SessionTransaction from .waiting_room import MatchingInWaitingRoom @@ -23,4 +28,7 @@ "SessionTransaction", "MessageTransaction", "MatchingInWaitingRoom", + "episodes_to_csv", + "episodes_to_json", + "get_rewards_from_episodes", ] diff --git a/sotopia/database/serialization.py b/sotopia/database/serialization.py new file mode 100644 index 000000000..939d988a5 --- /dev/null +++ b/sotopia/database/serialization.py @@ -0,0 +1,200 @@ +import json + +import pandas as pd +from pydantic import BaseModel, Field + +from .logs import EpisodeLog +from .persistent_profile import AgentProfile, EnvironmentProfile + + +class TwoAgentEpisodeWithScenarioBackgroundGoals(BaseModel): + episode_id: str = Field(required=True) + scenario: str = Field(required=True) + codename: str = Field(required=True) + agents_background: dict[str, str] = Field(required=True) + social_goals: dict[str, str] = Field(required=True) + social_interactions: str = Field(required=True) + reasoning: str = Field(required=False) + rewards: list[dict[str, float]] = Field(required=False) + + +def _map_gender_to_adj(gender: str) -> str: + gender_to_adj = { + "Man": "male", + "Woman": "female", + "Nonbinary": "nonbinary", + } + if gender: + return gender_to_adj[gender] + else: + return "" + + +def get_rewards_from_episode(episode: EpisodeLog) -> list[dict[str, float]]: + assert ( + len(episode.rewards) == 2 + and (not isinstance(episode.rewards[0], float)) + and (not isinstance(episode.rewards[1], float)) + ) + return [episode.rewards[0][1], episode.rewards[1][1]] + + +def get_scenario_from_episode(episode: EpisodeLog) -> str: + """Get the scenario from the episode. + + Args: + episode (EpisodeLog): The episode. + + Returns: + str: The scenario. + """ + return EnvironmentProfile.get(pk=episode.environment).scenario + + +def get_codename_from_episode(episode: EpisodeLog) -> str: + """Get the codename from the episode. + + Args: + episode (EpisodeLog): The episode. + + Returns: + str: The codename. + """ + return EnvironmentProfile.get(pk=episode.environment).codename + + +def get_agents_background_from_episode(episode: EpisodeLog) -> dict[str, str]: + """Get the agents' background from the episode. + + Args: + episode (EpisodeLog): The episode. + + Returns: + list[str]: The agents' background. + """ + agents = [AgentProfile.get(pk=agent) for agent in episode.agents] + + return { + f"{profile.first_name} {profile.last_name}": f"{profile.first_name} {profile.last_name} is a {profile.age}-year-old {_map_gender_to_adj(profile.gender)} {profile.occupation.lower()}. {profile.gender_pronoun} pronouns. {profile.public_info} Personality and values description: {profile.personality_and_values} {profile.first_name}'s secrets: {profile.secret}" + for profile in agents + } + + +def get_agent_name_to_social_goal_from_episode( + episode: EpisodeLog, +) -> dict[str, str]: + agents = [AgentProfile.get(agent) for agent in episode.agents] + agent_names = [ + agent.first_name + " " + agent.last_name for agent in agents + ] + environment = EnvironmentProfile.get(episode.environment) + agent_goals = { + agent_names[0]: environment.agent_goals[0], + agent_names[1]: environment.agent_goals[1], + } + return agent_goals + + +def get_social_interactions_from_episode( + episode: EpisodeLog, +) -> str: + assert isinstance(episode.tag, str) + list_of_social_interactions = episode.render_for_humans()[1] + if len(list_of_social_interactions) < 3: + return "" + if "script" in episode.tag.split("_"): + overall_social_interaction = list_of_social_interactions[1:-3] + else: + overall_social_interaction = list_of_social_interactions[0:-3] + # only get the sentence after "Conversation Starts:\n\n" + starter_msg_list = overall_social_interaction[0].split( + "Conversation Starts:\n\n" + ) + if len(starter_msg_list) < 3: + overall_social_interaction = list_of_social_interactions[1:-3] + # raise ValueError("The starter message is not in the expected format") + else: + overall_social_interaction[0] = starter_msg_list[-1] + return "\n\n".join(overall_social_interaction) + + +def episodes_to_csv( + episodes: list[EpisodeLog], csv_file_path: str = "episodes.csv" +) -> None: + """Save episodes to a csv file. + + Args: + episodes (list[EpisodeLog]): List of episodes. + filepath (str, optional): The file path. Defaults to "episodes.csv". + """ + data = { + "episode_id": [episode.pk for episode in episodes], + "scenario": [ + get_scenario_from_episode(episode) for episode in episodes + ], + "codename": [ + get_codename_from_episode(episode) for episode in episodes + ], + "agents_background": [ + get_agents_background_from_episode(episode) for episode in episodes + ], + "social_goals": [ + get_agent_name_to_social_goal_from_episode(episode) + for episode in episodes + ], + "social_interactions": [ + get_social_interactions_from_episode(episode) + for episode in episodes + ], + } + df = pd.DataFrame(data) + df.to_csv(csv_file_path, index=False) + + +def episodes_to_json( + episodes: list[EpisodeLog], jsonl_file_path: str = "episodes.jsonl" +) -> None: + """Save episodes to a json file. + + Args: + episodes (list[EpisodeLog]): List of episodes. + filepath (str, optional): The file path. Defaults to "episodes.json". + """ + with open(jsonl_file_path, "w") as f: + for episode in episodes: + data = TwoAgentEpisodeWithScenarioBackgroundGoals( + episode_id=episode.pk, + scenario=get_scenario_from_episode(episode), + codename=get_codename_from_episode(episode), + agents_background=get_agents_background_from_episode(episode), + social_goals=get_agent_name_to_social_goal_from_episode( + episode + ), + social_interactions=get_social_interactions_from_episode( + episode + ), + reasoning=episode.reasoning, + rewards=get_rewards_from_episode(episode), + ) + json.dump(dict(data), f) + f.write("\n") + + +def jsonl_to_episodes( + jsonl_file_path: str, +) -> list[TwoAgentEpisodeWithScenarioBackgroundGoals]: + """Load episodes from a jsonl file. + + Args: + jsonl_file_path (str): The file path. + + Returns: + list[TwoAgentEpisodeWithScenarioBackgroundGoals]: List of episodes. + """ + episodes = [] + with open(jsonl_file_path, "r") as f: + for line in f: + data = json.loads(line) + episode = TwoAgentEpisodeWithScenarioBackgroundGoals(**data) + episodes.append(episode) + return episodes diff --git a/tests/database/test_database.py b/tests/database/test_database.py index ad6a0d90d..ac3d0d3c1 100644 --- a/tests/database/test_database.py +++ b/tests/database/test_database.py @@ -54,30 +54,8 @@ def _test_create_episode_log_setup_and_tear_down() -> Generator[ EpisodeLog.delete("tmppk_episode_log") -def test_get_agent_by_name( - _test_create_episode_log_setup_and_tear_down: Any, -) -> None: - agent_profile = AgentProfile.find(AgentProfile.first_name == "John").all() - assert agent_profile[0].pk == "tmppk_agent1" - - -def test_create_episode_log( - _test_create_episode_log_setup_and_tear_down: Any, -) -> None: - try: - _ = EpisodeLog( - environment="", - agents=["", ""], - messages=[], - rewards=[[0, 0, 0]], - reasoning=[""], - rewards_prompt="", - ) - assert False - except Exception as e: - assert isinstance(e, ValidationError) - - episode_log = EpisodeLog( +def create_dummy_episode_log() -> EpisodeLog: + episode = EpisodeLog( environment="env", agents=["tmppk_agent1", "tmppk_agent2"], messages=[ @@ -126,6 +104,33 @@ def test_create_episode_log( pk="tmppk_episode_log", rewards_prompt="", ) + return episode + + +def test_get_agent_by_name( + _test_create_episode_log_setup_and_tear_down: Any, +) -> None: + agent_profile = AgentProfile.find(AgentProfile.first_name == "John").all() + assert agent_profile[0].pk == "tmppk_agent1" + + +def test_create_episode_log( + _test_create_episode_log_setup_and_tear_down: Any, +) -> None: + try: + _ = EpisodeLog( + environment="", + agents=["", ""], + messages=[], + rewards=[[0, 0, 0]], + reasoning=[""], + rewards_prompt="", + ) + assert False + except Exception as e: + assert isinstance(e, ValidationError) + + episode_log = create_dummy_episode_log() episode_log.save() assert episode_log.pk == "tmppk_episode_log" retrieved_episode_log: EpisodeLog = EpisodeLog.get(episode_log.pk)