-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add just recipe for server * hhiuuiubiuu hu * ws * asdlk * working controller example * webio stuff * webio stuff * scaffolding * push * it works I guess idk untested * format * EXECUTE ARBITRARY FLOWCHART LETS GO * EXECUTE ARBITRARY FLOWCHART LETS GO * format * format * remove commented * EXECUTE ARBITRARY FLOWCHART LETS GO * print * solid reactive mpl which works, minus param for add (need to use typed dict) * format * checkpoint * progress * use wrappers * nvm go back * works great * comments * delete old code * small update * start frontend * start frontend * UI works but with ZIP problem :( * Fix zip problem * Add optional zipping via hard coded list * UI is not as ugly and it works well now * cleanup * works with flowchart * remove unused * fix default ui input value on start * remove unused * remove unused types * use zustand * small refactor * cleanup and scaffold tests * just use 1 setter * remove test file * remove unused packages * reorganize * a * organize imports * add test recipe * static instead of cls * merge * fix up * smooth * transform * use a button instead * clean deps Signed-off-by: Joey Yu <[email protected]> --------- Signed-off-by: Joey Yu <[email protected]> Co-authored-by: Sasha Aleshchenko <[email protected]> Co-authored-by: Joey Yu <[email protected]>
- Loading branch information
1 parent
2d81dd5
commit eb5e9b1
Showing
39 changed files
with
1,293 additions
and
325 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
def add(x, y): | ||
return x + y |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
def bignum(x): | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from captain.decorators import ui_input | ||
|
||
|
||
@ui_input | ||
def button(x): | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
def constant(): | ||
return 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from captain.decorators import ui_input | ||
|
||
|
||
@ui_input | ||
def gamepad(x): | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from captain.decorators import ui_input | ||
|
||
|
||
@ui_input | ||
def slider(x): | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
def subtract(x, y): | ||
return x - y |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
import os | ||
from dataclasses import dataclass | ||
from functools import partial | ||
from typing import Any, Callable, Mapping, Tuple | ||
|
||
import reactivex as rx | ||
import reactivex.operators as ops | ||
from reactivex import Observable, Subject | ||
|
||
from captain.logging import logger | ||
from captain.types.events import FlowUIEvent | ||
from captain.types.flowchart import FCBlock, FlowChart | ||
from captain.utils.blocks import import_blocks, is_ui_input | ||
|
||
BLOCKS_DIR = os.path.join("captain", "blocks") | ||
|
||
ZIPPED_BLOCKS = [] # TODO: I (sasha) am anti zip in all cases. | ||
|
||
|
||
@dataclass | ||
class FCBlockIO: | ||
block: FCBlock | ||
i: Subject | ||
o: Observable | ||
|
||
|
||
def find_islands(blocks: dict[str, FCBlock]) -> list[list[FCBlock]]: | ||
visited = set() | ||
|
||
def dfs(block: FCBlock, island: list[FCBlock]): | ||
visited.add(block.id) | ||
island.append(block) | ||
neighbors = [i.source for i in block.ins] + [o.target for o in block.outs] | ||
for connection in neighbors: | ||
if connection in visited: | ||
continue | ||
dfs(blocks[connection], island) | ||
|
||
islands: list[list[FCBlock]] = [] | ||
for block in blocks.values(): | ||
if block.id not in visited: | ||
island: list[FCBlock] = [] | ||
dfs(block, island) | ||
islands.append(island) | ||
|
||
return islands | ||
|
||
|
||
def wire_flowchart( | ||
flowchart: FlowChart, | ||
on_publish, | ||
starter: Observable, | ||
ui_inputs: Mapping[str, Observable], | ||
block_funcs: Mapping[str, Callable], | ||
): | ||
blocks: dict[str, FCBlock] = {b.id: b for b in flowchart.blocks} | ||
islands = find_islands(blocks) | ||
block_ios: dict[str, FCBlockIO] = {} | ||
|
||
for island in islands: | ||
for block in island: | ||
|
||
def run_block(blk: FCBlock, kwargs: dict[str, Any]): | ||
fn = block_funcs[blk.block_type] | ||
logger.debug(f"Running block {blk.id}") | ||
return fn(**kwargs) | ||
|
||
def make_block_fn_props( | ||
blk: FCBlock, inputs: list[Tuple[str, Any]] | ||
) -> dict[str, Any]: | ||
logger.debug(f"Making params for block {blk.id} with {inputs}") | ||
return dict(inputs) | ||
|
||
input_subject = Subject() | ||
input_subject.subscribe( | ||
partial( | ||
lambda blk, x: logger.debug( | ||
f"Input got {x} for {blk.id} regardless of zip" | ||
), | ||
block, | ||
) | ||
) | ||
|
||
output_observable = input_subject.pipe( | ||
ops.map(partial(make_block_fn_props, block)), | ||
ops.map(partial(run_block, block)), | ||
ops.publish(), # Makes it so values are not emitted on each subscribe | ||
) | ||
|
||
output_observable.subscribe( | ||
partial( | ||
lambda blk, x: logger.debug( | ||
f"Got {x} for {blk.id} after zip and transform" | ||
), | ||
block, | ||
) | ||
) | ||
output_observable.subscribe( | ||
partial(lambda blk, x: on_publish(x, blk.id), block), | ||
on_error=lambda e: logger.debug(e), | ||
on_completed=lambda: logger.debug("completed"), | ||
) | ||
|
||
# Start emitting values for outputs | ||
output_observable.connect() | ||
|
||
if block.id in ui_inputs: | ||
logger.debug(f"Connecting {block.id} to ui input {ui_inputs[block.id]}") | ||
ui_inputs[block.id].subscribe( | ||
input_subject.on_next, | ||
input_subject.on_error, | ||
input_subject.on_completed, | ||
) | ||
ui_inputs[block.id].subscribe( | ||
on_next=lambda x: logger.debug(f"Got {x} from the UI input subject") | ||
) | ||
|
||
block_ios[block.id] = FCBlockIO( | ||
block=block, i=input_subject, o=output_observable | ||
) | ||
|
||
visitedBlocks = set() | ||
|
||
def rec_connect_blocks(io: FCBlockIO): | ||
logger.info(f"Recursively connecting {io.block.id} to its inputs") | ||
|
||
if not io.block.ins and io.block.id not in ui_inputs: | ||
logger.info( | ||
f"Connected {io.block.id} to start observable with ui inputs {ui_inputs.keys()}" | ||
) | ||
logger.debug(f"CREATED REACTIVE EDGE {io.block.id} -> {starter}") | ||
starter.subscribe(io.i.on_next, io.i.on_error, io.i.on_completed) | ||
return | ||
|
||
if not io.block.ins and io.block.id in ui_inputs: | ||
return | ||
|
||
logger.debug(f"Connecting {io.block.id}") | ||
|
||
combine_strategy = ( | ||
rx.zip if io.block.block_type in ZIPPED_BLOCKS else rx.combine_latest | ||
) | ||
in_combined = combine_strategy( | ||
*( | ||
block_ios[conn.source].o.pipe( | ||
ops.map(partial(lambda param, v: (param, v), conn.targetParam)) | ||
) | ||
for conn in io.block.ins | ||
) | ||
) | ||
|
||
for conn in io.block.ins: | ||
logger.debug( | ||
f"CREATED REACTIVE EDGE {conn.source} -> {io.block.id} thru {'zip' if io.block.block_type in ZIPPED_BLOCKS else 'combine_latest'} via {conn.targetParam}" | ||
) | ||
|
||
in_combined.subscribe(io.i.on_next, io.i.on_error, io.i.on_completed) | ||
for conn in io.block.ins: | ||
logger.debug(conn) | ||
if conn.source in visitedBlocks: | ||
continue | ||
visitedBlocks.add(conn.source) | ||
rec_connect_blocks(block_ios[conn.source]) | ||
|
||
# Connect the graph backwards starting from the terminal nodes | ||
terminals = filter(lambda b: not b.outs, blocks.values()) | ||
for block in terminals: | ||
visitedBlocks.add(block.id) | ||
rec_connect_blocks(block_ios[block.id]) | ||
|
||
|
||
class Flow: | ||
flowchart: FlowChart | ||
ui_inputs: dict[str, Subject] | ||
|
||
def __init__( | ||
self, flowchart: FlowChart, publish_fn: Callable, start_obs: Observable | ||
) -> None: | ||
self.flowchart = flowchart | ||
self.ui_inputs = {} | ||
funcs = import_blocks(BLOCKS_DIR) | ||
for block in flowchart.blocks: | ||
if is_ui_input(funcs[block.block_type]): | ||
logger.debug(f"Creating UI input for {block.id}") | ||
self.ui_inputs[block.id] = Subject() | ||
wire_flowchart( | ||
flowchart=self.flowchart, | ||
on_publish=publish_fn, | ||
starter=start_obs, | ||
ui_inputs=self.ui_inputs, | ||
block_funcs=funcs, | ||
) | ||
|
||
@staticmethod | ||
def from_json(data: str, publish_fn: Callable, start_obs: Observable): | ||
fc = FlowChart.model_validate_json(data) | ||
return Flow(fc, publish_fn, start_obs) | ||
|
||
def process_ui_event(self, event: FlowUIEvent): | ||
self.ui_inputs[event.ui_input_id].on_next([("x", event.value)]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from typing import Callable | ||
|
||
|
||
def ui_input(block: Callable) -> Callable: | ||
block.ui_input = True | ||
return block |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import logging | ||
|
||
logging.basicConfig( | ||
level=logging.INFO, | ||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", | ||
) | ||
|
||
logger = logging.getLogger(__name__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,133 @@ | ||
from fastapi import APIRouter | ||
import asyncio | ||
from asyncio import Future | ||
from typing import Any | ||
|
||
import reactivex.operators as ops | ||
from fastapi import APIRouter, WebSocket | ||
from pydantic import ValidationError | ||
from reactivex import Subject, create | ||
from reactivex.subject import BehaviorSubject | ||
|
||
from captain.controllers.reactive import Flow | ||
from captain.logging import logger | ||
from captain.types.events import ( | ||
FlowCancelEvent, | ||
FlowSocketMessage, | ||
FlowStartEvent, | ||
FlowStateUpdateEvent, | ||
FlowUIEvent, | ||
) | ||
from captain.types.flowchart import FlowChart | ||
|
||
router = APIRouter(tags=["blocks"], prefix="/blocks") | ||
|
||
|
||
@router.get("/") | ||
async def read_blocks(): | ||
return "Hello blocks!" | ||
|
||
|
||
class IgnoreComplete: | ||
def __call__(self, source): | ||
return create( | ||
lambda observer, scheduler: source.subscribe( | ||
observer.on_next, lambda err: observer.on_error(err) | ||
) | ||
) | ||
|
||
|
||
@router.websocket("/ws") | ||
async def websocket_endpoint(websocket: WebSocket): | ||
all_events = BehaviorSubject[str | int](0) | ||
button_events = BehaviorSubject(0) | ||
joy_events = BehaviorSubject("axis 0 0") | ||
|
||
def destruct(): | ||
print("completed") | ||
# TODO | ||
|
||
def on_next_event(event): | ||
print(event) | ||
if isinstance(str, event) and event.startswith("axis"): | ||
joy_events.on_next(event) | ||
else: | ||
button_events.on_next(event) | ||
|
||
def on_next_joy(event): | ||
x = float(event.split(" ")[-1]) | ||
print(f"Joy got {x}, event: {event}") | ||
|
||
def send_button(x) -> Future[None]: | ||
return asyncio.ensure_future(websocket.send_text(str(x))) | ||
|
||
button_events.pipe(IgnoreComplete()).pipe( | ||
ops.take_with_time(500), ops.flat_map_latest(send_button) | ||
).subscribe(on_next=print, on_error=lambda e: print(e), on_completed=destruct) | ||
joy_events.pipe(IgnoreComplete()).subscribe( | ||
on_next=on_next_joy, on_error=lambda e: print(e), on_completed=destruct | ||
) | ||
|
||
all_events.pipe(IgnoreComplete()).subscribe( | ||
on_next=on_next_event, on_error=lambda e: print(e), on_completed=destruct | ||
) | ||
|
||
await websocket.accept() | ||
while True: | ||
data = await websocket.receive_text() | ||
|
||
all_events.on_next(data) | ||
print(f"Got data: {data}") | ||
if data == "close": | ||
await websocket.close() | ||
break | ||
|
||
|
||
@router.websocket("/flowchart") | ||
async def websocket_flowchart(websocket: WebSocket): | ||
send_msg = send_message_factory(websocket) | ||
|
||
start_obs = Subject() | ||
start_obs.subscribe(on_next=lambda x: logger.info(f"Got start {x}")) | ||
|
||
def publish_fn(x, id): | ||
logger.debug(f"Publishing {x} for {id}") | ||
send_msg(FlowStateUpdateEvent(id=id, data=x).model_dump_json()) | ||
|
||
await websocket.accept() | ||
|
||
flow: Flow | None = None | ||
|
||
while True: | ||
data = await websocket.receive_text() | ||
try: | ||
message = FlowSocketMessage.model_validate_json(data) | ||
except ValidationError as e: | ||
logger.error(str(e)) | ||
continue | ||
|
||
match message.event: | ||
case FlowStartEvent(rf=rf): | ||
if flow is None: | ||
fc = FlowChart.from_react_flow(rf) | ||
logger.info("Creating flow from react flow instance") | ||
flow = Flow(fc, publish_fn, start_obs) | ||
case FlowCancelEvent(): | ||
flow = None | ||
logger.info("Cancelling flow") | ||
case FlowUIEvent(): | ||
if flow is None: | ||
logger.error("Can't process UI event for non existent flow") | ||
else: | ||
logger.debug(f"Got UI event {message.event}") | ||
flow.process_ui_event(message.event) | ||
|
||
|
||
def send_message_factory(websocket): | ||
def send_message(x: Any) -> Future[None]: | ||
""" | ||
USAGE: Flat map to this thingy | ||
""" | ||
logger.debug(f"supposed to send {x}") | ||
return asyncio.ensure_future(websocket.send_text(str(x))) | ||
|
||
return send_message |
Oops, something went wrong.