Skip to content

Commit

Permalink
using ruff to format
Browse files Browse the repository at this point in the history
  • Loading branch information
ProKil committed Apr 4, 2024
1 parent d91f5fb commit 7a0926f
Show file tree
Hide file tree
Showing 60 changed files with 540 additions and 1,085 deletions.
19 changes: 10 additions & 9 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,17 @@ repos:
hooks:
- id: prettier
types_or: [html]
- repo: https://github.com/psf/black
rev: 22.12.0
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.5
hooks:
- id: black
args: [--line-length=79]
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black", --line-length=72]
# Run the linter.
- id: ruff
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi, jupyter ]
- repo: https://github.com/kynan/nbstripout
rev: 0.6.0
hooks:
Expand Down
36 changes: 6 additions & 30 deletions examples/evaluate_existing_episode.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,14 @@
from datetime import datetime
from logging import FileHandler

import gin
import typer
from experiment_eval import _iterate_env_agent_combo_not_in_db
from rich import print
from rich.logging import RichHandler
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

from sotopia.agents.llm_agent import Agents
from sotopia.database.logs import AnnotationForEpisode, EpisodeLog
from sotopia.database.persistent_profile import EnvironmentProfile
from sotopia.generation_utils.generate import LLM_Name, agenerate_script
from sotopia.messages.message_classes import (
AgentAction,
Observation,
ScriptBackground,
)
from sotopia.samplers import (
BaseSampler,
ConstraintBasedSampler,
EnvAgentCombo,
)
from sotopia.server import aevaluate_one_episode, arun_one_script
from sotopia.generation_utils.generate import LLM_Name
from sotopia.server import aevaluate_one_episode

# timestamp, log level, logger name, and message
FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
Expand Down Expand Up @@ -61,7 +46,6 @@ def run_async_server_in_batch_aevaluate(
push_to_db: bool = False,
verbose: bool = False,
) -> None:

if not verbose:
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)
Expand All @@ -75,9 +59,7 @@ def run_async_server_in_batch_aevaluate(
episode = EpisodeLog.get(env_pk)
episode_batch.append(episode)
if len(episode_batch) == batch_size:
logging.info(
f"Running batch of {batch_size} episodes: {episode_batch}"
)
logging.info(f"Running batch of {batch_size} episodes: {episode_batch}")
episode_futures = [
aevaluate_one_episode(
episode=episode,
Expand All @@ -88,17 +70,13 @@ def run_async_server_in_batch_aevaluate(
for episode in episode_batch
]
asyncio.run(
tqdm_asyncio.gather(
*episode_futures, desc="Running one batch"
)
tqdm_asyncio.gather(*episode_futures, desc="Running one batch")
)

episode_batch = []
else:
if episode_batch:
logging.info(
f"Running batch of {batch_size} episodes: {episode_batch}"
)
logging.info(f"Running batch of {batch_size} episodes: {episode_batch}")
episode_futures = [
aevaluate_one_episode(
episode=episode,
Expand All @@ -109,9 +87,7 @@ def run_async_server_in_batch_aevaluate(
for episode in episode_batch
]
asyncio.run(
tqdm_asyncio.gather(
*episode_futures, desc="Running one batch"
)
tqdm_asyncio.gather(*episode_futures, desc="Running one batch")
)
return

Expand Down
48 changes: 13 additions & 35 deletions examples/experiment_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
import logging
import os
import subprocess
import sys
from datetime import datetime
from logging import FileHandler
from typing import Any, Callable, Generator, Literal, Sequence, cast
from typing import Any, Generator, cast

import gin
from absl import app, flags
from rich import print
from absl import flags
from rich.logging import RichHandler
from tqdm import tqdm

Expand All @@ -26,7 +24,7 @@
)
from sotopia.envs.parallel import ParallelSotopiaEnv
from sotopia.generation_utils.generate import LLM_Name
from sotopia.messages import AgentAction, Message, Observation
from sotopia.messages import AgentAction, Observation
from sotopia.samplers import (
BaseSampler,
ConstraintBasedSampler,
Expand Down Expand Up @@ -79,27 +77,19 @@ def check_existing_episodes(
(EpisodeLog.environment == env_id) & (EpisodeLog.tag == tag)
).all()
else:
existing_episode = EpisodeLog.find(
EpisodeLog.environment == env_id
).all()
existing_episode = EpisodeLog.find(EpisodeLog.environment == env_id).all()
if existing_episode:
for episode in existing_episode:
assert isinstance(
episode, EpisodeLog
), "episode should be an EpisodeLog"
if episode.agents == agent_ids and episode.models == list(
models.values()
):
assert isinstance(episode, EpisodeLog), "episode should be an EpisodeLog"
if episode.agents == agent_ids and episode.models == list(models.values()):
return True
return False
else:
return False


def _sample_env_agent_combo_and_push_to_db(env_id: str) -> None:
sampler = ConstraintBasedSampler[Observation, AgentAction](
env_candidates=[env_id]
)
sampler = ConstraintBasedSampler[Observation, AgentAction](env_candidates=[env_id])
env_agent_combo_list = list(
sampler.sample(agent_classes=[LLMAgent] * 2, replacement=False)
)
Expand All @@ -122,21 +112,15 @@ def _iterate_env_agent_combo_not_in_db(
for env_id in env_ids:
assert env_id is not None, "env_id should not be None"
env_agent_combo_storage_list = list(
EnvAgentComboStorage.find(
EnvAgentComboStorage.env_id == env_id
).all()
EnvAgentComboStorage.find(EnvAgentComboStorage.env_id == env_id).all()
)
if not env_agent_combo_storage_list:
_sample_env_agent_combo_and_push_to_db(env_id)
env_agent_combo_storage_list = list(
EnvAgentComboStorage.find(
EnvAgentComboStorage.env_id == env_id
).all()
EnvAgentComboStorage.find(EnvAgentComboStorage.env_id == env_id).all()
)
assert env_agent_combo_storage_list
first_env_agent_combo_storage_to_run: EnvAgentComboStorage | None = (
None
)
first_env_agent_combo_storage_to_run: EnvAgentComboStorage | None = None
for env_agent_combo_storage in env_agent_combo_storage_list:
env_agent_combo_storage = cast(
EnvAgentComboStorage, env_agent_combo_storage
Expand All @@ -156,9 +140,7 @@ def _iterate_env_agent_combo_not_in_db(
model_name=model_names["env"],
action_order="round-robin",
evaluators=[
RuleBasedTerminatedEvaluator(
max_turn_number=20, max_stale_turn=2
),
RuleBasedTerminatedEvaluator(max_turn_number=20, max_stale_turn=2),
],
terminal_evaluators=[
ReachGoalLLMEvaluator(model_names["env"]),
Expand Down Expand Up @@ -196,14 +178,10 @@ def run_async_server_in_batch(
logger.removeHandler(rich_handler)

# we cannot get the exact length of the generator, so we just give an estimate of the length
env_agent_combo_iter = _iterate_env_agent_combo_not_in_db(
model_names=model_names
)
env_agent_combo_iter = _iterate_env_agent_combo_not_in_db(model_names=model_names)
env_agent_combo_iter_length = sum(1 for _ in env_agent_combo_iter)

env_agent_combo_iter = _iterate_env_agent_combo_not_in_db(
model_names=model_names
)
env_agent_combo_iter = _iterate_env_agent_combo_not_in_db(model_names=model_names)
env_agent_combo_batch: list[EnvAgentCombo[Observation, AgentAction]] = []

while True:
Expand Down
61 changes: 18 additions & 43 deletions examples/fix_missing_episodes.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,12 @@
import asyncio
import logging
from collections import Counter, defaultdict
from typing import (
Any,
Dict,
Generator,
List,
Literal,
Optional,
Set,
cast,
)
from typing import Any, Dict, Generator, List, Set, cast

import gin
from absl import flags
from absl.flags import FLAGS
from rich.logging import RichHandler
from rich.terminal_theme import MONOKAI
from tqdm import tqdm

from sotopia.agents.llm_agent import LLMAgent
Expand Down Expand Up @@ -51,14 +41,15 @@
],
)


# get all episode logs
def get_all_episodes() -> List[EpisodeLog]:
episode_pks: List[str] = list(EpisodeLog.all_pks())
all_episodes = []
for pk in tqdm(episode_pks):
try:
curr_ep = EpisodeLog.get(pk)
except:
except Exception as _:
continue
all_episodes.append(curr_ep)
print(f"all episodes loaded {len(all_episodes)}")
Expand All @@ -74,14 +65,10 @@ def get_all_env_agent_combos(

for env_pk in experiment_env_pks:
env_agent_combo_storage_list = list(
EnvAgentComboStorage.find(
EnvAgentComboStorage.env_id == env_pk
).all()
EnvAgentComboStorage.find(EnvAgentComboStorage.env_id == env_pk).all()
)[start_combo_idx:end_combo_idx]
for combo in env_agent_combo_storage_list:
all_combos_map[cast(str, combo.pk)] = cast(
EnvAgentComboStorage, combo
)
all_combos_map[cast(str, combo.pk)] = cast(EnvAgentComboStorage, combo)
print(f"all combos loaded {len(all_combos_map)}")
return all_combos_map

Expand All @@ -104,9 +91,9 @@ def get_combo_model_map(
all_episodes: List[EpisodeLog],
all_combos_map: Dict[str, EnvAgentComboStorage],
) -> Dict[str, Counter[tuple[LLM_Name, LLM_Name, LLM_Name]]]:
combo_model_map: Dict[
str, Counter[tuple[LLM_Name, LLM_Name, LLM_Name]]
] = defaultdict(Counter)
combo_model_map: Dict[str, Counter[tuple[LLM_Name, LLM_Name, LLM_Name]]] = (
defaultdict(Counter)
)
bad_combos = []
valid_count = 0
invalid_count = 0
Expand Down Expand Up @@ -164,7 +151,7 @@ def get_combo_model_map(


def get_all_model_pairs(
combo_model_map: Dict[str, Counter[tuple[LLM_Name, LLM_Name, LLM_Name]]]
combo_model_map: Dict[str, Counter[tuple[LLM_Name, LLM_Name, LLM_Name]]],
) -> Set[tuple[LLM_Name, LLM_Name, LLM_Name]]:
all_model_pairs = set()
for key in combo_model_map:
Expand All @@ -184,19 +171,17 @@ def get_all_missing_model_pairs(
all_model_pairs: Set[tuple[LLM_Name, LLM_Name, LLM_Name]],
num_required: int,
) -> Dict[str, Counter[tuple[LLM_Name, LLM_Name, LLM_Name]]]:
combo_missing_model_map: Dict[
str, Counter[tuple[LLM_Name, LLM_Name, LLM_Name]]
] = defaultdict(Counter)
combo_missing_model_map: Dict[str, Counter[tuple[LLM_Name, LLM_Name, LLM_Name]]] = (
defaultdict(Counter)
)
missing_count = 0
for key in combo_model_map:
for model_pair in all_model_pairs:
if combo_model_map[key][model_pair] < num_required:
combo_missing_model_map[key][model_pair] += (
num_required - combo_model_map[key][model_pair]
)
missing_count += (
num_required - combo_model_map[key][model_pair]
)
missing_count += num_required - combo_model_map[key][model_pair]
print("-" * 20 + f"Missing {missing_count} Model Pairs" + "-" * 20)
print()
return combo_missing_model_map
Expand All @@ -205,9 +190,7 @@ def get_all_missing_model_pairs(
# temporarily used to ensure each (env, agents, models) setting is unique; needs to change
# to follow the Counter when multiple experiments must be run for one setting
def get_missing_model_combo_map(
combo_missing_model_map: Dict[
str, Counter[tuple[LLM_Name, LLM_Name, LLM_Name]]
],
combo_missing_model_map: Dict[str, Counter[tuple[LLM_Name, LLM_Name, LLM_Name]]],
all_combos_map: Dict[str, EnvAgentComboStorage],
) -> Dict[tuple[LLM_Name, LLM_Name], List[tuple[str, str, str]]]:
missing_model_combo_map = defaultdict(list)
Expand Down Expand Up @@ -241,17 +224,13 @@ def yield_env_agent_combo(
model_name=model_names["env"],
action_order="round-robin",
evaluators=[
RuleBasedTerminatedEvaluator(
max_turn_number=20, max_stale_turn=2
),
RuleBasedTerminatedEvaluator(max_turn_number=20, max_stale_turn=2),
],
terminal_evaluators=[
ReachGoalLLMEvaluator(model_names["env"]),
],
)
agent_profiles = [
AgentProfile.get(id) for id in (agent_id1, agent_id2)
]
agent_profiles = [AgentProfile.get(id) for id in (agent_id1, agent_id2)]

agents = [
LLMAgent(agent_profile=agent_profile, model_name=agent_model)
Expand All @@ -265,9 +244,7 @@ def yield_env_agent_combo(

@gin.configurable
def re_run_missing_episodes(
combo_with_models: dict[
tuple[LLM_Name, LLM_Name], list[tuple[str, str, str]]
],
combo_with_models: dict[tuple[LLM_Name, LLM_Name], list[tuple[str, str, str]]],
model_names: dict[str, LLM_Name] = {
"env": "gpt-4",
"agent1": "gpt-3.5-turbo",
Expand All @@ -289,9 +266,7 @@ def re_run_missing_episodes(
env_agent_combo_iter_length = len(combo_and_models_to_run)
print(f"total missing episodes: {env_agent_combo_iter_length}")

env_agent_combo_iter = yield_env_agent_combo(
combo_and_models_to_run, model_names
)
env_agent_combo_iter = yield_env_agent_combo(combo_and_models_to_run, model_names)
env_agent_combo_batch: list[EnvAgentCombo[Observation, AgentAction]] = []
while True:
for env_agent_combo in tqdm(
Expand Down
Loading

0 comments on commit 7a0926f

Please sign in to comment.