Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enabled saving and evaluation for moderator #271

Merged
merged 8 commits into from
Jan 7, 2025
Merged
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import logging
import sys
import json
from rich.logging import RichHandler

from aact import NodeFactory

from sotopia.experimental.agents.base_agent import BaseAgent
from sotopia.experimental.agents.datamodels import Observation, AgentAction
from sotopia.database.persistent_profile import AgentProfile

from sotopia.generation_utils import agenerate
from sotopia.generation_utils.generate import StrOutputParser
Expand Down Expand Up @@ -33,11 +35,13 @@ def __init__(
input_channels: list[str],
output_channel: str,
query_interval: int,
agent_name: str,
node_name: str,
goal: str,
model_name: str,
redis_url: str,
goal: str,
agent_name: str | None = None,
background: dict | None = None,
agent_pk: str | None = None,
redis_url: str = "redis://localhost:6379/0",
):
super().__init__(
[(input_channel, Observation) for input_channel in input_channels],
Expand All @@ -47,23 +51,48 @@ def __init__(
)
self.output_channel = output_channel
self.query_interval = query_interval
self.count_ticks = 0
self.count_ticks: int = 0
self.message_history: list[Observation] = []
self.name = agent_name
self.model_name = model_name
self.goal = goal
self.goal: str = goal
self.model_name: str = model_name
self.agent_profile_pk: str = agent_pk
self.name: str = agent_name
self.background: dict = background

def set_profile(self, use_pk_value: bool) -> None:
    """Resolve and cache this agent's AgentProfile.

    If ``use_pk_value`` is True, load the existing profile identified by
    ``self.agent_profile_pk`` from the database; otherwise create and save
    a new profile built from ``self.name`` and ``self.background``.

    Side effects: updates ``self.agent_profile_pk``, ``self.name``
    (normalized to "first last"), and ``self.background`` (the full
    profile dump).
    """
    if use_pk_value:
        # NOTE(review): assumes agent_pk was supplied at construction — confirm.
        profile = AgentProfile.get(pk=self.agent_profile_pk)
    else:
        # "First Last Extra" -> ("First", "Last Extra"); a single token
        # yields last_name == "" (same semantics as split(" ", 1)).
        first_name, _, last_name = self.name.partition(" ")
        # background defaults to None in __init__, so guard before unpacking.
        profile = AgentProfile(
            first_name=first_name,
            last_name=last_name,
            **(self.background or {}),
        )
        profile.save()

    self.agent_profile_pk = profile.pk
    self.name = " ".join([profile.first_name, profile.last_name]).strip()
    self.background = profile.model_dump()

def _format_message_history(self, message_history: list[Observation]) -> str:
    """Render every observation in natural language, one per line."""
    # TODO: akhatua Fix the mapping of action to be grammatically correct
    rendered = [message.to_natural_language() for message in message_history]
    return "\n".join(rendered)

async def aact(self, obs: Observation) -> AgentAction:
if obs.turn_number == -1:
args = json.loads(obs.last_turn)
self.set_profile(args["use_pk_value"])
return AgentAction(
agent_name=self.name,
output_channel=self.output_channel,
action_type="none",
argument=self.model_name,
argument=json.dumps(
{"pk": self.agent_profile_pk, "model_name": self.model_name}
),
)

self.message_history.append(obs)
Expand Down
7 changes: 6 additions & 1 deletion examples/experimental/sotopia_original_replica/origin.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ node_class = "moderator"
[nodes.node_args]
output_channels = ["moderator:Jane", "moderator:Jack"]
input_channels = ["Jane:moderator", "Jack:moderator"]
agent_backgrounds = {"Jane" = "", "Jack" = ""}
agent_mapping = {"moderator:Jane" = "Jane", "moderator:Jack" = "Jack"}
scenario = "Two friends are sitting in a cafe and catching up with each other's lives."
max_turns = 2
push_to_db = false
will_eval = true
use_pk_value = false

[[nodes]]
node_name = "Jack"
JXZhou0224 marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -26,6 +27,8 @@ output_channel = "Jack:moderator"
goal = "Your goal is to borrow 5000 dollars from Jane."
model_name = "gpt-4o-mini"
agent_name = "Jack"
background = {"occupation" = "construction worker"}
agent_pk = ""


[[nodes]]
Expand All @@ -39,6 +42,8 @@ input_channels = ["moderator:Jane"]
goal = "Your goal is to help Jack; however, you are in a financial crisis yourself and can only afford to give him 500 dollars."
model_name = "gpt-4o-mini"
agent_name = "Jane"
background = {"occupation" = "gardener"}
agent_pk = ""

[[nodes]]
node_name = "chat_print"
Expand Down
2 changes: 1 addition & 1 deletion sotopia/database/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class BaseEpisodeLog(BaseModel):
tag: str | None = Field(index=True, default="")
models: list[str] | None = Field(index=True, default=[])
messages: list[list[tuple[str, str, str]]] # Messages arranged by turn
reasoning: str
reasoning: str = Field(default="")
rewards: list[tuple[float, dict[str, float]] | float] # Rewards arranged by turn
rewards_prompt: str

Expand Down
25 changes: 25 additions & 0 deletions sotopia/experimental/agents/evaluators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not using the existing evaluators

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The existing evaluators base class can only take in message history. However, there might be cases where we need other information in EpisodeLog to generate an evaluation. So I made a new class of evaluator for future use.

from .logs import EpisodeLog


class BaseEvaluator(ABC):
    """Abstract interface for episode-level evaluators.

    Unlike history-only evaluators, implementations receive the complete
    EpisodeLog, so they can base their judgment on any recorded episode
    data (messages, models, tags, ...).
    """

    @abstractmethod
    def evaluate(self, epilog: "EpisodeLog") -> tuple[float, str]:
        """Evaluate an episode; return ``(score, reward_prompt)``."""
        raise NotImplementedError


class DummyEvaluator(BaseEvaluator):
    """Placeholder evaluator that performs no real evaluation."""

    def evaluate(self, epilog: "EpisodeLog") -> tuple[float, str]:
        """Return a fixed ``(0.0, explanation)`` without inspecting the episode."""
        return 0.0, "No evaluation implemented"
8 changes: 8 additions & 0 deletions sotopia/experimental/agents/logs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from redis_om import JsonModel
from sotopia.database.logs import BaseEpisodeLog
from sotopia.database.persistent_profile import AgentProfile


class EpisodeLog(BaseEpisodeLog, JsonModel):
    """Redis-persisted episode log for the experimental pipeline.

    Combines the shared BaseEpisodeLog fields with redis_om's JsonModel so
    instances can be saved to and loaded from Redis.
    """

    def render_for_humans(self) -> tuple[list[AgentProfile], list[str]]:
        """Render the episode as (agent profiles, human-readable messages).

        Not implemented yet for the experimental EpisodeLog.
        """
        raise NotImplementedError
93 changes: 54 additions & 39 deletions sotopia/experimental/agents/moderator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import sys

import json

if sys.version_info < (3, 11):
from typing_extensions import Self
Expand All @@ -14,9 +14,10 @@
from typing import Literal, Any, AsyncIterator
from pydantic import Field

from sotopia.database import EpisodeLog
from .datamodels import AgentAction, Observation
from .evaluators import DummyEvaluator
from sotopia.messages import ActionType
from .logs import EpisodeLog


@DataModelFactory.register("observations")
Expand All @@ -30,12 +31,12 @@ class Observations(DataModel):
class Moderator(Node[AgentAction, Observation]):
def __init__(
self,
node_name,
input_channels: list[str],
output_channels: list[str],
scenario: str,
agent_mapping: dict[str, str],
node_name: str,
agent_backgrounds: dict[str, str],
tag: str = "",
redis_url: str = "redis://localhost:6379/0",
action_order: Literal["simultaneous", "round-robin", "random"] = "round-robin",
available_actions: list[ActionType] = [
Expand All @@ -47,6 +48,9 @@ def __init__(
],
max_turns: int = 20,
push_to_db: bool = False,
will_eval: bool = False,
use_pk_value: bool = False,
evaluator: str = "DummyEvaluator",
):
super().__init__(
input_channel_types=[
Expand All @@ -62,6 +66,7 @@ def __init__(
self.task_scheduler: asyncio.Task[None] | None = None
self.shutdown_event: asyncio.Event = asyncio.Event()
self.agent_mapping: dict[str, str] = agent_mapping
self.tag: str = tag
self.action_order: Literal["simultaneous", "round-robin", "random"] = (
action_order
)
Expand All @@ -71,14 +76,17 @@ def __init__(
self.current_agent_index: int = 0
self.scenario: str = scenario
self.agents: list[str] = list(agent_mapping.values())
self.agents_pk: dict[str, str] = {}
self.agent_models: dict[str, str] = {}
self.agents_awake: dict[str, bool] = {name: False for name in self.agents}
self.all_agents_awake: asyncio.Event = asyncio.Event()
self.message_history: list[list[tuple[str, str, str]]] = [
[("Environment", "Environment", self.scenario)]
]
self.push_to_db = push_to_db
self.agent_backgrounds = agent_backgrounds
self.push_to_db: bool = push_to_db
self.will_eval: bool = will_eval
self.use_pk_value: bool = use_pk_value
self.evaluator: str = evaluator

if self.action_order == "round-robin":
pass
Expand Down Expand Up @@ -131,7 +139,7 @@ async def booting(self) -> None:
"""
1. send checking message to agents for every 0.1 seconds, until all agents are awake
- this message has turn_number of -1 for identification, agents should not record this into actual message_history
- if the agent booted successfully, it is expected to return its agent_profile's pk for record.
- if the agent booted succesfully, he is expected to return its agent_profile's pk for record.
2. (under round-robin action order)after all agents are awake, send agent[0] a message to allow the agent to start speaking
"""
while not self.all_agents_awake.is_set():
Expand All @@ -140,7 +148,11 @@ async def booting(self) -> None:
observations_map={
output_channel: Observation(
agent_name="moderator",
last_turn=self.scenario,
last_turn=json.dumps(
{
"use_pk_value": self.use_pk_value,
}
),
turn_number=-1,
available_actions=["none"],
)
Expand All @@ -152,17 +164,20 @@ async def booting(self) -> None:
while not self.observation_queue.empty():
agent_action = await self.observation_queue.get()
self.agents_awake[agent_action.agent_name] = True
self.agent_models[agent_action.agent_name] = agent_action.argument
args: dict = json.loads(agent_action.argument)
self.agents_pk[agent_action.agent_name] = args["pk"]
self.agent_models[agent_action.agent_name] = args["model_name"]
if False not in self.agents_awake.values():
self.all_agents_awake.set()
print("all agents are awake")

if self.action_order == "round-robin":
await self.send(
Observations(
observations_map={
output_channel: Observation(
agent_name="moderator",
last_turn=self.agent_backgrounds[agent_name],
last_turn="conversation start",
turn_number=0,
available_actions=self.available_actions
if agent_name == self.agents[0]
Expand All @@ -175,36 +190,45 @@ async def booting(self) -> None:
self.current_agent_index += 1

async def wrap_up_and_stop(self) -> None:
if self.push_to_db:
await self.save()
epilog = await self.save()
print("episode saved")
if self.will_eval:
epilog = await self.eval(epilog)
await asyncio.sleep(0.5)
print("stopping all agents")
print("result of this episode:\n", epilog)
await self.r.publish(
f"shutdown:{self.node_name}",
"shutdown:moderator",
"shutdown",
)

async def eval(self, epilog: EpisodeLog) -> EpisodeLog:
    """Score a finished episode with the configured evaluator.

    Only the "DummyEvaluator" name is dispatched here; any other value of
    ``self.evaluator`` leaves the episode's rewards untouched.
    NOTE(review): unknown evaluator names are silently ignored — confirm intended.

    Returns the (possibly re-persisted) EpisodeLog.
    """
    if self.evaluator == "DummyEvaluator":
        evaluator = DummyEvaluator()
        reward, reward_prompt = evaluator.evaluate(epilog)
        # NOTE(review): stores a single reward, not one per agent — confirm
        # this matches BaseEpisodeLog.rewards' per-turn expectation.
        epilog.rewards = [reward]
        epilog.rewards_prompt = reward_prompt
    if self.push_to_db:
        # Persist the updated rewards back to Redis.
        epilog.save()
    return epilog

async def save(self) -> EpisodeLog:
"""
save the EpisodeLog to redis, without evaluating
TODO: specify what to be added inside tag
TODO: update the code so that EpisodeLog.render_for_humans() can work
-currently it cannot work because no AgentProfile has been uploaded to redis
-such a process should be done back in the agents' end
-also the current agentslist is consist of names, but not uuid's of agents
save the EpisodeLog to redis
"""
epilog = EpisodeLog(
environment=self.scenario,
agents=self.agents,
tag=None,
agents=list(self.agents_pk.values()),
tag=self.tag,
models=list(self.agent_models.values()),
messages=self.message_history,
reasoning="",
rewards=[0] * len(self.agents),
rewards=[0.0] * len(self.agents),
rewards_prompt="",
)
epilog.save()
# print(epilog.render_for_humans())
if self.push_to_db:
epilog.save()
return epilog

async def aact(self, agent_action: AgentAction) -> Observations | None:
Expand All @@ -216,24 +240,15 @@ async def aact(self, agent_action: AgentAction) -> Observations | None:
if agent_action.action_type == "none":
return None

if len(self.message_history) == 1:
self.message_history[0].append(
self.message_history.append(
[
(
agent_action.agent_name,
"Environment",
agent_action.to_natural_language(),
)
)
else:
self.message_history.append(
[
(
agent_action.agent_name,
"Environment",
agent_action.to_natural_language(),
)
]
)
]
)

if self.turn_number < self.max_turns:
self.turn_number += 1
Expand Down
Loading