From be8418613de0fc370f9426227152c3e3ed074923 Mon Sep 17 00:00:00 2001
From: Haofei Yu <1125027232@qq.com>
Date: Sun, 14 Apr 2024 20:52:36 -0400
Subject: [PATCH] support agentprofile and environmentprofile serialization

---
 sotopia/database/__init__.py      |  16 ++-
 sotopia/database/serialization.py | 170 +++++++++++++++++++++++++++++-
 2 files changed, 184 insertions(+), 2 deletions(-)

diff --git a/sotopia/database/__init__.py b/sotopia/database/__init__.py
index 591ab1e2e..b64a57525 100644
--- a/sotopia/database/__init__.py
+++ b/sotopia/database/__init__.py
@@ -8,9 +8,16 @@
     RelationshipType,
 )
 from .serialization import (
+    agentprofiles_to_csv,
+    agentprofiles_to_jsonl,
+    environmentprofiles_to_csv,
+    environmentprofiles_to_jsonl,
     episodes_to_csv,
     episodes_to_json,
     get_rewards_from_episode,
+    jsonl_to_agentprofiles,
+    jsonl_to_environmentprofiles,
+    jsonl_to_episodes,
 )
 from .session_transaction import MessageTransaction, SessionTransaction
 from .waiting_room import MatchingInWaitingRoom
@@ -28,7 +35,14 @@
     "SessionTransaction",
     "MessageTransaction",
     "MatchingInWaitingRoom",
+    "agentprofiles_to_csv",
+    "agentprofiles_to_jsonl",
+    "environmentprofiles_to_csv",
+    "environmentprofiles_to_jsonl",
     "episodes_to_csv",
     "episodes_to_json",
-    "get_rewards_from_episodes",
+    "jsonl_to_agentprofiles",
+    "jsonl_to_environmentprofiles",
+    "jsonl_to_episodes",
+    "get_rewards_from_episode",
 ]
diff --git a/sotopia/database/serialization.py b/sotopia/database/serialization.py
index 939d988a5..60d2528d7 100644
--- a/sotopia/database/serialization.py
+++ b/sotopia/database/serialization.py
@@ -6,7 +6,6 @@
 from .logs import EpisodeLog
 from .persistent_profile import AgentProfile, EnvironmentProfile
 
-
 class TwoAgentEpisodeWithScenarioBackgroundGoals(BaseModel):
     episode_id: str = Field(required=True)
     scenario: str = Field(required=True)
@@ -17,6 +16,34 @@ class TwoAgentEpisodeWithScenarioBackgroundGoals(BaseModel):
     reasoning: str = Field(required=False)
     rewards: list[dict[str, float]] = Field(required=False)
 
+class AgentProfileWithPersonalInformation(BaseModel):
+    agent_id: str = Field(required=True)
+    first_name: str = Field(required=True)
+    last_name: str = Field(required=True)
+    age: int = Field(required=True)
+    occupation: str = Field(required=True)
+    gender: str = Field(required=True)
+    gender_pronoun: str = Field(required=True)
+    public_info: str = Field(required=True)
+    big_five: str = Field(required=True)
+    moral_values: list[str] = Field(required=True)
+    schwartz_personal_values: list[str] = Field(required=True)
+    personality_and_values: str = Field(required=True)
+    decision_making_style: str = Field(required=True)
+    secret: str = Field(required=True)
+    mbti: str = Field(required=True)
+
+class EnvironmentProfileWithTwoAgentRequirements(BaseModel):
+    scenario_id: str = Field(required=True)
+    codename: str = Field(required=True)
+    source: str = Field(required=True)
+    scenario: str = Field(required=True)
+    agent_goals: list[str] = Field(required=True)
+    relationship: str = Field(required=True)
+    age_constraint: str = Field(required=True)
+    occupation_constraint: str = Field(required=True)
+    agent_constraint: str = Field(required=True)
+
 
 def _map_gender_to_adj(gender: str) -> str:
     gender_to_adj = {
@@ -179,6 +206,107 @@ def episodes_to_json(
             json.dump(dict(data), f)
             f.write("\n")
 
+def agentprofiles_to_csv(
+    agent_profiles: list[AgentProfile], csv_file_path: str = "agent_profiles.csv"
+) -> None:
+    """Save agent profiles to a csv file.
+
+    Args:
+        agent_profiles (list[AgentProfile]): List of agent profiles.
+        csv_file_path (str, optional): The file path. Defaults to "agent_profiles.csv".
+    """
+    data = {
+        "agent_id": [profile.pk for profile in agent_profiles],
+        "first_name": [profile.first_name for profile in agent_profiles],
+        "last_name": [profile.last_name for profile in agent_profiles],
+        "age": [profile.age for profile in agent_profiles],
+        "occupation": [profile.occupation for profile in agent_profiles],
+    }
+    df = pd.DataFrame(data)
+    df.to_csv(csv_file_path, index=False)
+
+
+def agentprofiles_to_jsonl(
+    agent_profiles: list[AgentProfile], jsonl_file_path: str = "agent_profiles.jsonl"
+) -> None:
+    """Save agent profiles to a jsonl file, one JSON object per line.
+
+    Args:
+        agent_profiles (list[AgentProfile]): List of agent profiles.
+        jsonl_file_path (str, optional): The file path. Defaults to "agent_profiles.jsonl".
+    """
+    with open(jsonl_file_path, "w") as f:
+        for profile in agent_profiles:
+            data = AgentProfileWithPersonalInformation(
+                agent_id=profile.pk,
+                first_name=profile.first_name,
+                last_name=profile.last_name,
+                age=profile.age,
+                occupation=profile.occupation,
+                gender=profile.gender,
+                gender_pronoun=profile.gender_pronoun,
+                public_info=profile.public_info,
+                big_five=profile.big_five,
+                moral_values=profile.moral_values,
+                schwartz_personal_values=profile.schwartz_personal_values,
+                personality_and_values=profile.personality_and_values,
+                decision_making_style=profile.decision_making_style,
+                secret=profile.secret,
+                mbti=profile.mbti
+            )
+            json.dump(dict(data), f)
+            f.write("\n")
+
+
+def environmentprofiles_to_csv(
+    environment_profiles: list[EnvironmentProfile], csv_file_path: str = "environment_profiles.csv"
+) -> None:
+    """Save environment profiles to a csv file.
+
+    Args:
+        environment_profiles (list[EnvironmentProfile]): List of environment profiles.
+        csv_file_path (str, optional): The file path. Defaults to "environment_profiles.csv".
+    """
+    data = {
+        "scenario_id": [profile.pk for profile in environment_profiles],
+        "codename": [profile.codename for profile in environment_profiles],
+        "source": [profile.source for profile in environment_profiles],
+        "scenario": [profile.scenario for profile in environment_profiles],
+        "agent_goals": [profile.agent_goals for profile in environment_profiles],
+        "relationship": [profile.relationship for profile in environment_profiles],
+        "age_constraint": [profile.age_constraint for profile in environment_profiles],
+        "occupation_constraint": [profile.occupation_constraint for profile in environment_profiles],
+        "agent_constraint": [profile.agent_constraint for profile in environment_profiles],
+    }
+    df = pd.DataFrame(data)
+    df.to_csv(csv_file_path, index=False)
+
+
+def environmentprofiles_to_jsonl(
+    environment_profiles: list[EnvironmentProfile], jsonl_file_path: str = "environment_profiles.jsonl"
+) -> None:
+    """Save environment profiles to a jsonl file, one JSON object per line.
+
+    Args:
+        environment_profiles (list[EnvironmentProfile]): List of environment profiles.
+        jsonl_file_path (str, optional): The file path. Defaults to "environment_profiles.jsonl".
+    """
+    with open(jsonl_file_path, "w") as f:
+        for profile in environment_profiles:
+            data = EnvironmentProfileWithTwoAgentRequirements(
+                scenario_id=profile.pk,
+                codename=profile.codename,
+                source=profile.source,
+                scenario=profile.scenario,
+                agent_goals=profile.agent_goals,
+                relationship=profile.relationship,
+                age_constraint=profile.age_constraint,
+                occupation_constraint=profile.occupation_constraint,
+                agent_constraint=profile.agent_constraint if profile.agent_constraint else "nan"
+            )
+            json.dump(dict(data), f)
+            f.write("\n")
+
 
 def jsonl_to_episodes(
     jsonl_file_path: str,
@@ -198,3 +326,43 @@ def jsonl_to_episodes(
             episode = TwoAgentEpisodeWithScenarioBackgroundGoals(**data)
             episodes.append(episode)
     return episodes
+
+
+def jsonl_to_agentprofiles(
+    jsonl_file_path: str,
+) -> list[AgentProfileWithPersonalInformation]:
+    """Load agent profiles from a jsonl file.
+
+    Args:
+        jsonl_file_path (str): The file path.
+
+    Returns:
+        list[AgentProfileWithPersonalInformation]: List of agent profiles.
+    """
+    agent_profiles = []
+    with open(jsonl_file_path, "r") as f:
+        for line in f:
+            data = json.loads(line)
+            agent_profile = AgentProfileWithPersonalInformation(**data)
+            agent_profiles.append(agent_profile)
+    return agent_profiles
+
+
+def jsonl_to_environmentprofiles(
+    jsonl_file_path: str,
+) -> list[EnvironmentProfileWithTwoAgentRequirements]:
+    """Load environment profiles from a jsonl file.
+
+    Args:
+        jsonl_file_path (str): The file path.
+
+    Returns:
+        list[EnvironmentProfileWithTwoAgentRequirements]: List of environment profiles.
+    """
+    environment_profiles = []
+    with open(jsonl_file_path, "r") as f:
+        for line in f:
+            data = json.loads(line)
+            environment_profile = EnvironmentProfileWithTwoAgentRequirements(**data)
+            environment_profiles.append(environment_profile)
+    return environment_profiles