From 0026b3e55326d2160353015edc1ed74432dedd31 Mon Sep 17 00:00:00 2001 From: XuhuiZhou Date: Wed, 29 May 2024 08:43:48 -0700 Subject: [PATCH] chore: Update benchmark tag to "benchmark_{model}_final" --- sotopia/benchmark/cli.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/sotopia/benchmark/cli.py b/sotopia/benchmark/cli.py index fc768e2d1..3c8e9c144 100644 --- a/sotopia/benchmark/cli.py +++ b/sotopia/benchmark/cli.py @@ -54,7 +54,6 @@ def check_existing_episodes( assert isinstance(episode, EpisodeLog), "episode should be an EpisodeLog" if episode.agents == agent_ids and episode.models == list(models.values()): return True - breakpoint() return False else: return False @@ -98,14 +97,19 @@ def _iterate_all_env_agent_combo_not_in_db( assert isinstance(agent_index, list), "agent_index should be a list" envs_index_mapping: dict[str, list[str]] = {env_id: [] for env_id in set(hard_envs)} # Repeat 10 times to match the number of combos + model_names_switched = { + "env": model_names["env"], + "agent1": model_names["agent2"], + "agent2": model_names["agent1"], + } for _ in range(10): for index, env_id in zip(agent_index, hard_envs): envs_index_mapping[env_id].append(index) - + for env_agent_combo_storage in env_agent_combo_storage_list: agent_ids = env_agent_combo_storage.agent_ids env_id = env_agent_combo_storage.env_id - if check_existing_episodes(env_id, agent_ids, model_names, tag): + if check_existing_episodes(env_id, agent_ids, model_names, tag) or check_existing_episodes(env_id=env_id, agent_ids=agent_ids, models=model_names_switched, tag=tag): logging.info( f"Episode for {env_id} with agents {agent_ids} using {list(model_names.values())} already exists" ) @@ -157,7 +161,6 @@ def run_async_benchmark_in_batch( model_names=model_names, tag=tag, env_agent_combo_storage_list=benchmark_combo ) env_agent_combo_iter_length = sum(1 for _ in env_agent_combo_iter) - breakpoint() env_agent_combo_iter = _iterate_all_env_agent_combo_not_in_db( model_names=model_names, tag=tag, env_agent_combo_storage_list=benchmark_combo ) @@ -210,6 +213,7 @@ def run_async_benchmark_in_batch( pk = episode.pk assert isinstance(pk, str) EpisodeLog.delete(pk) + breakpoint() env_agent_combo_iter = _iterate_all_env_agent_combo_not_in_db( model_names=model_names, @@ -217,6 +221,11 @@ def run_async_benchmark_in_batch( env_agent_combo_storage_list=benchmark_combo, ) env_agent_combo_iter_length = sum(1 for _ in env_agent_combo_iter) + env_agent_combo_iter = _iterate_all_env_agent_combo_not_in_db( + model_names=model_names, + tag=tag, + env_agent_combo_storage_list=benchmark_combo, + ) env_agent_combo_batch = [] number_of_fix_turns += 1 if env_agent_combo_iter_length == 0 or number_of_fix_turns >= 5: @@ -271,7 +280,7 @@ def cli( model = cast(LLM_Name, model) partner_model = cast(LLM_Name, partner_model) evaluator_model = cast(LLM_Name, evaluator_model) - tag = f"benchmark_{model}_q" + tag = f"benchmark_{model}_final" run_async_benchmark_in_batch( batch_size=batch_size, model_names={"env": evaluator_model, "agent1": model, "agent2": partner_model},