Skip to content

Commit

Permalink
remove more LLM_NAME references
Browse files Browse the repository at this point in the history
  • Loading branch information
clementou committed Apr 10, 2024
1 parent 96024a1 commit 490f4d5
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 27 deletions.
15 changes: 6 additions & 9 deletions examples/experiment_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@
import logging
import os
import subprocess
import sys
from datetime import datetime
from logging import FileHandler
from typing import Any, Callable, Generator, Literal, Sequence, cast
from typing import Any, Generator, cast

import gin
from absl import app, flags
from rich import print
from absl import flags
from rich.logging import RichHandler
from tqdm import tqdm

Expand All @@ -25,8 +23,7 @@
RuleBasedTerminatedEvaluator,
)
from sotopia.envs.parallel import ParallelSotopiaEnv
from sotopia.generation_utils.generate import LLM_Name
from sotopia.messages import AgentAction, Message, Observation
from sotopia.messages import AgentAction, Observation
from sotopia.samplers import (
BaseSampler,
ConstraintBasedSampler,
Expand Down Expand Up @@ -71,7 +68,7 @@
def check_existing_episodes(
env_id: str,
agent_ids: list[str],
models: dict[str, LLM_Name],
models: dict[str, str],
tag: str | None = None,
) -> bool:
if tag:
Expand Down Expand Up @@ -112,7 +109,7 @@ def _sample_env_agent_combo_and_push_to_db(env_id: str) -> None:

@gin.configurable
def _iterate_env_agent_combo_not_in_db(
model_names: dict[str, LLM_Name],
model_names: dict[str, str],
env_ids: list[str] = [],
tag: str | None = None,
) -> Generator[EnvAgentCombo[Observation, AgentAction], None, None]:
Expand Down Expand Up @@ -181,7 +178,7 @@ def _iterate_env_agent_combo_not_in_db(
def run_async_server_in_batch(
*,
batch_size: int = 1,
model_names: dict[str, LLM_Name] = {
model_names: dict[str, str] = {
"env": "gpt-4",
"agent1": "gpt-3.5-turbo",
"agent2": "gpt-3.5-turbo",
Expand Down
20 changes: 6 additions & 14 deletions examples/generate_script.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import asyncio
import logging
import os
import subprocess
from datetime import datetime
from logging import FileHandler
from typing import Any

Expand All @@ -14,14 +12,8 @@
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

from sotopia.envs.parallel import ParallelSotopiaEnv
from sotopia.generation_utils.generate import LLM_Name, agenerate_script
from sotopia.messages.message_classes import (
AgentAction,
Observation,
ScriptBackground,
)
from sotopia.samplers import EnvAgentCombo, UniformSampler
from sotopia.messages.message_classes import AgentAction, Observation
from sotopia.samplers import EnvAgentCombo
from sotopia.server import arun_one_script, run_async_server
from sotopia_conf.gin_utils import parse_gin_flags, run

Expand All @@ -45,7 +37,7 @@

@gin.configurable
def single_step(
model_names: dict[str, LLM_Name],
model_names: dict[str, str],
tag: str | None = None,
batch_size: int = 5,
push_to_db: bool = True,
Expand Down Expand Up @@ -111,7 +103,7 @@ def single_step(

@gin.configurable
def full_freeform(
model_names: dict[str, LLM_Name],
model_names: dict[str, str],
tag: str | None = None,
batch_size: int = 5,
push_to_db: bool = True,
Expand Down Expand Up @@ -186,14 +178,14 @@ def full_freeform(
def run_async_server_in_batch_script(
*,
batch_size: int = 10,
model: LLM_Name = "gpt-3.5-turbo",
model: str = "gpt-3.5-turbo",
tag: str | None = None,
push_to_db: bool = True,
json_in_script: bool = False,
generate_in_full: bool = False,
verbose: bool = False,
) -> None:
model_names: dict[str, LLM_Name] = {
model_names: dict[str, str] = {
"env": model,
"agent1": model,
"agent2": model,
Expand Down
8 changes: 4 additions & 4 deletions sotopia/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

@beartype
def run_sync_server(
model_name_dict: dict[str, LLM_Name],
model_name_dict: dict[str, str],
action_order: Literal["simutaneous", "round-robin", "random"],
agents_info: dict[str, dict[str, str]] | None = None,
partial_background_file: str | None = None,
Expand Down Expand Up @@ -124,7 +124,7 @@ def run_sync_server(
async def arun_one_episode(
env: ParallelSotopiaEnv,
agent_list: Sequence[BaseAgent[Observation, AgentAction]],
model_dict: dict[str, LLM_Name],
model_dict: dict[str, str],
omniscient: bool = False,
script_like: bool = False,
json_in_script: bool = False,
Expand Down Expand Up @@ -257,7 +257,7 @@ async def arun_one_episode(
@gin.configurable
@beartype
async def run_async_server(
model_dict: dict[str, LLM_Name],
model_dict: dict[str, str],
sampler: BaseSampler[Observation, AgentAction] = BaseSampler(),
action_order: Literal[
"simutaneous", "round-robin", "random"
Expand Down Expand Up @@ -358,7 +358,7 @@ def get_agent_class(
async def arun_one_script(
env: ParallelSotopiaEnv,
agent_list: Sequence[BaseAgent[Observation, AgentAction]],
model_dict: dict[str, LLM_Name],
model_dict: dict[str, str],
omniscient: bool = False,
tag: str | None = None,
push_to_db: bool = False,
Expand Down

0 comments on commit 490f4d5

Please sign in to comment.