Skip to content

Commit

Permalink
Add rewards to the logged episodes in jsonl
Browse files Browse the repository at this point in the history
  • Loading branch information
ProKil committed Apr 5, 2024
1 parent ca40fc4 commit 4d59569
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 16 deletions.
4 changes: 2 additions & 2 deletions sotopia/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
RelationshipProfile,
RelationshipType,
)
from .session_transaction import MessageTransaction, SessionTransaction
from .utils import (
from .serialization import (
episodes_to_csv,
episodes_to_json,
get_rewards_from_episode,
)
from .session_transaction import MessageTransaction, SessionTransaction
from .waiting_room import MatchingInWaitingRoom

__all__ = [
Expand Down
50 changes: 36 additions & 14 deletions sotopia/database/utils.py → sotopia/database/serialization.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import json

import pandas as pd
from pydantic import ConstrainedList, conlist, root_validator
from redis_om import HashModel, JsonModel
from redis_om.model.model import Field
from pydantic import BaseModel, Field

from .logs import EpisodeLog
from .persistent_profile import AgentProfile, EnvironmentProfile


class TwoAgentEpisodeWithScenarioBackgroundGoals(JsonModel):
episode_id: str = Field(index=True)
scenario: str = Field(index=True)
codename: str = Field(index=True)
agents_background: dict[str, str] = Field(index=True)
social_goals: dict[str, str] = Field(index=True)
social_interactions: str = Field(index=True)
class TwoAgentEpisodeWithScenarioBackgroundGoals(BaseModel):
episode_id: str = Field(required=True)
scenario: str = Field(required=True)
codename: str = Field(required=True)
agents_background: dict[str, str] = Field(required=True)
social_goals: dict[str, str] = Field(required=True)
social_interactions: str = Field(required=True)
reasoning: str = Field(required=False)
rewards: list[dict[str, float]] = Field(required=False)


def _map_gender_to_adj(gender: str) -> str:
Expand Down Expand Up @@ -119,7 +119,7 @@ def get_social_interactions_from_episode(


def episodes_to_csv(
episodes: list[EpisodeLog], filepath: str = "episodes.csv"
episodes: list[EpisodeLog], csv_file_path: str = "episodes.csv"
) -> None:
"""Save episodes to a csv file.
Expand Down Expand Up @@ -148,19 +148,19 @@ def episodes_to_csv(
],
}
df = pd.DataFrame(data)
df.to_csv(filepath, index=False)
df.to_csv(csv_file_path, index=False)


def episodes_to_json(
episodes: list[EpisodeLog], filepath: str = "episodes.json"
episodes: list[EpisodeLog], jsonl_file_path: str = "episodes.jsonl"
) -> None:
"""Save episodes to a json file.
Args:
episodes (list[EpisodeLog]): List of episodes.
filepath (str, optional): The file path. Defaults to "episodes.json".
"""
with open(filepath, "w") as f:
with open(jsonl_file_path, "w") as f:
for episode in episodes:
data = TwoAgentEpisodeWithScenarioBackgroundGoals(
episode_id=episode.pk,
Expand All @@ -173,6 +173,28 @@ def episodes_to_json(
social_interactions=get_social_interactions_from_episode(
episode
),
reasoning=episode.reasoning,
rewards=get_rewards_from_episode(episode),
)
json.dump(dict(data), f)
f.write("\n")


def jsonl_to_episodes(
jsonl_file_path: str,
) -> list[TwoAgentEpisodeWithScenarioBackgroundGoals]:
"""Load episodes from a jsonl file.
Args:
jsonl_file_path (str): The file path.
Returns:
list[TwoAgentEpisodeWithScenarioBackgroundGoals]: List of episodes.
"""
episodes = []
with open(jsonl_file_path, "r") as f:
for line in f:
data = json.loads(line)
episode = TwoAgentEpisodeWithScenarioBackgroundGoals(**data)
episodes.append(episode)
return episodes

0 comments on commit 4d59569

Please sign in to comment.