Skip to content

Commit

Permalink
support agentprofile and environmentprofile serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
lwaekfjlk committed Apr 15, 2024
1 parent c50169b commit be84186
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 2 deletions.
16 changes: 15 additions & 1 deletion sotopia/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,16 @@
RelationshipType,
)
from .serialization import (
agentprofiles_to_csv,
agentprofiles_to_json,
environmentprofiles_to_csv,
environmentprofiles_to_json,
episodes_to_csv,
episodes_to_json,
get_rewards_from_episode,
jsonl_to_agentprofiles,
jsonl_to_environmentprofiles,
jsonl_to_episodes,
)
from .session_transaction import MessageTransaction, SessionTransaction
from .waiting_room import MatchingInWaitingRoom
Expand All @@ -28,7 +35,14 @@
"SessionTransaction",
"MessageTransaction",
"MatchingInWaitingRoom",
"agentprofiles_to_csv",
"agentprofiles_to_json",
"environmentprofiles_to_csv",
"environmentprofiles_to_json",
"episodes_to_csv",
"episodes_to_json",
"get_rewards_from_episodes",
"jsonl_to_agentprofiles",
"jsonl_to_environmentprofiles",
"jsonl_to_episodes",
"get_rewards_from_episode",
]
170 changes: 169 additions & 1 deletion sotopia/database/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from .logs import EpisodeLog
from .persistent_profile import AgentProfile, EnvironmentProfile


class TwoAgentEpisodeWithScenarioBackgroundGoals(BaseModel):
episode_id: str = Field(required=True)
scenario: str = Field(required=True)
Expand All @@ -17,6 +16,34 @@ class TwoAgentEpisodeWithScenarioBackgroundGoals(BaseModel):
reasoning: str = Field(required=False)
rewards: list[dict[str, float]] = Field(required=False)

class AgentProfileWithPersonalInformation(BaseModel):
agent_id: str = Field(required=True)
first_name: str = Field(required=True)
last_name: str = Field(required=True)
age: int = Field(required=True)
occupation: str = Field(required=True)
gender: str = Field(required=True)
gender_pronoun: str = Field(required=True)
public_info: str = Field(required=True)
big_five: str = Field(required=True)
moral_values: list[str] = Field(required=True)
schwartz_personal_values: list[str] = Field(required=True)
personality_and_values: str = Field(required=True)
decision_making_style: str = Field(required=True)
secret: str = Field(required=True)
mbti: str = Field(required=True)

class EnvironmentProfileWithTwoAgentRequirements(BaseModel):
scenario_id: str = Field(required=True)
codename: str = Field(required=True)
source: str = Field(required=True)
scenario: str = Field(required=True)
agent_goals: list[str] = Field(required=True)
relationship: str = Field(required=True)
age_constraint: str = Field(required=True)
occupation_constraint: str = Field(required=True)
agent_constraint: str = Field(required=True)


def _map_gender_to_adj(gender: str) -> str:
gender_to_adj = {
Expand Down Expand Up @@ -179,6 +206,107 @@ def episodes_to_json(
json.dump(dict(data), f)
f.write("\n")

def agentprofiles_to_csv(
agent_profiles: list[AgentProfile], csv_file_path: str = "agent_profiles.csv"
) -> None:
"""Save agent profiles to a csv file.
Args:
agent_profiles (list[AgentProfile]): List of agent profiles.
filepath (str, optional): The file path. Defaults to "agent_profiles.csv".
"""
data = {
"agent_id": [profile.pk for profile in agent_profiles],
"first_name": [profile.first_name for profile in agent_profiles],
"last_name": [profile.last_name for profile in agent_profiles],
"age": [profile.age for profile in agent_profiles],
"occupation": [profile.occupation for profile in agent_profiles],
}
df = pd.DataFrame(data)
df.to_csv(csv_file_path, index=False)


def agentprofiles_to_jsonl(
agent_profiles: list[AgentProfile], jsonl_file_path: str = "agent_profiles.jsonl"
) -> None:
"""Save agent profiles to a json file.
Args:
agent_profiles (list[AgentProfile]): List of agent profiles.
filepath (str, optional): The file path. Defaults to "agent_profiles.json".
"""
with open(jsonl_file_path, "w") as f:
for profile in agent_profiles:
data = AgentProfileWithPersonalInformation(
agent_id=profile.pk,
first_name=profile.first_name,
last_name=profile.last_name,
age=profile.age,
occupation=profile.occupation,
gender=profile.gender,
gender_pronoun=profile.gender_pronoun,
public_info=profile.public_info,
big_five=profile.big_five,
moral_values=profile.moral_values,
schwartz_personal_values=profile.schwartz_personal_values,
personality_and_values=profile.personality_and_values,
decision_making_style=profile.decision_making_style,
secret=profile.secret,
mbti=profile.mbti
)
json.dump(dict(data), f)
f.write("\n")


def environmentprofiles_to_csv(
environment_profiles: list[EnvironmentProfile], csv_file_path: str = "environment_profiles.csv"
) -> None:
"""Save environment profiles to a csv file.
Args:
environment_profiles (list[EnvironmentProfile]): List of environment profiles.
filepath (str, optional): The file path. Defaults to "environment_profiles.csv".
"""
data = {
"scenario_id": [profile.pk for profile in environment_profiles],
"codename": [profile.codename for profile in environment_profiles],
"source": [profile.source for profile in environment_profiles],
"scenario": [profile.scenario for profile in environment_profiles],
"agent_goals": [profile.agent_goals for profile in environment_profiles],
"relationship": [profile.relationship for profile in environment_profiles],
"age_constraint": [profile.age_constraint for profile in environment_profiles],
"occupation_constraint": [profile.occupation_constraint for profile in environment_profiles],
"agent_constraint": [profile.agent_constraint for profile in environment_profiles],
}
df = pd.DataFrame(data)
df.to_csv(csv_file_path, index=False)


def environmentprofiles_to_jsonl(
environment_profiles: list[EnvironmentProfile], jsonl_file_path: str = "environment_profiles.jsonl"
) -> None:
"""Save environment profiles to a json file.
Args:
environment_profiles (list[EnvironmentProfile]): List of environment profiles.
filepath (str, optional): The file path. Defaults to "environment_profiles.json".
"""
with open(jsonl_file_path, "w") as f:
for profile in environment_profiles:
data = EnvironmentProfileWithTwoAgentRequirements(
scenario_id=profile.pk,
codename=profile.codename,
source=profile.source,
scenario=profile.scenario,
agent_goals=profile.agent_goals,
relationship=profile.relationship,
age_constraint=profile.age_constraint,
occupation_constraint=profile.occupation_constraint,
agent_constraint=profile.agent_constraint if profile.agent_constraint else "nan"
)
json.dump(dict(data), f)
f.write("\n")


def jsonl_to_episodes(
jsonl_file_path: str,
Expand All @@ -198,3 +326,43 @@ def jsonl_to_episodes(
episode = TwoAgentEpisodeWithScenarioBackgroundGoals(**data)
episodes.append(episode)
return episodes


def jsonl_to_agentprofiles(
jsonl_file_path: str,
) -> list[AgentProfileWithPersonalInformation]:
"""Load agent profiles from a jsonl file.
Args:
jsonl_file_path (str): The file path.
Returns:
list[AgentProfileWithPersonalInformation]: List of agent profiles.
"""
agent_profiles = []
with open(jsonl_file_path, "r") as f:
for line in f:
data = json.loads(line)
agent_profile = AgentProfileWithPersonalInformation(**data)
agent_profiles.append(agent_profile)
return agent_profiles


def jsonl_to_environmentprofiles(
jsonl_file_path: str,
) -> list[EnvironmentProfileWithTwoAgentRequirements]:
"""Load environment profiles from a jsonl file.
Args:
jsonl_file_path (str): The file path.
Returns:
list[EnvironmentProfileWithTwoAgentSettings]: List of environment profiles.
"""
environment_profiles = []
with open(jsonl_file_path, "r") as f:
for line in f:
data = json.loads(line)
environment_profile = EnvironmentProfileWithTwoAgentRequirements(**data)
environment_profiles.append(environment_profile)
return environment_profiles

0 comments on commit be84186

Please sign in to comment.