Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sending images over operators, zmq multipart messages, change message types #24

Merged
merged 3 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions backend/core/core/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,6 @@
Protocol,
URILocation,
)
from .messages import (
MESSAGE_SUBJECT_TO_MODEL,
BaseMessage,
DataMessage,
ErrorMessage,
MessageSubject,
PipelineMessage,
URIConnectMessage,
URIConnectResponseMessage,
URIMessage,
URIUpdateMessage,
)
from .pipeline import (
EdgeJSON,
InputJSON,
Expand Down
94 changes: 0 additions & 94 deletions backend/core/core/models/messages.py

This file was deleted.

12 changes: 0 additions & 12 deletions backend/core/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
PortJSON,
PortType,
)
from .models.messages import PutPipelineNodeMessage

logger = get_logger("pipeline")

Expand Down Expand Up @@ -189,14 +188,3 @@ def get_edge_model(self, input_id: IdType, output_id: IdType) -> EdgeJSON:
if not edge:
raise ValueError(f"Edge not found between {input_id} and {output_id}")
return EdgeJSON(**edge)

def put_node(self, message: PutPipelineNodeMessage) -> PutPipelineNodeMessage:
node_id = message.node.id
if node_id not in self.nodes:
raise ValueError(f"Node {node_id} not found in the graph.")

logger.warning(f"Updating node {node_id} with message: {message}")
current_node = self.nodes[node_id]
current_node.update(message.node.model_dump())
logger.info(f"Updated node {node_id} with message: {message}")
return message
6 changes: 3 additions & 3 deletions backend/operators/containerfiles/test/operator0.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
from uuid import UUID

from core.logger import get_logger
from operators.examples import create_hello_world, receive_hello_world
from operators.examples import recv_image, send_image_every_second

logger = get_logger("operator_main", "DEBUG")


async def async_main(operator_id: str):
# Initialize the operator with the provided ID
if operator_id == "12345678-1234-1234-1234-1234567890ab":
operator = create_hello_world(UUID(operator_id))
operator = send_image_every_second(UUID(operator_id))
else:
operator = receive_hello_world(UUID(operator_id))
operator = recv_image(UUID(operator_id))

print(operator.id)

Expand Down
35 changes: 25 additions & 10 deletions backend/operators/operators/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,50 @@

import numpy as np
from core.logger import get_logger
from core.models.messages import BaseMessage, DataMessage

from .messengers.base import BytesMessage, MessageHeader, MessageSubject
from .operator import operator

logger = get_logger("operators.examples", "DEBUG")


@operator
def create_hello_world(inputs: BaseMessage | None) -> BaseMessage:
return DataMessage(data=b"Hello, World!")
def create_hello_world(inputs: BytesMessage | None) -> BytesMessage:
logger.info("Creating message...")
if inputs:
return BytesMessage(header=inputs.header, data=b"Hello, World!")
else:
header = MessageHeader(subject=MessageSubject.BYTES, meta={})
return BytesMessage(header=header,data=b"Hello, World!")


@operator
def receive_hello_world(inputs: BaseMessage | None) -> BaseMessage:
def receive_hello_world(inputs: BytesMessage | None) -> BytesMessage:
if inputs:
logger.info(f"Received message: {inputs}")
return inputs or DataMessage(data=b"No input provided")
header = MessageHeader(subject=MessageSubject.BYTES, meta={})
return inputs or BytesMessage(header=header, data=b"No input provided")


@operator
def process_hello_world(inputs: BaseMessage | None) -> BaseMessage:
def process_hello_world(inputs: BytesMessage | None) -> BytesMessage:
if inputs:
logger.info(f"Processing message: {inputs}")
return inputs or DataMessage(data=b"No input provided")
header = MessageHeader(subject=MessageSubject.BYTES, meta={})
return inputs or BytesMessage(header=header, data=b"No input provided")


# TODO: This operator does not work, because of send_model() serialization
@operator
def send_image_every_second(inputs: BaseMessage | None) -> BaseMessage:
def send_image_every_second(inputs: BytesMessage | None) -> BytesMessage:
time.sleep(1)
arr = np.random.randint(0, 255, (100, 100), dtype=np.uint8)
return DataMessage(data=arr.tobytes())
header = MessageHeader(subject=MessageSubject.BYTES, meta={})
return BytesMessage(header=header,data=arr.tobytes())

@operator
def recv_image(inputs: BytesMessage | None) -> BytesMessage:
if inputs:
arr = np.frombuffer(inputs.data, dtype=np.uint8).reshape(100, 100)
logger.info(f"Received image: {arr}")
header = MessageHeader(subject=MessageSubject.BYTES, meta={})
return inputs or BytesMessage(header=header, data=b"No input provided")
36 changes: 33 additions & 3 deletions backend/operators/operators/messengers/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,39 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any

from core.models.base import OperatorID
from core.models.messages import BaseMessage
from core.models.pipeline import InputJSON, OutputJSON
from nats.js import JetStreamContext
from pydantic import BaseModel


class MessageSubject(str, Enum):
BYTES = "bytes"
SHM = "shm"


class MessageHeader(BaseModel):
subject: MessageSubject
meta: dict[str, Any] = {}

class BaseMessage(BaseModel):
header: MessageHeader


class BytesMessage(BaseMessage):
data: bytes


class ShmMessage(BaseMessage):
shm_meta: dict[str, Any] = {}


MESSAGE_SUBJECT_TO_MODEL: dict[MessageSubject, type[BaseMessage]] = {
MessageSubject.BYTES: BytesMessage,
MessageSubject.SHM: ShmMessage,
}



class BaseMessenger(ABC):
Expand Down Expand Up @@ -32,11 +62,11 @@ def output_ports(self) -> list[OutputJSON]:
pass

@abstractmethod
async def send(self, msg, dst: str):
async def send(self, message: BytesMessage):
pass

@abstractmethod
async def recv(self, src: str) -> BaseMessage | None:
async def recv(self) -> BytesMessage | None:
pass

@abstractmethod
Expand Down
Loading