From 2f223db21557c15080bf865ac692d276b8f0b770 Mon Sep 17 00:00:00 2001
From: Baris Palaska
Date: Fri, 27 Sep 2024 19:16:15 +0100
Subject: [PATCH] Upgrade to Datafusion 41 (#1062)

* upgrade dependencies, some fixes, still wip
compiles
add license
Update datafusion protobuf definitions (#1057)
* update datafusion proto defs
* allow optionals in proto3
update docker environment for higher protoc version
* runs
* test e2e, fix python
* rm unnecessary dependency
* create BallistaLogicalExtensionCodec that can decode/encode file formats, fix some tests
* fix tests
* clippy, tomlfmt
* fix grpc connect info extract
* extract into method, remove unnecessary log
* datafusion to 41, adjust other deps
---
 Cargo.toml                                     |  26 +-
 ballista-cli/Cargo.toml                        |   4 +-
 ballista-cli/src/main.rs                       |  28 +-
 ballista/client/README.md                      |   3 +-
 ballista/client/src/context.rs                 |  32 +-
 ballista/core/Cargo.toml                       |   6 +-
 ballista/core/src/config.rs                    |  14 +-
 .../src/execution_plans/distributed_query.rs   |   4 +
 .../src/execution_plans/shuffle_reader.rs      |   4 +
 .../src/execution_plans/shuffle_writer.rs      |   4 +
 .../src/execution_plans/unresolved_shuffle.rs  |   4 +
 ballista/core/src/serde/mod.rs                 | 104 +++++-
 ballista/core/src/serde/scheduler/mod.rs       |   5 +
 ballista/core/src/utils.rs                     |  27 +-
 ballista/executor/Cargo.toml                   |   1 -
 ballista/executor/src/collect.rs               |   4 +
 ballista/executor/src/executor.rs              |  22 +-
 ballista/scheduler/Cargo.toml                  |  17 +-
 ballista/scheduler/src/api/handlers.rs         | 297 ++++++++++--------
 ballista/scheduler/src/api/mod.rs              | 153 ++-------
 ballista/scheduler/src/bin/main.rs             |  13 +-
 ballista/scheduler/src/cluster/mod.rs          |  26 +-
 ballista/scheduler/src/config.rs               |   6 +-
 ballista/scheduler/src/planner.rs              |   2 +
 ballista/scheduler/src/scheduler_process.rs    |  85 ++---
 .../scheduler/src/scheduler_server/grpc.rs     |  15 +-
 .../scheduler/src/scheduler_server/mod.rs      |   6 +-
 .../scheduler/src/state/execution_graph.rs     | 140 +++------
 .../src/state/execution_graph_dot.rs           |  17 +-
 ballista/scheduler/src/test_utils.rs           |  73 ++++-
 examples/Cargo.toml                            |   4 +-
 python/.cargo/config.toml                      |  11 +
 python/Cargo.toml                              |  11 +-
 python/src/context.rs                          |  25 +-
 python/src/lib.rs                              |   3 +-
 35 files changed, 653 insertions(+), 543 deletions(-)
 create mode 100644 python/.cargo/config.toml

diff --git a/Cargo.toml b/Cargo.toml
index 3c451885f..ea1c8321a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,19 +21,23 @@ members = ["ballista-cli", "ballista/cache", "ballista/client", "ballista/core",
 resolver = "2"
 
 [workspace.dependencies]
-arrow = { version = "52.0.0", features = ["ipc_compression"] }
-arrow-flight = { version = "52.0.0", features = ["flight-sql-experimental"] }
-arrow-schema = { version = "52.0.0", default-features = false }
+arrow = { version = "52.2.0", features = ["ipc_compression"] }
+arrow-flight = { version = "52.2.0", features = ["flight-sql-experimental"] }
+arrow-schema = { version = "52.2.0", default-features = false }
+clap = { version = "3", features = ["derive", "cargo"] }
 configure_me = { version = "0.4.0" }
 configure_me_codegen = { version = "0.4.4" }
-datafusion = "39.0.0"
-datafusion-cli = "39.0.0"
-datafusion-proto = "39.0.0"
-datafusion-proto-common = "39.0.0"
-object_store = "0.10.1"
-sqlparser = "0.47.0"
-tonic = { version = "0.11" }
-tonic-build = { version = "0.11", default-features = false, features = [
+# when upgrading beyond v41, bump directly to datafusion v43 to avoid the serde bug in v42 (https://github.com/apache/datafusion/pull/12626)
+datafusion = "41.0.0"
+datafusion-cli = "41.0.0"
+datafusion-proto = "41.0.0"
+datafusion-proto-common = "41.0.0"
+object_store = 
"0.10.2" +prost = "0.12.0" +prost-types = "0.12.0" +sqlparser = "0.49.0" +tonic = { version = "0.11.0" } +tonic-build = { version = "0.11.0", default-features = false, features = [ "transport", "prost" ] } diff --git a/ballista-cli/Cargo.toml b/ballista-cli/Cargo.toml index e07ad2797..dc8ff7cbd 100644 --- a/ballista-cli/Cargo.toml +++ b/ballista-cli/Cargo.toml @@ -30,14 +30,14 @@ readme = "README.md" [dependencies] ballista = { path = "../ballista/client", version = "0.12.0", features = ["standalone"] } -clap = { version = "3", features = ["derive", "cargo"] } +clap = { workspace = true } datafusion = { workspace = true } datafusion-cli = { workspace = true } dirs = "5.0.1" env_logger = "0.10" mimalloc = { version = "0.1", default-features = false } num_cpus = "1.13.0" -rustyline = "11.0" +rustyline = "11.0.0" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } [features] diff --git a/ballista-cli/src/main.rs b/ballista-cli/src/main.rs index a8055c877..6aeecd6c9 100644 --- a/ballista-cli/src/main.rs +++ b/ballista-cli/src/main.rs @@ -36,7 +36,7 @@ struct Args { short = 'p', long, help = "Path to your data, default to current directory", - validator(is_valid_data_dir) + value_parser(parse_valid_data_dir) )] data_path: Option, @@ -44,14 +44,14 @@ struct Args { short = 'c', long, help = "The batch size of each query, or use Ballista default", - validator(is_valid_batch_size) + value_parser(parse_batch_size) )] batch_size: Option, #[clap( long, help = "The max concurrent tasks, only for Ballista local mode. Default: all available cores", - validator(is_valid_concurrent_tasks_size) + value_parser(parse_valid_concurrent_tasks_size) )] concurrent_tasks: Option, @@ -60,7 +60,7 @@ struct Args { long, multiple_values = true, help = "Execute commands from file(s), then exit", - validator(is_valid_file) + value_parser(parse_valid_file) )] file: Vec, @@ -69,12 +69,12 @@ struct Args { long, multiple_values = true, help = "Run the provided files on startup instead of ~/.ballistarc", - validator(is_valid_file), + value_parser(parse_valid_file), conflicts_with = "file" )] rc: Option>, - #[clap(long, arg_enum, default_value_t = PrintFormat::Table)] + #[clap(long, value_enum, default_value_t = PrintFormat::Table)] format: PrintFormat, #[clap(long, help = "Ballista scheduler host")] @@ -168,32 +168,32 @@ pub async fn main() -> Result<()> { Ok(()) } -fn is_valid_file(dir: &str) -> std::result::Result<(), String> { +fn parse_valid_file(dir: &str) -> std::result::Result { if Path::new(dir).is_file() { - Ok(()) + Ok(dir.to_string()) } else { Err(format!("Invalid file '{dir}'")) } } -fn is_valid_data_dir(dir: &str) -> std::result::Result<(), String> { +fn parse_valid_data_dir(dir: &str) -> std::result::Result { if Path::new(dir).is_dir() { - Ok(()) + Ok(dir.to_string()) } else { Err(format!("Invalid data directory '{dir}'")) } } -fn is_valid_batch_size(size: &str) -> std::result::Result<(), String> { +fn parse_batch_size(size: &str) -> std::result::Result { match size.parse::() { - Ok(size) if size > 0 => Ok(()), + Ok(size) if size > 0 => Ok(size), _ => Err(format!("Invalid batch size '{size}'")), } } -fn is_valid_concurrent_tasks_size(size: &str) -> std::result::Result<(), String> { +fn parse_valid_concurrent_tasks_size(size: &str) -> std::result::Result { match size.parse::() { - Ok(size) if size > 0 => Ok(()), + Ok(size) if size > 0 => Ok(size), _ => Err(format!("Invalid concurrent_tasks size '{size}'")), } } diff --git a/ballista/client/README.md 
b/ballista/client/README.md index 19dc14390..ac65bc985 100644 --- a/ballista/client/README.md +++ b/ballista/client/README.md @@ -92,7 +92,8 @@ data set. Download the file and add it to the `testdata` folder before running t ```rust,no_run use ballista::prelude::*; -use datafusion::prelude::{col, min, max, avg, sum, ParquetReadOptions}; +use datafusion::prelude::{col, ParquetReadOptions}; +use datafusion::functions_aggregate::{min_max::min, min_max::max, sum::sum, average::avg}; #[tokio::main] async fn main() -> Result<()> { diff --git a/ballista/client/src/context.rs b/ballista/client/src/context.rs index de22b777e..269afc64d 100644 --- a/ballista/client/src/context.rs +++ b/ballista/client/src/context.rs @@ -19,6 +19,7 @@ use datafusion::arrow::datatypes::SchemaRef; use datafusion::execution::context::DataFilePaths; +use datafusion::sql::TableReference; use log::info; use parking_lot::Mutex; use sqlparser::ast::Statement; @@ -33,7 +34,6 @@ use ballista_core::utils::{ }; use datafusion_proto::protobuf::LogicalPlanNode; -use datafusion::catalog::TableReference; use datafusion::dataframe::DataFrame; use datafusion::datasource::{source_as_provider, TableProvider}; use datafusion::error::{DataFusionError, Result}; @@ -791,7 +791,7 @@ mod standalone_tests { let res = df.collect().await.unwrap(); let expected = vec![ "+--------------+", - "| MIN(test.id) |", + "| min(test.id) |", "+--------------+", "| 0 |", "+--------------+", @@ -802,7 +802,7 @@ mod standalone_tests { let res = df.collect().await.unwrap(); let expected = vec![ "+--------------+", - "| MAX(test.id) |", + "| max(test.id) |", "+--------------+", "| 7 |", "+--------------+", @@ -818,7 +818,7 @@ mod standalone_tests { let res = df.collect().await.unwrap(); let expected = vec![ "+--------------+", - "| SUM(test.id) |", + "| sum(test.id) |", "+--------------+", "| 28 |", "+--------------+", @@ -833,7 +833,7 @@ mod standalone_tests { let res = df.collect().await.unwrap(); let expected = vec![ "+--------------+", - "| AVG(test.id) |", + "| avg(test.id) |", "+--------------+", "| 3.5 |", "+--------------+", @@ -849,7 +849,7 @@ mod standalone_tests { let res = df.collect().await.unwrap(); let expected = vec![ "+----------------+", - "| COUNT(test.id) |", + "| count(test.id) |", "+----------------+", "| 8 |", "+----------------+", @@ -867,7 +867,7 @@ mod standalone_tests { let res = df.collect().await.unwrap(); let expected = vec![ "+--------------------------+", - "| APPROX_DISTINCT(test.id) |", + "| approx_distinct(test.id) |", "+--------------------------+", "| 8 |", "+--------------------------+", @@ -885,7 +885,7 @@ mod standalone_tests { let res = df.collect().await.unwrap(); let expected = vec![ "+--------------------------+", - "| ARRAY_AGG(test.id) |", + "| array_agg(test.id) |", "+--------------------------+", "| [4, 5, 6, 7, 2, 3, 0, 1] |", "+--------------------------+", @@ -914,7 +914,7 @@ mod standalone_tests { let res = df.collect().await.unwrap(); let expected = vec![ "+-------------------+", - "| VAR_POP(test.id) |", + "| var_pop(test.id) |", "+-------------------+", "| 5.250000000000001 |", "+-------------------+", @@ -946,7 +946,7 @@ mod standalone_tests { let res = df.collect().await.unwrap(); let expected = vec![ "+--------------------+", - "| STDDEV(test.id) |", + "| stddev(test.id) |", "+--------------------+", "| 2.4494897427831783 |", "+--------------------+", @@ -960,7 +960,7 @@ mod standalone_tests { let res = df.collect().await.unwrap(); let expected = vec![ "+--------------------+", - "| STDDEV(test.id) |", + 
"| stddev(test.id) |", "+--------------------+", "| 2.4494897427831783 |", "+--------------------+", @@ -996,25 +996,27 @@ mod standalone_tests { let res = df.collect().await.unwrap(); let expected = vec![ "+--------------------------------+", - "| CORR(test.id,test.tinyint_col) |", + "| corr(test.id,test.tinyint_col) |", "+--------------------------------+", "| 0.21821789023599245 |", "+--------------------------------+", ]; assert_result_eq(expected, &res); } + // enable when upgrading Datafusion to > 42 + #[ignore] #[tokio::test] async fn test_aggregate_approx_percentile() { let context = create_test_context().await; let df = context - .sql("select approx_percentile_cont_with_weight(\"id\", 2, 0.5) from test") + .sql("select approx_percentile_cont_with_weight(id, 2, 0.5) from test") .await .unwrap(); let res = df.collect().await.unwrap(); let expected = vec![ "+-------------------------------------------------------------------+", - "| APPROX_PERCENTILE_CONT_WITH_WEIGHT(test.id,Int64(2),Float64(0.5)) |", + "| approx_percentile_cont_with_weight(test.id,Int64(2),Float64(0.5)) |", "+-------------------------------------------------------------------+", "| 1 |", "+-------------------------------------------------------------------+", @@ -1028,7 +1030,7 @@ mod standalone_tests { let res = df.collect().await.unwrap(); let expected = vec![ "+------------------------------------------------------+", - "| APPROX_PERCENTILE_CONT(test.double_col,Float64(0.5)) |", + "| approx_percentile_cont(test.double_col,Float64(0.5)) |", "+------------------------------------------------------+", "| 7.574999999999999 |", "+------------------------------------------------------+", diff --git a/ballista/core/Cargo.toml b/ballista/core/Cargo.toml index fccdd0ecc..8a01f56fb 100644 --- a/ballista/core/Cargo.toml +++ b/ballista/core/Cargo.toml @@ -51,7 +51,7 @@ async-trait = "0.1.41" ballista-cache = { path = "../cache", version = "0.12.0" } bytes = "1.0" chrono = { version = "0.4", default-features = false } -clap = { version = "3", features = ["derive", "cargo"] } +clap = { workspace = true } datafusion = { workspace = true } datafusion-objectstore-hdfs = { version = "0.1.4", default-features = false, optional = true } datafusion-proto = { workspace = true } @@ -68,8 +68,8 @@ once_cell = "1.9.0" parking_lot = "0.12" parse_arg = "0.1.3" -prost = "0.12" -prost-types = "0.12" +prost = { workspace = true } +prost-types = { workspace = true } rand = "0.8" serde = { version = "1", features = ["derive"] } sqlparser = { workspace = true } diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs index 03c8f6b9c..46424ecf4 100644 --- a/ballista/core/src/config.rs +++ b/ballista/core/src/config.rs @@ -18,7 +18,7 @@ //! 
Ballista configuration -use clap::ArgEnum; +use clap::ValueEnum; use core::fmt; use std::collections::HashMap; use std::result; @@ -307,7 +307,7 @@ impl BallistaConfig { // an enum used to configure the scheduler policy // needs to be visible to code generated by configure_me -#[derive(Clone, ArgEnum, Copy, Debug, serde::Deserialize)] +#[derive(Clone, ValueEnum, Copy, Debug, serde::Deserialize)] pub enum TaskSchedulingPolicy { PullStaged, PushStaged, @@ -317,7 +317,7 @@ impl std::str::FromStr for TaskSchedulingPolicy { type Err = String; fn from_str(s: &str) -> std::result::Result { - ArgEnum::from_str(s, true) + ValueEnum::from_str(s, true) } } @@ -329,7 +329,7 @@ impl parse_arg::ParseArgFromStr for TaskSchedulingPolicy { // an enum used to configure the log rolling policy // needs to be visible to code generated by configure_me -#[derive(Clone, ArgEnum, Copy, Debug, serde::Deserialize)] +#[derive(Clone, ValueEnum, Copy, Debug, serde::Deserialize)] pub enum LogRotationPolicy { Minutely, Hourly, @@ -341,7 +341,7 @@ impl std::str::FromStr for LogRotationPolicy { type Err = String; fn from_str(s: &str) -> std::result::Result { - ArgEnum::from_str(s, true) + ValueEnum::from_str(s, true) } } @@ -353,7 +353,7 @@ impl parse_arg::ParseArgFromStr for LogRotationPolicy { // an enum used to configure the source data cache policy // needs to be visible to code generated by configure_me -#[derive(Clone, ArgEnum, Copy, Debug, serde::Deserialize)] +#[derive(Clone, ValueEnum, Copy, Debug, serde::Deserialize)] pub enum DataCachePolicy { LocalDiskFile, } @@ -362,7 +362,7 @@ impl std::str::FromStr for DataCachePolicy { type Err = String; fn from_str(s: &str) -> std::result::Result { - ArgEnum::from_str(s, true) + ValueEnum::from_str(s, true) } } diff --git a/ballista/core/src/execution_plans/distributed_query.rs b/ballista/core/src/execution_plans/distributed_query.rs index b96367bbc..050ba877a 100644 --- a/ballista/core/src/execution_plans/distributed_query.rs +++ b/ballista/core/src/execution_plans/distributed_query.rs @@ -154,6 +154,10 @@ impl DisplayAs for DistributedQueryExec { } impl ExecutionPlan for DistributedQueryExec { + fn name(&self) -> &str { + "DistributedQueryExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs index 79dfe296d..2f856b394 100644 --- a/ballista/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/core/src/execution_plans/shuffle_reader.rs @@ -107,6 +107,10 @@ impl DisplayAs for ShuffleReaderExec { } impl ExecutionPlan for ShuffleReaderExec { + fn name(&self) -> &str { + "ShuffleReaderExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/ballista/core/src/execution_plans/shuffle_writer.rs b/ballista/core/src/execution_plans/shuffle_writer.rs index 7f21b18b4..87e4feeac 100644 --- a/ballista/core/src/execution_plans/shuffle_writer.rs +++ b/ballista/core/src/execution_plans/shuffle_writer.rs @@ -355,6 +355,10 @@ impl DisplayAs for ShuffleWriterExec { } impl ExecutionPlan for ShuffleWriterExec { + fn name(&self) -> &str { + "ShuffleWriterExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/ballista/core/src/execution_plans/unresolved_shuffle.rs b/ballista/core/src/execution_plans/unresolved_shuffle.rs index b3c30c0d5..e227e2ac3 100644 --- a/ballista/core/src/execution_plans/unresolved_shuffle.rs +++ b/ballista/core/src/execution_plans/unresolved_shuffle.rs @@ -82,6 +82,10 @@ impl DisplayAs for UnresolvedShuffleExec { } impl ExecutionPlan for 
UnresolvedShuffleExec { + fn name(&self) -> &str { + "UnresolvedShuffleExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs index 08208eed7..2bb555d1a 100644 --- a/ballista/core/src/serde/mod.rs +++ b/ballista/core/src/serde/mod.rs @@ -21,9 +21,13 @@ use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction}; use arrow_flight::sql::ProstMessageExt; -use datafusion::common::DataFusionError; +use datafusion::common::{DataFusionError, Result}; use datafusion::execution::FunctionRegistry; use datafusion::physical_plan::{ExecutionPlan, Partitioning}; +use datafusion_proto::logical_plan::file_formats::{ + ArrowLogicalExtensionCodec, AvroLogicalExtensionCodec, CsvLogicalExtensionCodec, + JsonLogicalExtensionCodec, ParquetLogicalExtensionCodec, +}; use datafusion_proto::physical_plan::from_proto::parse_protobuf_hash_partitioning; use datafusion_proto::protobuf::proto_error; use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; @@ -84,7 +88,7 @@ pub struct BallistaCodec< impl Default for BallistaCodec { fn default() -> Self { Self { - logical_extension_codec: Arc::new(DefaultLogicalExtensionCodec {}), + logical_extension_codec: Arc::new(BallistaLogicalExtensionCodec::default()), physical_extension_codec: Arc::new(BallistaPhysicalExtensionCodec {}), logical_plan_repr: PhantomData, physical_plan_repr: PhantomData, @@ -114,6 +118,102 @@ impl BallistaCodec, + file_format_codecs: Vec>, +} + +impl BallistaLogicalExtensionCodec { + fn try_any( + &self, + mut f: impl FnMut(&dyn LogicalExtensionCodec) -> Result, + ) -> Result { + let mut last_err = None; + for codec in &self.file_format_codecs { + match f(codec.as_ref()) { + Ok(node) => return Ok(node), + Err(err) => last_err = Some(err), + } + } + + Err(last_err.unwrap_or_else(|| { + DataFusionError::NotImplemented("Empty list of composed codecs".to_owned()) + })) + } +} + +impl Default for BallistaLogicalExtensionCodec { + fn default() -> Self { + Self { + default_codec: Arc::new(DefaultLogicalExtensionCodec {}), + file_format_codecs: vec![ + Arc::new(CsvLogicalExtensionCodec {}), + Arc::new(JsonLogicalExtensionCodec {}), + Arc::new(ParquetLogicalExtensionCodec {}), + Arc::new(ArrowLogicalExtensionCodec {}), + Arc::new(AvroLogicalExtensionCodec {}), + ], + } + } +} + +impl LogicalExtensionCodec for BallistaLogicalExtensionCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[datafusion::logical_expr::LogicalPlan], + ctx: &datafusion::prelude::SessionContext, + ) -> Result { + self.default_codec.try_decode(buf, inputs, ctx) + } + + fn try_encode( + &self, + node: &datafusion::logical_expr::Extension, + buf: &mut Vec, + ) -> Result<()> { + self.default_codec.try_encode(node, buf) + } + + fn try_decode_table_provider( + &self, + buf: &[u8], + table_ref: &datafusion::sql::TableReference, + schema: datafusion::arrow::datatypes::SchemaRef, + ctx: &datafusion::prelude::SessionContext, + ) -> Result> { + self.default_codec + .try_decode_table_provider(buf, table_ref, schema, ctx) + } + + fn try_encode_table_provider( + &self, + table_ref: &datafusion::sql::TableReference, + node: Arc, + buf: &mut Vec, + ) -> Result<()> { + self.default_codec + .try_encode_table_provider(table_ref, node, buf) + } + + fn try_decode_file_format( + &self, + buf: &[u8], + ctx: &datafusion::prelude::SessionContext, + ) -> Result> { + self.try_any(|codec| codec.try_decode_file_format(buf, ctx)) + } + + fn try_encode_file_format( + &self, + buf: &mut Vec, + node: Arc, + ) -> 
Result<()> { + self.try_any(|codec| codec.try_encode_file_format(buf, node.clone())) + } +} + #[derive(Debug)] pub struct BallistaPhysicalExtensionCodec {} diff --git a/ballista/core/src/serde/scheduler/mod.rs b/ballista/core/src/serde/scheduler/mod.rs index 0ced200e5..23c9c4256 100644 --- a/ballista/core/src/serde/scheduler/mod.rs +++ b/ballista/core/src/serde/scheduler/mod.rs @@ -25,6 +25,7 @@ use datafusion::arrow::array::{ use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::common::DataFusionError; use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::planner::ExprPlanner; use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::Partitioning; @@ -299,6 +300,10 @@ pub struct SimpleFunctionRegistry { } impl FunctionRegistry for SimpleFunctionRegistry { + fn expr_planners(&self) -> Vec> { + vec![] + } + fn udfs(&self) -> HashSet { self.scalar_functions.keys().cloned().collect() } diff --git a/ballista/core/src/utils.rs b/ballista/core/src/utils.rs index 45a4f53fd..7e88ffaf3 100644 --- a/ballista/core/src/utils.rs +++ b/ballista/core/src/utils.rs @@ -35,6 +35,7 @@ use datafusion::execution::context::{ QueryPlanner, SessionConfig, SessionContext, SessionState, }; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::logical_expr::{DdlStatement, LogicalPlan}; use datafusion::physical_plan::aggregates::AggregateExec; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; @@ -62,13 +63,14 @@ use tonic::transport::{Channel, Error, Server}; /// Default session builder using the provided configuration pub fn default_session_builder(config: SessionConfig) -> SessionState { - SessionState::new_with_config_rt( - config, - Arc::new( + SessionStateBuilder::new() + .with_default_features() + .with_config(config) + .with_runtime_env(Arc::new( RuntimeEnv::new(with_object_store_registry(RuntimeConfig::default())) .unwrap(), - ), - ) + )) + .build() } /// Stream data to disk in Arrow IPC format @@ -252,15 +254,16 @@ pub fn create_df_ctx_with_ballista_query_planner( let session_config = SessionConfig::new() .with_target_partitions(config.default_shuffle_partitions()) .with_information_schema(true); - let mut session_state = SessionState::new_with_config_rt( - session_config, - Arc::new( + let session_state = SessionStateBuilder::new() + .with_default_features() + .with_config(session_config) + .with_runtime_env(Arc::new( RuntimeEnv::new(with_object_store_registry(RuntimeConfig::default())) .unwrap(), - ), - ) - .with_query_planner(planner); - session_state = session_state.with_session_id(session_id); + )) + .with_query_planner(planner) + .with_session_id(session_id) + .build(); // the SessionContext created here is the client side context, but the session_id is from server side. 
SessionContext::new_with_state(session_state) } diff --git a/ballista/executor/Cargo.toml b/ballista/executor/Cargo.toml index 6bebaa877..e0ca6efb6 100644 --- a/ballista/executor/Cargo.toml +++ b/ballista/executor/Cargo.toml @@ -48,7 +48,6 @@ dashmap = "5.4.0" datafusion = { workspace = true } datafusion-proto = { workspace = true } futures = "0.3" -hyper = "0.14.4" log = "0.4" mimalloc = { version = "0.1", default-features = false, optional = true } num_cpus = "1.13.0" diff --git a/ballista/executor/src/collect.rs b/ballista/executor/src/collect.rs index eb96e314f..1d77e7198 100644 --- a/ballista/executor/src/collect.rs +++ b/ballista/executor/src/collect.rs @@ -67,6 +67,10 @@ impl DisplayAs for CollectExec { } impl ExecutionPlan for CollectExec { + fn name(&self) -> &str { + "CollectExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/ballista/executor/src/executor.rs b/ballista/executor/src/executor.rs index ccc7f2739..4e83b1251 100644 --- a/ballista/executor/src/executor.rs +++ b/ballista/executor/src/executor.rs @@ -28,6 +28,8 @@ use ballista_core::serde::scheduler::PartitionId; use dashmap::DashMap; use datafusion::execution::context::TaskContext; use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::functions::all_default_functions; +use datafusion::functions_aggregate::all_default_aggregate_functions; use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use futures::future::AbortHandle; use std::collections::HashMap; @@ -103,12 +105,22 @@ impl Executor { concurrent_tasks: usize, execution_engine: Option>, ) -> Self { + let scalar_functions = all_default_functions() + .into_iter() + .map(|f| (f.name().to_string(), f)) + .collect(); + + let aggregate_functions = all_default_aggregate_functions() + .into_iter() + .map(|f| (f.name().to_string(), f)) + .collect(); + Self { metadata, work_dir: work_dir.to_owned(), - // TODO add logic to dynamically load UDF/UDAFs libs from files - scalar_functions: HashMap::new(), - aggregate_functions: HashMap::new(), + scalar_functions, + aggregate_functions, + // TODO: set to default window functions when they are moved to udwf window_functions: HashMap::new(), runtime, runtime_with_data_cache, @@ -277,6 +289,10 @@ mod test { } impl ExecutionPlan for NeverendingOperator { + fn name(&self) -> &str { + "NeverendingOperator" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/ballista/scheduler/Cargo.toml b/ballista/scheduler/Cargo.toml index dd878b9a6..596db6d11 100644 --- a/ballista/scheduler/Cargo.toml +++ b/ballista/scheduler/Cargo.toml @@ -46,20 +46,19 @@ anyhow = "1" arrow-flight = { workspace = true } async-recursion = "1.0.0" async-trait = "0.1.41" +axum = "0.6.20" ballista-core = { path = "../core", version = "0.12.0", features = ["s3"] } base64 = { version = "0.21" } -clap = { version = "3", features = ["derive", "cargo"] } +clap = { workspace = true } configure_me = { workspace = true } dashmap = "5.4.0" datafusion = { workspace = true } datafusion-proto = { workspace = true } -etcd-client = { version = "0.12", optional = true } +etcd-client = { version = "0.14", optional = true } flatbuffers = { version = "23.5.26" } futures = "0.3" graphviz-rust = "0.8.0" -http = "0.2" -http-body = "0.4" -hyper = "0.14.4" +http = "0.2.9" itertools = "0.12.0" log = "0.4" object_store = { workspace = true } @@ -67,20 +66,20 @@ once_cell = { version = "1.16.0", optional = true } parking_lot = "0.12" parse_arg = "0.1.3" prometheus = { version = "0.13", features = ["process"], optional = true } -prost = "0.12" 
-prost-types = { version = "0.12.0" } +prost = { workspace = true } +prost-types = { workspace = true } rand = "0.8" serde = { version = "1", features = ["derive"] } sled_package = { package = "sled", version = "0.34", optional = true } tokio = { version = "1.0", features = ["full"] } tokio-stream = { version = "0.1", features = ["net"], optional = true } tonic = { workspace = true } -tower = { version = "0.4" } +# tonic 0.12.2 depends on tower 0.4.7 +tower = { version = "0.4.7", default-features = false, features = ["make", "util"] } tracing = { workspace = true } tracing-appender = { workspace = true } tracing-subscriber = { workspace = true } uuid = { version = "1.0", features = ["v4"] } -warp = "0.3" [dev-dependencies] ballista-core = { path = "../core", version = "0.12.0" } diff --git a/ballista/scheduler/src/api/handlers.rs b/ballista/scheduler/src/api/handlers.rs index 463ca2175..4d0366ff8 100644 --- a/ballista/scheduler/src/api/handlers.rs +++ b/ballista/scheduler/src/api/handlers.rs @@ -14,6 +14,11 @@ use crate::scheduler_server::event::QueryStageSchedulerEvent; use crate::scheduler_server::SchedulerServer; use crate::state::execution_graph::ExecutionStage; use crate::state::execution_graph_dot::ExecutionGraphDot; +use axum::{ + extract::{Path, State}, + response::{IntoResponse, Response}, + Json, +}; use ballista_core::serde::protobuf::job_status::Status; use ballista_core::BALLISTA_VERSION; use datafusion::physical_plan::metrics::{MetricValue, MetricsSet, Time}; @@ -22,10 +27,9 @@ use datafusion_proto::physical_plan::AsExecutionPlan; use graphviz_rust::cmd::{CommandArg, Format}; use graphviz_rust::exec; use graphviz_rust::printer::PrinterContext; -use http::header::CONTENT_TYPE; - +use http::{header::CONTENT_TYPE, StatusCode}; +use std::sync::Arc; use std::time::Duration; -use warp::Rejection; #[derive(Debug, serde::Serialize)] struct SchedulerStateResponse { @@ -64,22 +68,26 @@ pub struct QueryStageSummary { pub elapsed_compute: String, } -/// Return current scheduler state -pub(crate) async fn get_scheduler_state( - data_server: SchedulerServer, -) -> Result { +pub async fn get_scheduler_state< + T: AsLogicalPlan + Clone + Send + Sync + 'static, + U: AsExecutionPlan + Send + Sync + 'static, +>( + State(data_server): State>>, +) -> impl IntoResponse { let response = SchedulerStateResponse { started: data_server.start_time, version: BALLISTA_VERSION, }; - Ok(warp::reply::json(&response)) + Json(response) } -/// Return list of executors -pub(crate) async fn get_executors( - data_server: SchedulerServer, -) -> Result { - let state = data_server.state; +pub async fn get_executors< + T: AsLogicalPlan + Clone + Send + Sync + 'static, + U: AsExecutionPlan + Send + Sync + 'static, +>( + State(data_server): State>>, +) -> impl IntoResponse { + let state = &data_server.state; let executors: Vec = state .executor_manager .get_executor_state() @@ -94,21 +102,23 @@ pub(crate) async fn get_executors( }) .collect(); - Ok(warp::reply::json(&executors)) + Json(executors) } -/// Return list of jobs -pub(crate) async fn get_jobs( - data_server: SchedulerServer, -) -> Result { +pub async fn get_jobs< + T: AsLogicalPlan + Clone + Send + Sync + 'static, + U: AsExecutionPlan + Send + Sync + 'static, +>( + State(data_server): State>>, +) -> Result { // TODO: Display last seen information in UI - let state = data_server.state; + let state = &data_server.state; let jobs = state .task_manager .get_jobs() .await - .map_err(|_| warp::reject())?; + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let jobs: 
Vec = jobs .iter() @@ -157,31 +167,34 @@ pub(crate) async fn get_jobs( }) .collect(); - Ok(warp::reply::json(&jobs)) + Ok(Json(jobs)) } -pub(crate) async fn cancel_job( - data_server: SchedulerServer, - job_id: String, -) -> Result { +pub async fn cancel_job< + T: AsLogicalPlan + Clone + Send + Sync + 'static, + U: AsExecutionPlan + Send + Sync + 'static, +>( + State(data_server): State>>, + Path(job_id): Path, +) -> Result { // 404 if job doesn't exist data_server .state .task_manager .get_job_status(&job_id) .await - .map_err(|_| warp::reject())? - .ok_or_else(warp::reject)?; + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .ok_or(StatusCode::NOT_FOUND)?; data_server .query_stage_event_loop .get_sender() - .map_err(|_| warp::reject())? + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? .post_event(QueryStageSchedulerEvent::JobCancel(job_id)) .await - .map_err(|_| warp::reject())?; + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - Ok(warp::reply::json(&CancelJobResponse { cancelled: true })) + Ok(Json(CancelJobResponse { cancelled: true })) } #[derive(Debug, serde::Serialize)] @@ -189,69 +202,71 @@ pub struct QueryStagesResponse { pub stages: Vec, } -/// Get the execution graph for the specified job id -pub(crate) async fn get_query_stages( - data_server: SchedulerServer, - job_id: String, -) -> Result { +pub async fn get_query_stages< + T: AsLogicalPlan + Clone + Send + Sync + 'static, + U: AsExecutionPlan + Send + Sync + 'static, +>( + State(data_server): State>>, + Path(job_id): Path, +) -> Result { if let Some(graph) = data_server .state .task_manager .get_job_execution_graph(&job_id) .await - .map_err(|_| warp::reject())? + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? { - Ok(warp::reply::json(&QueryStagesResponse { - stages: graph - .as_ref() - .stages() - .iter() - .map(|(id, stage)| { - let mut summary = QueryStageSummary { - stage_id: id.to_string(), - stage_status: stage.variant_name().to_string(), - input_rows: 0, - output_rows: 0, - elapsed_compute: "".to_string(), - }; - match stage { - ExecutionStage::Running(running_stage) => { - summary.input_rows = running_stage - .stage_metrics - .as_ref() - .map(|m| get_combined_count(m.as_slice(), "input_rows")) - .unwrap_or(0); - summary.output_rows = running_stage - .stage_metrics - .as_ref() - .map(|m| get_combined_count(m.as_slice(), "output_rows")) - .unwrap_or(0); - summary.elapsed_compute = running_stage - .stage_metrics - .as_ref() - .map(|m| get_elapsed_compute_nanos(m.as_slice())) - .unwrap_or_default(); - } - ExecutionStage::Successful(completed_stage) => { - summary.input_rows = get_combined_count( - &completed_stage.stage_metrics, - "input_rows", - ); - summary.output_rows = get_combined_count( - &completed_stage.stage_metrics, - "output_rows", - ); - summary.elapsed_compute = - get_elapsed_compute_nanos(&completed_stage.stage_metrics); - } - _ => {} + let stages = graph + .as_ref() + .stages() + .iter() + .map(|(id, stage)| { + let mut summary = QueryStageSummary { + stage_id: id.to_string(), + stage_status: stage.variant_name().to_string(), + input_rows: 0, + output_rows: 0, + elapsed_compute: "".to_string(), + }; + match stage { + ExecutionStage::Running(running_stage) => { + summary.input_rows = running_stage + .stage_metrics + .as_ref() + .map(|m| get_combined_count(m.as_slice(), "input_rows")) + .unwrap_or(0); + summary.output_rows = running_stage + .stage_metrics + .as_ref() + .map(|m| get_combined_count(m.as_slice(), "output_rows")) + .unwrap_or(0); + summary.elapsed_compute = running_stage + 
.stage_metrics + .as_ref() + .map(|m| get_elapsed_compute_nanos(m.as_slice())) + .unwrap_or_default(); + } + ExecutionStage::Successful(completed_stage) => { + summary.input_rows = get_combined_count( + &completed_stage.stage_metrics, + "input_rows", + ); + summary.output_rows = get_combined_count( + &completed_stage.stage_metrics, + "output_rows", + ); + summary.elapsed_compute = + get_elapsed_compute_nanos(&completed_stage.stage_metrics); } - summary - }) - .collect(), - })) + _ => {} + } + summary + }) + .collect(); + + Ok(Json(QueryStagesResponse { stages })) } else { - Ok(warp::reply::json(&QueryStagesResponse { stages: vec![] })) + Ok(Json(QueryStagesResponse { stages: vec![] })) } } @@ -286,78 +301,96 @@ fn get_combined_count(metrics: &[MetricsSet], name: &str) -> usize { .sum() } -/// Generate a dot graph for the specified job id and return as plain text -pub(crate) async fn get_job_dot_graph( - data_server: SchedulerServer, - job_id: String, -) -> Result { +pub async fn get_job_dot_graph< + T: AsLogicalPlan + Clone + Send + Sync + 'static, + U: AsExecutionPlan + Send + Sync + 'static, +>( + State(data_server): State>>, + Path(job_id): Path, +) -> Result { if let Some(graph) = data_server .state .task_manager .get_job_execution_graph(&job_id) .await - .map_err(|_| warp::reject())? + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? { - ExecutionGraphDot::generate(graph.as_ref()).map_err(|_| warp::reject()) + ExecutionGraphDot::generate(graph.as_ref()) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } else { Ok("Not Found".to_string()) } } -/// Generate a dot graph for the specified job id and query stage and return as plain text -pub(crate) async fn get_query_stage_dot_graph( - data_server: SchedulerServer, - job_id: String, - stage_id: usize, -) -> Result { +pub async fn get_query_stage_dot_graph< + T: AsLogicalPlan + Clone + Send + Sync + 'static, + U: AsExecutionPlan + Send + Sync + 'static, +>( + State(data_server): State>>, + Path((job_id, stage_id)): Path<(String, usize)>, +) -> Result { if let Some(graph) = data_server .state .task_manager .get_job_execution_graph(&job_id) .await - .map_err(|_| warp::reject())? + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? 
{ ExecutionGraphDot::generate_for_query_stage(graph.as_ref(), stage_id) - .map_err(|_| warp::reject()) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } else { Ok("Not Found".to_string()) } } -/// Generate an SVG graph for the specified job id and return it as plain text -pub(crate) async fn get_job_svg_graph( - data_server: SchedulerServer, - job_id: String, -) -> Result { - let dot = get_job_dot_graph(data_server, job_id).await; - match dot { - Ok(dot) => { - let graph = graphviz_rust::parse(&dot); - if let Ok(graph) = graph { - exec( - graph, - &mut PrinterContext::default(), - vec![CommandArg::Format(Format::Svg)], - ) - .map(|bytes| String::from_utf8_lossy(&bytes).to_string()) - .map_err(|_| warp::reject()) - } else { - Ok("Cannot parse graph".to_string()) - } +pub async fn get_job_svg_graph< + T: AsLogicalPlan + Clone + Send + Sync + 'static, + U: AsExecutionPlan + Send + Sync + 'static, +>( + State(data_server): State>>, + Path(job_id): Path, +) -> Result { + let dot = get_job_dot_graph(State(data_server.clone()), Path(job_id)).await?; + match graphviz_rust::parse(&dot) { + Ok(graph) => { + let result = exec( + graph, + &mut PrinterContext::default(), + vec![CommandArg::Format(Format::Svg)], + ) + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + let svg = String::from_utf8_lossy(&result).to_string(); + Ok(Response::builder() + .header(CONTENT_TYPE, "image/svg+xml") + .body(svg) + .unwrap()) } - _ => Ok("Not Found".to_string()), + Err(_) => Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .body("Cannot parse graph".to_string()) + .unwrap()), } } -pub(crate) async fn get_scheduler_metrics( - data_server: SchedulerServer, -) -> Result { - Ok(data_server - .metrics_collector() - .gather_metrics() - .map_err(|_| warp::reject())? 
- .map(|(data, content_type)| { - warp::reply::with_header(data, CONTENT_TYPE, content_type) - }) - .unwrap_or_else(|| warp::reply::with_header(vec![], CONTENT_TYPE, "text/html"))) +pub async fn get_scheduler_metrics< + T: AsLogicalPlan + Clone + Send + Sync + 'static, + U: AsExecutionPlan + Send + Sync + 'static, +>( + State(data_server): State>>, +) -> impl IntoResponse { + match data_server.metrics_collector().gather_metrics() { + Ok(Some((data, content_type))) => Response::builder() + .header(CONTENT_TYPE, content_type) + .body(axum::body::Body::from(data)) + .unwrap(), + Ok(None) => Response::builder() + .status(StatusCode::NO_CONTENT) + .body(axum::body::Body::empty()) + .unwrap(), + Err(_) => Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(axum::body::Body::empty()) + .unwrap(), + } } diff --git a/ballista/scheduler/src/api/mod.rs b/ballista/scheduler/src/api/mod.rs index 8f5555d06..c33d5157a 100644 --- a/ballista/scheduler/src/api/mod.rs +++ b/ballista/scheduler/src/api/mod.rs @@ -13,126 +13,39 @@ mod handlers; use crate::scheduler_server::SchedulerServer; -use anyhow::Result; +use axum::routing::patch; +use axum::{routing::get, Router}; use datafusion_proto::logical_plan::AsLogicalPlan; use datafusion_proto::physical_plan::AsExecutionPlan; -use std::{ - pin::Pin, - task::{Context as TaskContext, Poll}, -}; -use warp::filters::BoxedFilter; -use warp::{Buf, Filter, Reply}; - -pub enum EitherBody { - Left(A), - Right(B), -} - -pub type Error = Box; -pub type HttpBody = dyn http_body::Body + 'static; - -impl http_body::Body for EitherBody -where - A: http_body::Body + Send + Unpin, - B: http_body::Body + Send + Unpin, - A::Error: Into, - B::Error: Into, -{ - type Data = A::Data; - type Error = Error; - - fn poll_data( - self: Pin<&mut Self>, - cx: &mut TaskContext<'_>, - ) -> Poll>> { - match self.get_mut() { - EitherBody::Left(b) => Pin::new(b).poll_data(cx).map(map_option_err), - EitherBody::Right(b) => Pin::new(b).poll_data(cx).map(map_option_err), - } - } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut TaskContext<'_>, - ) -> Poll, Self::Error>> { - match self.get_mut() { - EitherBody::Left(b) => Pin::new(b).poll_trailers(cx).map_err(Into::into), - EitherBody::Right(b) => Pin::new(b).poll_trailers(cx).map_err(Into::into), - } - } - - fn is_end_stream(&self) -> bool { - match self { - EitherBody::Left(b) => b.is_end_stream(), - EitherBody::Right(b) => b.is_end_stream(), - } - } -} - -fn map_option_err>( - err: Option>, -) -> Option> { - err.map(|e| e.map_err(Into::into)) -} - -fn with_data_server( - db: SchedulerServer, -) -> impl Filter,), Error = std::convert::Infallible> + Clone -{ - warp::any().map(move || db.clone()) -} - -pub fn get_routes( - scheduler_server: SchedulerServer, -) -> BoxedFilter<(impl Reply,)> { - let route_scheduler_state = warp::path!("api" / "state") - .and(with_data_server(scheduler_server.clone())) - .and_then(handlers::get_scheduler_state); - - let route_executors = warp::path!("api" / "executors") - .and(with_data_server(scheduler_server.clone())) - .and_then(handlers::get_executors); - - let route_jobs = warp::path!("api" / "jobs") - .and(with_data_server(scheduler_server.clone())) - .and_then(|data_server| handlers::get_jobs(data_server)); - - let route_cancel_job = warp::path!("api" / "job" / String) - .and(warp::patch()) - .and(with_data_server(scheduler_server.clone())) - .and_then(|job_id, data_server| handlers::cancel_job(data_server, job_id)); - - let route_query_stages = warp::path!("api" / "job" / String / 
"stages") - .and(with_data_server(scheduler_server.clone())) - .and_then(|job_id, data_server| handlers::get_query_stages(data_server, job_id)); - - let route_job_dot = warp::path!("api" / "job" / String / "dot") - .and(with_data_server(scheduler_server.clone())) - .and_then(|job_id, data_server| handlers::get_job_dot_graph(data_server, job_id)); - - let route_query_stage_dot = - warp::path!("api" / "job" / String / "stage" / usize / "dot") - .and(with_data_server(scheduler_server.clone())) - .and_then(|job_id, stage_id, data_server| { - handlers::get_query_stage_dot_graph(data_server, job_id, stage_id) - }); - - let route_job_dot_svg = warp::path!("api" / "job" / String / "dot_svg") - .and(with_data_server(scheduler_server.clone())) - .and_then(|job_id, data_server| handlers::get_job_svg_graph(data_server, job_id)); - - let route_scheduler_metrics = warp::path!("api" / "metrics") - .and(with_data_server(scheduler_server)) - .and_then(|data_server| handlers::get_scheduler_metrics(data_server)); - - let routes = route_scheduler_state - .or(route_executors) - .or(route_jobs) - .or(route_cancel_job) - .or(route_query_stages) - .or(route_job_dot) - .or(route_query_stage_dot) - .or(route_job_dot_svg) - .or(route_scheduler_metrics); - routes.boxed() +use std::sync::Arc; + +pub fn get_routes< + T: AsLogicalPlan + Clone + Send + Sync + 'static, + U: AsExecutionPlan + Send + Sync + 'static, +>( + scheduler_server: Arc>, +) -> Router { + Router::new() + .route("/api/state", get(handlers::get_scheduler_state::)) + .route("/api/executors", get(handlers::get_executors::)) + .route("/api/jobs", get(handlers::get_jobs::)) + .route("/api/job/:job_id", patch(handlers::cancel_job::)) + .route( + "/api/job/:job_id/stages", + get(handlers::get_query_stages::), + ) + .route( + "/api/job/:job_id/dot", + get(handlers::get_job_dot_graph::), + ) + .route( + "/api/job/:job_id/stage/:stage_id/dot", + get(handlers::get_query_stage_dot_graph::), + ) + .route( + "/api/job/:job_id/dot_svg", + get(handlers::get_job_svg_graph::), + ) + .route("/api/metrics", get(handlers::get_scheduler_metrics::)) + .with_state(scheduler_server) } diff --git a/ballista/scheduler/src/bin/main.rs b/ballista/scheduler/src/bin/main.rs index ee9364c78..d2e2c9ceb 100644 --- a/ballista/scheduler/src/bin/main.rs +++ b/ballista/scheduler/src/bin/main.rs @@ -47,8 +47,17 @@ mod config { )); } -#[tokio::main] -async fn main() -> Result<()> { +fn main() -> Result<()> { + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_io() + .enable_time() + .thread_stack_size(32 * 1024 * 1024) // 32MB + .build() + .unwrap(); + + runtime.block_on(inner()) +} +async fn inner() -> Result<()> { // parse options let (opt, _remaining_args) = Config::including_optional_config_files(&["/etc/ballista/scheduler.toml"]) diff --git a/ballista/scheduler/src/cluster/mod.rs b/ballista/scheduler/src/cluster/mod.rs index 8313f0330..b7489a25d 100644 --- a/ballista/scheduler/src/cluster/mod.rs +++ b/ballista/scheduler/src/cluster/mod.rs @@ -20,7 +20,7 @@ use std::fmt; use std::pin::Pin; use std::sync::Arc; -use clap::ArgEnum; +use clap::ValueEnum; use datafusion::common::tree_node::TreeNode; use datafusion::common::tree_node::TreeNodeRecursion; use datafusion::datasource::listing::PartitionedFile; @@ -65,7 +65,7 @@ pub mod test_util; // an enum used to configure the backend // needs to be visible to code generated by configure_me -#[derive(Debug, Clone, ArgEnum, serde::Deserialize, PartialEq, Eq)] +#[derive(Debug, Clone, ValueEnum, serde::Deserialize, PartialEq, 
Eq)] pub enum ClusterStorage { Etcd, Memory, @@ -76,7 +76,7 @@ impl std::str::FromStr for ClusterStorage { type Err = String; fn from_str(s: &str) -> std::result::Result { - ArgEnum::from_str(s, true) + ValueEnum::from_str(s, true) } } @@ -764,7 +764,10 @@ mod test { }; use crate::state::execution_graph::ExecutionGraph; use crate::state::task_manager::JobInfoCache; - use crate::test_utils::{mock_completed_task, test_aggregation_plan_with_job_id}; + use crate::test_utils::{ + mock_completed_task, revive_graph_and_complete_next_stage, + test_aggregation_plan_with_job_id, + }; #[tokio::test] async fn test_bind_task_bias() -> Result<()> { @@ -1008,10 +1011,11 @@ mod test { async fn mock_graph( job_id: &str, - num_partition: usize, + num_target_partitions: usize, num_pending_task: usize, ) -> Result { - let mut graph = test_aggregation_plan_with_job_id(num_partition, job_id).await; + let mut graph = + test_aggregation_plan_with_job_id(num_target_partitions, job_id).await; let executor = ExecutorMetadata { id: "executor_0".to_string(), host: "localhost".to_string(), @@ -1020,14 +1024,10 @@ mod test { specification: ExecutorSpecification { task_slots: 32 }, }; - if let Some(task) = graph.pop_next_task(&executor.id)? { - let task_status = mock_completed_task(task, &executor.id); - graph.update_task_status(&executor, vec![task_status], 1, 1)?; - } - - graph.revive(); + // complete first stage + revive_graph_and_complete_next_stage(&mut graph)?; - for _i in 0..num_partition - num_pending_task { + for _ in 0..num_target_partitions - num_pending_task { if let Some(task) = graph.pop_next_task(&executor.id)? { let task_status = mock_completed_task(task, &executor.id); graph.update_task_status(&executor, vec![task_status], 1, 1)?; diff --git a/ballista/scheduler/src/config.rs b/ballista/scheduler/src/config.rs index d15e928ce..822809110 100644 --- a/ballista/scheduler/src/config.rs +++ b/ballista/scheduler/src/config.rs @@ -19,7 +19,7 @@ //! Ballista scheduler specific configuration use ballista_core::config::TaskSchedulingPolicy; -use clap::ArgEnum; +use clap::ValueEnum; use std::fmt; /// Configurations for the ballista scheduler of scheduling jobs and tasks @@ -189,7 +189,7 @@ pub enum ClusterStorageConfig { /// Policy of distributing tasks to available executor slots /// /// It needs to be visible to code generated by configure_me -#[derive(Clone, ArgEnum, Copy, Debug, serde::Deserialize)] +#[derive(Clone, ValueEnum, Copy, Debug, serde::Deserialize)] pub enum TaskDistribution { /// Eagerly assign tasks to executor slots. 
This will assign as many task slots per executor /// as are currently available @@ -208,7 +208,7 @@ impl std::str::FromStr for TaskDistribution { type Err = String; fn from_str(s: &str) -> std::result::Result { - ArgEnum::from_str(s, true) + ValueEnum::from_str(s, true) } } diff --git a/ballista/scheduler/src/planner.rs b/ballista/scheduler/src/planner.rs index 3da9f339f..0e18a062c 100644 --- a/ballista/scheduler/src/planner.rs +++ b/ballista/scheduler/src/planner.rs @@ -592,6 +592,8 @@ order by Ok(()) } + #[ignore] + // enable when upgrading Datafusion, a bug is fixed with https://github.com/apache/datafusion/pull/11926/ #[tokio::test] async fn roundtrip_serde_aggregate() -> Result<(), BallistaError> { let ctx = datafusion_test_context("testdata").await?; diff --git a/ballista/scheduler/src/scheduler_process.rs b/ballista/scheduler/src/scheduler_process.rs index 6bcaaec5c..1f7f7ac35 100644 --- a/ballista/scheduler/src/scheduler_process.rs +++ b/ballista/scheduler/src/scheduler_process.rs @@ -15,26 +15,18 @@ // specific language governing permissions and limitations // under the License. -use anyhow::{Context, Result}; +use anyhow::{Error, Result}; #[cfg(feature = "flight-sql")] use arrow_flight::flight_service_server::FlightServiceServer; -use futures::future::{self, Either, TryFutureExt}; -use hyper::{server::conn::AddrStream, service::make_service_fn, Server}; -use log::info; -use std::convert::Infallible; -use std::net::SocketAddr; -use std::sync::Arc; -use tonic::transport::server::Connected; -use tower::Service; - -use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; - use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer; use ballista_core::serde::BallistaCodec; use ballista_core::utils::create_grpc_server; use ballista_core::BALLISTA_VERSION; +use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; +use log::info; +use std::{net::SocketAddr, sync::Arc}; -use crate::api::{get_routes, EitherBody, Error}; +use crate::api::get_routes; use crate::cluster::BallistaCluster; use crate::config::SchedulerConfig; use crate::flight_sql::FlightSqlServiceImpl; @@ -70,58 +62,31 @@ pub async fn start_server( scheduler_server.init().await?; - Server::bind(&addr) - .serve(make_service_fn(move |request: &AddrStream| { - let config = &scheduler_server.state.config; - let scheduler_grpc_server = - SchedulerGrpcServer::new(scheduler_server.clone()) - .max_encoding_message_size( - config.grpc_server_max_encoding_message_size as usize, - ) - .max_decoding_message_size( - config.grpc_server_max_decoding_message_size as usize, - ); - - let keda_scaler = ExternalScalerServer::new(scheduler_server.clone()); - - let tonic_builder = create_grpc_server() - .add_service(scheduler_grpc_server) - .add_service(keda_scaler); + let config = &scheduler_server.state.config; + let scheduler_grpc_server = SchedulerGrpcServer::new(scheduler_server.clone()) + .max_encoding_message_size(config.grpc_server_max_encoding_message_size as usize) + .max_decoding_message_size(config.grpc_server_max_decoding_message_size as usize); - #[cfg(feature = "flight-sql")] - let tonic_builder = tonic_builder.add_service(FlightServiceServer::new( - FlightSqlServiceImpl::new(scheduler_server.clone()), - )); + let keda_scaler = ExternalScalerServer::new(scheduler_server.clone()); - let mut tonic = tonic_builder.into_service(); + let tonic_builder = create_grpc_server() + .add_service(scheduler_grpc_server) + .add_service(keda_scaler); - let mut warp = 
warp::service(get_routes(scheduler_server.clone())); + #[cfg(feature = "flight-sql")] + let tonic_builder = tonic_builder.add_service(FlightServiceServer::new( + FlightSqlServiceImpl::new(scheduler_server.clone()), + )); - let connect_info = request.connect_info(); - future::ok::<_, Infallible>(tower::service_fn( - move |req: hyper::Request| { - // Set the connect info from hyper to tonic - let (mut parts, body) = req.into_parts(); - parts.extensions.insert(connect_info.clone()); - let req = http::Request::from_parts(parts, body); + let tonic = tonic_builder.into_service().into_router(); - if req.uri().path().starts_with("/api") { - return Either::Left( - warp.call(req) - .map_ok(|res| res.map(EitherBody::Left)) - .map_err(Error::from), - ); - } + let axum = get_routes(Arc::new(scheduler_server)); + let merged = axum + .merge(tonic) + .into_make_service_with_connect_info::(); - Either::Right( - tonic - .call(req) - .map_ok(|res| res.map(EitherBody::Right)) - .map_err(Error::from), - ) - }, - )) - })) + axum::Server::bind(&addr) + .serve(merged) .await - .context("Could not start grpc server") + .map_err(Error::from) } diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs index 2d759fb7b..6992bf756 100644 --- a/ballista/scheduler/src/scheduler_server/grpc.rs +++ b/ballista/scheduler/src/scheduler_server/grpc.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. +use axum::extract::ConnectInfo; use ballista_core::config::{BallistaConfig, BALLISTA_JOB_NAME}; use ballista_core::serde::protobuf::execute_query_params::{OptionalSessionId, Query}; use std::collections::HashMap; use std::convert::TryInto; +use std::net::SocketAddr; use ballista_core::serde::protobuf::executor_registration::OptionalHost; use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpc; @@ -70,7 +72,7 @@ impl SchedulerGrpc "Bad request because poll work is not supported for push-based task scheduling", )); } - let remote_addr = request.remote_addr(); + let remote_addr = extract_connect_info(&request); if let PollWorkParams { metadata: Some(metadata), num_free_slots, @@ -155,7 +157,7 @@ impl SchedulerGrpc &self, request: Request, ) -> Result, Status> { - let remote_addr = request.remote_addr(); + let remote_addr = extract_connect_info(&request); if let RegisterExecutorParams { metadata: Some(metadata), } = request.into_inner() @@ -191,7 +193,7 @@ impl SchedulerGrpc &self, request: Request, ) -> Result, Status> { - let remote_addr = request.remote_addr(); + let remote_addr = extract_connect_info(&request); let HeartBeatParams { executor_id, metrics, @@ -634,6 +636,13 @@ impl SchedulerGrpc } } +fn extract_connect_info(request: &Request) -> Option> { + request + .extensions() + .get::>() + .cloned() +} + #[cfg(all(test, feature = "sled"))] mod test { use std::sync::Arc; diff --git a/ballista/scheduler/src/scheduler_server/mod.rs b/ballista/scheduler/src/scheduler_server/mod.rs index e6525f188..c2bf657b8 100644 --- a/ballista/scheduler/src/scheduler_server/mod.rs +++ b/ballista/scheduler/src/scheduler_server/mod.rs @@ -345,7 +345,7 @@ mod test { use datafusion::functions_aggregate::sum::sum; use datafusion::logical_expr::{col, LogicalPlan}; - use datafusion::test_util::scan_empty; + use datafusion::test_util::scan_empty_with_partitions; use datafusion_proto::protobuf::LogicalPlanNode; use datafusion_proto::protobuf::PhysicalPlanNode; @@ -700,7 +700,9 @@ mod test { Field::new("gmv", DataType::UInt64, false), 
]); - scan_empty(None, &schema, Some(vec![0, 1])) + // partitions need to be > 1 for the datafusion's optimizer to insert a repartition node + // behavior changed with: https://github.com/apache/datafusion/pull/11875 + scan_empty_with_partitions(None, &schema, Some(vec![0, 1]), 2) .unwrap() .aggregate(vec![col("id")], vec![sum(col("gmv"))]) .unwrap() diff --git a/ballista/scheduler/src/state/execution_graph.rs b/ballista/scheduler/src/state/execution_graph.rs index 9ee95d671..333545d35 100644 --- a/ballista/scheduler/src/state/execution_graph.rs +++ b/ballista/scheduler/src/state/execution_graph.rs @@ -1694,7 +1694,9 @@ mod test { use crate::state::execution_graph::ExecutionGraph; use crate::test_utils::{ - mock_completed_task, mock_executor, mock_failed_task, test_aggregation_plan, + mock_completed_task, mock_executor, mock_failed_task, + revive_graph_and_complete_next_stage, + revive_graph_and_complete_next_stage_with_executor, test_aggregation_plan, test_coalesce_plan, test_join_plan, test_two_aggregations_plan, test_union_all_plan, test_union_plan, }; @@ -1793,19 +1795,13 @@ mod test { join_graph.revive(); assert_eq!(join_graph.stage_count(), 4); - assert_eq!(join_graph.available_tasks(), 2); + assert_eq!(join_graph.available_tasks(), 4); // Complete the first stage - if let Some(task) = join_graph.pop_next_task(&executor1.id)? { - let task_status = mock_completed_task(task, &executor1.id); - join_graph.update_task_status(&executor1, vec![task_status], 1, 1)?; - } + revive_graph_and_complete_next_stage_with_executor(&mut join_graph, &executor1)?; // Complete the second stage - if let Some(task) = join_graph.pop_next_task(&executor2.id)? { - let task_status = mock_completed_task(task, &executor2.id); - join_graph.update_task_status(&executor2, vec![task_status], 1, 1)?; - } + revive_graph_and_complete_next_stage_with_executor(&mut join_graph, &executor2)?; join_graph.revive(); // There are 4 tasks pending schedule for the 3rd stage @@ -1823,7 +1819,7 @@ mod test { // Two stages were reset, 1 Running stage rollback to Unresolved and 1 Completed stage move to Running assert_eq!(reset.0.len(), 2); - assert_eq!(join_graph.available_tasks(), 1); + assert_eq!(join_graph.available_tasks(), 2); drain_tasks(&mut join_graph)?; assert!(join_graph.is_successful(), "Failed to complete join plan"); @@ -1844,19 +1840,19 @@ mod test { join_graph.revive(); assert_eq!(join_graph.stage_count(), 4); - assert_eq!(join_graph.available_tasks(), 2); + assert_eq!(join_graph.available_tasks(), 4); // Complete the first stage - if let Some(task) = join_graph.pop_next_task(&executor1.id)? { - let task_status = mock_completed_task(task, &executor1.id); - join_graph.update_task_status(&executor1, vec![task_status], 1, 1)?; - } + assert_eq!(revive_graph_and_complete_next_stage(&mut join_graph)?, 2); // Complete the second stage - if let Some(task) = join_graph.pop_next_task(&executor2.id)? 
{ - let task_status = mock_completed_task(task, &executor2.id); - join_graph.update_task_status(&executor2, vec![task_status], 1, 1)?; - } + assert_eq!( + revive_graph_and_complete_next_stage_with_executor( + &mut join_graph, + &executor2 + )?, + 2 + ); // There are 0 tasks pending schedule now assert_eq!(join_graph.available_tasks(), 0); @@ -1865,7 +1861,7 @@ mod test { // Two stages were reset, 1 Resolved stage rollback to Unresolved and 1 Completed stage move to Running assert_eq!(reset.0.len(), 2); - assert_eq!(join_graph.available_tasks(), 1); + assert_eq!(join_graph.available_tasks(), 2); drain_tasks(&mut join_graph)?; assert!(join_graph.is_successful(), "Failed to complete join plan"); @@ -1886,13 +1882,10 @@ mod test { agg_graph.revive(); assert_eq!(agg_graph.stage_count(), 2); - assert_eq!(agg_graph.available_tasks(), 1); + assert_eq!(agg_graph.available_tasks(), 2); // Complete the first stage - if let Some(task) = agg_graph.pop_next_task(&executor1.id)? { - let task_status = mock_completed_task(task, &executor1.id); - agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?; - } + revive_graph_and_complete_next_stage_with_executor(&mut agg_graph, &executor1)?; // 1st task in the second stage if let Some(task) = agg_graph.pop_next_task(&executor2.id)? { @@ -1920,12 +1913,12 @@ mod test { // Two stages were reset, 1 Running stage rollback to Unresolved and 1 Completed stage move to Running assert_eq!(reset.0.len(), 2); - assert_eq!(agg_graph.available_tasks(), 1); + assert_eq!(agg_graph.available_tasks(), 2); // Call the reset again let reset = agg_graph.reset_stages_on_lost_executor(&executor1.id)?; assert_eq!(reset.0.len(), 0); - assert_eq!(agg_graph.available_tasks(), 1); + assert_eq!(agg_graph.available_tasks(), 2); drain_tasks(&mut agg_graph)?; assert!(agg_graph.is_successful(), "Failed to complete agg plan"); @@ -1935,24 +1928,20 @@ mod test { #[tokio::test] async fn test_do_not_retry_killed_task() -> Result<()> { - let executor1 = mock_executor("executor-id1".to_string()); - let executor2 = mock_executor("executor-id2".to_string()); + let executor = mock_executor("executor-id-123".to_string()); let mut agg_graph = test_aggregation_plan(4).await; // Call revive to move the leaf Resolved stages to Running agg_graph.revive(); // Complete the first stage - if let Some(task) = agg_graph.pop_next_task(&executor1.id)? 
{ - let task_status = mock_completed_task(task, &executor1.id); - agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?; - } + revive_graph_and_complete_next_stage(&mut agg_graph)?; // 1st task in the second stage - let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap(); - let task_status1 = mock_completed_task(task1, &executor2.id); + let task1 = agg_graph.pop_next_task(&executor.id)?.unwrap(); + let task_status1 = mock_completed_task(task1, &executor.id); // 2rd task in the second stage - let task2 = agg_graph.pop_next_task(&executor2.id)?.unwrap(); + let task2 = agg_graph.pop_next_task(&executor.id)?.unwrap(); let task_status2 = mock_failed_task( task2, FailedTask { @@ -1964,7 +1953,7 @@ mod test { ); agg_graph.update_task_status( - &executor2, + &executor, vec![task_status1, task_status2], 4, 4, @@ -1983,24 +1972,20 @@ mod test { #[tokio::test] async fn test_max_task_failed_count() -> Result<()> { - let executor1 = mock_executor("executor-id1".to_string()); - let executor2 = mock_executor("executor-id2".to_string()); + let executor = mock_executor("executor-id2".to_string()); let mut agg_graph = test_aggregation_plan(2).await; // Call revive to move the leaf Resolved stages to Running agg_graph.revive(); // Complete the first stage - if let Some(task) = agg_graph.pop_next_task(&executor1.id)? { - let task_status = mock_completed_task(task, &executor1.id); - agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?; - } + revive_graph_and_complete_next_stage(&mut agg_graph)?; // 1st task in the second stage - let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap(); - let task_status1 = mock_completed_task(task1, &executor2.id); + let task1 = agg_graph.pop_next_task(&executor.id)?.unwrap(); + let task_status1 = mock_completed_task(task1, &executor.id); // 2rd task in the second stage, failed due to IOError - let task2 = agg_graph.pop_next_task(&executor2.id)?.unwrap(); + let task2 = agg_graph.pop_next_task(&executor.id)?.unwrap(); let task_status2 = mock_failed_task( task2.clone(), FailedTask { @@ -2012,7 +1997,7 @@ mod test { ); agg_graph.update_task_status( - &executor2, + &executor, vec![task_status1, task_status2], 4, 4, @@ -2023,7 +2008,7 @@ mod test { let mut last_attempt = 0; // 2rd task's attempts for attempt in 1..5 { - if let Some(task2_attempt) = agg_graph.pop_next_task(&executor2.id)? { + if let Some(task2_attempt) = agg_graph.pop_next_task(&executor.id)? { assert_eq!( task2_attempt.partition.partition_id, task2.partition.partition_id @@ -2041,7 +2026,7 @@ mod test { )), }, ); - agg_graph.update_task_status(&executor2, vec![task_status], 4, 4)?; + agg_graph.update_task_status(&executor, vec![task_status], 4, 4)?; } } @@ -2075,10 +2060,7 @@ mod test { agg_graph.revive(); // Complete the Stage 1 - if let Some(task) = agg_graph.pop_next_task(&executor1.id)? { - let task_status = mock_completed_task(task, &executor1.id); - agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?; - } + revive_graph_and_complete_next_stage_with_executor(&mut agg_graph, &executor1)?; // 1st task in the Stage 2 if let Some(task) = agg_graph.pop_next_task(&executor2.id)? { @@ -2103,13 +2085,10 @@ mod test { // Two stages were reset, Stage 2 rollback to Unresolved and Stage 1 move to Running assert_eq!(reset.0.len(), 2); - assert_eq!(agg_graph.available_tasks(), 1); + assert_eq!(agg_graph.available_tasks(), 2); // Complete the Stage 1 again - if let Some(task) = agg_graph.pop_next_task(&executor1.id)? 
{ - let task_status = mock_completed_task(task, &executor1.id); - agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?; - } + revive_graph_and_complete_next_stage_with_executor(&mut agg_graph, &executor1)?; // Stage 2 move to Running agg_graph.revive(); @@ -2148,10 +2127,7 @@ mod test { agg_graph.revive(); // Complete the Stage 1 - if let Some(task) = agg_graph.pop_next_task(&executor1.id)? { - let task_status = mock_completed_task(task, &executor1.id); - agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?; - } + revive_graph_and_complete_next_stage(&mut agg_graph)?; // 1st task in the Stage 2 let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap(); @@ -2198,7 +2174,7 @@ mod test { let running_stage = agg_graph.running_stages(); assert_eq!(running_stage.len(), 1); assert_eq!(running_stage[0], 1); - assert_eq!(agg_graph.available_tasks(), 1); + assert_eq!(agg_graph.available_tasks(), 2); drain_tasks(&mut agg_graph)?; assert!(agg_graph.is_successful(), "Failed to complete agg plan"); @@ -2216,10 +2192,7 @@ mod test { assert_eq!(agg_graph.stage_count(), 3); // Complete the Stage 1 - if let Some(task) = agg_graph.pop_next_task(&executor1.id)? { - let task_status = mock_completed_task(task, &executor1.id); - agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?; - } + revive_graph_and_complete_next_stage(&mut agg_graph)?; // Complete the Stage 2, 5 tasks run on executor_2 and 3 tasks run on executor_1 for _i in 0..5 { @@ -2283,10 +2256,7 @@ mod test { agg_graph.revive(); for attempt in 0..6 { - if let Some(task) = agg_graph.pop_next_task(&executor1.id)? { - let task_status = mock_completed_task(task, &executor1.id); - agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?; - } + revive_graph_and_complete_next_stage(&mut agg_graph)?; // 1rd task in the Stage 2, failed due to FetchPartitionError if let Some(task1) = agg_graph.pop_next_task(&executor2.id)? { @@ -2318,7 +2288,7 @@ mod test { let running_stage = agg_graph.running_stages(); assert_eq!(running_stage.len(), 1); assert_eq!(running_stage[0], 1); - assert_eq!(agg_graph.available_tasks(), 1); + assert_eq!(agg_graph.available_tasks(), 2); } else { // Job is failed after exceeds the max_stage_failures assert_eq!(stage_events.len(), 1); @@ -2355,10 +2325,7 @@ mod test { assert_eq!(agg_graph.stage_count(), 3); // Complete the Stage 1 - if let Some(task) = agg_graph.pop_next_task(&executor1.id)? { - let task_status = mock_completed_task(task, &executor1.id); - agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?; - } + revive_graph_and_complete_next_stage(&mut agg_graph)?; // Complete the Stage 2, 5 tasks run on executor_2, 2 tasks run on executor_1, 1 task runs on executor_3 for _i in 0..5 { @@ -2559,10 +2526,7 @@ mod test { assert_eq!(agg_graph.stage_count(), 3); // Complete the Stage 1 - if let Some(task) = agg_graph.pop_next_task(&executor1.id)? { - let task_status = mock_completed_task(task, &executor1.id); - agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?; - } + revive_graph_and_complete_next_stage(&mut agg_graph)?; // Complete the Stage 2, 5 tasks run on executor_2, 3 tasks run on executor_1 for _i in 0..5 { @@ -2662,10 +2626,7 @@ mod test { assert_eq!(agg_graph.stage_count(), 3); // Complete the Stage 1 - if let Some(task) = agg_graph.pop_next_task(&executor1.id)? 
{ - let task_status = mock_completed_task(task, &executor1.id); - agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?; - } + revive_graph_and_complete_next_stage(&mut agg_graph)?; // Complete the Stage 2, 5 tasks run on executor_2, 3 tasks run on executor_1 for _i in 0..5 { @@ -2735,7 +2696,7 @@ mod test { let running_stage = agg_graph.running_stages(); assert_eq!(running_stage.len(), 1); assert_eq!(running_stage[0], 1); - assert_eq!(agg_graph.available_tasks(), 1); + assert_eq!(agg_graph.available_tasks(), 2); // There are two failed stage attempts: Stage 2 and Stage 3 assert_eq!(agg_graph.failed_stage_attempts.len(), 2); @@ -2759,14 +2720,9 @@ mod test { let executor1 = mock_executor("executor-id1".to_string()); let executor2 = mock_executor("executor-id2".to_string()); let mut agg_graph = test_aggregation_plan(4).await; - // Call revive to move the leaf Resolved stages to Running - agg_graph.revive(); // Complete the Stage 1 - if let Some(task) = agg_graph.pop_next_task(&executor1.id)? { - let task_status = mock_completed_task(task, &executor1.id); - agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?; - } + revive_graph_and_complete_next_stage(&mut agg_graph)?; // 1st task in the Stage 2 let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap(); diff --git a/ballista/scheduler/src/state/execution_graph_dot.rs b/ballista/scheduler/src/state/execution_graph_dot.rs index 3a6dce7af..d5d7e7aec 100644 --- a/ballista/scheduler/src/state/execution_graph_dot.rs +++ b/ballista/scheduler/src/state/execution_graph_dot.rs @@ -432,13 +432,13 @@ mod tests { let expected = r#"digraph G { subgraph cluster0 { label = "Stage 1 [Resolved]"; - stage_1_0 [shape=box, label="ShuffleWriter [0 partitions]"] + stage_1_0 [shape=box, label="ShuffleWriter [2 partitions]"] stage_1_0_0 [shape=box, label="MemoryExec"] stage_1_0_0 -> stage_1_0 } subgraph cluster1 { label = "Stage 2 [Resolved]"; - stage_2_0 [shape=box, label="ShuffleWriter [0 partitions]"] + stage_2_0 [shape=box, label="ShuffleWriter [2 partitions]"] stage_2_0_0 [shape=box, label="MemoryExec"] stage_2_0_0 -> stage_2_0 } @@ -462,7 +462,7 @@ filter_expr="] } subgraph cluster3 { label = "Stage 4 [Resolved]"; - stage_4_0 [shape=box, label="ShuffleWriter [0 partitions]"] + stage_4_0 [shape=box, label="ShuffleWriter [2 partitions]"] stage_4_0_0 [shape=box, label="MemoryExec"] stage_4_0_0 -> stage_4_0 } @@ -531,19 +531,19 @@ filter_expr="] let expected = r#"digraph G { subgraph cluster0 { label = "Stage 1 [Resolved]"; - stage_1_0 [shape=box, label="ShuffleWriter [0 partitions]"] + stage_1_0 [shape=box, label="ShuffleWriter [2 partitions]"] stage_1_0_0 [shape=box, label="MemoryExec"] stage_1_0_0 -> stage_1_0 } subgraph cluster1 { label = "Stage 2 [Resolved]"; - stage_2_0 [shape=box, label="ShuffleWriter [0 partitions]"] + stage_2_0 [shape=box, label="ShuffleWriter [2 partitions]"] stage_2_0_0 [shape=box, label="MemoryExec"] stage_2_0_0 -> stage_2_0 } subgraph cluster2 { label = "Stage 3 [Resolved]"; - stage_3_0 [shape=box, label="ShuffleWriter [0 partitions]"] + stage_3_0 [shape=box, label="ShuffleWriter [2 partitions]"] stage_3_0_0 [shape=box, label="MemoryExec"] stage_3_0_0 -> stage_3_0 } @@ -635,7 +635,7 @@ filter_expr="] Field::new("a", DataType::UInt32, false), Field::new("b", DataType::UInt32, false), ])); - let table = Arc::new(MemTable::try_new(schema.clone(), vec![])?); + let table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![], vec![]])?); ctx.register_table("foo", table.clone())?; ctx.register_table("bar", 
table.clone())?; ctx.register_table("baz", table)?; @@ -660,7 +660,8 @@ filter_expr="] let ctx = SessionContext::new_with_config(config); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, false)])); - let table = Arc::new(MemTable::try_new(schema.clone(), vec![])?); + // we specify the input partitions to be > 1 because of https://github.com/apache/datafusion/issues/12611 + let table = Arc::new(MemTable::try_new(schema.clone(), vec![vec![], vec![]])?); ctx.register_table("foo", table.clone())?; ctx.register_table("bar", table.clone())?; ctx.register_table("baz", table)?; diff --git a/ballista/scheduler/src/test_utils.rs b/ballista/scheduler/src/test_utils.rs index 59e6a875a..5e5dee124 100644 --- a/ballista/scheduler/src/test_utils.rs +++ b/ballista/scheduler/src/test_utils.rs @@ -16,6 +16,7 @@ // under the License. use ballista_core::error::{BallistaError, Result}; +use datafusion::catalog::Session; use std::any::Any; use std::collections::HashMap; use std::future::Future; @@ -44,19 +45,18 @@ use ballista_core::serde::{protobuf, BallistaCodec}; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::common::DataFusionError; use datafusion::datasource::{TableProvider, TableType}; -use datafusion::execution::context::{SessionConfig, SessionContext, SessionState}; -use datafusion::functions_aggregate::sum::sum; -use datafusion::logical_expr::expr::Sort; -use datafusion::logical_expr::{Expr, LogicalPlan}; +use datafusion::execution::context::{SessionConfig, SessionContext}; +use datafusion::functions_aggregate::{count::count, sum::sum}; +use datafusion::logical_expr::{Expr, LogicalPlan, SortExpr}; use datafusion::physical_plan::display::DisplayableExecutionPlan; use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::{col, count, CsvReadOptions, JoinType}; -use datafusion::test_util::scan_empty; +use datafusion::prelude::{col, CsvReadOptions, JoinType}; +use datafusion::test_util::scan_empty_with_partitions; use crate::cluster::BallistaCluster; use crate::scheduler_server::event::QueryStageSchedulerEvent; -use crate::state::execution_graph::{ExecutionGraph, TaskDescription}; +use crate::state::execution_graph::{ExecutionGraph, ExecutionStage, TaskDescription}; use ballista_core::utils::default_session_builder; use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; use parking_lot::Mutex; @@ -89,7 +89,7 @@ impl TableProvider for ExplodingTableProvider { async fn scan( &self, - _ctx: &SessionState, + _ctx: &dyn Session, _projection: Option<&Vec>, _filters: &[Expr], _limit: Option, @@ -783,6 +783,47 @@ pub fn assert_failed_event(job_id: &str, collector: &TestMetricsCollector) { assert!(found, "{}", "Expected failed event for job {job_id}"); } +pub fn revive_graph_and_complete_next_stage(graph: &mut ExecutionGraph) -> Result { + let executor = mock_executor("executor-id1".to_string()); + revive_graph_and_complete_next_stage_with_executor(graph, &executor) +} + +pub fn revive_graph_and_complete_next_stage_with_executor( + graph: &mut ExecutionGraph, + executor: &ExecutorMetadata, +) -> Result { + graph.revive(); + + // find the num_available_tasks of the next running stage + let num_available_tasks = graph + .stages() + .iter() + .map(|(_stage_id, stage)| { + if let ExecutionStage::Running(stage) = stage { + stage + .task_infos + .iter() + .filter(|info| info.is_none()) + .count() + } else { + 0 + } + }) + .find(|num_available_tasks| num_available_tasks > &0) + .unwrap(); + + if num_available_tasks > 0 { + for _ in 
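The `vec![vec![], vec![]]` argument above is `MemTable`'s partition layout: the outer `Vec` has one entry per partition and each inner `Vec<RecordBatch>` holds that partition's batches, so two empty inner vectors give a two-partition empty table. A small sketch of the same idea outside the test, assuming DataFusion 41 and tokio (the table name and query are illustrative):

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, false)]));

    // Two empty partitions: one inner Vec<RecordBatch> per partition.
    let table = MemTable::try_new(schema.clone(), vec![vec![], vec![]])?;

    let ctx = SessionContext::new();
    ctx.register_table("foo", Arc::new(table))?;

    let df = ctx.sql("SELECT count(*) FROM foo").await?;
    df.show().await?;
    Ok(())
}
```

Each inner vector can also carry pre-populated `RecordBatch`es when the test needs actual data rather than an empty multi-partition source.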
0..num_available_tasks { + if let Some(task) = graph.pop_next_task(&executor.id).unwrap() { + let task_status = mock_completed_task(task, &executor.id); + graph.update_task_status(executor, vec![task_status], 1, 1)?; + } + } + } + + Ok(num_available_tasks) +} + pub async fn test_aggregation_plan(partition: usize) -> ExecutionGraph { test_aggregation_plan_with_job_id(partition, "job").await } @@ -800,7 +841,8 @@ pub async fn test_aggregation_plan_with_job_id( Field::new("gmv", DataType::UInt64, false), ]); - let logical_plan = scan_empty(None, &schema, Some(vec![0, 1])) + // we specify the input partitions to be > 1 because of https://github.com/apache/datafusion/issues/12611 + let logical_plan = scan_empty_with_partitions(None, &schema, Some(vec![0, 1]), 2) .unwrap() .aggregate(vec![col("id")], vec![sum(col("gmv"))]) .unwrap() @@ -833,7 +875,8 @@ pub async fn test_two_aggregations_plan(partition: usize) -> ExecutionGraph { Field::new("gmv", DataType::UInt64, false), ]); - let logical_plan = scan_empty(None, &schema, Some(vec![0, 1, 2])) + // we specify the input partitions to be > 1 because of https://github.com/apache/datafusion/issues/12611 + let logical_plan = scan_empty_with_partitions(None, &schema, Some(vec![0, 1, 2]), 2) .unwrap() .aggregate(vec![col("id"), col("name")], vec![sum(col("gmv"))]) .unwrap() @@ -867,7 +910,8 @@ pub async fn test_coalesce_plan(partition: usize) -> ExecutionGraph { Field::new("gmv", DataType::UInt64, false), ]); - let logical_plan = scan_empty(None, &schema, Some(vec![0, 1])) + // we specify the input partitions to be > 1 because of https://github.com/apache/datafusion/issues/12611 + let logical_plan = scan_empty_with_partitions(None, &schema, Some(vec![0, 1]), 2) .unwrap() .limit(0, Some(1)) .unwrap() @@ -898,14 +942,15 @@ pub async fn test_join_plan(partition: usize) -> ExecutionGraph { Field::new("gmv", DataType::UInt64, false), ]); - let left_plan = scan_empty(Some("left"), &schema, None).unwrap(); + // we specify the input partitions to be > 1 because of https://github.com/apache/datafusion/issues/12611 + let left_plan = scan_empty_with_partitions(Some("left"), &schema, None, 2).unwrap(); - let right_plan = scan_empty(Some("right"), &schema, None) + let right_plan = scan_empty_with_partitions(Some("right"), &schema, None, 2) .unwrap() .build() .unwrap(); - let sort_expr = Expr::Sort(Sort::new(Box::new(col("id")), false, false)); + let sort_expr = Expr::Sort(SortExpr::new(Box::new(col("id")), false, false)); let logical_plan = left_plan .join(right_plan, JoinType::Inner, (vec!["id"], vec!["id"]), None) diff --git a/examples/Cargo.toml b/examples/Cargo.toml index e41e60518..3fd07740c 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -38,7 +38,7 @@ ballista = { path = "../ballista/client", version = "0.12.0" } datafusion = { workspace = true } futures = "0.3" num_cpus = "1.13.0" -prost = "0.12" +prost = { workspace = true } tokio = { version = "1.0", features = [ "macros", "rt", @@ -46,4 +46,4 @@ tokio = { version = "1.0", features = [ "sync", "parking_lot" ] } -tonic = "0.10" +tonic = { workspace = true } diff --git a/python/.cargo/config.toml b/python/.cargo/config.toml new file mode 100644 index 000000000..d47f983e4 --- /dev/null +++ b/python/.cargo/config.toml @@ -0,0 +1,11 @@ +[target.x86_64-apple-darwin] +rustflags = [ + "-C", "link-arg=-undefined", + "-C", "link-arg=dynamic_lookup", +] + +[target.aarch64-apple-darwin] +rustflags = [ + "-C", "link-arg=-undefined", + "-C", "link-arg=dynamic_lookup", +] diff --git a/python/Cargo.toml 
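DataFusion 41 models a sort key as `SortExpr::new(expr, asc, nulls_first)`, which is why both the scheduler test utils and the Python bindings now build `Expr::Sort(SortExpr::new(..))` instead of the old `expr::Sort`. A minimal sketch of constructing a sorted plan with that shape, assuming DataFusion 41 (schema and column names are illustrative):

```rust
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::logical_expr::{Expr, SortExpr};
use datafusion::prelude::col;
use datafusion::test_util::scan_empty_with_partitions;

fn main() -> Result<()> {
    let schema = Schema::new(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("gmv", DataType::UInt64, false),
    ]);

    // Sort by `id`, descending (asc = false), nulls last (nulls_first = false).
    let sort_expr = Expr::Sort(SortExpr::new(Box::new(col("id")), false, false));

    let plan = scan_empty_with_partitions(None, &schema, None, 2)?
        .sort(vec![sort_expr])?
        .build()?;

    println!("{}", plan.display_indent());
    Ok(())
}
```

The Python bindings perform the reverse conversion, unwrapping each `PySortExpr` into its inner expression before handing it to `ParquetReadOptions`.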
b/python/Cargo.toml index 2b2dff410..eb662cb1e 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -33,13 +33,12 @@ publish = false async-trait = "0.1.77" ballista = { path = "../ballista/client", version = "0.12.0" } ballista-core = { path = "../ballista/core", version = "0.12.0" } -datafusion = "35.0.0" -datafusion-proto = "35.0.0" +datafusion = "41.0.0" +datafusion-proto = "41.0.0" +datafusion-python = "41.0.0" -# we need to use a recent build of ADP that has a public PyDataFrame -datafusion-python = { git = "https://github.com/apache/arrow-datafusion-python", rev = "5296c0cfcf8e6fcb654d5935252469bf04f929e9" } - -pyo3 = { version = "0.20", features = ["extension-module", "abi3", "abi3-py38"] } +pyo3 = { version = "0.21", features = ["extension-module", "abi3", "abi3-py38"] } +pyo3-log = "0.11.0" tokio = { version = "1.35", features = ["macros", "rt", "rt-multi-thread", "sync"] } [lib] diff --git a/python/src/context.rs b/python/src/context.rs index 0d0231c67..be7dd6109 100644 --- a/python/src/context.rs +++ b/python/src/context.rs @@ -16,6 +16,7 @@ // under the License. use crate::utils::to_pyerr; +use datafusion::logical_expr::SortExpr; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use std::path::PathBuf; @@ -30,7 +31,7 @@ use datafusion_python::context::{ }; use datafusion_python::dataframe::PyDataFrame; use datafusion_python::errors::DataFusionError; -use datafusion_python::expr::PyExpr; +use datafusion_python::expr::sort_expr::PySortExpr; use datafusion_python::sql::logical::PyLogicalPlan; use datafusion_python::utils::wait_for_future; @@ -187,7 +188,7 @@ impl PySessionContext { file_extension: &str, skip_metadata: bool, schema: Option>, - file_sort_order: Option>>, + file_sort_order: Option>>, py: Python, ) -> PyResult { let mut options = ParquetReadOptions::default() @@ -199,7 +200,14 @@ impl PySessionContext { options.file_sort_order = file_sort_order .unwrap_or_default() .into_iter() - .map(|e| e.into_iter().map(|f| f.into()).collect()) + .map(|e| { + e.into_iter() + .map(|f| { + let sort_expr: SortExpr = f.into(); + *sort_expr.expr + }) + .collect() + }) .collect(); let result = self.ctx.read_parquet(path, options); @@ -299,7 +307,7 @@ impl PySessionContext { file_extension: &str, skip_metadata: bool, schema: Option>, - file_sort_order: Option>>, + file_sort_order: Option>>, py: Python, ) -> PyResult<()> { let mut options = ParquetReadOptions::default() @@ -311,7 +319,14 @@ impl PySessionContext { options.file_sort_order = file_sort_order .unwrap_or_default() .into_iter() - .map(|e| e.into_iter().map(|f| f.into()).collect()) + .map(|e| { + e.into_iter() + .map(|f| { + let sort_expr: SortExpr = f.into(); + *sort_expr.expr + }) + .collect() + }) .collect(); let result = self.ctx.register_parquet(name, path, options); diff --git a/python/src/lib.rs b/python/src/lib.rs index 04cf232a2..5fbd2491b 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -22,7 +22,8 @@ mod utils; pub use crate::context::PySessionContext; #[pymodule] -fn pyballista_internal(_py: Python, m: &PyModule) -> PyResult<()> { +fn pyballista_internal(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { + pyo3_log::init(); // Ballista structs m.add_class::()?; // DataFusion structs
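The new `pyo3_log::init()` call in the module initializer forwards Rust `log` records to Python's `logging` module. A stripped-down sketch of a pyo3 0.21 module using the same `Bound<'_, PyModule>` signature as the patch, assuming the `pyo3`, `pyo3-log`, and `log` crates (the module and class names are made up):

```rust
use pyo3::prelude::*;

/// A toy class so the module has something to export.
#[pyclass]
struct Greeter;

#[pymethods]
impl Greeter {
    #[new]
    fn new() -> Self {
        Greeter
    }

    fn greet(&self) -> String {
        // With pyo3-log initialized, this shows up in Python's logging output.
        log::info!("greet() called from Rust");
        "hello from Rust".to_string()
    }
}

#[pymodule]
fn my_extension(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
    // Route Rust `log` records to Python's `logging` module.
    pyo3_log::init();
    m.add_class::<Greeter>()?;
    Ok(())
}
```

From Python, once a handler is configured (for example via `logging.basicConfig()`), calling `my_extension.Greeter().greet()` surfaces the Rust log line through the standard `logging` handlers.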