Skip to content

Commit

Permalink
Refactor database utils and test_database.py
Browse files Browse the repository at this point in the history
  • Loading branch information
XuhuiZhou committed Mar 17, 2024
1 parent 903a5eb commit ca40fc4
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
26 changes: 17 additions & 9 deletions sotopia/database/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import json

import pandas as pd

from .logs import EpisodeLog
from .persistent_profile import AgentProfile, EnvironmentProfile
from pydantic import ConstrainedList, conlist, root_validator
from redis_om import HashModel, JsonModel
from redis_om.model.model import Field

from .logs import EpisodeLog
from .persistent_profile import AgentProfile, EnvironmentProfile


class TwoAgentEpisodeWithScenarioBackgroundGoals(JsonModel):
episode_id: str = Field(index=True)
scenario: str = Field(index=True)
Expand All @@ -16,6 +17,7 @@ class TwoAgentEpisodeWithScenarioBackgroundGoals(JsonModel):
social_goals: dict[str, str] = Field(index=True)
social_interactions: str = Field(index=True)


def _map_gender_to_adj(gender: str) -> str:
gender_to_adj = {
"Man": "male",
Expand Down Expand Up @@ -113,7 +115,7 @@ def get_social_interactions_from_episode(
# raise ValueError("The starter message is not in the expected format")
else:
overall_social_interaction[0] = starter_msg_list[-1]
return "\n\n".join(overall_social_interaction)
return "\n\n".join(overall_social_interaction)


def episodes_to_csv(
Expand All @@ -137,11 +139,13 @@ def episodes_to_csv(
get_agents_background_from_episode(episode) for episode in episodes
],
"social_goals": [
get_agent_name_to_social_goal_from_episode(episode) for episode in episodes
get_agent_name_to_social_goal_from_episode(episode)
for episode in episodes
],
"social_interactions": [
get_social_interactions_from_episode(episode) for episode in episodes
]
get_social_interactions_from_episode(episode)
for episode in episodes
],
}
df = pd.DataFrame(data)
df.to_csv(filepath, index=False)
Expand All @@ -163,8 +167,12 @@ def episodes_to_json(
scenario=get_scenario_from_episode(episode),
codename=get_codename_from_episode(episode),
agents_background=get_agents_background_from_episode(episode),
social_goals=get_agent_name_to_social_goal_from_episode(episode),
social_interactions=get_social_interactions_from_episode(episode),
social_goals=get_agent_name_to_social_goal_from_episode(
episode
),
social_interactions=get_social_interactions_from_episode(
episode
),
)
json.dump(dict(data), f)
f.write("\n")
4 changes: 2 additions & 2 deletions tests/database/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def _test_create_episode_log_setup_and_tear_down() -> Generator[
AgentProfile.delete("tmppk_agent2")
EpisodeLog.delete("tmppk_episode_log")


def create_dummy_episode_log() -> EpisodeLog:
episode = EpisodeLog(
environment="env",
Expand Down Expand Up @@ -102,7 +103,7 @@ def create_dummy_episode_log() -> EpisodeLog:
reasoning="",
pk="tmppk_episode_log",
rewards_prompt="",
)
)
return episode


Expand Down Expand Up @@ -141,4 +142,3 @@ def test_create_episode_log(
agent_profiles, messages_and_rewards = episode_log.render_for_humans()
assert len(agent_profiles) == 2
assert len(messages_and_rewards) == 4

0 comments on commit ca40fc4

Please sign in to comment.