Skip to content

Commit

Permalink
Refactor test_create_episode_log function and add create_dummy_episod…
Browse files Browse the repository at this point in the history
…e_log helper function
  • Loading branch information
XuhuiZhou committed Mar 17, 2024
1 parent 021fd5d commit 903a5eb
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 94 deletions.
125 changes: 57 additions & 68 deletions sotopia/database/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@

from .logs import EpisodeLog
from .persistent_profile import AgentProfile, EnvironmentProfile

from pydantic import ConstrainedList, conlist, root_validator
from redis_om import HashModel, JsonModel
from redis_om.model.model import Field

class TwoAgentEpisodeWithScenarioBackgroundGoals(JsonModel):
episode_id: str = Field(index=True)
scenario: str = Field(index=True)
codename: str = Field(index=True)
agents_background: dict[str, str] = Field(index=True)
social_goals: dict[str, str] = Field(index=True)
social_interactions: str = Field(index=True)

def _map_gender_to_adj(gender: str) -> str:
gender_to_adj = {
Expand Down Expand Up @@ -68,63 +78,42 @@ def get_agents_background_from_episode(episode: EpisodeLog) -> dict[str, str]:
}


def get_social_goals_from_episode(
epsidoes: list[EpisodeLog],
) -> list[dict[str, str]]:
"""Obtain social goals from episodes.
Args:
epsidoes (list[EpisodeLog]): List of episodes.
Returns:
list[dict[str, str]]: List of social goals with agent names as the index.
"""
social_goals = []
for episode in epsidoes:
agents = [AgentProfile.get(agent) for agent in episode.agents]
agent_names = [
agent.first_name + " " + agent.last_name for agent in agents
]
environment = EnvironmentProfile.get(episode.environment)
agent_goals = {
agent_names[0]: environment.agent_goals[0],
agent_names[1]: environment.agent_goals[1],
}
social_goals.append(agent_goals)
return social_goals
def get_agent_name_to_social_goal_from_episode(
episode: EpisodeLog,
) -> dict[str, str]:
agents = [AgentProfile.get(agent) for agent in episode.agents]
agent_names = [
agent.first_name + " " + agent.last_name for agent in agents
]
environment = EnvironmentProfile.get(episode.environment)
agent_goals = {
agent_names[0]: environment.agent_goals[0],
agent_names[1]: environment.agent_goals[1],
}
return agent_goals


def get_social_interactions_from_episode(
epsidoes: list[EpisodeLog],
) -> list[str]:
"""Obtain pure social interactions from episodes.
Args:
epsidoes (list[EpisodeLog]): List of episodes.
Returns:
list[str]: List of social interactions.
"""
social_interactions = []
for episode in epsidoes:
assert isinstance(episode.tag, str)
list_of_social_interactions = episode.render_for_humans()[1]
if len(list_of_social_interactions) < 3:
continue
if "script" in episode.tag.split("_"):
episode: EpisodeLog,
) -> str:
assert isinstance(episode.tag, str)
list_of_social_interactions = episode.render_for_humans()[1]
if len(list_of_social_interactions) < 3:
return ""
if "script" in episode.tag.split("_"):
overall_social_interaction = list_of_social_interactions[1:-3]
else:
overall_social_interaction = list_of_social_interactions[0:-3]
# only get the sentence after "Conversation Starts:\n\n"
starter_msg_list = overall_social_interaction[0].split(
"Conversation Starts:\n\n"
)
if len(starter_msg_list) < 3:
overall_social_interaction = list_of_social_interactions[1:-3]
# raise ValueError("The starter message is not in the expected format")
else:
overall_social_interaction = list_of_social_interactions[0:-3]
# only get the sentence after "Conversation Starts:\n\n"
starter_msg_list = overall_social_interaction[0].split(
"Conversation Starts:\n\n"
)
if len(starter_msg_list) < 3:
overall_social_interaction = list_of_social_interactions[1:-3]
# raise ValueError("The starter message is not in the expected format")
else:
overall_social_interaction[0] = starter_msg_list[-1]

social_interactions.append("\n\n".join(overall_social_interaction))
return social_interactions
overall_social_interaction[0] = starter_msg_list[-1]
return "\n\n".join(overall_social_interaction)


def episodes_to_csv(
Expand All @@ -147,8 +136,12 @@ def episodes_to_csv(
"agents_background": [
get_agents_background_from_episode(episode) for episode in episodes
],
"social_goals": get_social_goals_from_episode(episodes),
"social_interactions": get_social_interactions_from_episode(episodes),
"social_goals": [
get_agent_name_to_social_goal_from_episode(episode) for episode in episodes
],
"social_interactions": [
get_social_interactions_from_episode(episode) for episode in episodes
]
}
df = pd.DataFrame(data)
df.to_csv(filepath, index=False)
Expand All @@ -165,17 +158,13 @@ def episodes_to_json(
"""
with open(filepath, "w") as f:
for episode in episodes:
data = {
"episode_id": episode.pk,
"scenario": get_scenario_from_episode(episode),
"codename": get_codename_from_episode(episode),
"agents_background": get_agents_background_from_episode(
episode
),
"social_goals": get_social_goals_from_episode([episode]),
"social_interactions": get_social_interactions_from_episode(
[episode]
),
}
json.dump(data, f)
data = TwoAgentEpisodeWithScenarioBackgroundGoals(
episode_id=episode.pk,
scenario=get_scenario_from_episode(episode),
codename=get_codename_from_episode(episode),
agents_background=get_agents_background_from_episode(episode),
social_goals=get_agent_name_to_social_goal_from_episode(episode),
social_interactions=get_social_interactions_from_episode(episode),
)
json.dump(dict(data), f)
f.write("\n")
57 changes: 31 additions & 26 deletions tests/database/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,31 +53,8 @@ def _test_create_episode_log_setup_and_tear_down() -> Generator[
AgentProfile.delete("tmppk_agent2")
EpisodeLog.delete("tmppk_episode_log")


def test_get_agent_by_name(
_test_create_episode_log_setup_and_tear_down: Any,
) -> None:
agent_profile = AgentProfile.find(AgentProfile.first_name == "John").all()
assert agent_profile[0].pk == "tmppk_agent1"


def test_create_episode_log(
_test_create_episode_log_setup_and_tear_down: Any,
) -> None:
try:
_ = EpisodeLog(
environment="",
agents=["", ""],
messages=[],
rewards=[[0, 0, 0]],
reasoning=[""],
rewards_prompt="",
)
assert False
except Exception as e:
assert isinstance(e, ValidationError)

episode_log = EpisodeLog(
def create_dummy_episode_log() -> EpisodeLog:
episode = EpisodeLog(
environment="env",
agents=["tmppk_agent1", "tmppk_agent2"],
messages=[
Expand Down Expand Up @@ -125,7 +102,34 @@ def test_create_episode_log(
reasoning="",
pk="tmppk_episode_log",
rewards_prompt="",
)
)
return episode


def test_get_agent_by_name(
_test_create_episode_log_setup_and_tear_down: Any,
) -> None:
agent_profile = AgentProfile.find(AgentProfile.first_name == "John").all()
assert agent_profile[0].pk == "tmppk_agent1"


def test_create_episode_log(
_test_create_episode_log_setup_and_tear_down: Any,
) -> None:
try:
_ = EpisodeLog(
environment="",
agents=["", ""],
messages=[],
rewards=[[0, 0, 0]],
reasoning=[""],
rewards_prompt="",
)
assert False
except Exception as e:
assert isinstance(e, ValidationError)

episode_log = create_dummy_episode_log()
episode_log.save()
assert episode_log.pk == "tmppk_episode_log"
retrieved_episode_log: EpisodeLog = EpisodeLog.get(episode_log.pk)
Expand All @@ -137,3 +141,4 @@ def test_create_episode_log(
agent_profiles, messages_and_rewards = episode_log.render_for_humans()
assert len(agent_profiles) == 2
assert len(messages_and_rewards) == 4

0 comments on commit 903a5eb

Please sign in to comment.