diff --git a/sotopia/benchmark/cli.py b/sotopia/benchmark/cli.py index f275982c6..e4ecdd3c5 100644 --- a/sotopia/benchmark/cli.py +++ b/sotopia/benchmark/cli.py @@ -52,7 +52,13 @@ def check_existing_episodes( if existing_episode: for episode in existing_episode: assert isinstance(episode, EpisodeLog), "episode should be an EpisodeLog" - if episode.agents == agent_ids and episode.models == list(models.values()): + assert isinstance(episode.models, list), "episode.models should be a list" + episode_models_dict = { + 'env': episode.models[0], + 'agent1': episode.models[1], + 'agent2': episode.models[2], + } + if episode.agents == agent_ids and episode_models_dict == models: return True return False else: @@ -90,7 +96,7 @@ def _iterate_all_env_agent_combo_not_in_db( model_names: dict[str, LLM_Name], env_agent_combo_storage_list: list[EnvAgentComboStorage], tag: str | None = None, -) -> Generator[EnvAgentCombo[Observation, AgentAction], None, None]: +) -> list[EnvAgentCombo[Observation, AgentAction]]: """We iterate over each environment and return the **first** env-agent combo that is not in the database.""" hard_envs = EnvironmentList.get("01HAK34YPB1H1RWXQDASDKHSNS").environments agent_index = EnvironmentList.get("01HAK34YPB1H1RWXQDASDKHSNS").agent_index