Skip to content

Commit

Permalink
assign work to machine names
Browse files Browse the repository at this point in the history
This addition allows us to assign work to agents sitting on specific `machine_name`s.

Work is assigned round-robin among the agents whose machine name matches each operator's requested machine name. Simple, but sufficient for now.
  • Loading branch information
swelborn committed Sep 12, 2024
1 parent a10b780 commit 97b58b5
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 23 deletions.
21 changes: 15 additions & 6 deletions backend/agent/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@
import podman
import podman.errors
from core.constants import BUCKET_AGENTS, STREAM_AGENTS, STREAM_OPERATORS
from core.events.pipelines import PipelineRunEvent
from core.logger import get_logger
from core.models.agent import AgentStatus, AgentVal
from core.models.base import IdType
from core.models.pipeline import OperatorJSON, PipelineJSON
from core.models.pipeline import OperatorJSON, PipelineAssignment, PipelineJSON
from core.models.uri import URI, CommBackend, URILocation
from core.pipeline import Pipeline
from nats.aio.client import Client as NATSClient
Expand Down Expand Up @@ -259,13 +258,21 @@ async def server_loop(self):
asyncio.gather(*[msg.ack() for msg in msgs])
for msg in msgs:
try:
event = PipelineRunEvent.model_validate_json(msg.data)
event = PipelineAssignment.model_validate_json(msg.data)
except ValidationError:
logger.error("Invalid message")
return
valid_pipeline = PipelineJSON(id=event.id, **event.data)
logger.error(f"Invalid message: {msg.data}")
continue
try:
assert event.agent_id == self.id
except AssertionError:
logger.error(
f"Agent ID mismatch: {event.agent_id} != {self.id}"
)
continue
valid_pipeline = PipelineJSON.model_validate(event.pipeline)
logger.info(f"Validated pipeline: {valid_pipeline.id}")
self.pipeline = Pipeline.from_pipeline(valid_pipeline)
self.my_operator_ids = event.operators_assigned
break

self.containers = await self.start_operators()
Expand Down Expand Up @@ -306,6 +313,8 @@ async def start_operators(self) -> dict[uuid.UUID, Container]:
futures = []
with PodmanClient(base_url=self._podman_service_uri) as client:
for id, op_info in self.pipeline.operators.items():
if id not in self.my_operator_ids:
continue
logger.info(f"Starting operator {id} with image {op_info.image}")

# TODO: could use async to create multiple at once
Expand Down
9 changes: 6 additions & 3 deletions backend/app/app/tests/api/routes/test.http
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ Authorization: bearer put_the_access_token_here
],
"params": {
"hello": "world"
}
},
"machine_name": "mothership6"
},
{
"id": "12345678-1234-1234-1234-1234567890cd",
Expand All @@ -69,7 +70,8 @@ Authorization: bearer put_the_access_token_here
"outputs": [],
"params": {
"hello": "world"
}
},
"machine_name": "perlmutter"
},
{
"id": "12345678-1234-1234-1234-1234567890ef",
Expand All @@ -81,7 +83,8 @@ Authorization: bearer put_the_access_token_here
"outputs": [],
"params": {
"hello": "world"
}
},
"machine_name": "perlmutter"
}
],
"ports": [
Expand Down
7 changes: 7 additions & 0 deletions backend/core/core/models/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class OperatorJSON(PipelineNodeJSON):
params: dict[str, Any] = {}
inputs: list[PortID] = []
outputs: list[PortID] = []
machine_name: str | None = None


class EdgeJSON(BaseModel):
Expand All @@ -45,3 +46,9 @@ class PipelineJSON(BaseModel):
operators: Sequence[OperatorJSON] = []
ports: Sequence[PortJSON] = []
edges: Sequence[EdgeJSON] = []


class PipelineAssignment(BaseModel):
    """Per-agent slice of a pipeline: which operators one agent must run."""

    # ID of the agent this assignment is addressed to; the agent rejects
    # messages whose agent_id does not match its own.
    agent_id: IdType
    # Subset of the pipeline's operator IDs that this agent should start.
    operators_assigned: list[OperatorID] = []
    # The full pipeline definition; the agent filters it by operators_assigned.
    pipeline: PipelineJSON
111 changes: 97 additions & 14 deletions backend/orchestrator/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from collections.abc import Awaitable, Callable
from itertools import cycle
from uuid import uuid4

import nats
Expand All @@ -15,7 +16,9 @@
from core.events.pipelines import PipelineRunEvent
from core.logger import get_logger
from core.models.agent import AgentVal
from core.pipeline import Pipeline, PipelineJSON
from core.models.base import IdType
from core.models.pipeline import PipelineAssignment, PipelineJSON
from core.pipeline import OperatorJSON, Pipeline
from nats.aio.client import Client as NATSClient
from nats.aio.msg import Msg as NATSMsg
from nats.js import JetStreamContext
Expand All @@ -29,6 +32,88 @@
logger = get_logger("orchestrator", "DEBUG")


async def publish_assignment(js: JetStreamContext, assignment: PipelineAssignment):
    """Publish *assignment* to the target agent's subject on the agents stream."""
    subject = f"{STREAM_AGENTS}.{assignment.agent_id}"
    payload = assignment.model_dump_json().encode()
    await js.publish(subject, stream=f"{STREAM_AGENTS}", payload=payload)


def assign_pipeline_to_agents(agent_infos: list[AgentVal], pipeline: Pipeline):
    """Partition a pipeline's operators across the available agents.

    Operators pinned to a machine (``machine_name`` set) are distributed
    round-robin over the agents that report the same machine name; unpinned
    operators are distributed round-robin over agents with no machine name.

    Args:
        agent_infos: Agents currently available to receive work.
        pipeline: The pipeline whose operators need to be placed.

    Returns:
        A list of ``PipelineAssignment``, one per agent that received work.

    Raises:
        Exception: If a machine-pinned operator has no agent on that machine,
            or an unpinned operator exists but no unpinned agent does.
    """
    # Bucket agents by machine name; agents without one form their own pool.
    agents_on_machines: dict[str, list[AgentVal]] = {}
    agents_without_machines: list[AgentVal] = []
    for agent_info in agent_infos:
        if agent_info.machine_name:
            agents_on_machines.setdefault(agent_info.machine_name, []).append(
                agent_info
            )
        else:
            agents_without_machines.append(agent_info)

    # Bucket operators the same way.
    operators_on_machines: dict[str, list[OperatorJSON]] = {}
    operators_without_machines: list[OperatorJSON] = []
    for operator in pipeline.operators.values():
        if operator.machine_name:
            operators_on_machines.setdefault(operator.machine_name, []).append(
                operator
            )
        else:
            operators_without_machines.append(operator)

    assignments: dict[IdType, list[OperatorJSON]] = {}

    # Pinned operators must land on an agent on the matching machine.
    for machine_name, operators in operators_on_machines.items():
        candidates = agents_on_machines.get(machine_name)
        if not candidates:
            raise Exception(f"No agents available for machine {machine_name}")
        _assign_round_robin(candidates, operators, assignments)

    # Unpinned operators go to the pool of agents without a machine name.
    if operators_without_machines:
        if not agents_without_machines:
            raise Exception("No agents available for operators without machine names")
        _assign_round_robin(
            agents_without_machines, operators_without_machines, assignments
        )

    pipeline_assignments = [
        PipelineAssignment(
            agent_id=agent_id,
            operators_assigned=[op.id for op in operators],
            pipeline=pipeline.to_json(),
        )
        for agent_id, operators in assignments.items()
    ]

    # Lazy %-formatting so the summary is only built when INFO is enabled.
    logger.info("Final assignments:\n%s", _format_assignments(pipeline_assignments))

    return pipeline_assignments


def _assign_round_robin(
    agents: list[AgentVal],
    operators: list[OperatorJSON],
    assignments: dict[IdType, list[OperatorJSON]],
) -> None:
    """Distribute *operators* over *agents* round-robin, accumulating into *assignments*."""
    agent_cycle = cycle(agents)
    for operator in operators:
        agent_id = next(agent_cycle).uri.id
        assignments.setdefault(agent_id, []).append(operator)


def _format_assignments(pipeline_assignments: list[PipelineAssignment]) -> str:
    """Build a human-readable agent -> operator/machine summary for logging."""
    lines: list[str] = []
    for assignment in pipeline_assignments:
        # One pass to index machine names, instead of a linear scan per operator
        # (which could also raise StopIteration on an unknown ID).
        machine_by_op = {
            op.id: op.machine_name for op in assignment.pipeline.operators
        }
        lines.append(f"Agent {assignment.agent_id}:")
        lines.extend(
            f"  Operator {op_id} on machine {machine_by_op.get(op_id)}"
            for op_id in assignment.operators_assigned
        )
    return "\n".join(lines)


async def handle_run_pipeline(msg: NATSMsg, js: JetStreamContext):
logger.info("Received pipeline run event...")
await msg.ack()
Expand All @@ -40,27 +125,25 @@ async def handle_run_pipeline(msg: NATSMsg, js: JetStreamContext):
return
valid_pipeline = PipelineJSON(id=event.id, **event.data)
logger.info(f"Validated pipeline: {valid_pipeline.id}")
_ = Pipeline.from_pipeline(valid_pipeline)
pipeline = Pipeline.from_pipeline(valid_pipeline)

agents = await get_agents(js)
logger.info(f"There are currently {len(agents)} agent(s) available...")
logger.info(f"Agents: {agents}")
if len(agents) < 1:
if len(agents) == 0:
logger.info("No agents available to run pipeline.")
# TODO: publish event to send message back to API that says this won't work
return

first_agent_info = await get_agent_info(js, agents[0])
if not first_agent_info:
logger.error("Agent info not found.")
agent_infos = await asyncio.gather(*[get_agent_info(js, agent) for agent in agents])
agent_infos = [agent_info for agent_info in agent_infos if agent_info]

try:
assignments = assign_pipeline_to_agents(agent_infos, pipeline)
except Exception as e:
logger.error(f"Failed to assign pipeline to agents: {e}")
return

logger.info(f"Assigning pipeline to agent: {first_agent_info.uri.id}")
logger.info(f"Agent info: {first_agent_info}")
await js.publish(
f"{STREAM_AGENTS}.{first_agent_info.uri.id}",
stream=f"{STREAM_AGENTS}",
payload=msg.data,
await asyncio.gather(
*[publish_assignment(js, assignment) for assignment in assignments]
)

logger.info("Pipeline run event processed.")
Expand Down

0 comments on commit 97b58b5

Please sign in to comment.