diff --git a/sotopia/server.py b/sotopia/server.py index bc9c9a035..d023b22be 100644 --- a/sotopia/server.py +++ b/sotopia/server.py @@ -29,7 +29,7 @@ RuleBasedTerminatedEvaluator, unweighted_aggregate_evaluate, ) -from sotopia.generation_utils.generate import LLM_Name, agenerate_script +from sotopia.generation_utils.generate import agenerate_script from sotopia.messages import AgentAction, Message, Observation from sotopia.messages.message_classes import ( ScriptBackground, @@ -46,7 +46,7 @@ @beartype def run_sync_server( - model_name_dict: dict[str, LLM_Name], + model_name_dict: dict[str, str], action_order: Literal["simutaneous", "round-robin", "random"], agents_info: dict[str, dict[str, str]] | None = None, partial_background_file: str | None = None, @@ -124,7 +124,7 @@ def run_sync_server( async def arun_one_episode( env: ParallelSotopiaEnv, agent_list: Sequence[BaseAgent[Observation, AgentAction]], - model_dict: dict[str, LLM_Name], + model_dict: dict[str, str], omniscient: bool = False, script_like: bool = False, json_in_script: bool = False, @@ -257,7 +257,7 @@ async def arun_one_episode( @gin.configurable @beartype async def run_async_server( - model_dict: dict[str, LLM_Name], + model_dict: dict[str, str], sampler: BaseSampler[Observation, AgentAction] = BaseSampler(), action_order: Literal[ "simutaneous", "round-robin", "random" @@ -358,7 +358,7 @@ def get_agent_class( async def arun_one_script( env: ParallelSotopiaEnv, agent_list: Sequence[BaseAgent[Observation, AgentAction]], - model_dict: dict[str, LLM_Name], + model_dict: dict[str, str], omniscient: bool = False, tag: str | None = None, push_to_db: bool = False,