Skip to content

Commit

Permalink
add evaluation node
Browse files Browse the repository at this point in the history
  • Loading branch information
XuhuiZhou committed Jan 8, 2025
1 parent bbf6061 commit 7558927
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 75 deletions.
9 changes: 9 additions & 0 deletions examples/experimental/sotopia_original_replica/origin.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ node_class = "moderator"
[nodes.node_args]
output_channels = ["moderator:Jane", "moderator:Jack"]
input_channels = ["Jane:moderator", "Jack:moderator"]
evaluator_channels = ["evaluator:moderator"]
agent_mapping = {"moderator:Jane" = "Jane", "moderator:Jack" = "Jack"}
scenario = "Two friends are sitting in a cafe and catching up with each other's lives."
max_turns = 2
Expand Down Expand Up @@ -55,3 +56,11 @@ node_class = "chat_print"

[nodes.node_args]
env_agents = ["Jack", "Jane"]

[[nodes]]
node_name = "evaluator"
node_class = "evaluator"

[nodes.node_args]
# Channel names follow the "<sender>:<receiver>" convention used by every
# other channel in this file (e.g. "moderator:Jane" / "Jane:moderator").
input_channels = ["moderator:evaluator"]
# Replies go back to the moderator; this must differ from input_channels or
# the node would publish onto the very channel it consumes. It also matches
# the moderator's evaluator_channels = ["evaluator:moderator"].
output_channels = ["evaluator:moderator"]
42 changes: 22 additions & 20 deletions sotopia/experimental/agents/evaluators.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
from abc import ABC, abstractmethod
from aact import NodeFactory, Node
from .logs import EpisodeLog
from .datamodels import AgentAction, Observation


class BaseEvaluator(ABC):
def __init__(self):
pass
@NodeFactory.register("evaluator")
class Evaluator(Node[AgentAction, Observation]):
    """Base evaluation node.

    Subscribes to the moderator's evaluator channel(s), receiving
    ``AgentAction`` messages, and publishes ``Observation`` messages back.
    Concrete evaluators implement :meth:`aevaluate` to score a finished
    episode.
    """

    def __init__(
        self,
        node_name: str,
        input_channels: list[str],
        output_channels: list[str],
        redis_url: str,
    ):
        super().__init__(
            # Each input channel carries AgentAction payloads from the moderator.
            input_channel_types=[
                (input_channel, AgentAction) for input_channel in input_channels
            ],
            # Each output channel carries Observation payloads back out.
            output_channel_types=[
                (output_channel, Observation) for output_channel in output_channels
            ],
            node_name=node_name,
            redis_url=redis_url,
        )

    async def aevaluate(self, episode: EpisodeLog) -> AgentAction | None:
        """Score *episode*; subclasses must override.

        Returns an AgentAction describing the evaluation, or None if no
        evaluation is produced.
        """
        raise NotImplementedError


class DummyEvaluator(BaseEvaluator):
def __init__(self):
super().__init__()

def evaluate(self, epilog: EpisodeLog) -> tuple[float, str]:
"""
evaluate an episode, returns the score and reward prompt
"""
return 0.0, "No evaluation implemented"
72 changes: 17 additions & 55 deletions sotopia/experimental/agents/moderator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,15 @@
else:
from typing import Self


from aact import Message, NodeFactory, Node
from aact.messages import DataModel, DataModelFactory

from typing import Literal, Any, AsyncIterator
from pydantic import Field

from .datamodels import AgentAction, Observation
from sotopia.envs.evaluators import Evaluator, unweighted_aggregate_evaluate
from sotopia.messages import ActionType
from .logs import EpisodeLog
import itertools


@DataModelFactory.register("observations")
Expand All @@ -37,6 +34,7 @@ def __init__(
output_channels: list[str],
scenario: str,
agent_mapping: dict[str, str],
evaluator_channels: list[str] = [],
tag: str = "",
redis_url: str = "redis://localhost:6379/0",
action_order: Literal["simultaneous", "round-robin", "random"] = "round-robin",
Expand All @@ -49,10 +47,7 @@ def __init__(
],
max_turns: int = 20,
push_to_db: bool = False,
evaluators: list[Evaluator] = [],
terminal_evaluators: list[Evaluator] = [],
use_pk_value: bool = False,
evaluator: str = "DummyEvaluator",
) -> None:
super().__init__(
input_channel_types=[
Expand Down Expand Up @@ -82,14 +77,10 @@ def __init__(
self.agent_models: dict[str, str] = {}
self.agents_awake: dict[str, bool] = {name: False for name in self.agents}
self.all_agents_awake: asyncio.Event = asyncio.Event()
self.message_history: list[list[tuple[str, str, str]]] = [
[("Environment", "Environment", self.scenario)]
]
self.evaluator_channels: list[str] = evaluator_channels
self.push_to_db: bool = push_to_db
self.evaluators: list[Evaluator] = evaluators
self.terminal_evaluators: list[Evaluator] = terminal_evaluators
self.use_pk_value: bool = use_pk_value
self.evaluator: str = evaluator
self.epilog: EpisodeLog = EpisodeLog(messages=[], rewards=[], rewards_prompt="")

if self.action_order == "round-robin":
pass
Expand Down Expand Up @@ -193,10 +184,10 @@ async def booting(self) -> None:
self.current_agent_index += 1

async def wrap_up_and_stop(self) -> None:
epilog = await self.save()
print("episode saved")
if self.terminal_evaluators:
epilog = await self.eval(epilog)
if self.evaluator_channels:
epilog = await self.aeval(self.epilog)
if self.push_to_db:
epilog.save()
await asyncio.sleep(0.5)
print("result of this episode:\n", epilog)
await self.r.publish(
Expand All @@ -213,48 +204,19 @@ async def episode_log_to_messages(
messages.append((message[0], message[1], message[2]))
return messages

async def eval(self, epilog: EpisodeLog) -> EpisodeLog:
async def aeval(self, epilog: EpisodeLog) -> EpisodeLog:
"""
evaluate the episode
"""
messages = await self.episode_log_to_messages(epilog)
if self.terminal_evaluators:
response = unweighted_aggregate_evaluate(
list(
itertools.chain(
*await asyncio.gather(
*[
evaluator.__acall__(
turn_number=self.turn_number,
messages=messages, # type: ignore
)
for evaluator in self.evaluators
]
)
)
)
)
epilog.rewards = response.p1_rate # type: ignore
epilog.rewards_prompt = response.comments # type: ignore
if self.push_to_db:
epilog.save()
return epilog
for evaluator_channel in self.evaluator_channels:
await self.r.publish(evaluator_channel, epilog.model_dump_json())
print("episode eval started")

async def save(self) -> EpisodeLog:
"""
save the EpisodeLog to redis
"""
epilog = EpisodeLog(
environment=self.scenario,
agents=list(self.agents_pk.values()),
tag=self.tag,
models=list(self.agent_models.values()),
messages=self.message_history,
rewards=[0.0] * len(self.agents),
rewards_prompt="",
)
if self.push_to_db:
epilog.save()
for evaluator_channel in self.evaluator_channels:
await self.observation_queue.get()
print("episode eval finished")
epilog.rewards = [0.0] * len(self.agents) # TODO: get real rewards
epilog.rewards_prompt = "" # TODO: get real rewards_prompt
return epilog

async def astep(self, agent_action: AgentAction) -> Observations | None:
Expand All @@ -267,7 +229,7 @@ async def astep(self, agent_action: AgentAction) -> Observations | None:
return None

# message (sender, receivers (separated by comma), message content)
self.message_history.append(
self.epilog.messages.append(
[
(
agent_action.agent_name,
Expand Down

0 comments on commit 7558927

Please sign in to comment.