Commit
chore: Update benchmark tag to "benchmark_{model}_final"
XuhuiZhou committed May 29, 2024
1 parent 4906efe commit 0026b3e
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions sotopia/benchmark/cli.py
@@ -54,7 +54,6 @@ def check_existing_episodes(
                 assert isinstance(episode, EpisodeLog), "episode should be an EpisodeLog"
                 if episode.agents == agent_ids and episode.models == list(models.values()):
                     return True
-            breakpoint()
             return False
         else:
             return False
@@ -98,14 +97,19 @@ def _iterate_all_env_agent_combo_not_in_db(
     assert isinstance(agent_index, list), "agent_index should be a list"
     envs_index_mapping: dict[str, list[str]] = {env_id: [] for env_id in set(hard_envs)}
     # Repeat 10 times to match the number of combos
+    model_names_switched = {
+        "env": model_names["env"],
+        "agent1": model_names["agent2"],
+        "agent2": model_names["agent1"],
+    }
     for _ in range(10):
         for index, env_id in zip(agent_index, hard_envs):
             envs_index_mapping[env_id].append(index)

     for env_agent_combo_storage in env_agent_combo_storage_list:
         agent_ids = env_agent_combo_storage.agent_ids
         env_id = env_agent_combo_storage.env_id
-        if check_existing_episodes(env_id, agent_ids, model_names, tag):
+        if check_existing_episodes(env_id, agent_ids, model_names, tag) or check_existing_episodes(env_id=env_id, agent_ids=agent_ids, models=model_names_switched, tag=tag):
             logging.info(
                 f"Episode for {env_id} with agents {agent_ids} using {list(model_names.values())} already exists"
             )
@@ -157,7 +161,6 @@ def run_async_benchmark_in_batch(
         model_names=model_names, tag=tag, env_agent_combo_storage_list=benchmark_combo
     )
     env_agent_combo_iter_length = sum(1 for _ in env_agent_combo_iter)
-    breakpoint()
     env_agent_combo_iter = _iterate_all_env_agent_combo_not_in_db(
         model_names=model_names, tag=tag, env_agent_combo_storage_list=benchmark_combo
     )
@@ -210,13 +213,19 @@ def run_async_benchmark_in_batch(
                     pk = episode.pk
                     assert isinstance(pk, str)
                     EpisodeLog.delete(pk)
-            breakpoint()

             env_agent_combo_iter = _iterate_all_env_agent_combo_not_in_db(
                 model_names=model_names,
                 tag=tag,
                 env_agent_combo_storage_list=benchmark_combo,
             )
+            env_agent_combo_iter_length = sum(1 for _ in env_agent_combo_iter)
+            env_agent_combo_iter = _iterate_all_env_agent_combo_not_in_db(
+                model_names=model_names,
+                tag=tag,
+                env_agent_combo_storage_list=benchmark_combo,
+            )
             env_agent_combo_batch = []
             number_of_fix_turns += 1
             if env_agent_combo_iter_length == 0 or number_of_fix_turns >= 5:
@@ -271,7 +280,7 @@ def cli(
     model = cast(LLM_Name, model)
     partner_model = cast(LLM_Name, partner_model)
     evaluator_model = cast(LLM_Name, evaluator_model)
-    tag = f"benchmark_{model}_q"
+    tag = f"benchmark_{model}_final"
     run_async_benchmark_in_batch(
         batch_size=batch_size,
         model_names={"env": evaluator_model, "agent1": model, "agent2": partner_model},
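Note on the switched check: the new model_names_switched dictionary makes the duplicate-episode lookup order-insensitive, so an episode already recorded with the two agent models swapped (agent1 and agent2 exchanged) also counts as existing and is not re-run. A minimal sketch of the pattern, using a hypothetical episode_exists helper in place of check_existing_episodes:

from typing import Callable

ModelMap = dict[str, str]

def already_covered(
    env_id: str,
    agent_ids: list[str],
    model_names: ModelMap,
    episode_exists: Callable[[str, list[str], ModelMap], bool],
) -> bool:
    # Same environment and agents, but with the two agent models swapped.
    switched: ModelMap = {
        "env": model_names["env"],
        "agent1": model_names["agent2"],
        "agent2": model_names["agent1"],
    }
    # A hit under either ordering means this combo has already been benchmarked.
    return episode_exists(env_id, agent_ids, model_names) or episode_exists(
        env_id, agent_ids, switched
    )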

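Note on rebuilding the iterator: _iterate_all_env_agent_combo_not_in_db returns a generator, and counting its items with sum(1 for _ in ...) consumes it, so the code constructs it a second time before actually iterating. A quick self-contained illustration of that Python behavior (combos is a stand-in, not the project's function):

from typing import Generator

def combos() -> Generator[int, None, None]:
    # Stand-in for _iterate_all_env_agent_combo_not_in_db.
    yield from range(3)

it = combos()
print(sum(1 for _ in it))  # 3 -- counting exhausts the generator
print(list(it))            # [] -- nothing left to consume
it = combos()              # rebuild before the real iteration, as the diff does
print(list(it))            # [0, 1, 2]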