Sotopia Benchmark CLI API #69

Merged
merged 48 commits into from
Jun 17, 2024
Commits
d0811b6
add benchmark social agents
XuhuiZhou May 16, 2024
51a9132
add benchmark agents
XuhuiZhou May 16, 2024
b576be7
Add sotopia_benchmark cli api
ProKil May 16, 2024
7922e16
fix pre-commit
ProKil May 16, 2024
9c13c02
add evaluator model argument
ProKil May 16, 2024
165c245
Merge branch 'main' into feature/benchmark_agents.py
XuhuiZhou May 23, 2024
27aa565
finish benchmarking
XuhuiZhou May 23, 2024
b625cfc
benchmark done
XuhuiZhou May 24, 2024
7066a2a
chore: Fix formatting issue in redis_stats.ipynb and cli.py
XuhuiZhou May 25, 2024
b710ea6
switch back to LLM_Name
XuhuiZhou May 25, 2024
385a53c
Merge branch 'main' into feature/benchmark_agents.py
XuhuiZhou May 26, 2024
ad9e741
merge main
XuhuiZhou May 26, 2024
ee87f84
add together ai
XuhuiZhou May 26, 2024
41ad0e3
fix naming error
XuhuiZhou May 26, 2024
5ae2b26
roll back to llama2
XuhuiZhou May 26, 2024
af4e2ee
chore: Update langchain-together dependency to version 0.1.2
XuhuiZhou May 26, 2024
5361d0c
use chatopenai for together models
ProKil May 26, 2024
35be2a2
Merge branch 'feature/benchmark_agents.py' of github.com:sotopia-lab/…
ProKil May 26, 2024
7087f40
add logging
ProKil May 26, 2024
01b6291
Merge remote-tracking branch 'origin/main' into feature/benchmark_age…
ProKil May 26, 2024
4be892d
fix pre-commit
ProKil May 26, 2024
e8a59d2
add more logging options
ProKil May 27, 2024
9e15790
probably fix the event loop closed error: following https://github.co…
ProKil May 27, 2024
4906efe
modify cli; fix model position bug
XuhuiZhou May 28, 2024
0026b3e
chore: Update benchmark tag to "benchmark_{model}_final"
XuhuiZhou May 29, 2024
1d873be
Refactor _iterate_all_env_agent_combo_not_in_db function
XuhuiZhou May 29, 2024
cb49079
chore: Update python version to 3.11.2
XuhuiZhou May 29, 2024
f983b44
change to dict comparison
XuhuiZhou Jun 1, 2024
dbc873c
ignore jsonl
XuhuiZhou Jun 4, 2024
6a79d7e
✨ finish benchmarking script
XuhuiZhou Jun 5, 2024
9a38b93
chore: Refactor server.py and redis_stats.ipynb
XuhuiZhou Jun 5, 2024
0996090
add type ignore
XuhuiZhou Jun 5, 2024
f1a626b
push for the eval
XuhuiZhou Jun 5, 2024
b567224
Refactor run_async_benchmark_in_batch function
XuhuiZhou Jun 5, 2024
bfadacd
Refactor run_async_benchmark_in_batch function
XuhuiZhou Jun 5, 2024
9a5dce9
add doc
XuhuiZhou Jun 5, 2024
3dbde75
precommit fix
XuhuiZhou Jun 5, 2024
897389a
pre-commit
XuhuiZhou Jun 5, 2024
1c6b2b9
Merge remote-tracking branch 'origin/main' into feature/benchmark_age…
ProKil Jun 6, 2024
b7fd148
refactor
XuhuiZhou Jun 9, 2024
73c2c92
Merge branch 'feature/benchmark_agents.py' of github.com:sotopia-lab/…
XuhuiZhou Jun 9, 2024
df1bfde
Merge branch 'main' into feature/benchmark_agents.py
XuhuiZhou Jun 14, 2024
8352ea7
update w feedback
XuhuiZhou Jun 14, 2024
f6c86cb
pre commit
XuhuiZhou Jun 14, 2024
6bcb178
chore: Update authors in pyproject.toml and fetch benchmark_agents.js…
XuhuiZhou Jun 14, 2024
f5c269c
hotfix
XuhuiZhou Jun 15, 2024
98fe351
Merge branch 'main' into feature/benchmark_agents.py
XuhuiZhou Jun 15, 2024
7165109
chore: Remove unnecessary type hint in benchmark/cli.py
XuhuiZhou Jun 15, 2024
2 changes: 1 addition & 1 deletion .gitignore
@@ -137,7 +137,7 @@ data/*
deprecated/*

*.csv

*.jsonl
#backup
backup/*

11 changes: 11 additions & 0 deletions docs/pages/benchmark.md
@@ -0,0 +1,11 @@
# Benchmark your model as a social agent in Sotopia

```
sotopia_benchmark --model=<your_model_name>
```
or

```
python sotopia/benchmark/cli.py --model=<your_model_name>
```
Currently, this script runs over 100 simulations on the Sotopia Hard tasks, and the partner model is fixed to `meta-llama/Llama-3-70b-chat-hf`.
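
To benchmark several models in a row, the command above can be wrapped in a small driver. The following is only a sketch: it assumes `sotopia_benchmark` is on your `PATH`, it uses only the `--model` flag shown above, and the helper names `build_benchmark_command` and `run_benchmarks` are hypothetical, not part of Sotopia.

```python
import shlex
import subprocess


def build_benchmark_command(model: str) -> list[str]:
    """Build the argv list for one benchmark run (only --model is used)."""
    return ["sotopia_benchmark", f"--model={model}"]


def run_benchmarks(models: list[str], dry_run: bool = True) -> list[str]:
    """Run one benchmark per model; with dry_run=True only print the commands."""
    rendered = []
    for model in models:
        cmd = build_benchmark_command(model)
        rendered.append(shlex.join(cmd))
        if dry_run:
            print(rendered[-1])
        else:
            # check=True stops the loop if a benchmark run fails.
            subprocess.run(cmd, check=True)
    return rendered
```

Calling `run_benchmarks(["<model_a>", "<model_b>"], dry_run=False)` would launch the runs sequentially; the default dry run just prints each command so it can be inspected first.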
99 changes: 95 additions & 4 deletions notebooks/redis_stats.ipynb
@@ -183,7 +183,8 @@
"metadata": {},
"source": [
"## EnvAgentComboStorage\n",
"\n",
"Identify the combination of environment and agent that is used in the episodes.\n",
"Once we have the combination, we can use it to start the simulation.\n",
"Combo is a combination of Environment and two agents."
]
},
@@ -204,9 +205,99 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Episode Log\n",
"## EnvironmentList\n",
"Store a list of special environments (e.g., sotopia hard) that can be used to start certain simulations. Agent index is used to identify the special agent in the simulation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sotopia.database.persistent_profile import EnvironmentList\n",
"\n",
"all_list = EnvironmentList.all_pks()\n",
"all_list = list(all_list)\n",
"print(len(all_list))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sotopia.samplers import ConstraintBasedSampler\n",
"from sotopia.messages import AgentAction, Observation\n",
"from sotopia.agents import LLMAgent\n",
"import json\n",
"# In this example, we will demonstrate using the EnvironmentList class to sample a list of EnvAgentComboStorage and serialize it to a json file that can be used for sharing with others for benchmarking purposes.\n",
"\n",
"\n",
"def _sample_env_agent_combo_and_push_to_db(env_id: str) -> list[EnvAgentComboStorage]:\n",
" combo_list = []\n",
" sampler = ConstraintBasedSampler[Observation, AgentAction](env_candidates=[env_id])\n",
" env_agent_combo_list = list(\n",
" sampler.sample(agent_classes=[LLMAgent] * 2, replacement=False, size=10)\n",
" )\n",
" for env, agent in env_agent_combo_list:\n",
" combo = EnvAgentComboStorage(\n",
" env_id=env.profile.pk,\n",
" agent_ids=[agent[0].profile.pk, agent[1].profile.pk],\n",
" )\n",
" combo_list.append(combo)\n",
" return combo_list\n",
"\n",
"\n",
"# First we will extract the hard environments from the EnvironmentList\n",
"hard_envs = EnvironmentList.get(\"01HAK34YPB1H1RWXQDASDKHSNS\").environments\n",
"print(len(hard_envs))\n",
"hard_envs_set = set(hard_envs)\n",
"\n",
"# Next we will fetch the EnvAgentComboStorage entries stored for each hard environment\n",
"final_list_for_benchmark_agents = []\n",
"for env in hard_envs_set:\n",
" combo_list = EnvAgentComboStorage.find(EnvAgentComboStorage.env_id == env).all()\n",
" print(len(combo_list))\n",
" final_list_for_benchmark_agents.extend(combo_list)\n",
"\n",
"# Finally we will serialize the list to a json file\n",
"with open(\"../data/benchmark_agents.json\", \"w\") as f:\n",
" json.dump(\n",
" [combo.dict() for combo in final_list_for_benchmark_agents],\n",
" f,\n",
" indent=4,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"EnvironmentList.get(\"01HAK34YPB1H1RWXQDASDKHSNS\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Episode Log"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# find episode log by tag\n",
"Episodes = EpisodeLog.find(EpisodeLog.tag == \"aug20_gpt4_llama-2-70b-chat_zqi2\").all()\n",
"len(Episodes) ## Episode Log\n",
"\n",
"Episodelog stores the social conversation between two agents in an environment."
"## Episodelog stores the social conversation between two agents in an environment."
]
},
{
@@ -289,7 +380,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.11.2"
}
},
"nbformat": 4,
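
For readers without a Redis instance, the export step in the notebook cell above (gather the stored combos for each hard environment, then dump them to JSON) can be sketched with plain Python objects. This is an illustration under assumptions: `Combo` is a hypothetical dataclass standing in for `EnvAgentComboStorage`, `collect_combos` and `dump_combos` are made-up helper names, and the filter replaces the `EnvAgentComboStorage.find(...)` query with a list scan. Unlike the notebook, which iterates over a `set`, the sketch de-duplicates with `dict.fromkeys` so the output order is deterministic.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class Combo:
    """Hypothetical stand-in for EnvAgentComboStorage (no Redis needed)."""

    env_id: str
    agent_ids: list[str]


def collect_combos(hard_envs: list[str], stored: list[Combo]) -> list[Combo]:
    """Return the stored combos whose env_id is one of the hard environments."""
    selected: list[Combo] = []
    # dict.fromkeys drops duplicate environment ids while keeping first-seen order.
    for env_id in dict.fromkeys(hard_envs):
        selected.extend(c for c in stored if c.env_id == env_id)
    return selected


def dump_combos(combos: list[Combo]) -> str:
    """Serialize the combos as a JSON array of objects, indented like the notebook."""
    return json.dumps([asdict(c) for c in combos], indent=4)


if __name__ == "__main__":
    stored = [Combo("env1", ["a", "b"]), Combo("env2", ["c", "d"])]
    print(dump_combos(collect_combos(["env2", "env2"], stored)))
```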