diff --git a/backend/agent/agent/agent.py b/backend/agent/agent/agent.py index d51ffc7..08de9ed 100644 --- a/backend/agent/agent/agent.py +++ b/backend/agent/agent/agent.py @@ -11,11 +11,10 @@ import podman import podman.errors from core.constants import BUCKET_AGENTS, STREAM_AGENTS, STREAM_OPERATORS -from core.events.pipelines import PipelineRunEvent from core.logger import get_logger from core.models.agent import AgentStatus, AgentVal from core.models.base import IdType -from core.models.pipeline import OperatorJSON, PipelineJSON +from core.models.pipeline import OperatorJSON, PipelineAssignment, PipelineJSON from core.models.uri import URI, CommBackend, URILocation from core.pipeline import Pipeline from nats.aio.client import Client as NATSClient @@ -259,13 +258,21 @@ async def server_loop(self): asyncio.gather(*[msg.ack() for msg in msgs]) for msg in msgs: try: - event = PipelineRunEvent.model_validate_json(msg.data) + event = PipelineAssignment.model_validate_json(msg.data) except ValidationError: - logger.error("Invalid message") - return - valid_pipeline = PipelineJSON(id=event.id, **event.data) + logger.error(f"Invalid message: {msg.data}") + continue + try: + assert event.agent_id == self.id + except AssertionError: + logger.error( + f"Agent ID mismatch: {event.agent_id} != {self.id}" + ) + continue + valid_pipeline = PipelineJSON.model_validate(event.pipeline) logger.info(f"Validated pipeline: {valid_pipeline.id}") self.pipeline = Pipeline.from_pipeline(valid_pipeline) + self.my_operator_ids = event.operators_assigned break self.containers = await self.start_operators() @@ -306,6 +313,8 @@ async def start_operators(self) -> dict[uuid.UUID, Container]: futures = [] with PodmanClient(base_url=self._podman_service_uri) as client: for id, op_info in self.pipeline.operators.items(): + if id not in self.my_operator_ids: + continue logger.info(f"Starting operator {id} with image {op_info.image}") # TODO: could use async to create multiple at once diff --git a/backend/app/app/tests/api/routes/test.http b/backend/app/app/tests/api/routes/test.http index ed8af85..2e1b787 100644 --- a/backend/app/app/tests/api/routes/test.http +++ b/backend/app/app/tests/api/routes/test.http @@ -57,7 +57,8 @@ Authorization: bearer put_the_access_token_here ], "params": { "hello": "world" - } + }, + "machine_name": "mothership6" }, { "id": "12345678-1234-1234-1234-1234567890cd", @@ -69,7 +70,8 @@ Authorization: bearer put_the_access_token_here "outputs": [], "params": { "hello": "world" - } + }, + "machine_name": "perlmutter" }, { "id": "12345678-1234-1234-1234-1234567890ef", @@ -81,7 +83,8 @@ Authorization: bearer put_the_access_token_here "outputs": [], "params": { "hello": "world" - } + }, + "machine_name": "perlmutter" } ], "ports": [ diff --git a/backend/core/core/models/pipeline.py b/backend/core/core/models/pipeline.py index 16ac3d9..6955459 100644 --- a/backend/core/core/models/pipeline.py +++ b/backend/core/core/models/pipeline.py @@ -32,6 +32,7 @@ class OperatorJSON(PipelineNodeJSON): params: dict[str, Any] = {} inputs: list[PortID] = [] outputs: list[PortID] = [] + machine_name: str | None = None class EdgeJSON(BaseModel): @@ -45,3 +46,9 @@ class PipelineJSON(BaseModel): operators: Sequence[OperatorJSON] = [] ports: Sequence[PortJSON] = [] edges: Sequence[EdgeJSON] = [] + + +class PipelineAssignment(BaseModel): + agent_id: IdType + operators_assigned: list[OperatorID] = [] + pipeline: PipelineJSON diff --git a/backend/orchestrator/orchestrator/orchestrator.py b/backend/orchestrator/orchestrator/orchestrator.py index ff3ac19..8408907 100644 --- a/backend/orchestrator/orchestrator/orchestrator.py +++ b/backend/orchestrator/orchestrator/orchestrator.py @@ -1,5 +1,6 @@ import asyncio from collections.abc import Awaitable, Callable +from itertools import cycle from uuid import uuid4 import nats @@ -15,7 +16,9 @@ from core.events.pipelines import PipelineRunEvent from core.logger import get_logger from core.models.agent import AgentVal -from core.pipeline import Pipeline, PipelineJSON +from core.models.base import IdType +from core.models.pipeline import PipelineAssignment, PipelineJSON +from core.pipeline import OperatorJSON, Pipeline from nats.aio.client import Client as NATSClient from nats.aio.msg import Msg as NATSMsg from nats.js import JetStreamContext @@ -29,6 +32,88 @@ logger = get_logger("orchestrator", "DEBUG") +async def publish_assignment(js: JetStreamContext, assignment: PipelineAssignment): + await js.publish( + f"{STREAM_AGENTS}.{assignment.agent_id}", + stream=f"{STREAM_AGENTS}", + payload=assignment.model_dump_json().encode(), + ) + + +def assign_pipeline_to_agents(agent_infos: list[AgentVal], pipeline: Pipeline): + # Group agents by machine name + agents_on_machines: dict[str, list[AgentVal]] = {} + agents_without_machines: list[AgentVal] = [] + for agent_info in agent_infos: + if agent_info.machine_name: + if agent_info.machine_name not in agents_on_machines: + agents_on_machines[agent_info.machine_name] = [] + agents_on_machines[agent_info.machine_name].append(agent_info) + else: + agents_without_machines.append(agent_info) + + # Group operators by machine name + operators_on_machines: dict[str, list[OperatorJSON]] = {} + operators_without_machines: list[OperatorJSON] = [] + for operator in pipeline.operators.values(): + if operator.machine_name: + if operator.machine_name not in operators_on_machines: + operators_on_machines[operator.machine_name] = [] + operators_on_machines[operator.machine_name].append(operator) + else: + operators_without_machines.append(operator) + + # Assign operators with machine names to agents with corresponding machine names using round-robin + assignments: dict[IdType, list[OperatorJSON]] = {} + for machine_name, operators in operators_on_machines.items(): + agents_on_this_machine = agents_on_machines.get(machine_name, None) + if not agents_on_this_machine: + raise Exception(f"No agents available for machine {machine_name}") + + agent_cycle = cycle(agents_on_this_machine) + for operator in operators: + agent = next(agent_cycle) + agent_id = agent.uri.id + if agent_id not in assignments: + assignments[agent_id] = [] + assignments[agent_id].append(operator) + + # Assign operators without machine names to agents without machine names using round-robin + if operators_without_machines: + if not agents_without_machines: + raise Exception("No agents available for operators without machine names") + + agent_cycle = cycle(agents_without_machines) + for operator in operators_without_machines: + agent = next(agent_cycle) + agent_id = agent.uri.id + if agent_id not in assignments: + assignments[agent_id] = [] + assignments[agent_id].append(operator) + + pipeline_assignments: list[PipelineAssignment] = [] + for agent_id, operators in assignments.items(): + pipeline_assignment = PipelineAssignment( + agent_id=agent_id, + operators_assigned=[op.id for op in operators], + pipeline=pipeline.to_json(), + ) + pipeline_assignments.append(pipeline_assignment) + + formatted_assignments = "\n".join( + f"Agent {assignment.agent_id}:\n" + + "\n".join( + f" Operator {op_id} on machine {next(op.machine_name for op in assignment.pipeline.operators if op.id == op_id)}" + for op_id in assignment.operators_assigned + ) + for assignment in pipeline_assignments + ) + + logger.info(f"Final assignments:\n{formatted_assignments}") + + return pipeline_assignments + + async def handle_run_pipeline(msg: NATSMsg, js: JetStreamContext): logger.info("Received pipeline run event...") await msg.ack() @@ -40,27 +125,25 @@ async def handle_run_pipeline(msg: NATSMsg, js: JetStreamContext): return valid_pipeline = PipelineJSON(id=event.id, **event.data) logger.info(f"Validated pipeline: {valid_pipeline.id}") - _ = Pipeline.from_pipeline(valid_pipeline) + pipeline = Pipeline.from_pipeline(valid_pipeline) agents = await get_agents(js) - logger.info(f"There are currently {len(agents)} agent(s) available...") - logger.info(f"Agents: {agents}") - if len(agents) < 1: + if len(agents) == 0: logger.info("No agents available to run pipeline.") # TODO: publish event to send message back to API that says this won't work return - first_agent_info = await get_agent_info(js, agents[0]) - if not first_agent_info: - logger.error("Agent info not found.") + agent_infos = await asyncio.gather(*[get_agent_info(js, agent) for agent in agents]) + agent_infos = [agent_info for agent_info in agent_infos if agent_info] + + try: + assignments = assign_pipeline_to_agents(agent_infos, pipeline) + except Exception as e: + logger.error(f"Failed to assign pipeline to agents: {e}") return - logger.info(f"Assigning pipeline to agent: {first_agent_info.uri.id}") - logger.info(f"Agent info: {first_agent_info}") - await js.publish( - f"{STREAM_AGENTS}.{first_agent_info.uri.id}", - stream=f"{STREAM_AGENTS}", - payload=msg.data, + await asyncio.gather( + *[publish_assignment(js, assignment) for assignment in assignments] ) logger.info("Pipeline run event processed.")