diff --git a/crates/arroyo-api/src/pipelines.rs b/crates/arroyo-api/src/pipelines.rs index 348fe7a30..02e7102e3 100644 --- a/crates/arroyo-api/src/pipelines.rs +++ b/crates/arroyo-api/src/pipelines.rs @@ -6,12 +6,12 @@ use axum::{debug_handler, Json}; use axum_extra::extract::WithRejection; use http::StatusCode; +use petgraph::visit::NodeRef; use petgraph::{Direction, EdgeDirection}; use std::collections::HashMap; use std::num::ParseIntError; use std::str::FromStr; use std::string::ParseError; -use petgraph::visit::NodeRef; use std::time::{Duration, SystemTime}; use crate::{compiler_service, connection_profiles, jobs, types}; @@ -274,9 +274,9 @@ async fn register_schemas(compiled_sql: &mut CompiledSql) -> anyhow::Result<()> if node.operator_name == OperatorName::ConnectorSink { let mut op = ConnectorOp::decode(&node.operator_config[..]).map_err(|_| { anyhow!( - "failed to decode configuration for connector node {:?}", - node - ) + "failed to decode configuration for connector node {:?}", + node + ) })?; try_register_confluent_schema(&mut op, &schema).await?; @@ -329,11 +329,13 @@ pub(crate) async fn create_pipeline_int<'a>( for idx in g.node_indices() { let should_replace = { let node = &g.node_weight(idx).unwrap().operator_chain; - node.is_sink() && node.iter().next().unwrap().0.operator_config != default_sink().encode_to_vec() + node.is_sink() + && node.iter().next().unwrap().0.operator_config + != default_sink().encode_to_vec() }; if should_replace { if enable_sinks { - todo!("enable sinks") + todo!("enable sinks") // let new_idx = g.add_node(LogicalNode { // operator_id: format!("{}_1", g.node_weight(idx).unwrap().operator_id), // description: "Preview sink".to_string(), @@ -349,8 +351,14 @@ pub(crate) async fn create_pipeline_int<'a>( // g.add_edge(source, new_idx, weight); // } } else { - g.node_weight_mut(idx).unwrap().operator_chain.iter_mut().next().unwrap().0.operator_config = - default_sink().encode_to_vec(); + g.node_weight_mut(idx) + .unwrap() + .operator_chain + .iter_mut() + .next() + .unwrap() + .0 + .operator_config = default_sink().encode_to_vec(); } } } diff --git a/crates/arroyo-connectors/src/impulse/operator.rs b/crates/arroyo-connectors/src/impulse/operator.rs index b78484de0..622d09ebc 100644 --- a/crates/arroyo-connectors/src/impulse/operator.rs +++ b/crates/arroyo-connectors/src/impulse/operator.rs @@ -92,7 +92,11 @@ impl ImpulseSourceFunc { } } - async fn run(&mut self, ctx: &mut SourceContext, collector: &mut SourceCollector) -> SourceFinishType { + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType { let delay = self.delay(ctx); info!( "Starting impulse source with start {} delay {:?} and limit {}", @@ -138,18 +142,19 @@ impl ImpulseSourceFunc { let counter_column = counter_builder.finish(); let task_index_column = task_index_scalar.to_array_of_size(items).unwrap(); let timestamp_column = timestamp_builder.finish(); - collector.collect( - RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(counter_column), - Arc::new(task_index_column), - Arc::new(timestamp_column), - ], + collector + .collect( + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(counter_column), + Arc::new(task_index_column), + Arc::new(timestamp_column), + ], + ) + .unwrap(), ) - .unwrap(), - ) - .await; + .await; items = 0; } @@ -163,18 +168,19 @@ impl ImpulseSourceFunc { let counter_column = counter_builder.finish(); let task_index_column = task_index_scalar.to_array_of_size(items).unwrap(); let timestamp_column = 
timestamp_builder.finish(); - collector.collect( - RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(counter_column), - Arc::new(task_index_column), - Arc::new(timestamp_column), - ], + collector + .collect( + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(counter_column), + Arc::new(task_index_column), + Arc::new(timestamp_column), + ], + ) + .unwrap(), ) - .unwrap(), - ) - .await; + .await; items = 0; } ctx.table_manager @@ -222,18 +228,19 @@ impl ImpulseSourceFunc { let counter_column = counter_builder.finish(); let task_index_column = task_index_scalar.to_array_of_size(items).unwrap(); let timestamp_column = timestamp_builder.finish(); - collector.collect( - RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(counter_column), - Arc::new(task_index_column), - Arc::new(timestamp_column), - ], + collector + .collect( + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(counter_column), + Arc::new(task_index_column), + Arc::new(timestamp_column), + ], + ) + .unwrap(), ) - .unwrap(), - ) - .await; + .await; } SourceFinishType::Final @@ -262,7 +269,11 @@ impl SourceOperator for ImpulseSourceFunc { } } - async fn run(&mut self, ctx: &mut SourceContext, collector: &mut SourceCollector) -> SourceFinishType { + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType { self.run(ctx, collector).await } } diff --git a/crates/arroyo-connectors/src/kafka/mod.rs b/crates/arroyo-connectors/src/kafka/mod.rs index 2d1018751..8b3cefb18 100644 --- a/crates/arroyo-connectors/src/kafka/mod.rs +++ b/crates/arroyo-connectors/src/kafka/mod.rs @@ -399,48 +399,52 @@ impl Connector for KafkaConnector { None }; - Ok(ConstructedOperator::from_source(Box::new(KafkaSourceFunc { - topic: table.topic, - bootstrap_servers: profile.bootstrap_servers.to_string(), - group_id: group_id.clone(), - group_id_prefix: group_id_prefix.clone(), - offset_mode: *offset, - format: config.format.expect("Format must be set for Kafka source"), - framing: config.framing, - schema_resolver, - bad_data: config.bad_data, - client_configs, - context: Context::new(Some(profile.clone())), - messages_per_second: NonZeroU32::new( - config - .rate_limit - .map(|l| l.messages_per_second) - .unwrap_or(u32::MAX), - ) - .unwrap(), - metadata_fields: config.metadata_fields, - }))) + Ok(ConstructedOperator::from_source(Box::new( + KafkaSourceFunc { + topic: table.topic, + bootstrap_servers: profile.bootstrap_servers.to_string(), + group_id: group_id.clone(), + group_id_prefix: group_id_prefix.clone(), + offset_mode: *offset, + format: config.format.expect("Format must be set for Kafka source"), + framing: config.framing, + schema_resolver, + bad_data: config.bad_data, + client_configs, + context: Context::new(Some(profile.clone())), + messages_per_second: NonZeroU32::new( + config + .rate_limit + .map(|l| l.messages_per_second) + .unwrap_or(u32::MAX), + ) + .unwrap(), + metadata_fields: config.metadata_fields, + }, + ))) } TableType::Sink { commit_mode, key_field, timestamp_field, - } => Ok(ConstructedOperator::from_operator(Box::new(KafkaSinkFunc { - bootstrap_servers: profile.bootstrap_servers.to_string(), - producer: None, - consistency_mode: (*commit_mode).into(), - timestamp_field: timestamp_field.clone(), - timestamp_col: None, - key_field: key_field.clone(), - key_col: None, - write_futures: vec![], - client_config: client_configs(&profile, Some(table.clone()))?, - context: Context::new(Some(profile.clone())), - topic: table.topic, - serializer: 
ArrowSerializer::new( - config.format.expect("Format must be defined for KafkaSink"), - ), - }))), + } => Ok(ConstructedOperator::from_operator(Box::new( + KafkaSinkFunc { + bootstrap_servers: profile.bootstrap_servers.to_string(), + producer: None, + consistency_mode: (*commit_mode).into(), + timestamp_field: timestamp_field.clone(), + timestamp_col: None, + key_field: key_field.clone(), + key_col: None, + write_futures: vec![], + client_config: client_configs(&profile, Some(table.clone()))?, + context: Context::new(Some(profile.clone())), + topic: table.topic, + serializer: ArrowSerializer::new( + config.format.expect("Format must be defined for KafkaSink"), + ), + }, + ))), } } } diff --git a/crates/arroyo-connectors/src/kafka/sink/mod.rs b/crates/arroyo-connectors/src/kafka/sink/mod.rs index 94e7f1906..b5c47c52e 100644 --- a/crates/arroyo-connectors/src/kafka/sink/mod.rs +++ b/crates/arroyo-connectors/src/kafka/sink/mod.rs @@ -279,7 +279,12 @@ impl ArrowOperator for KafkaSinkFunc { .expect("Producer creation failed"); } - async fn process_batch(&mut self, batch: RecordBatch, ctx: &mut OperatorContext, _: &mut dyn Collector) { + async fn process_batch( + &mut self, + batch: RecordBatch, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { let values = self.serializer.serialize(&batch); let timestamps = batch .column( @@ -306,7 +311,12 @@ impl ArrowOperator for KafkaSinkFunc { } } - async fn handle_checkpoint(&mut self, _: CheckpointBarrier, ctx: &mut OperatorContext, _: &mut dyn Collector) { + async fn handle_checkpoint( + &mut self, + _: CheckpointBarrier, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { self.flush(ctx).await; if let ConsistencyMode::ExactlyOnce { next_transaction_index, @@ -372,7 +382,12 @@ impl ArrowOperator for KafkaSinkFunc { .expect("sent commit event"); } - async fn on_close(&mut self, final_message: &Option, ctx: &mut OperatorContext, _: &mut dyn Collector) { + async fn on_close( + &mut self, + final_message: &Option, + ctx: &mut OperatorContext, + _: &mut dyn Collector, + ) { self.flush(ctx).await; if !self.is_committing() { return; diff --git a/crates/arroyo-connectors/src/kafka/source/mod.rs b/crates/arroyo-connectors/src/kafka/source/mod.rs index 7f27d3e4f..38c4a1cd2 100644 --- a/crates/arroyo-connectors/src/kafka/source/mod.rs +++ b/crates/arroyo-connectors/src/kafka/source/mod.rs @@ -113,7 +113,9 @@ impl KafkaSourceFunc { partitions .iter() .enumerate() - .filter(|(i, _)| i % ctx.task_info.parallelism as usize == ctx.task_info.task_index as usize) + .filter(|(i, _)| { + i % ctx.task_info.parallelism as usize == ctx.task_info.task_index as usize + }) .map(|(_, p)| { let offset = state .get(&p.id()) @@ -145,7 +147,11 @@ impl KafkaSourceFunc { Ok(consumer) } - async fn run_int(&mut self, ctx: &mut SourceContext, collector: &mut SourceCollector) -> Result { + async fn run_int( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> Result { let consumer = self .get_consumer(ctx) .await @@ -157,10 +163,9 @@ impl KafkaSourceFunc { if consumer.assignment().unwrap().count() == 0 { warn!("Kafka Consumer {}-{} is subscribed to no partitions, as there are more subtasks than partitions... 
setting idle", ctx.task_info.operator_id, ctx.task_info.task_index); - collector.broadcast(ArrowMessage::Signal(SignalMessage::Watermark( - Watermark::Idle, - ))) - .await; + collector + .broadcast(SignalMessage::Watermark(Watermark::Idle)) + .await; } if let Some(schema_resolver) = &self.schema_resolver { @@ -285,7 +290,11 @@ impl KafkaSourceFunc { #[async_trait] impl SourceOperator for KafkaSourceFunc { - async fn run(&mut self, ctx: &mut SourceContext, collector: &mut SourceCollector) -> SourceFinishType { + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType { match self.run_int(ctx, collector).await { Ok(r) => r, Err(e) => { diff --git a/crates/arroyo-connectors/src/kafka/source/test.rs b/crates/arroyo-connectors/src/kafka/source/test.rs index 1dc0b709c..c2d4320c9 100644 --- a/crates/arroyo-connectors/src/kafka/source/test.rs +++ b/crates/arroyo-connectors/src/kafka/source/test.rs @@ -13,7 +13,7 @@ use std::sync::Arc; use std::time::{Duration, SystemTime}; use crate::kafka::SourceOffset; -use arroyo_operator::context::{batch_bounded, OperatorContext, BatchReceiver}; +use arroyo_operator::context::{batch_bounded, BatchReceiver, OperatorContext}; use arroyo_operator::operator::SourceOperator; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::formats::{Format, RawStringFormat}; diff --git a/crates/arroyo-connectors/src/sse/operator.rs b/crates/arroyo-connectors/src/sse/operator.rs index bf0a7539d..54ec8106d 100644 --- a/crates/arroyo-connectors/src/sse/operator.rs +++ b/crates/arroyo-connectors/src/sse/operator.rs @@ -70,7 +70,11 @@ impl SourceOperator for SSESourceFunc { arroyo_state::global_table_config("e", "sse source state") } - async fn run(&mut self, ctx: &mut SourceContext, collector: &mut SourceCollector) -> SourceFinishType { + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType { let s: &mut GlobalKeyedView<(), SSESourceState> = ctx .table_manager .get_global_keyed_state("e") @@ -136,7 +140,11 @@ impl SSESourceFunc { None } - async fn run_int(&mut self, ctx: &mut SourceContext, collector: &mut SourceCollector) -> Result { + async fn run_int( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> Result { collector.initialize_deserializer( self.format.clone(), self.framing.clone(), @@ -226,10 +234,9 @@ impl SSESourceFunc { } } else { // otherwise set idle and just process control messages - collector.broadcast(ArrowMessage::Signal(SignalMessage::Watermark( - Watermark::Idle, - ))) - .await; + collector + .broadcast(SignalMessage::Watermark(Watermark::Idle)) + .await; loop { let msg = ctx.control_rx.recv().await; diff --git a/crates/arroyo-controller/src/job_controller/mod.rs b/crates/arroyo-controller/src/job_controller/mod.rs index d43646690..48d709134 100644 --- a/crates/arroyo-controller/src/job_controller/mod.rs +++ b/crates/arroyo-controller/src/job_controller/mod.rs @@ -188,8 +188,7 @@ impl RunningJobModel { CheckpointingOrCommittingState::Committing(committing_state) => { if matches!(c.event_type(), TaskCheckpointEventType::FinishedCommit) { - committing_state - .subtask_committed(c.node_id, c.subtask_index); + committing_state.subtask_committed(c.node_id, c.subtask_index); self.compact_state().await?; } else { warn!("unexpected checkpoint event type {:?}", c.event_type()) @@ -389,7 +388,8 @@ impl RunningJobModel { self.job_id.clone(), &op.operator_id, self.epoch, - ).await?; + ) + .await?; if compacted_tables.is_empty() { 
continue; @@ -405,7 +405,6 @@ impl RunningJobModel { }) .await?; } - } } @@ -656,11 +655,13 @@ impl JobController { .map(|(id, w)| (*id, w.connect.clone())) .collect(); let program = self.model.program.clone(); - let operator_indices: Arc> = Arc::new(program.graph - .node_indices() - .map(|idx| (program.graph[idx].node_id, idx.index() as u32)) - .collect()); - + let operator_indices: Arc> = Arc::new( + program + .graph + .node_indices() + .map(|idx| (program.graph[idx].node_id, idx.index() as u32)) + .collect(), + ); self.model.metric_update_task = Some(tokio::spawn(async move { let mut metrics: HashMap<(u32, u32), HashMap> = HashMap::new(); @@ -691,8 +692,8 @@ impl JobController { values.into_iter().filter_map(move |m| { let subtask_idx = u32::from_str(find_label(&m.label, "subtask_idx")?).ok()?; - let operator_idx = - *operator_indices.get(&u32::from_str(find_label(&m.label, "node_id")?).ok()?)?; + let operator_idx = *operator_indices + .get(&u32::from_str(find_label(&m.label, "node_id")?).ok()?)?; let value = m .counter .map(|c| c.value) diff --git a/crates/arroyo-controller/src/lib.rs b/crates/arroyo-controller/src/lib.rs index f132d1b70..7754621c0 100644 --- a/crates/arroyo-controller/src/lib.rs +++ b/crates/arroyo-controller/src/lib.rs @@ -570,7 +570,9 @@ impl ControllerServer { .as_object() .unwrap() .into_iter() - .filter_map(|(k, v)| Some((u32::from_str(k).ok()?, v.as_u64()? as usize))) + .filter_map(|(k, v)| { + Some((u32::from_str(k).ok()?, v.as_u64()? as usize)) + }) .collect(), restart_nonce: p.config_restart_nonce, restart_mode: p.restart_mode, diff --git a/crates/arroyo-controller/src/states/scheduling.rs b/crates/arroyo-controller/src/states/scheduling.rs index 75da1c7a5..d61641198 100644 --- a/crates/arroyo-controller/src/states/scheduling.rs +++ b/crates/arroyo-controller/src/states/scheduling.rs @@ -431,7 +431,7 @@ impl State for Scheduling { // } // } // } - // + // // committing_state = Some(CommittingState::new(id, commit_subtasks, committing_data)); // } todo!("committing") diff --git a/crates/arroyo-datastream/src/logical.rs b/crates/arroyo-datastream/src/logical.rs index a2f8bf4d3..0634d55b4 100644 --- a/crates/arroyo-datastream/src/logical.rs +++ b/crates/arroyo-datastream/src/logical.rs @@ -1,14 +1,13 @@ use datafusion_proto::protobuf::ArrowType; use itertools::Itertools; +use crate::optimizers::Optimizer; use anyhow::anyhow; use arrow_schema::DataType; use arroyo_rpc::api_types::pipelines::{PipelineEdge, PipelineGraph, PipelineNode}; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::grpc::api; -use arroyo_rpc::grpc::api::{ - ArrowProgram, ArrowProgramConfig, ConnectorOp, EdgeType, -}; +use arroyo_rpc::grpc::api::{ArrowProgram, ArrowProgramConfig, ConnectorOp, EdgeType}; use petgraph::dot::Dot; use petgraph::graph::DiGraph; use petgraph::prelude::EdgeRef; @@ -24,7 +23,6 @@ use std::hash::Hasher; use std::str::FromStr; use std::sync::Arc; use strum::{Display, EnumString}; -use crate::optimizers::Optimizer; #[derive(Clone, Copy, Debug, Eq, PartialEq, EnumString, Display)] pub enum OperatorName { @@ -139,21 +137,12 @@ pub struct LogicalEdge { } impl LogicalEdge { - pub fn new( - edge_type: LogicalEdgeType, - schema: ArroyoSchema, - ) -> Self { - LogicalEdge { - edge_type, - schema, - } + pub fn new(edge_type: LogicalEdgeType, schema: ArroyoSchema) -> Self { + LogicalEdge { edge_type, schema } } pub fn project_all(edge_type: LogicalEdgeType, schema: ArroyoSchema) -> Self { - LogicalEdge { - edge_type, - schema, - } + LogicalEdge { edge_type, schema } } } @@ -179,18 
+168,28 @@ impl OperatorChain { .map(|(l, r)| (l.unwrap(), r)) } - pub fn iter_mut(&mut self) -> impl Iterator)> { + pub fn iter_mut( + &mut self, + ) -> impl Iterator)> { self.operators .iter_mut() .zip_longest(self.edges.iter_mut()) .map(|e| e.left_and_right()) .map(|(l, r)| (l.unwrap(), r)) } - + + pub fn first(&self) -> &ChainedLogicalOperator { + &self.operators[0] + } + + pub fn len(&self) -> usize { + self.operators.len() + } + pub fn is_source(&self) -> bool { self.operators[0].operator_name == OperatorName::ConnectorSource } - + pub fn is_sink(&self) -> bool { self.operators[0].operator_name == OperatorName::ConnectorSink } @@ -288,7 +287,7 @@ impl LogicalProgram { program_config, } } - + pub fn optimize(&mut self, optimizer: &dyn Optimizer) { optimizer.optimize(&mut self.graph); } @@ -338,7 +337,7 @@ impl LogicalProgram { let mut tasks_per_operator = HashMap::new(); for node in self.graph.node_weights() { for op in &node.operator_chain.operators { - tasks_per_operator.insert(op.operator_id.clone(), node.parallelism); + tasks_per_operator.insert(op.operator_id.clone(), node.parallelism); } } tasks_per_operator @@ -351,7 +350,7 @@ impl LogicalProgram { } tasks_per_node } - + pub fn features(&self) -> HashSet { let mut s = HashSet::new(); diff --git a/crates/arroyo-datastream/src/optimizers.rs b/crates/arroyo-datastream/src/optimizers.rs index ebaba645c..cbe85cbab 100644 --- a/crates/arroyo-datastream/src/optimizers.rs +++ b/crates/arroyo-datastream/src/optimizers.rs @@ -1,123 +1,107 @@ -use std::collections::HashSet; -use petgraph::prelude::*; use crate::logical::{LogicalEdgeType, LogicalGraph}; +use petgraph::data::DataMapMut; +use petgraph::prelude::*; +use petgraph::visit::NodeRef; +use std::collections::HashSet; +use std::mem; pub trait Optimizer { - fn optimize(&self, plan: &mut LogicalGraph); -} + fn optimize_once(&self, plan: &mut LogicalGraph) -> bool; -pub struct ChainingOptimizer { + fn optimize(&self, plan: &mut LogicalGraph) { + loop { + if !self.optimize_once(plan) { + break; + } + } + } } +pub struct ChainingOptimizer {} + +fn remove_in_place(graph: &mut DiGraph, node: NodeIndex) { + let incoming = graph.edges_directed(node, Incoming).next().unwrap(); + + let parent = incoming.source().id(); + let incoming = incoming.id(); + graph.remove_edge(incoming); + + let outgoing: Vec<_> = graph + .edges_directed(node, Outgoing) + .map(|e| (e.id(), e.target().id())) + .collect(); + + for (edge, target) in outgoing { + let weight = graph.remove_edge(edge).unwrap(); + graph.add_edge(parent, target, weight); + } + + graph.remove_node(node); +} impl Optimizer for ChainingOptimizer { - fn optimize(&self, plan: &mut LogicalGraph) { + fn optimize_once(&self, plan: &mut LogicalGraph) -> bool { let node_indices: Vec = plan.node_indices().collect(); - let mut removed_nodes = HashSet::new(); for &node_idx in &node_indices { - if removed_nodes.contains(&node_idx) { + let cur = plan.node_weight(node_idx).unwrap(); + + // sources can't be chained + if cur.operator_chain.is_source() { continue; } - let mut current_node = match plan.node_weight(node_idx) { - Some(node) => node.clone(), - None => continue, - }; + let mut successors = plan.edges_directed(node_idx, Outgoing).collect::>(); - // sources and sinks can't be chained - if current_node.operator_chain.is_source() || current_node.operator_chain.is_sink() { + if successors.len() != 1 { continue; } - let mut chain = vec![node_idx]; - let mut next_node_idx = node_idx; - - loop { - let mut successors = plan - .edges_directed(next_node_idx, 
Outgoing) - .collect::>(); - - if successors.len() != 1 { - break; - } - - let edge = successors.remove(0); - let edge_type = edge.weight().edge_type; - - if edge_type != LogicalEdgeType::Forward { - break; - } - - let successor_idx = edge.target(); - - if removed_nodes.contains(&successor_idx) { - break; - } - - let successor_node = match plan.node_weight(successor_idx) { - Some(node) => node.clone(), - None => break, - }; - - // skip if parallelism doesn't match or successor is a sink - if current_node.parallelism != successor_node.parallelism || successor_node.operator_chain.is_sink() - { - break; - } - - // skip successors with multiple predecessors - if plan.edges_directed(successor_idx, Incoming).count() > 1 { - break; - } - - chain.push(successor_idx); - next_node_idx = successor_idx; + let edge = successors.remove(0); + let edge_type = edge.weight().edge_type; + + if edge_type != LogicalEdgeType::Forward { + continue; } - if chain.len() > 1 { - for i in 1..chain.len() { - let node_to_merge_idx = chain[i]; - let node_to_merge = plan.node_weight(node_to_merge_idx).unwrap().clone(); - - current_node.description = format!( - "{} -> {}", - current_node.description, node_to_merge.description - ); - - current_node - .operator_chain - .operators - .extend(node_to_merge.operator_chain.operators.clone()); - - if let Some(edge_idx) = plan.find_edge(chain[i - 1], node_to_merge_idx) { - let edge = plan.edge_weight(edge_idx).unwrap(); - current_node - .operator_chain - .edges - .push(edge.schema.clone()); - } - - removed_nodes.insert(node_to_merge_idx); - } - - plan[node_idx] = current_node; - - let last_node_idx = *chain.last().unwrap(); - let outgoing_edges: Vec<_> = plan - .edges_directed(last_node_idx, petgraph::Outgoing) - .map(|e| (e.id(), e.target(), e.weight().clone())) - .collect(); - - for (edge_id, target_idx, edge_weight) in outgoing_edges { - plan.remove_edge(edge_id); - plan.add_edge(node_idx, target_idx, edge_weight); - } + let successor_idx = edge.target(); + + let successor_node = plan.node_weight(successor_idx).unwrap(); + + // skip if parallelism doesn't match or successor is a sink + if cur.parallelism != successor_node.parallelism + || successor_node.operator_chain.is_sink() + { + continue; } - } - for node_idx in removed_nodes { - plan.remove_node(node_idx); + // skip successors with multiple predecessors + if plan.edges_directed(successor_idx, Incoming).count() > 1 { + continue; + } + + // construct the new node + let mut new_cur = cur.clone(); + + new_cur.description = format!("{} -> {}", cur.description, successor_node.description); + + new_cur + .operator_chain + .operators + .extend(successor_node.operator_chain.operators.clone()); + + new_cur + .operator_chain + .edges + .push(edge.weight().schema.clone()); + + mem::swap(&mut new_cur, plan.node_weight_mut(node_idx).unwrap()); + + // remove the old successor + remove_in_place(plan, successor_idx); + return true; } + + false } -} \ No newline at end of file +} diff --git a/crates/arroyo-operator/src/context.rs b/crates/arroyo-operator/src/context.rs index d0fd7d57e..23f23535f 100644 --- a/crates/arroyo-operator/src/context.rs +++ b/crates/arroyo-operator/src/context.rs @@ -16,8 +16,8 @@ use arroyo_rpc::{get_hasher, CompactionResult, ControlMessage, ControlResp}; use arroyo_state::tables::table_manager::TableManager; use arroyo_state::{BackingStore, StateBackend}; use arroyo_types::{ - from_micros, ArrowMessage, ChainInfo, CheckpointBarrier, SourceError, TaskInfo, UserError, - Watermark, + from_micros, ArrowMessage, 
ChainInfo, CheckpointBarrier, SignalMessage, SourceError, TaskInfo, + UserError, Watermark, }; use async_trait::async_trait; use datafusion::common::hash_utils; @@ -257,12 +257,11 @@ pub struct SourceContext { pub watermarks: WatermarkHolder, } - impl SourceContext { pub fn from_operator( ctx: OperatorContext, chain_info: Arc, - control_rx: Receiver, + control_rx: Receiver, ) -> Self { Self { out_schema: ctx.out_schema.expect("sources must have downstream nodes"), @@ -271,7 +270,7 @@ impl SourceContext { task_info: ctx.task_info.clone(), }, control_tx: ctx.control_tx, - control_rx, + control_rx, chain_info, task_info: ctx.task_info, table_manager: ctx.table_manager, @@ -279,7 +278,6 @@ impl SourceContext { } } - pub async fn load_compacted(&mut self, compaction: CompactionResult) { //TODO: support compaction in the table manager self.table_manager @@ -304,7 +302,6 @@ impl SourceContext { .await .unwrap(); } - } pub struct SourceCollector { @@ -328,7 +325,7 @@ impl SourceCollector { task_info: &Arc, ) -> Self { Self { - buffer: ContextBuffer::new(out_schema.schema.clone()), + buffer: ContextBuffer::new(out_schema.schema.clone()), out_schema, collector, control_tx, @@ -359,7 +356,7 @@ impl SourceCollector { schema_resolver, )); } - + pub fn initialize_deserializer( &mut self, format: Format, @@ -381,12 +378,12 @@ impl SourceCollector { pub fn should_flush(&self) -> bool { self.buffer.should_flush() || self - .deserializer - .as_ref() - .map(|d| d.should_flush()) - .unwrap_or(false) + .deserializer + .as_ref() + .map(|d| d.should_flush()) + .unwrap_or(false) } - + pub async fn deserialize_slice( &mut self, msg: &[u8], @@ -405,7 +402,6 @@ impl SourceCollector { Ok(()) } - /// Handling errors and rate limiting error reporting. /// Considers the `bad_data` option to determine whether to drop or fail on bad data. 
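// --- Illustrative sketch, not part of this diff ---
// The context.rs changes in this file narrow `Collector::broadcast` from
// `ArrowMessage` to `SignalMessage`: only signals can be broadcast, and the
// collector itself wraps them in `ArrowMessage::Signal` before fanning out to
// the output queues (see the ArrowCollector hunk below). The types here are
// simplified stand-ins for the arroyo-types/arroyo-operator definitions,
// kept only to show the shape of the new contract.

#[derive(Clone, Debug)]
enum Watermark { Idle, EventTime(u64) }

#[derive(Clone, Debug)]
enum SignalMessage { Watermark(Watermark), EndOfData }

#[derive(Clone, Debug)]
enum ArrowMessage { Signal(SignalMessage) /* Data(RecordBatch) elided */ }

trait Collector {
    // After this change, callers can no longer push data batches through the
    // broadcast path; only control signals go to every downstream queue.
    fn broadcast(&mut self, message: SignalMessage);
}

struct FanOutCollector { out_qs: Vec<Vec<ArrowMessage>> }

impl Collector for FanOutCollector {
    fn broadcast(&mut self, message: SignalMessage) {
        for q in &mut self.out_qs {
            // Wrap the signal once per queue, mirroring ArrowCollector.
            q.push(ArrowMessage::Signal(message.clone()));
        }
    }
}

fn main() {
    let mut c = FanOutCollector { out_qs: vec![vec![], vec![]] };
    // A source with no assigned partitions marks itself idle, as the Kafka
    // and SSE sources do elsewhere in this diff.
    c.broadcast(SignalMessage::Watermark(Watermark::Idle));
    assert_eq!(c.out_qs[0].len(), 1);
}
// --- end sketch ---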
@@ -447,7 +443,7 @@ impl SourceCollector { } Ok(()) - } + } pub async fn flush_buffer(&mut self) -> Result<(), UserError> { if self.buffer.size() > 0 { @@ -474,15 +470,13 @@ impl SourceCollector { Ok(()) } - - pub async fn broadcast(&mut self, message: ArrowMessage) { + + pub async fn broadcast(&mut self, message: SignalMessage) { if let Err(e) = self.flush_buffer().await { self.buffered_error.replace(e); } self.collector.broadcast(message).await; } - - } pub async fn send_checkpoint_event( @@ -496,7 +490,7 @@ pub async fn send_checkpoint_event( tx.send(ControlResp::CheckpointEvent(arroyo_rpc::CheckpointEvent { checkpoint_epoch: barrier.epoch, node_id: info.node_id, - operator_id: info.operator_id.clone(), + operator_id: info.operator_id.clone(), subtask_index: info.task_index, time: SystemTime::now(), event_type, @@ -539,36 +533,7 @@ impl ErrorReporter { #[async_trait] pub trait Collector: Send { async fn collect(&mut self, batch: RecordBatch); - async fn broadcast(&mut self, message: ArrowMessage); -} - -pub struct ChainCollector { - messages: Vec, -} - -impl ChainCollector { - pub fn new() -> Self { - Self { messages: vec![] } - } - - pub fn iter(&self) -> impl Iterator { - self.messages.iter() - } - - pub fn clear(&mut self) { - self.messages.clear(); - } -} - -#[async_trait] -impl Collector for ChainCollector { - async fn collect(&mut self, batch: RecordBatch) { - self.messages.push(batch); - } - - async fn broadcast(&mut self, message: ArrowMessage) { - todo!() - } + async fn broadcast(&mut self, message: SignalMessage); } #[derive(Clone)] @@ -684,15 +649,17 @@ impl Collector for ArrowCollector { } } - async fn broadcast(&mut self, message: ArrowMessage) { + async fn broadcast(&mut self, message: SignalMessage) { for out_node in &self.out_qs { for q in out_node { - q.send(message.clone()).await.unwrap_or_else(|e| { - panic!( - "failed to broadcast message <{:?}> for operator {}: {}", - message, self.chain_info, e - ) - }); + q.send(ArrowMessage::Signal(message.clone())) + .await + .unwrap_or_else(|e| { + panic!( + "failed to broadcast message <{:?}> for operator {}: {}", + message, self.chain_info, e + ) + }); } } } @@ -773,7 +740,7 @@ impl OperatorContext { error_reporter: ErrorReporter { tx: control_tx, task_info, - } + }, } } diff --git a/crates/arroyo-operator/src/lib.rs b/crates/arroyo-operator/src/lib.rs index cfbd5ae58..8ac0eacab 100644 --- a/crates/arroyo-operator/src/lib.rs +++ b/crates/arroyo-operator/src/lib.rs @@ -9,7 +9,7 @@ use crate::inq_reader::InQReader; use arrow::array::types::{TimestampNanosecondType, UInt64Type}; use arrow::array::{Array, PrimitiveArray, RecordBatch, UInt64Array}; use arrow::compute::kernels::numeric::{div, rem}; -use arroyo_types::{ArrowMessage, CheckpointBarrier, Data, SignalMessage, TaskInfoRef}; +use arroyo_types::{ArrowMessage, CheckpointBarrier, Data, SignalMessage, TaskInfo}; use bincode::{Decode, Encode}; use crate::context::OperatorContext; @@ -113,7 +113,7 @@ impl CheckpointCounter { #[allow(unused)] pub struct RunContext + Send + Sync> { - pub task_info: TaskInfoRef, + pub task_info: Arc, pub name: String, pub counter: CheckpointCounter, pub closed: HashSet, diff --git a/crates/arroyo-operator/src/operator.rs b/crates/arroyo-operator/src/operator.rs index 620620942..888ec8c52 100644 --- a/crates/arroyo-operator/src/operator.rs +++ b/crates/arroyo-operator/src/operator.rs @@ -1,4 +1,7 @@ -use crate::context::{send_checkpoint_event, ArrowCollector, BatchReceiver, BatchSender, ChainCollector, Collector, OperatorContext, SourceCollector, 
SourceContext, WatermarkHolder}; +use crate::context::{ + send_checkpoint_event, ArrowCollector, BatchReceiver, BatchSender, Collector, OperatorContext, + SourceCollector, SourceContext, WatermarkHolder, +}; use crate::inq_reader::InQReader; use crate::udfs::{ArroyoUdaf, UdafArg}; use crate::{CheckpointCounter, ControlOutcome, SourceFinishType}; @@ -31,6 +34,7 @@ use datafusion::physical_plan::{displayable, ExecutionPlan}; use dlopen2::wrapper::Container; use futures::future::OptionFuture; use futures::stream::FuturesUnordered; +use futures::FutureExt; use std::any::Any; use std::borrow::Cow; use std::collections::{HashMap, HashSet}; @@ -40,6 +44,7 @@ use std::io::ErrorKind; use std::path::Path; use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; use std::time::{Duration, SystemTime}; use tokio::fs::OpenOptions; use tokio::io::AsyncWriteExt; @@ -75,6 +80,23 @@ impl ConstructedOperator { pub fn from_operator(operator: Box) -> Self { Self::Operator(operator) } + + pub fn name(&self) -> String { + match self { + Self::Source(s) => s.name(), + Self::Operator(s) => s.name(), + } + } + + pub fn display(&self) -> DisplayableOperator { + match self { + Self::Source(_) => DisplayableOperator { + name: self.name().into(), + fields: vec![], + }, + Self::Operator(op) => op.display(), + } + } } pub enum OperatorNode { @@ -140,14 +162,17 @@ impl OperatorNode { ) { match self { OperatorNode::Source(mut s) => { - let mut source_context = SourceContext::from_operator(s.context, chain_info.clone(), control_rx); - - let mut collector = SourceCollector::new(source_context.out_schema.clone(), - collector, - control_tx.clone(), - &source_context.chain_info, - &source_context.task_info); - + let mut source_context = + SourceContext::from_operator(s.context, chain_info.clone(), control_rx); + + let mut collector = SourceCollector::new( + source_context.out_schema.clone(), + collector, + control_tx.clone(), + &source_context.chain_info, + &source_context.task_info, + ); + s.operator.on_start(&mut source_context).await; ready.wait().await; @@ -168,20 +193,26 @@ impl OperatorNode { let result = s.operator.run(&mut source_context, &mut collector).await; - s.operator.on_close(&mut source_context, &mut collector).await; - + s.operator + .on_close(&mut source_context, &mut collector) + .await; + if let Some(final_message) = result.into() { - collector - .broadcast(ArrowMessage::Signal(final_message)) - .await; + collector.broadcast(final_message).await; } } OperatorNode::Chained(mut o) => { - let result = operator_run_behavior(&mut o, in_qs, control_tx, control_rx, &mut collector, ready).await; + let result = operator_run_behavior( + &mut o, + in_qs, + control_tx, + control_rx, + &mut collector, + ready, + ) + .await; if let Some(final_message) = result.into() { - collector - .broadcast(ArrowMessage::Signal(final_message)) - .await; + collector.broadcast(final_message).await; } } } @@ -226,22 +257,19 @@ impl OperatorNode { let collector = ArrowCollector::new(chain_info.clone(), out_schema, out_qs); - self - .run_behavior( - &chain_info, - control_tx.clone(), - control_rx, - &mut in_qs, - ready, - collector, - ) - .await; - + self.run_behavior( + &chain_info, + control_tx.clone(), + control_rx, + &mut in_qs, + ready, + collector, + ) + .await; + info!( "Task finished {}-{} ({})", - chain_info.node_id, - chain_info.task_index, - chain_info.description + chain_info.node_id, chain_info.task_index, chain_info.description ); control_tx @@ -275,9 +303,7 @@ async fn run_checkpoint( .await; collector - 
.broadcast(ArrowMessage::Signal(SignalMessage::Barrier( - checkpoint_barrier, - ))) + .broadcast(SignalMessage::Barrier(checkpoint_barrier)) .await; checkpoint_barrier.then_stop @@ -294,7 +320,11 @@ pub trait SourceOperator: Send + 'static { #[allow(unused_variables)] async fn on_start(&mut self, ctx: &mut SourceContext) {} - async fn run(&mut self, ctx: &mut SourceContext, collector: &mut SourceCollector) -> SourceFinishType; + async fn run( + &mut self, + ctx: &mut SourceContext, + collector: &mut SourceCollector, + ) -> SourceFinishType; #[allow(unused_variables)] async fn on_close(&mut self, ctx: &mut SourceContext, collector: &mut SourceCollector) {} @@ -325,6 +355,34 @@ pub trait SourceOperator: Send + 'static { } } +macro_rules! call_and_recurse { + ($self:expr, $final_collector:expr, $name:ident, $arg:expr) => { + match &mut $self.next { + Some(next) => { + let mut collector = ChainedCollector { + cur: next, + index: 0, + in_partitions: 1, + final_collector: $final_collector, + }; + + $self + .operator + .$name($arg, &mut $self.context, &mut collector) + .await; + + Box::pin(next.$name($arg, $final_collector)).await; + } + None => { + $self + .operator + .$name($arg, &mut $self.context, $final_collector) + .await; + } + } + }; +} + pub struct ChainedCollector<'a, 'b> { cur: &'a mut ChainedOperator, final_collector: &'b mut dyn Collector, @@ -372,8 +430,17 @@ where }; } - async fn broadcast(&mut self, message: ArrowMessage) { - todo!() + async fn broadcast(&mut self, message: SignalMessage) { + match message { + SignalMessage::Watermark(w) => { + self.cur + .handle_watermark(w, self.index, self.final_collector) + .await; + } + m => { + todo!("Unsupported signal message: {:?}", m); + } + } } } @@ -393,7 +460,7 @@ impl ChainedOperator { } } -struct ChainIteratorMut<'a> { +pub struct ChainIteratorMut<'a> { current: Option<&'a mut ChainedOperator>, } @@ -411,7 +478,7 @@ impl<'a> Iterator for ChainIteratorMut<'a> { } } -struct ChainIterator<'a> { +pub struct ChainIterator<'a> { current: Option<&'a ChainedOperator>, } @@ -429,6 +496,22 @@ impl<'a> Iterator for ChainIterator<'a> { } } +struct IndexedFuture { + f: Pin> + Send>>, + i: usize, +} + +impl Future for IndexedFuture { + type Output = (usize, Box); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.f.as_mut().poll(cx) { + Poll::Ready(r) => Poll::Ready((self.i, r)), + Poll::Pending => Poll::Pending, + } + } +} + impl ChainedOperator { async fn handle_controller_message(&mut self, control_message: &ControlMessage) { for (op, ctx) in self.iter_mut() { @@ -459,10 +542,6 @@ impl ChainedOperator { self.iter().filter_map(|(op, _)| op.tick_interval()).min() } - fn display(&self) -> DisplayableOperator { - todo!() - } - async fn on_start(&mut self) { for (op, ctx) in self.iter_mut() { op.on_start(ctx).await; @@ -490,15 +569,22 @@ impl ChainedOperator { fn future_to_poll( &mut self, - ) -> Option> + Send>>> { + ) -> Option)> + Send>>> { let futures = self .iter_mut() - .filter_map(|(op, _)| op.future_to_poll()) + .enumerate() + .filter_map(|(i, (op, _))| { + Some(IndexedFuture { + f: op.future_to_poll()?, + i, + }) + }) .collect::>(); + let task = self.context.task_info.clone(); match futures.len() { 0 => None, - 1 => futures.into_iter().next(), + 1 => Some(Box::pin(futures.into_iter().next().unwrap())), _ => { Some(Box::pin(async move { let mut futures = FuturesUnordered::from_iter(futures.into_iter()); @@ -550,7 +636,7 @@ impl ChainedOperator { ) .await; - self.handle_checkpoint(*t).await; + 
self.handle_checkpoint(*t, collector).await; send_checkpoint_event( control_tx, @@ -655,7 +741,7 @@ impl ChainedOperator { .await; if let Some(watermark) = watermark { final_collector - .broadcast(ArrowMessage::Signal(SignalMessage::Watermark(watermark))) + .broadcast(SignalMessage::Watermark(watermark)) .await; } } @@ -664,14 +750,44 @@ impl ChainedOperator { async fn handle_future_result( &mut self, + op_index: usize, result: Box, - ctx: &mut OperatorContext, + final_collector: &mut dyn Collector, ) { - todo!() + let mut op = self; + for _ in 0..op_index { + op = op + .next + .as_mut() + .expect("Future produced from operator index larger than chain size"); + } + + match &mut op.next { + None => { + op.operator + .handle_future_result(result, &mut op.context, final_collector) + .await; + } + Some(next) => { + let mut collector = ChainedCollector { + cur: next, + final_collector, + index: 0, + in_partitions: 1, + }; + op.operator + .handle_future_result(result, &mut op.context, &mut collector) + .await; + } + } } - async fn handle_checkpoint(&mut self, b: CheckpointBarrier) { - todo!() + async fn handle_checkpoint( + &mut self, + b: CheckpointBarrier, + final_collector: &mut dyn Collector, + ) { + call_and_recurse!(self, final_collector, handle_checkpoint, b) } async fn handle_commit( @@ -683,8 +799,8 @@ impl ChainedOperator { todo!() } - async fn handle_tick(&mut self, tick: u64) { - todo!() + async fn handle_tick(&mut self, tick: u64, final_collector: &mut dyn Collector) { + call_and_recurse!(self, final_collector, handle_tick, tick) } async fn on_close( @@ -748,7 +864,7 @@ async fn operator_run_behavior( for (i, q) in in_qs.iter_mut().enumerate() { let stream = async_stream::stream! { while let Some(item) = q.recv().await { - yield(i,item); + yield(i, item); } }; sel.push(Box::pin(stream)); @@ -759,6 +875,7 @@ async fn operator_run_behavior( let mut ticks = 0u64; let mut interval = tokio::time::interval(this.tick_interval().unwrap_or(Duration::from_secs(60))); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); loop { @@ -827,12 +944,11 @@ async fn operator_run_behavior( } } Some(val) = operator_future => { - todo!() + this.handle_future_result(val.0, val.1, collector).await; } _ = interval.tick() => { - //todo!() - // this.handle_tick(ticks, ctx).await; - // ticks += 1; + this.handle_tick(ticks, collector).await; + ticks += 1; } } } diff --git a/crates/arroyo-planner/src/lib.rs b/crates/arroyo-planner/src/lib.rs index 897c12c12..895b07e6b 100644 --- a/crates/arroyo-planner/src/lib.rs +++ b/crates/arroyo-planner/src/lib.rs @@ -58,7 +58,9 @@ use crate::rewriters::{SourceMetadataVisitor, TimeWindowUdfChecker, UnnestRewrit use crate::udafs::EmptyUdaf; use arrow::compute::kernels::cast_utils::parse_interval_day_time; use arroyo_datastream::logical::LogicalProgram; +use arroyo_datastream::optimizers::ChainingOptimizer; use arroyo_operator::connector::Connection; +use arroyo_rpc::config::config; use arroyo_rpc::df::ArroyoSchema; use arroyo_rpc::TIMESTAMP_FIELD; use arroyo_udf_host::parse::{inner_type, UdfDef}; @@ -76,8 +78,6 @@ use std::{collections::HashMap, sync::Arc}; use syn::Item; use tracing::{debug, info, warn}; use unicase::UniCase; -use arroyo_datastream::optimizers::ChainingOptimizer; -use arroyo_rpc::config::config; const DEFAULT_IDLE_TIME: Option = Some(Duration::from_secs(5 * 60)); pub const ASYNC_RESULT_FIELD: &str = "__async_result"; @@ -813,9 +813,9 @@ pub async fn parse_and_get_arrow_program( python_udfs: schema_provider.python_udfs.clone(), }, ); - + 
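// --- Illustrative sketch, not part of this diff ---
// The hunk just below only runs the ChainingOptimizer when
// `pipeline.enable_chaining` is set. The new `Optimizer` trait (optimizers.rs,
// earlier in this diff) drives `optimize_once` to a fixed point: each pass
// merges at most one node into its single Forward successor and returns true,
// and the driver re-runs it until a pass makes no change. Below is a minimal
// stand-alone version of that driver; a Vec of chain lengths stands in for
// the LogicalGraph and `PairwiseChainer` is a toy optimizer, not the real one.

trait Optimizer {
    // Returns true if it changed the plan; the caller re-runs until false.
    fn optimize_once(&self, plan: &mut Vec<usize>) -> bool;

    fn optimize(&self, plan: &mut Vec<usize>) {
        while self.optimize_once(plan) {}
    }
}

struct PairwiseChainer;

impl Optimizer for PairwiseChainer {
    fn optimize_once(&self, plan: &mut Vec<usize>) -> bool {
        // Fold the second "node" into the first, the way ChainingOptimizer
        // folds a node into its Forward successor and removes it in place.
        if plan.len() < 2 {
            return false;
        }
        let merged = plan.remove(1);
        plan[0] += merged;
        true
    }
}

fn main() {
    let mut plan = vec![1, 1, 1, 1];
    PairwiseChainer.optimize(&mut plan);
    // Four single-operator nodes collapse into one chain of length 4.
    assert_eq!(plan, vec![4]);
}
// --- end sketch ---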
if arroyo_rpc::config::config().pipeline.enable_chaining { - program.optimize(&ChainingOptimizer{}); + program.optimize(&ChainingOptimizer {}); } Ok(CompiledSql { diff --git a/crates/arroyo-rpc/src/config.rs b/crates/arroyo-rpc/src/config.rs index c7b7621a6..9f2b38932 100644 --- a/crates/arroyo-rpc/src/config.rs +++ b/crates/arroyo-rpc/src/config.rs @@ -426,7 +426,7 @@ pub struct PipelineConfig { /// Default sink, for when none is specified #[serde(default)] pub default_sink: DefaultSink, - + /// Whether to enable operator chaining pub enable_chaining: bool, diff --git a/crates/arroyo-rpc/src/lib.rs b/crates/arroyo-rpc/src/lib.rs index ac9f4780a..5bdb01bf2 100644 --- a/crates/arroyo-rpc/src/lib.rs +++ b/crates/arroyo-rpc/src/lib.rs @@ -92,7 +92,7 @@ impl From for CompactionResult { pub struct CheckpointCompleted { pub checkpoint_epoch: u32, pub node_id: u32, - pub operator_id: String, + pub operator_id: String, pub subtask_metadata: SubtaskCheckpointMetadata, } diff --git a/crates/arroyo-state/src/checkpoint_state.rs b/crates/arroyo-state/src/checkpoint_state.rs index ebdc710c3..e8f6dfed6 100644 --- a/crates/arroyo-state/src/checkpoint_state.rs +++ b/crates/arroyo-state/src/checkpoint_state.rs @@ -316,8 +316,7 @@ impl CheckpointState { } } { for i in 0..operator_state.subtasks_checkpointed { - self.subtasks_to_commit - .insert((c.node_id, i as u32)); + self.subtasks_to_commit.insert((c.node_id, i as u32)); } self.commit_data .entry(c.node_id) diff --git a/crates/arroyo-state/src/committing_state.rs b/crates/arroyo-state/src/committing_state.rs index 830543d3c..7ff0065bf 100644 --- a/crates/arroyo-state/src/committing_state.rs +++ b/crates/arroyo-state/src/committing_state.rs @@ -26,8 +26,7 @@ impl CommittingState { } pub fn subtask_committed(&mut self, node_id: u32, subtask_index: u32) { - self.subtasks_to_commit - .remove(&(node_id, subtask_index)); + self.subtasks_to_commit.remove(&(node_id, subtask_index)); } pub fn done(&self) -> bool { diff --git a/crates/arroyo-state/src/tables/expiring_time_key_map.rs b/crates/arroyo-state/src/tables/expiring_time_key_map.rs index e4b171d19..cee6c169c 100644 --- a/crates/arroyo-state/src/tables/expiring_time_key_map.rs +++ b/crates/arroyo-state/src/tables/expiring_time_key_map.rs @@ -24,7 +24,7 @@ use arroyo_rpc::{ }; use arroyo_storage::StorageProviderRef; use arroyo_types::{ - from_micros, from_nanos, print_time, server_for_hash, to_micros, to_nanos, TaskInfoRef, + from_micros, from_nanos, print_time, server_for_hash, to_micros, to_nanos, TaskInfo, }; use datafusion::parquet::arrow::async_reader::ParquetObjectReader; @@ -50,7 +50,7 @@ use super::{table_checkpoint_path, CompactionConfig, Table, TableEpochCheckpoint #[derive(Debug, Clone)] pub struct ExpiringTimeKeyTable { table_name: String, - task_info: TaskInfoRef, + task_info: Arc, schema: SchemaWithHashAndOperation, retention: Duration, storage_provider: StorageProviderRef, @@ -258,7 +258,7 @@ impl Table for ExpiringTimeKeyTable { fn from_config( config: Self::ConfigMessage, - task_info: arroyo_types::TaskInfoRef, + task_info: Arc, storage_provider: arroyo_storage::StorageProviderRef, checkpoint_message: Option, ) -> anyhow::Result { @@ -353,7 +353,7 @@ impl Table for ExpiringTimeKeyTable { TableEnum::ExpiringKeyedTimeTable } - fn task_info(&self) -> TaskInfoRef { + fn task_info(&self) -> Arc { self.task_info.clone() } diff --git a/crates/arroyo-state/src/tables/global_keyed_map.rs b/crates/arroyo-state/src/tables/global_keyed_map.rs index 422964fe8..b09638817 100644 --- 
a/crates/arroyo-state/src/tables/global_keyed_map.rs +++ b/crates/arroyo-state/src/tables/global_keyed_map.rs @@ -7,7 +7,7 @@ use arroyo_rpc::grpc::rpc::{ OperatorMetadata, TableEnum, }; use arroyo_storage::StorageProviderRef; -use arroyo_types::{to_micros, Data, Key, TaskInfoRef}; +use arroyo_types::{to_micros, Data, Key, TaskInfo}; use bincode::config; use once_cell::sync::Lazy; @@ -41,7 +41,7 @@ static GLOBAL_KEY_VALUE_SCHEMA: Lazy> = Lazy::new(|| { #[derive(Debug, Clone)] pub struct GlobalKeyedTable { table_name: String, - pub task_info: TaskInfoRef, + pub task_info: Arc, storage_provider: StorageProviderRef, pub files: Vec, } @@ -125,7 +125,7 @@ impl Table for GlobalKeyedTable { fn from_config( config: Self::ConfigMessage, - task_info: TaskInfoRef, + task_info: Arc, storage_provider: StorageProviderRef, checkpoint_message: Option, ) -> anyhow::Result { @@ -184,7 +184,7 @@ impl Table for GlobalKeyedTable { TableEnum::GlobalKeyValue } - fn task_info(&self) -> TaskInfoRef { + fn task_info(&self) -> Arc { self.task_info.clone() } @@ -227,7 +227,7 @@ impl Table for GlobalKeyedTable { pub struct GlobalKeyedCheckpointer { table_name: String, epoch: u32, - task_info: TaskInfoRef, + task_info: Arc, storage_provider: StorageProviderRef, latest_values: BTreeMap, Vec>, commit_data: Option>, diff --git a/crates/arroyo-state/src/tables/mod.rs b/crates/arroyo-state/src/tables/mod.rs index 6777b6f89..4a5c11891 100644 --- a/crates/arroyo-state/src/tables/mod.rs +++ b/crates/arroyo-state/src/tables/mod.rs @@ -5,10 +5,11 @@ use arroyo_rpc::grpc::rpc::{ TableSubtaskCheckpointMetadata, }; use arroyo_storage::StorageProviderRef; -use arroyo_types::TaskInfoRef; +use arroyo_types::TaskInfo; use prost::Message; use std::any::Any; use std::collections::{HashMap, HashSet}; +use std::sync::Arc; use std::time::SystemTime; use tracing::debug; @@ -79,7 +80,7 @@ pub(crate) trait Table: Send + Sync + 'static + Clone { // * checkpoint_message: If restoring from a checkpoint, the checkpoint data for that checkpoint's epoch. fn from_config( config: Self::ConfigMessage, - task_info: TaskInfoRef, + task_info: Arc, storage_provider: StorageProviderRef, checkpoint_message: Option, ) -> anyhow::Result @@ -117,7 +118,7 @@ pub(crate) trait Table: Send + Sync + 'static + Clone { fn table_type() -> TableEnum; - fn task_info(&self) -> TaskInfoRef; + fn task_info(&self) -> Arc; fn files_to_keep( config: Self::ConfigMessage, @@ -155,7 +156,7 @@ pub trait ErasedTable: Send + Sync + 'static { // * checkpoint_message: If restoring from a checkpoint, the checkpoint data for that checkpoint's epoch. 
fn from_config( config: TableConfig, - task_info: TaskInfoRef, + task_info: Arc, storage_provider: StorageProviderRef, checkpoint_message: Option, ) -> anyhow::Result @@ -241,7 +242,7 @@ pub trait ErasedTable: Send + Sync + 'static { impl ErasedTable for T { fn from_config( config: TableConfig, - task_info: TaskInfoRef, + task_info: Arc, storage_provider: StorageProviderRef, checkpoint_message: Option, ) -> anyhow::Result diff --git a/crates/arroyo-state/src/tables/table_manager.rs b/crates/arroyo-state/src/tables/table_manager.rs index eb8e1007e..975945eb0 100644 --- a/crates/arroyo-state/src/tables/table_manager.rs +++ b/crates/arroyo-state/src/tables/table_manager.rs @@ -11,7 +11,7 @@ use arroyo_rpc::{ CheckpointCompleted, ControlResp, }; use arroyo_storage::StorageProviderRef; -use arroyo_types::{from_micros, to_micros, CheckpointBarrier, Data, Key, TaskInfo, TaskInfoRef}; +use arroyo_types::{from_micros, to_micros, CheckpointBarrier, Data, Key, TaskInfo}; use tokio::sync::{ mpsc::{self, Receiver, Sender}, oneshot, @@ -38,7 +38,7 @@ pub struct TableManager { // ordered by table, then epoch. tables: HashMap>, writer: BackendWriter, - task_info: TaskInfoRef, + task_info: Arc, storage: StorageProviderRef, caches: HashMap>, } @@ -55,7 +55,7 @@ pub struct BackendFlusher { storage: StorageProviderRef, control_tx: Sender, finish_tx: Option>, - task_info: TaskInfoRef, + task_info: Arc, tables: HashMap>, table_configs: HashMap, table_checkpointers: HashMap>, @@ -195,7 +195,7 @@ impl BackendFlusher { impl BackendWriter { fn new( - task_info: TaskInfoRef, + task_info: Arc, control_tx: Sender, table_configs: HashMap, tables: HashMap>, diff --git a/crates/arroyo-types/src/lib.rs b/crates/arroyo-types/src/lib.rs index b59530afa..f7030ebb4 100644 --- a/crates/arroyo-types/src/lib.rs +++ b/crates/arroyo-types/src/lib.rs @@ -372,9 +372,6 @@ pub trait RecordBatchBuilder: Default + Debug + Sync + Send + 'static { fn schema(&self) -> SchemaRef; } -/// A reference-counted reference to a [TaskInfo]. 
-pub type TaskInfoRef = Arc; - #[derive(Eq, PartialEq, Hash, Debug, Clone, Encode, Decode)] pub struct TaskInfo { pub job_id: String, diff --git a/crates/arroyo-worker/src/arrow/async_udf.rs b/crates/arroyo-worker/src/arrow/async_udf.rs index 87966d9dc..7acdc7ba0 100644 --- a/crates/arroyo-worker/src/arrow/async_udf.rs +++ b/crates/arroyo-worker/src/arrow/async_udf.rs @@ -5,7 +5,10 @@ use arrow_schema::{Field, Schema}; use arroyo_datastream::logical::DylibUdfConfig; use arroyo_df::ASYNC_RESULT_FIELD; use arroyo_operator::context::{Collector, OperatorContext}; -use arroyo_operator::operator::{ArrowOperator, AsDisplayable, ConstructedOperator, DisplayableOperator, OperatorConstructor, Registry}; +use arroyo_operator::operator::{ + ArrowOperator, AsDisplayable, ConstructedOperator, DisplayableOperator, OperatorConstructor, + Registry, +}; use arroyo_rpc::grpc::api; use arroyo_rpc::grpc::rpc::TableConfig; use arroyo_state::global_table_config; @@ -81,24 +84,26 @@ impl OperatorConstructor for AsyncUdfConstructor { ) })?; - Ok(ConstructedOperator::from_operator(Box::new(AsyncUdfOperator { - name: config.name.clone(), - udf: (&*udf).try_into()?, - ordered, - allowed_in_flight: config.max_concurrency, - timeout: Duration::from_micros(config.timeout_micros), - config, - registry, - input_exprs: vec![], - final_exprs: vec![], - next_id: 0, - inputs: BTreeMap::new(), - outputs: BTreeMap::new(), - watermarks: VecDeque::new(), - input_row_converter: RowConverter::new(vec![]).unwrap(), - output_row_converter: RowConverter::new(vec![]).unwrap(), - input_schema: None, - }))) + Ok(ConstructedOperator::from_operator(Box::new( + AsyncUdfOperator { + name: config.name.clone(), + udf: (&*udf).try_into()?, + ordered, + allowed_in_flight: config.max_concurrency, + timeout: Duration::from_micros(config.timeout_micros), + config, + registry, + input_exprs: vec![], + final_exprs: vec![], + next_id: 0, + inputs: BTreeMap::new(), + outputs: BTreeMap::new(), + watermarks: VecDeque::new(), + input_row_converter: RowConverter::new(vec![]).unwrap(), + output_row_converter: RowConverter::new(vec![]).unwrap(), + input_schema: None, + }, + ))) } } @@ -232,7 +237,8 @@ impl ArrowOperator for AsyncUdfOperator { gs.get_all() .iter() .filter(|(task_index, _)| { - **task_index % ctx.task_info.parallelism as usize == ctx.task_info.task_index as usize + **task_index % ctx.task_info.parallelism as usize + == ctx.task_info.task_index as usize }) .for_each(|(_, state)| { for (k, v) in &state.inputs { @@ -288,7 +294,12 @@ impl ArrowOperator for AsyncUdfOperator { Some(Duration::from_millis(50)) } - async fn process_batch(&mut self, batch: RecordBatch, _: &mut OperatorContext, _: &mut dyn Collector) { + async fn process_batch( + &mut self, + batch: RecordBatch, + _: &mut OperatorContext, + _: &mut dyn Collector, + ) { let arg_batch: Vec<_> = self .input_exprs .iter() @@ -323,7 +334,12 @@ impl ArrowOperator for AsyncUdfOperator { } } - async fn handle_tick(&mut self, _: u64, ctx: &mut OperatorContext, collector: &mut dyn Collector) { + async fn handle_tick( + &mut self, + _: u64, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) { let Some((ids, results)) = self .udf .drain_results() @@ -379,7 +395,12 @@ impl ArrowOperator for AsyncUdfOperator { None } - async fn handle_checkpoint(&mut self, b: CheckpointBarrier, ctx: &mut OperatorContext, collector: &mut dyn Collector) { + async fn handle_checkpoint( + &mut self, + b: CheckpointBarrier, + ctx: &mut OperatorContext, + collector: &mut dyn Collector, + ) { let gs = 
            ctx.table_manager.get_global_keyed_state("a").await.unwrap();

        let state = AsyncUdfState {
@@ -399,7 +420,12 @@ impl ArrowOperator for AsyncUdfOperator {
         gs.insert(ctx.task_info.task_index, state).await;
     }

-    async fn on_close(&mut self, final_message: &Option<SignalMessage>, ctx: &mut OperatorContext, collector: &mut dyn Collector) {
+    async fn on_close(
+        &mut self,
+        final_message: &Option<SignalMessage>,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) {
         if let Some(SignalMessage::EndOfData) = final_message {
             while !self.inputs.is_empty() && !self.outputs.is_empty() {
                 self.handle_tick(0, ctx, collector).await;
@@ -460,7 +486,8 @@ impl AsyncUdfOperator {

         if watermark_id <= oldest_unprocessed {
             // we've processed everything before this watermark, we can emit and drop it
-            collector.broadcast(ArrowMessage::Signal(SignalMessage::Watermark(watermark)))
+            collector
+                .broadcast(SignalMessage::Watermark(watermark))
                 .await;
         } else {
             // we still have messages preceding this watermark to work on
diff --git a/crates/arroyo-worker/src/arrow/instant_join.rs b/crates/arroyo-worker/src/arrow/instant_join.rs
index 244683266..c996e1987 100644
--- a/crates/arroyo-worker/src/arrow/instant_join.rs
+++ b/crates/arroyo-worker/src/arrow/instant_join.rs
@@ -5,7 +5,7 @@ use arrow_array::{RecordBatch, TimestampNanosecondArray};
 use arroyo_df::physical::{ArroyoPhysicalExtensionCodec, DecodingContext};
 use arroyo_operator::context::{Collector, OperatorContext};
 use arroyo_operator::operator::{
-    ArrowOperator, DisplayableOperator, OperatorConstructor, ConstructedOperator, Registry,
+    ArrowOperator, ConstructedOperator, DisplayableOperator, OperatorConstructor, Registry,
 };
 use arroyo_rpc::{
     df::{ArroyoSchema, ArroyoSchemaRef},
@@ -232,10 +232,15 @@ impl ArrowOperator for InstantJoin {
         }
     }

-    async fn process_batch(&mut self, _: RecordBatch, _: &mut OperatorContext, _: &mut dyn Collector) {
+    async fn process_batch(
+        &mut self,
+        _: RecordBatch,
+        _: &mut OperatorContext,
+        _: &mut dyn Collector,
+    ) {
         unreachable!();
     }
-    
+
     async fn process_batch_index(
         &mut self,
         index: usize,
@@ -293,7 +298,12 @@ impl ArrowOperator for InstantJoin {
         Some(Watermark::EventTime(watermark))
     }

-    async fn handle_checkpoint(&mut self, b: CheckpointBarrier, ctx: &mut OperatorContext, collector: &mut dyn Collector) {
+    async fn handle_checkpoint(
+        &mut self,
+        b: CheckpointBarrier,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) {
         let watermark = ctx.last_present_watermark();
         ctx.table_manager
             .get_expiring_time_key_table("left", watermark)
@@ -349,7 +359,12 @@ impl ArrowOperator for InstantJoin {
         }))
     }

-    async fn handle_future_result(&mut self, result: Box<dyn Any + Send>, _: &mut OperatorContext, collector: &mut dyn Collector) {
+    async fn handle_future_result(
+        &mut self,
+        result: Box<dyn Any + Send>,
+        _: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) {
         let data: Box> = result.downcast().expect("invalid data in future");
         if let Some((bin, batch_option)) = *data {
             match batch_option {
@@ -359,7 +374,8 @@ impl ArrowOperator for InstantJoin {
                 Some((batch, future)) => match self.execs.get_mut(&bin) {
                     Some(exec) => {
                         exec.active_exec = future.clone();
-                        collector.collect(batch.expect("should compute batch in future"))
+                        collector
+                            .collect(batch.expect("should compute batch in future"))
                             .await;
                         self.futures.lock().await.push(future);
                     }
diff --git a/crates/arroyo-worker/src/arrow/join_with_expiration.rs b/crates/arroyo-worker/src/arrow/join_with_expiration.rs
index a03694ea9..0edabbf29 100644
--- a/crates/arroyo-worker/src/arrow/join_with_expiration.rs
+++ b/crates/arroyo-worker/src/arrow/join_with_expiration.rs
@@ -4,7 +4,8 @@ use arrow_array::RecordBatch;
 use arroyo_df::physical::{ArroyoPhysicalExtensionCodec, DecodingContext};
 use arroyo_operator::context::{Collector, OperatorContext};
 use arroyo_operator::operator::{
-    ArrowOperator, AsDisplayable, DisplayableOperator, OperatorConstructor, ConstructedOperator, Registry,
+    ArrowOperator, AsDisplayable, ConstructedOperator, DisplayableOperator, OperatorConstructor,
+    Registry,
 };
 use arroyo_rpc::{
     df::ArroyoSchema,
@@ -14,6 +15,7 @@ use arroyo_state::timestamp_table_config;
 use datafusion::execution::context::SessionContext;
 use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
 use datafusion::physical_plan::ExecutionPlan;
+use datafusion::prelude::col;
 use datafusion_proto::{physical_plan::AsExecutionPlan, protobuf::PhysicalPlanNode};
 use futures::StreamExt;
 use prost::Message;
@@ -23,7 +25,6 @@ use std::{
     sync::{Arc, RwLock},
     time::Duration,
 };
-use datafusion::prelude::col;
 use tracing::warn;

 pub struct JoinWithExpiration {
@@ -121,7 +122,7 @@ impl JoinWithExpiration {
         &mut self,
         left: RecordBatch,
         right: RecordBatch,
-        collector: &mut dyn Collector
+        collector: &mut dyn Collector,
     ) {
         {
             self.right_passer.write().unwrap().replace(right);
@@ -165,7 +166,12 @@ impl ArrowOperator for JoinWithExpiration {
         }
     }
-    async fn process_batch(&mut self, _record_batch: RecordBatch, _ctx: &mut OperatorContext, _: &mut dyn Collector) {
+    async fn process_batch(
+        &mut self,
+        _record_batch: RecordBatch,
+        _ctx: &mut OperatorContext,
+        _: &mut dyn Collector,
+    ) {
         unreachable!();
     }

     async fn process_batch_index(
@@ -174,7 +180,7 @@
         total_inputs: usize,
         record_batch: RecordBatch,
         ctx: &mut OperatorContext,
-        collector: &mut dyn Collector
+        collector: &mut dyn Collector,
     ) {
         match index / (total_inputs / 2) {
             0 => self
@@ -255,16 +261,18 @@ impl OperatorConstructor for JoinWithExpirationConstructor {
             ttl = Duration::from_secs(24 * 60 * 60);
         }

-        Ok(ConstructedOperator::from_operator(Box::new(JoinWithExpiration {
-            left_expiration: ttl,
-            right_expiration: ttl,
-            left_input_schema,
-            right_input_schema,
-            left_schema,
-            right_schema,
-            left_passer,
-            right_passer,
-            join_execution_plan,
-        })))
+        Ok(ConstructedOperator::from_operator(Box::new(
+            JoinWithExpiration {
+                left_expiration: ttl,
+                right_expiration: ttl,
+                left_input_schema,
+                right_input_schema,
+                left_schema,
+                right_schema,
+                left_passer,
+                right_passer,
+                join_execution_plan,
+            },
+        )))
     }
 }
diff --git a/crates/arroyo-worker/src/arrow/mod.rs b/crates/arroyo-worker/src/arrow/mod.rs
index ae8767933..66ad38d65 100644
--- a/crates/arroyo-worker/src/arrow/mod.rs
+++ b/crates/arroyo-worker/src/arrow/mod.rs
@@ -4,7 +4,8 @@ use arroyo_df::physical::ArroyoPhysicalExtensionCodec;
 use arroyo_df::physical::DecodingContext;
 use arroyo_operator::context::{Collector, OperatorContext};
 use arroyo_operator::operator::{
-    ArrowOperator, AsDisplayable, DisplayableOperator, OperatorConstructor, ConstructedOperator, Registry,
+    ArrowOperator, AsDisplayable, ConstructedOperator, DisplayableOperator, OperatorConstructor,
+    Registry,
 };
 use arroyo_rpc::grpc::api;
 use datafusion::common::DataFusionError;
@@ -71,7 +72,12 @@ impl ArrowOperator for ValueExecutionOperator {
         }
     }

-    async fn process_batch(&mut self, record_batch: RecordBatch, ctx: &mut OperatorContext, collector: &mut dyn Collector) {
+    async fn process_batch(
+        &mut self,
+        record_batch: RecordBatch,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) {
         let mut records = self.executor.process_batch(record_batch).await;
         while let Some(batch) = records.next().await {
             let batch = batch.expect("should be able to compute batch");
@@ -196,7 +202,12 @@ impl ArrowOperator for KeyExecutionOperator {
         }
     }

-    async fn process_batch(&mut self, batch: RecordBatch, _: &mut OperatorContext, collector: &mut dyn Collector) {
+    async fn process_batch(
+        &mut self,
+        batch: RecordBatch,
+        _: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) {
         let mut records = self.executor.process_batch(batch).await;
         while let Some(batch) = records.next().await {
             let batch = batch.expect("should be able to compute batch");
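The hunks above are all the same mechanical change: `ArrowOperator` callbacks (`process_batch`, `handle_checkpoint`, `handle_future_result`, `on_close`, ...) now take the output `Collector` as an explicit `&mut dyn Collector` parameter instead of reaching it through `OperatorContext`. A minimal sketch of an operator written against the new shape; `PassThrough` is hypothetical, the remaining required trait items are elided, and the `async_trait` setup is assumed to match the rest of the crate:

```rust
// Sketch only: a hypothetical pass-through operator using the new signatures.
use arrow_array::RecordBatch;
use arroyo_operator::context::{Collector, OperatorContext};
use arroyo_operator::operator::ArrowOperator;
use async_trait::async_trait;

struct PassThrough;

#[async_trait]
impl ArrowOperator for PassThrough {
    fn name(&self) -> String {
        "pass_through".to_string()
    }

    // Output goes through the collector argument, not through the context.
    async fn process_batch(
        &mut self,
        batch: RecordBatch,
        _ctx: &mut OperatorContext,
        collector: &mut dyn Collector,
    ) {
        collector.collect(batch).await;
    }
}
```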
diff --git a/crates/arroyo-worker/src/arrow/session_aggregating_window.rs b/crates/arroyo-worker/src/arrow/session_aggregating_window.rs
index 0317f7966..25d00d1e0 100644
--- a/crates/arroyo-worker/src/arrow/session_aggregating_window.rs
+++ b/crates/arroyo-worker/src/arrow/session_aggregating_window.rs
@@ -20,7 +20,7 @@ use arrow_schema::{DataType, Field, FieldRef};
 use arroyo_df::schemas::window_arrow_struct;
 use arroyo_operator::{
     context::OperatorContext,
-    operator::{ArrowOperator, OperatorConstructor, ConstructedOperator},
+    operator::{ArrowOperator, ConstructedOperator, OperatorConstructor},
 };
 use arroyo_rpc::{
     grpc::{api, rpc::TableConfig},
@@ -33,6 +33,7 @@ use arroyo_types::{from_nanos, print_time, to_nanos, CheckpointBarrier, Watermar
 use datafusion::{execution::context::SessionContext, physical_plan::ExecutionPlan};

 use arroyo_df::physical::{ArroyoPhysicalExtensionCodec, DecodingContext};
+use arroyo_operator::context::Collector;
 use arroyo_operator::operator::Registry;
 use arroyo_rpc::df::{ArroyoSchema, ArroyoSchemaRef};
 use datafusion::execution::{
@@ -45,7 +46,6 @@ use std::time::Duration;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio_stream::StreamExt;
 use tracing::{debug, warn};
-use arroyo_operator::context::Collector;

 // TODO: advance futures outside of method calls.
 pub struct SessionAggregatingWindowFunc {
@@ -71,7 +71,11 @@ impl SessionAggregatingWindowFunc {
         result
     }

-    async fn advance(&mut self, ctx: &mut OperatorContext, collector: &mut dyn Collector) -> Result<()> {
+    async fn advance(
+        &mut self,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) -> Result<()> {
         let Some(watermark) = ctx.last_present_watermark() else {
             debug!("no watermark, not advancing");
             return Ok(());
@@ -809,7 +813,12 @@ impl ArrowOperator for SessionAggregatingWindowFunc {
     }

     // TODO: filter out late data
-    async fn process_batch(&mut self, batch: RecordBatch, ctx: &mut OperatorContext, _: &mut dyn Collector) {
+    async fn process_batch(
+        &mut self,
+        batch: RecordBatch,
+        ctx: &mut OperatorContext,
+        _: &mut dyn Collector,
+    ) {
         debug!("received batch {:?}", batch);
         let current_watermark = ctx.last_present_watermark();
         let batch = if let Some(watermark) = current_watermark {
@@ -864,7 +873,12 @@ impl ArrowOperator for SessionAggregatingWindowFunc {
         Some(watermark)
     }

-    async fn handle_checkpoint(&mut self, b: CheckpointBarrier, ctx: &mut OperatorContext, collector: &mut dyn Collector) {
+    async fn handle_checkpoint(
+        &mut self,
+        b: CheckpointBarrier,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) {
         let watermark = ctx.last_present_watermark();
         let table = ctx
             .table_manager
diff --git a/crates/arroyo-worker/src/arrow/sliding_aggregating_window.rs b/crates/arroyo-worker/src/arrow/sliding_aggregating_window.rs
index 744857e17..6d5897db2 100644
--- a/crates/arroyo-worker/src/arrow/sliding_aggregating_window.rs
+++ b/crates/arroyo-worker/src/arrow/sliding_aggregating_window.rs
@@ -4,7 +4,7 @@ use arrow_array::{types::TimestampNanosecondType, Array, PrimitiveArray, RecordB
 use arrow_schema::SchemaRef;
 use arroyo_operator::{
     context::OperatorContext,
-    operator::{ArrowOperator, OperatorConstructor, ConstructedOperator},
+    operator::{ArrowOperator, ConstructedOperator, OperatorConstructor},
 };
 use arroyo_rpc::grpc::{api, rpc::TableConfig};
 use arroyo_state::timestamp_table_config;
@@ -21,7 +21,9 @@ use std::{

 use futures::stream::FuturesUnordered;

+use super::sync::streams::KeyedCloneableStreamFuture;
 use arroyo_df::physical::{ArroyoPhysicalExtensionCodec, DecodingContext};
+use arroyo_operator::context::Collector;
 use arroyo_operator::operator::{AsDisplayable, DisplayableOperator, Registry};
 use arroyo_rpc::df::ArroyoSchema;
 use datafusion::execution::{
@@ -39,8 +41,6 @@ use std::time::Duration;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio_stream::StreamExt;
 use tracing::info;
-use arroyo_operator::context::Collector;
-use super::sync::streams::KeyedCloneableStreamFuture;

 pub struct SlidingAggregatingWindowFunc {
     slide: Duration,
@@ -113,7 +113,11 @@ impl SlidingAggregatingWindowFunc {
         }
     }

-    async fn advance(&mut self, ctx: &mut OperatorContext, collector: &mut dyn Collector) -> Result<()> {
+    async fn advance(
+        &mut self,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) -> Result<()> {
         let bin_start = match self.state {
             SlidingWindowState::NoData => unreachable!(),
             SlidingWindowState::OnlyBufferedData { earliest_bin_time } => earliest_bin_time,
@@ -596,7 +600,12 @@ impl ArrowOperator for SlidingAggregatingWindowFunc {
     }

     // TODO: filter out late data
-    async fn process_batch(&mut self, batch: RecordBatch, ctx: &mut OperatorContext, _: &mut dyn Collector) {
+    async fn process_batch(
+        &mut self,
+        batch: RecordBatch,
+        ctx: &mut OperatorContext,
+        _: &mut dyn Collector,
+    ) {
         let bin = self
             .binning_function
             .evaluate(&batch)
@@ -683,7 +692,12 @@ impl ArrowOperator for SlidingAggregatingWindowFunc {
         Some(watermark)
     }

-    async fn handle_checkpoint(&mut self, b: CheckpointBarrier, ctx: &mut OperatorContext, collector: &mut dyn Collector) {
+    async fn handle_checkpoint(
+        &mut self,
+        b: CheckpointBarrier,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) {
         let watermark = ctx
             .watermark()
             .and_then(|watermark: Watermark| match watermark {
diff --git a/crates/arroyo-worker/src/arrow/tumbling_aggregating_window.rs b/crates/arroyo-worker/src/arrow/tumbling_aggregating_window.rs
index c55182535..98a573e06 100644
--- a/crates/arroyo-worker/src/arrow/tumbling_aggregating_window.rs
+++ b/crates/arroyo-worker/src/arrow/tumbling_aggregating_window.rs
@@ -5,7 +5,8 @@ use arrow_schema::SchemaRef;
 use arroyo_df::schemas::add_timestamp_field_arrow;
 use arroyo_operator::context::{Collector, OperatorContext};
 use arroyo_operator::operator::{
-    ArrowOperator, AsDisplayable, DisplayableOperator, OperatorConstructor, ConstructedOperator, Registry,
+    ArrowOperator, AsDisplayable, ConstructedOperator, DisplayableOperator, OperatorConstructor,
+    Registry,
 };
 use arroyo_rpc::grpc::{api, rpc::TableConfig};
 use arroyo_state::timestamp_table_config;
@@ -31,6 +32,7 @@ use datafusion::execution::{
     SendableRecordBatchStream,
 };
 use datafusion::physical_expr::PhysicalExpr;
+use datafusion::prelude::col;
 use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
 use datafusion_proto::{
     physical_plan::{from_proto::parse_physical_expr, AsExecutionPlan},
@@ -38,7 +40,6 @@ use datafusion_proto::{
 };
 use prost::Message;
 use std::time::Duration;
-use datafusion::prelude::col;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio::sync::Mutex;
 use tracing::{debug, warn};
@@ -247,7 +248,12 @@ impl ArrowOperator for TumblingAggregatingWindowFunc {
         }
     }

-    async fn process_batch(&mut self, batch: RecordBatch, ctx: &mut OperatorContext, _: &mut dyn Collector) {
+    async fn process_batch(
+        &mut self,
+        batch: RecordBatch,
+        ctx: &mut OperatorContext,
+        _: &mut dyn Collector,
+    ) {
         let bin = self
             .binning_function
             .evaluate(&batch)
@@ -401,7 +407,12 @@ impl ArrowOperator for TumblingAggregatingWindowFunc {
         }))
     }

-    async fn handle_future_result(&mut self, result: Box<dyn Any + Send>, ctx: &mut OperatorContext, collector: &mut dyn Collector) {
+    async fn handle_future_result(
+        &mut self,
+        result: Box<dyn Any + Send>,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) {
         let data: Box> = result.downcast().expect("invalid data in future");
         if let Some((bin, batch_option)) = *data {
             match batch_option {
@@ -422,7 +433,12 @@ impl ArrowOperator for TumblingAggregatingWindowFunc {
         }
     }

-    async fn handle_checkpoint(&mut self, b: CheckpointBarrier, ctx: &mut OperatorContext, collector: &mut dyn Collector) {
+    async fn handle_checkpoint(
+        &mut self,
+        b: CheckpointBarrier,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) {
         let watermark = ctx
             .watermark()
             .and_then(|watermark: Watermark| match watermark {
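The window operators above hand asynchronously computed results back to themselves through `handle_future_result`, which receives the future's output type-erased as `Box<dyn Any + Send>` and downcasts it before emitting anything. A rough sketch of that unpacking step; the payload type here is invented for illustration (the real operators also carry the re-armed poll future):

```rust
use arrow_array::RecordBatch;
use std::any::Any;

// Illustrative payload: (bin index, optionally a finished batch).
type FuturePayload = Option<(usize, Option<RecordBatch>)>;

fn unpack(result: Box<dyn Any + Send>) -> FuturePayload {
    // downcast() fails loudly if the concrete type doesn't match, which is
    // preferable to silently mis-attributing state to the wrong window bin.
    let data: Box<FuturePayload> = result.downcast().expect("invalid data in future");
    *data
}
```

When the payload does contain a batch, the operator forwards it with `collector.collect(batch).await` and pushes the cloned future back onto its pending set, as in the hunks above.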
diff --git a/crates/arroyo-worker/src/arrow/updating_aggregator.rs b/crates/arroyo-worker/src/arrow/updating_aggregator.rs
index ae1586554..ef252df1b 100644
--- a/crates/arroyo-worker/src/arrow/updating_aggregator.rs
+++ b/crates/arroyo-worker/src/arrow/updating_aggregator.rs
@@ -13,11 +13,12 @@ use arrow_array::{Array, BooleanArray, RecordBatch, StructArray};
 use arrow_array::cast::AsArray;
 use arrow_schema::SchemaRef;
 use arroyo_df::physical::{ArroyoPhysicalExtensionCodec, DecodingContext};
+use arroyo_operator::context::Collector;
 use arroyo_operator::{
     context::OperatorContext,
     operator::{
-        ArrowOperator, AsDisplayable, DisplayableOperator, OperatorConstructor, ConstructedOperator,
-        Registry,
+        ArrowOperator, AsDisplayable, ConstructedOperator, DisplayableOperator,
+        OperatorConstructor, Registry,
     },
 };
 use arroyo_rpc::df::ArroyoSchemaRef;
@@ -25,6 +26,7 @@ use arroyo_rpc::grpc::{api::UpdatingAggregateOperator, rpc::TableConfig};
 use arroyo_rpc::{updating_meta_fields, UPDATING_META_FIELD};
 use arroyo_state::timestamp_table_config;
 use arroyo_types::{CheckpointBarrier, SignalMessage, Watermark};
+use datafusion::common::utils::coerced_fixed_size_list_to_list;
 use datafusion::execution::{
     runtime_env::{RuntimeConfig, RuntimeEnv},
     SendableRecordBatchStream,
@@ -35,11 +37,9 @@ use futures::{lock::Mutex, Future};
 use itertools::Itertools;
 use prost::Message;
 use std::time::Duration;
-use datafusion::common::utils::coerced_fixed_size_list_to_list;
 use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
 use tokio_stream::StreamExt;
 use tracing::log::warn;
-use arroyo_operator::context::Collector;

 pub struct UpdatingAggregatingFunc {
     partial_aggregation_plan: Arc<dyn ExecutionPlan>,
@@ -59,7 +59,11 @@ pub struct UpdatingAggregatingFunc {
 }

 impl UpdatingAggregatingFunc {
-    async fn flush(&mut self, ctx: &mut OperatorContext, collector: &mut dyn Collector) -> Result<()> {
+    async fn flush(
+        &mut self,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) -> Result<()> {
         if self.sender.is_none() {
             return Ok(());
         }
@@ -160,11 +164,12 @@ impl UpdatingAggregatingFunc {
         }

         if !batches_to_write.is_empty() {
-            collector.collect(concat_batches(
-                &batches_to_write[0].schema(),
-                batches_to_write.iter(),
-            )?)
-            .await;
+            collector
+                .collect(concat_batches(
+                    &batches_to_write[0].schema(),
+                    batches_to_write.iter(),
+                )?)
+                .await;
         }

         Ok(())
@@ -244,14 +249,24 @@ impl ArrowOperator for UpdatingAggregatingFunc {
         }
     }

-    async fn process_batch(&mut self, batch: RecordBatch, _ctx: &mut OperatorContext, _: &mut dyn Collector) {
+    async fn process_batch(
+        &mut self,
+        batch: RecordBatch,
+        _ctx: &mut OperatorContext,
+        _: &mut dyn Collector,
+    ) {
         if self.sender.is_none() {
             self.init_exec();
         }
         self.sender.as_ref().unwrap().send(batch).unwrap();
     }

-    async fn handle_checkpoint(&mut self, b: CheckpointBarrier, ctx: &mut OperatorContext, collector: &mut dyn Collector) {
+    async fn handle_checkpoint(
+        &mut self,
+        b: CheckpointBarrier,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) {
         self.flush(ctx, collector).await.unwrap();
     }

@@ -285,7 +300,12 @@ impl ArrowOperator for UpdatingAggregatingFunc {
         Some(self.flush_interval)
     }

-    async fn handle_tick(&mut self, _tick: u64, ctx: &mut OperatorContext, collector: &mut dyn Collector) {
+    async fn handle_tick(
+        &mut self,
+        _tick: u64,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) {
         self.flush(ctx, collector).await.unwrap();
     }

@@ -334,7 +354,12 @@ impl ArrowOperator for UpdatingAggregatingFunc {
         }))
     }

-    async fn on_close(&mut self, final_message: &Option<SignalMessage>, ctx: &mut OperatorContext, collector: &mut dyn Collector) {
+    async fn on_close(
+        &mut self,
+        final_message: &Option<SignalMessage>,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) {
         if let Some(SignalMessage::EndOfData) = final_message {
             self.flush(ctx, collector).await.unwrap();
         }
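`UpdatingAggregatingFunc` routes every output trigger (checkpoint barrier, periodic tick, and `EndOfData`) through the same `flush`, so buffered results are concatenated and emitted once per trigger. A condensed sketch of just the concatenation step, with a made-up `PendingBatches` buffer standing in for the operator's exec-plan plumbing:

```rust
use anyhow::Result;
use arrow::compute::concat_batches;
use arrow_array::RecordBatch;

// Hypothetical stand-in for output buffered between flushes.
struct PendingBatches(Vec<RecordBatch>);

impl PendingBatches {
    // Called on checkpoint, tick, and end-of-data: merge whatever is buffered
    // into one batch for the collector, or return None if there is nothing.
    fn drain(&mut self) -> Result<Option<RecordBatch>> {
        if self.0.is_empty() {
            return Ok(None);
        }
        let merged = concat_batches(&self.0[0].schema(), self.0.iter())?;
        self.0.clear();
        Ok(Some(merged))
    }
}
```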
diff --git a/crates/arroyo-worker/src/arrow/watermark_generator.rs b/crates/arroyo-worker/src/arrow/watermark_generator.rs
index 669469fb0..e0d4ad268 100644
--- a/crates/arroyo-worker/src/arrow/watermark_generator.rs
+++ b/crates/arroyo-worker/src/arrow/watermark_generator.rs
@@ -3,7 +3,8 @@ use arrow_array::RecordBatch;
 use arroyo_operator::context::{Collector, OperatorContext};
 use arroyo_operator::get_timestamp_col;
 use arroyo_operator::operator::{
-    ArrowOperator, AsDisplayable, DisplayableOperator, OperatorConstructor, ConstructedOperator, Registry,
+    ArrowOperator, AsDisplayable, ConstructedOperator, DisplayableOperator, OperatorConstructor,
+    Registry,
 };
 use arroyo_rpc::df::ArroyoSchema;
 use arroyo_rpc::grpc::api::ExpressionWatermarkConfig;
@@ -131,20 +132,30 @@ impl ArrowOperator for WatermarkGenerator {
         self.state_cache = state;
     }

-    async fn on_close(&mut self, final_message: &Option<SignalMessage>, ctx: &mut OperatorContext, collector: &mut dyn Collector) {
+    async fn on_close(
+        &mut self,
+        final_message: &Option<SignalMessage>,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) {
         if let Some(SignalMessage::EndOfData) = final_message {
             // send final watermark on close
             collector
-                .broadcast(ArrowMessage::Signal(SignalMessage::Watermark(
+                .broadcast(SignalMessage::Watermark(
                     // this is in the year 2554, far enough out be close to inifinity,
                     // but can still be formatted.
                     Watermark::EventTime(from_nanos(u64::MAX as u128)),
-                )))
+                ))
                 .await;
         }
     }

-    async fn process_batch(&mut self, record: RecordBatch, ctx: &mut OperatorContext, collector: &mut dyn Collector) {
+    async fn process_batch(
+        &mut self,
+        record: RecordBatch,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) {
         collector.collect(record.clone()).await;

         self.last_event = SystemTime::now();
@@ -182,16 +193,19 @@ impl ArrowOperator for WatermarkGenerator {
                 to_millis(watermark)
             );
             collector
-                .broadcast(ArrowMessage::Signal(SignalMessage::Watermark(
-                    Watermark::EventTime(watermark),
-                )))
+                .broadcast(SignalMessage::Watermark(Watermark::EventTime(watermark)))
                 .await;
             self.state_cache.last_watermark_emitted_at = max_timestamp;
             self.idle = false;
         }
     }

-    async fn handle_checkpoint(&mut self, b: CheckpointBarrier, ctx: &mut OperatorContext, _: &mut dyn Collector) {
+    async fn handle_checkpoint(
+        &mut self,
+        b: CheckpointBarrier,
+        ctx: &mut OperatorContext,
+        _: &mut dyn Collector,
+    ) {
         let gs = ctx
             .table_manager
             .get_global_keyed_state("s")
@@ -201,17 +215,21 @@ impl ArrowOperator for WatermarkGenerator {
         gs.insert(ctx.task_info.task_index, self.state_cache).await;
     }

-    async fn handle_tick(&mut self, _: u64, ctx: &mut OperatorContext, collector: &mut dyn Collector) {
+    async fn handle_tick(
+        &mut self,
+        t: u64,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) {
         if let Some(idle_time) = self.idle_time {
             if self.last_event.elapsed().unwrap_or(Duration::ZERO) > idle_time && !self.idle {
                 info!(
                     "Setting partition {} to idle after {:?}",
                     ctx.task_info.task_index, idle_time
                 );
-                collector.broadcast(ArrowMessage::Signal(SignalMessage::Watermark(
-                    Watermark::Idle,
-                )))
-                .await;
+                collector
+                    .broadcast(SignalMessage::Watermark(Watermark::Idle))
+                    .await;
                 self.idle = true;
             }
         }
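The other half of the collector change shows up in `WatermarkGenerator`: `broadcast` now takes a `SignalMessage` directly, so the `ArrowMessage::Signal(..)` wrapper is gone. A small sketch of emitting watermarks through the collector under the new call shape, using the `arroyo_types` types imported elsewhere in this diff:

```rust
use arroyo_operator::context::Collector;
use arroyo_types::{SignalMessage, Watermark};
use std::time::SystemTime;

// Emit either an event-time watermark or an idle marker downstream.
async fn emit_watermark(collector: &mut dyn Collector, event_time: Option<SystemTime>) {
    let watermark = match event_time {
        Some(t) => Watermark::EventTime(t),
        None => Watermark::Idle,
    };
    collector.broadcast(SignalMessage::Watermark(watermark)).await;
}
```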
diff --git a/crates/arroyo-worker/src/arrow/window_fn.rs b/crates/arroyo-worker/src/arrow/window_fn.rs
index 29decfbb1..2ba03ac50 100644
--- a/crates/arroyo-worker/src/arrow/window_fn.rs
+++ b/crates/arroyo-worker/src/arrow/window_fn.rs
@@ -8,7 +8,9 @@ use arrow::compute::{max, min};
 use arrow_array::RecordBatch;
 use arroyo_df::physical::{ArroyoPhysicalExtensionCodec, DecodingContext};
 use arroyo_operator::context::{Collector, OperatorContext};
-use arroyo_operator::operator::{ArrowOperator, OperatorConstructor, ConstructedOperator, Registry};
+use arroyo_operator::operator::{
+    ArrowOperator, ConstructedOperator, OperatorConstructor, Registry,
+};
 use arroyo_rpc::df::ArroyoSchema;
 use arroyo_rpc::grpc::rpc::TableConfig;
 use arroyo_rpc::{df::ArroyoSchemaRef, grpc::api};
@@ -137,7 +139,12 @@ impl ArrowOperator for WindowFunctionOperator {
             }
         }
     }
-    async fn process_batch(&mut self, batch: RecordBatch, ctx: &mut OperatorContext, collector: &mut dyn Collector) {
+    async fn process_batch(
+        &mut self,
+        batch: RecordBatch,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) {
         let current_watermark = ctx.last_present_watermark();
         let table = ctx
             .table_manager
@@ -187,7 +194,12 @@ impl ArrowOperator for WindowFunctionOperator {
         Some(Watermark::EventTime(watermark))
     }

-    async fn handle_checkpoint(&mut self, b: CheckpointBarrier, ctx: &mut OperatorContext, collector: &mut dyn Collector) {
+    async fn handle_checkpoint(
+        &mut self,
+        b: CheckpointBarrier,
+        ctx: &mut OperatorContext,
+        collector: &mut dyn Collector,
+    ) {
         let watermark = ctx.last_present_watermark();
         ctx.table_manager
             .get_expiring_time_key_table("input", watermark)
diff --git a/crates/arroyo-worker/src/engine.rs b/crates/arroyo-worker/src/engine.rs
index 2f436a115..79663575c 100644
--- a/crates/arroyo-worker/src/engine.rs
+++ b/crates/arroyo-worker/src/engine.rs
@@ -710,7 +710,16 @@ impl Engine {
         let join_task = {
             let control_tx = control_tx.clone();
             tokio::spawn(async move {
-                operator.start(control_tx.clone(), control_rx, in_qs, out_qs, node.out_schema, ready).await;
+                operator
+                    .start(
+                        control_tx.clone(),
+                        control_rx,
+                        in_qs,
+                        out_qs,
+                        node.out_schema,
+                        ready,
+                    )
+                    .await;
             })
         };

@@ -783,7 +792,7 @@ pub async fn construct_node(
        })
    } else {
        let mut head = None;
-        let mut cur: &mut Option<ChainedOperator> = &mut None;
+        let mut cur: Option<&mut ChainedOperator> = None;
        let mut input_partitions = input_partitions as usize;
        for (node, edge) in chain.iter() {
            let ConstructedOperator::Operator(op) =
@@ -805,7 +814,7 @@ pub async fn construct_node(
                restore_from,
                control_tx.clone(),
                input_partitions,
-                if let Some(cur) = cur {
+                if let Some(cur) = &mut cur {
                    vec![cur.context.out_schema.clone().unwrap()]
                } else {
                    in_schemas.clone()
@@ -817,10 +826,11 @@ pub async fn construct_node(

            if cur.is_none() {
                head = Some(ChainedOperator::new(op, ctx));
-                cur = &mut head;
+                cur = head.as_mut();
                input_partitions = 1;
            } else {
                cur.as_mut().unwrap().next = Some(Box::new(ChainedOperator::new(op, ctx)));
+                cur = Some(cur.unwrap().next.as_mut().unwrap().as_mut());
            }
        }
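The `construct_node` change above is more than a reformat: the old code held `cur` as `&mut Option<ChainedOperator>` pointing at a local `None`, re-pointed it at `head` for the first operator, and never advanced it, so later operators in a chain all targeted the same `next` slot. The new code models the cursor as `Option<&mut ChainedOperator>` and moves it onto the node it just appended. A toy version of the append loop, with invented names, showing why that final reassignment matters:

```rust
// Simplified stand-in for ChainedOperator: a singly linked chain of nodes.
struct Node {
    id: usize,
    next: Option<Box<Node>>,
}

fn build_chain(ids: &[usize]) -> Option<Node> {
    let mut head: Option<Node> = None;
    let mut cur: Option<&mut Node> = None;

    for &id in ids {
        if cur.is_none() {
            head = Some(Node { id, next: None });
            cur = head.as_mut();
        } else {
            let tail = cur.unwrap();
            tail.next = Some(Box::new(Node { id, next: None }));
            // Advance to the node we just appended; without this, every later
            // iteration would overwrite the same `next` pointer.
            cur = Some(tail.next.as_mut().unwrap().as_mut());
        }
    }

    head
}
```

The `Option<&mut _>` cursor also keeps the borrow checker happy across iterations, which the previous `&mut Option<_>` shape made awkward.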
diff --git a/crates/arroyo-worker/src/lib.rs b/crates/arroyo-worker/src/lib.rs
index 74b3f95be..9ea192132 100644
--- a/crates/arroyo-worker/src/lib.rs
+++ b/crates/arroyo-worker/src/lib.rs
@@ -111,12 +111,10 @@ impl LocalRunner {
         let name = format!("{}-0", self.program.name);
         let total_nodes = self.program.total_nodes();
         let engine = Engine::for_local(self.program, name);
-
+
         let (control_tx, mut control_rx) = channel(128);
-
-        let _running_engine = engine
-            .start(control_tx)
-            .await;
+
+        let _running_engine = engine.start(control_tx).await;

         let mut finished_nodes = HashSet::new();

@@ -447,7 +445,7 @@ impl WorkerGrpc for WorkerServer {
         }

         let (control_tx, control_rx) = channel(128);
-
+
         let engine = {
             let network = { self.network.lock().unwrap().take().unwrap() };

@@ -464,9 +462,16 @@ impl WorkerGrpc for WorkerServer {
             None
         };

-        let program =
-            Program::from_logical(self.name.to_string(), &self.job_id, &logical.graph, &req.tasks, registry, checkpoint_metadata.as_ref(), control_tx.clone())
-                .await;
+        let program = Program::from_logical(
+            self.name.to_string(),
+            &self.job_id,
+            &logical.graph,
+            &req.tasks,
+            registry,
+            checkpoint_metadata.as_ref(),
+            control_tx.clone(),
+        )
+        .await;

         let engine = Engine::new(
             program,
@@ -476,9 +481,7 @@ impl WorkerGrpc for WorkerServer {
             network,
             req.tasks,
         );
-        engine
-            .start(control_tx)
-            .await
+        engine.start(control_tx).await
        };

        self.shutdown_guard
diff --git a/crates/arroyo-worker/src/utils.rs b/crates/arroyo-worker/src/utils.rs
index 36e9ab3a1..238260608 100644
--- a/crates/arroyo-worker/src/utils.rs
+++ b/crates/arroyo-worker/src/utils.rs
@@ -1,7 +1,9 @@
 use crate::engine::construct_operator;
 use arrow_schema::Schema;
-use arroyo_datastream::logical::LogicalProgram;
+use arroyo_datastream::logical::{ChainedLogicalOperator, LogicalEdgeType, LogicalProgram};
 use arroyo_df::physical::new_registry;
+use arroyo_operator::operator::Registry;
+use arroyo_rpc::grpc::api::EdgeType;
 use std::fmt::Write;
 use std::sync::Arc;

@@ -13,85 +15,120 @@ fn format_arrow_schema_fields(schema: &Schema) -> Vec<(String, String)> {
         .collect()
 }

+fn write_op(d2: &mut String, registry: &Arc<Registry>, idx: usize, el: &ChainedLogicalOperator) {
+    let operator = construct_operator(el.operator_name, &el.operator_config, registry.clone());
+    let display = operator.display();
+
+    let mut label = format!("### {} ({})", operator.name(), &display.name);
+    for (field, value) in display.fields {
+        label.push_str(&format!("\n## {}\n\n{}", field, value));
+    }
+
+    writeln!(
+        d2,
+        "{}: {{
+  label: |markdown
+{}
+  |
+  shape: rectangle
+}}",
+        idx, label
+    )
+    .unwrap();
+}
+
+fn write_edge(
+    d2: &mut String,
+    from: usize,
+    to: usize,
+    edge_idx: usize,
+    edge_type: &LogicalEdgeType,
+    schema: &Schema,
+) {
+    let edge_label = format!("{}", edge_type);
+
+    let schema_node_name = format!("schema_{}", edge_idx);
+    let schema_fields = format_arrow_schema_fields(&schema);
+
+    writeln!(d2, "{}: {{", schema_node_name).unwrap();
+    writeln!(d2, "  shape: sql_table").unwrap();
+
+    for (field_name, field_type) in schema_fields {
+        writeln!(
+            d2,
+            "  \"{}\": \"{}\"",
+            field_name.replace("\"", "\\\""),
+            field_type.replace("\"", "\\\"")
+        )
+        .unwrap();
+    }
+
+    writeln!(d2, "}}").unwrap();
+
+    writeln!(
+        d2,
+        "{} -> {}: \"{}\"",
+        from,
+        schema_node_name,
+        edge_label.replace("\"", "\\\"")
+    )
+    .unwrap();
+
+    writeln!(d2, "{} -> {}", schema_node_name, to).unwrap();
+}
+
 pub async fn to_d2(logical: &LogicalProgram) -> anyhow::Result<String> {
-    todo!()
-//     let mut registry = new_registry();
-//
-//     for (name, udf) in &logical.program_config.udf_dylibs {
-//         registry.load_dylib(name, udf).await?;
-//     }
-//
-//     for udf in logical.program_config.python_udfs.values() {
-//         registry.add_python_udf(udf).await?;
-//     }
-//
-//     let registry = Arc::new(registry);
-//
-//     let mut d2 = String::new();
-//
-//     for idx in logical.graph.node_indices() {
-//         let node = logical.graph.node_weight(idx).unwrap();
-//         let operator = construct_operator(
-//             node.operator_name,
-//             node.operator_config.clone(),
-//             registry.clone(),
-//         );
-//         let display = operator.display();
-//
-//         let mut label = format!("### {} ({})", operator.name(), &display.name);
-//         for (field, value) in display.fields {
-//             label.push_str(&format!("\n## {}\n\n{}", field, value));
-//         }
-//
-//         writeln!(
-//             &mut d2,
-//             "{}: {{
-//   label: |markdown
-// {}
-//   |
-//   shape: rectangle
-// }}",
-//             idx.index(),
-//             label
-//         )
-//         .unwrap();
-//     }
-//
-//     for idx in logical.graph.edge_indices() {
-//         let edge = logical.graph.edge_weight(idx).unwrap();
-//         let (from, to) = logical.graph.edge_endpoints(idx).unwrap();
-//
-//         let edge_label = format!("{}", edge.edge_type);
-//
-//         let schema_node_name = format!("schema_{}", idx.index());
-//         let schema_fields = format_arrow_schema_fields(&edge.schema.schema);
-//
-//         writeln!(&mut d2, "{}: {{", schema_node_name).unwrap();
-//         writeln!(&mut d2, "  shape: sql_table").unwrap();
-//
-//         for (field_name, field_type) in schema_fields {
-//             writeln!(
-//                 &mut d2,
-//                 "  \"{}\": \"{}\"",
-//                 field_name.replace("\"", "\\\""),
-//                 field_type.replace("\"", "\\\"")
-//             )
-//             .unwrap();
-//         }
-//
-//         writeln!(&mut d2, "}}").unwrap();
-//
-//         writeln!(
-//             &mut d2,
-//             "{} -> {}: \"{}\"",
-//             from.index(),
-//             schema_node_name,
-//             edge_label.replace("\"", "\\\"")
-//         )
-//         .unwrap();
-//
-//         writeln!(&mut d2, "{} -> {}", schema_node_name, to.index()).unwrap();
-//     }
-//
-//     Ok(d2)
+    let mut registry = new_registry();
+
+    for (name, udf) in &logical.program_config.udf_dylibs {
+        registry.load_dylib(name, udf).await?;
+    }
+
+    for udf in logical.program_config.python_udfs.values() {
+        registry.add_python_udf(udf).await?;
+    }
+
+    let registry = Arc::new(registry);
+
+    let mut d2 = String::new();
+
+    for idx in logical.graph.node_indices() {
+        let node = logical.graph.node_weight(idx).unwrap();
+
+        if node.operator_chain.len() == 1 {
+            let el = node.operator_chain.first();
+            write_op(&mut d2, &registry, idx.index(), el);
+        } else {
+            writeln!(d2, "{}: {{", idx.index()).unwrap();
+            for (i, (el, edge)) in node.operator_chain.iter().enumerate() {
+                write_op(&mut d2, &registry, i, el);
+                if let Some(edge) = edge {
+                    write_edge(
+                        &mut d2,
+                        i,
+                        i + 1,
+                        i,
+                        &LogicalEdgeType::Forward,
+                        &edge.schema,
+                    );
+                }
+            }
+            writeln!(d2, "}}").unwrap();
+        }
+    }
+
+    for idx in logical.graph.edge_indices() {
+        let edge = logical.graph.edge_weight(idx).unwrap();
+        let (from, to) = logical.graph.edge_endpoints(idx).unwrap();
+        write_edge(
+            &mut d2,
+            from.index(),
+            to.index(),
+            idx.index(),
+            &edge.edge_type,
+            &edge.schema.schema,
+        );
+    }
+
+    Ok(d2)
 }
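With the commented-out body restored and adapted to operator chains, `to_d2` again returns the diagram source as a plain `String` (chains become nested D2 containers with per-element schema tables). A small usage sketch; the module path and file name are assumptions, and rendering relies on the external `d2` CLI (e.g. `d2 pipeline.d2 pipeline.svg`):

```rust
use arroyo_datastream::logical::LogicalProgram;

// Assumed path: to_d2 as exposed from arroyo-worker's utils module.
async fn dump_pipeline_diagram(program: &LogicalProgram) -> anyhow::Result<()> {
    let d2 = arroyo_worker::utils::to_d2(program).await?;
    std::fs::write("pipeline.d2", d2)?;
    Ok(())
}
```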