From 169cadd2bbd37c2483da9a8a1fe6ac9453d6f0a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marko=20Milenkovi=C4=87?= Date: Fri, 22 Nov 2024 19:13:28 +0000 Subject: [PATCH] Make easier to create custom schedulers and executors (#1118) --- Cargo.toml | 1 + ballista/executor/Cargo.toml | 2 +- ballista/executor/src/bin/main.rs | 90 ++--- ballista/executor/src/config.rs | 71 ++++ ballista/executor/src/executor_process.rs | 72 +--- ballista/executor/src/lib.rs | 1 + ballista/scheduler/Cargo.toml | 2 +- ballista/scheduler/scheduler_config_spec.toml | 4 +- ballista/scheduler/src/bin/main.rs | 111 ++---- ballista/scheduler/src/cluster/memory.rs | 2 +- ballista/scheduler/src/cluster/mod.rs | 14 +- ballista/scheduler/src/config.rs | 129 ++++++- ballista/scheduler/src/scheduler_process.rs | 18 +- docs/source/index.rst | 1 + .../source/user-guide/extending-components.md | 250 ++++++++++++++ examples/Cargo.toml | 10 +- examples/examples/custom-client.rs | 123 +++++++ examples/examples/custom-executor.rs | 64 ++++ examples/examples/custom-scheduler.rs | 68 ++++ examples/src/lib.rs | 3 + examples/src/object_store.rs | 323 ++++++++++++++++++ 21 files changed, 1150 insertions(+), 209 deletions(-) create mode 100644 ballista/executor/src/config.rs create mode 100644 docs/source/user-guide/extending-components.md create mode 100644 examples/examples/custom-client.rs create mode 100644 examples/examples/custom-executor.rs create mode 100644 examples/examples/custom-scheduler.rs create mode 100644 examples/src/object_store.rs diff --git a/Cargo.toml b/Cargo.toml index 4e88716dc..0467d8ab6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ members = ["ballista-cli", "ballista/client", "ballista/core", "ballista/executo resolver = "2" [workspace.dependencies] +anyhow = "1" arrow = { version = "53", features = ["ipc_compression"] } arrow-flight = { version = "53", features = ["flight-sql-experimental"] } clap = { version = "3", features = ["derive", "cargo"] } diff --git a/ballista/executor/Cargo.toml b/ballista/executor/Cargo.toml index e1822e9c1..a7c5c65cc 100644 --- a/ballista/executor/Cargo.toml +++ b/ballista/executor/Cargo.toml @@ -37,7 +37,7 @@ path = "src/bin/main.rs" default = ["mimalloc"] [dependencies] -anyhow = "1" +anyhow = { workspace = true } arrow = { workspace = true } arrow-flight = { workspace = true } async-trait = { workspace = true } diff --git a/ballista/executor/src/bin/main.rs b/ballista/executor/src/bin/main.rs index 9f5ed12f1..5ef88e8bf 100644 --- a/ballista/executor/src/bin/main.rs +++ b/ballista/executor/src/bin/main.rs @@ -18,24 +18,15 @@ //! Ballista Rust executor binary. 
use anyhow::Result; -use std::sync::Arc; - +use ballista_core::config::LogRotationPolicy; use ballista_core::print_version; +use ballista_executor::config::prelude::*; use ballista_executor::executor_process::{ start_executor_process, ExecutorProcessConfig, }; -use config::prelude::*; - -#[allow(unused_imports)] -#[macro_use] -extern crate configure_me; - -#[allow(clippy::all, warnings)] -mod config { - // Ideally we would use the include_config macro from configure_me, but then we cannot use - // #[allow(clippy::all)] to silence clippy warnings from the generated code - include!(concat!(env!("OUT_DIR"), "/executor_configure_me_config.rs")); -} +use std::env; +use std::sync::Arc; +use tracing_subscriber::EnvFilter; #[cfg(feature = "mimalloc")] #[global_allocator] @@ -53,46 +44,39 @@ async fn main() -> Result<()> { std::process::exit(0); } - let log_file_name_prefix = format!( - "executor_{}_{}", - opt.external_host - .clone() - .unwrap_or_else(|| "localhost".to_string()), - opt.bind_port - ); + let config: ExecutorProcessConfig = opt.try_into()?; + + let rust_log = env::var(EnvFilter::DEFAULT_ENV); + let log_filter = + EnvFilter::new(rust_log.unwrap_or(config.special_mod_log_level.clone())); + + let tracing = tracing_subscriber::fmt() + .with_ansi(false) + .with_thread_names(config.print_thread_info) + .with_thread_ids(config.print_thread_info) + .with_env_filter(log_filter); - let config = ExecutorProcessConfig { - special_mod_log_level: opt.log_level_setting, - external_host: opt.external_host, - bind_host: opt.bind_host, - port: opt.bind_port, - grpc_port: opt.bind_grpc_port, - scheduler_host: opt.scheduler_host, - scheduler_port: opt.scheduler_port, - scheduler_connect_timeout_seconds: opt.scheduler_connect_timeout_seconds, - concurrent_tasks: opt.concurrent_tasks, - task_scheduling_policy: opt.task_scheduling_policy, - work_dir: opt.work_dir, - log_dir: opt.log_dir, - log_file_name_prefix, - log_rotation_policy: opt.log_rotation_policy, - print_thread_info: opt.print_thread_info, - job_data_ttl_seconds: opt.job_data_ttl_seconds, - job_data_clean_up_interval_seconds: opt.job_data_clean_up_interval_seconds, - grpc_max_decoding_message_size: opt.grpc_server_max_decoding_message_size, - grpc_max_encoding_message_size: opt.grpc_server_max_encoding_message_size, - executor_heartbeat_interval_seconds: opt.executor_heartbeat_interval_seconds, - data_cache_policy: opt.data_cache_policy, - cache_dir: opt.cache_dir, - cache_capacity: opt.cache_capacity, - cache_io_concurrency: opt.cache_io_concurrency, - execution_engine: None, - function_registry: None, - config_producer: None, - runtime_producer: None, - logical_codec: None, - physical_codec: None, - }; + // File layer + if let Some(log_dir) = &config.log_dir { + let log_file = match config.log_rotation_policy { + LogRotationPolicy::Minutely => { + tracing_appender::rolling::minutely(log_dir, &config.log_file_name_prefix) + } + LogRotationPolicy::Hourly => { + tracing_appender::rolling::hourly(log_dir, &config.log_file_name_prefix) + } + LogRotationPolicy::Daily => { + tracing_appender::rolling::daily(log_dir, &config.log_file_name_prefix) + } + LogRotationPolicy::Never => { + tracing_appender::rolling::never(log_dir, &config.log_file_name_prefix) + } + }; + + tracing.with_writer(log_file).init(); + } else { + tracing.init(); + } start_executor_process(Arc::new(config)).await } diff --git a/ballista/executor/src/config.rs b/ballista/executor/src/config.rs new file mode 100644 index 000000000..78db477f9 --- /dev/null +++ 
b/ballista/executor/src/config.rs
@@ -0,0 +1,71 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use ballista_core::error::BallistaError;
+
+use crate::executor_process::ExecutorProcessConfig;
+
+// Ideally we would use the include_config macro from configure_me, but then we cannot use
+// #[allow(clippy::all)] to silence clippy warnings from the generated code
+include!(concat!(env!("OUT_DIR"), "/executor_configure_me_config.rs"));
+
+impl TryFrom<Config> for ExecutorProcessConfig {
+    type Error = BallistaError;
+
+    fn try_from(opt: Config) -> Result<Self, Self::Error> {
+        let log_file_name_prefix = format!(
+            "executor_{}_{}",
+            opt.external_host
+                .clone()
+                .unwrap_or_else(|| "localhost".to_string()),
+            opt.bind_port
+        );
+
+        Ok(ExecutorProcessConfig {
+            special_mod_log_level: opt.log_level_setting,
+            external_host: opt.external_host,
+            bind_host: opt.bind_host,
+            port: opt.bind_port,
+            grpc_port: opt.bind_grpc_port,
+            scheduler_host: opt.scheduler_host,
+            scheduler_port: opt.scheduler_port,
+            scheduler_connect_timeout_seconds: opt.scheduler_connect_timeout_seconds,
+            concurrent_tasks: opt.concurrent_tasks,
+            task_scheduling_policy: opt.task_scheduling_policy,
+            work_dir: opt.work_dir,
+            log_dir: opt.log_dir,
+            log_file_name_prefix,
+            log_rotation_policy: opt.log_rotation_policy,
+            print_thread_info: opt.print_thread_info,
+            job_data_ttl_seconds: opt.job_data_ttl_seconds,
+            job_data_clean_up_interval_seconds: opt.job_data_clean_up_interval_seconds,
+            grpc_max_decoding_message_size: opt.grpc_server_max_decoding_message_size,
+            grpc_max_encoding_message_size: opt.grpc_server_max_encoding_message_size,
+            executor_heartbeat_interval_seconds: opt.executor_heartbeat_interval_seconds,
+            data_cache_policy: opt.data_cache_policy,
+            cache_dir: opt.cache_dir,
+            cache_capacity: opt.cache_capacity,
+            cache_io_concurrency: opt.cache_io_concurrency,
+            override_execution_engine: None,
+            override_function_registry: None,
+            override_config_producer: None,
+            override_runtime_producer: None,
+            override_logical_codec: None,
+            override_physical_codec: None,
+        })
+    }
+}
diff --git a/ballista/executor/src/executor_process.rs b/ballista/executor/src/executor_process.rs
index 9a6187bda..db276e108 100644
--- a/ballista/executor/src/executor_process.rs
+++ b/ballista/executor/src/executor_process.rs
@@ -21,7 +21,6 @@ use std::net::SocketAddr;
 use std::sync::atomic::Ordering;
 use std::sync::Arc;
 use std::time::{Duration, Instant, UNIX_EPOCH};
-use std::{env, io};
 
 use anyhow::{Context, Result};
 use arrow_flight::flight_service_server::FlightServiceServer;
@@ -37,7 +36,6 @@ use tokio::signal;
 use tokio::sync::mpsc;
 use tokio::task::JoinHandle;
 use tokio::{fs, time};
-use tracing_subscriber::EnvFilter;
 use uuid::Uuid;
 
 use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
@@ -98,57 +96,20 @@ pub struct ExecutorProcessConfig {
     pub executor_heartbeat_interval_seconds: u64,
     /// Optional execution engine to use to execute physical plans, will default to
     /// DataFusion if none is provided.
-    pub execution_engine: Option<Arc<dyn ExecutionEngine>>,
+    pub override_execution_engine: Option<Arc<dyn ExecutionEngine>>,
     /// Overrides default function registry
-    pub function_registry: Option<Arc<BallistaFunctionRegistry>>,
+    pub override_function_registry: Option<Arc<BallistaFunctionRegistry>>,
     /// [RuntimeProducer] override option
-    pub runtime_producer: Option<RuntimeProducer>,
+    pub override_runtime_producer: Option<RuntimeProducer>,
     /// [ConfigProducer] override option
-    pub config_producer: Option<ConfigProducer>,
+    pub override_config_producer: Option<ConfigProducer>,
     /// [LogicalExtensionCodec] override option
-    pub logical_codec: Option<Arc<dyn LogicalExtensionCodec>>,
+    pub override_logical_codec: Option<Arc<dyn LogicalExtensionCodec>>,
     /// [PhysicalExtensionCodec] override option
-    pub physical_codec: Option<Arc<dyn PhysicalExtensionCodec>>,
+    pub override_physical_codec: Option<Arc<dyn PhysicalExtensionCodec>>,
 }
 
 pub async fn start_executor_process(opt: Arc<ExecutorProcessConfig>) -> Result<()> {
-    let rust_log = env::var(EnvFilter::DEFAULT_ENV);
-    let log_filter =
-        EnvFilter::new(rust_log.unwrap_or(opt.special_mod_log_level.clone()));
-    // File layer
-    if let Some(log_dir) = opt.log_dir.clone() {
-        let log_file = match opt.log_rotation_policy {
-            LogRotationPolicy::Minutely => {
-                tracing_appender::rolling::minutely(log_dir, &opt.log_file_name_prefix)
-            }
-            LogRotationPolicy::Hourly => {
-                tracing_appender::rolling::hourly(log_dir, &opt.log_file_name_prefix)
-            }
-            LogRotationPolicy::Daily => {
-                tracing_appender::rolling::daily(log_dir, &opt.log_file_name_prefix)
-            }
-            LogRotationPolicy::Never => {
-                tracing_appender::rolling::never(log_dir, &opt.log_file_name_prefix)
-            }
-        };
-        tracing_subscriber::fmt()
-            .with_ansi(false)
-            .with_thread_names(opt.print_thread_info)
-            .with_thread_ids(opt.print_thread_info)
-            .with_writer(log_file)
-            .with_env_filter(log_filter)
-            .init();
-    } else {
-        // Console layer
-        tracing_subscriber::fmt()
-            .with_ansi(false)
-            .with_thread_names(opt.print_thread_info)
-            .with_thread_ids(opt.print_thread_info)
-            .with_writer(io::stdout)
-            .with_env_filter(log_filter)
-            .init();
-    }
-
     let addr = format!("{}:{}", opt.bind_host, opt.port);
     let addr = addr
         .parse()
@@ -194,23 +155,26 @@ pub async fn start_executor_process(opt: Arc<ExecutorProcessConfig>) -> Result<()> {
     // put them to session config
     let metrics_collector = Arc::new(LoggingMetricsCollector::default());
     let config_producer = opt
-        .config_producer
+        .override_config_producer
         .clone()
         .unwrap_or_else(|| Arc::new(default_config_producer));
 
     let wd = work_dir.clone();
-    let runtime_producer: RuntimeProducer = Arc::new(move |_| {
-        let config = RuntimeConfig::new().with_temp_file_path(wd.clone());
-        Ok(Arc::new(RuntimeEnv::new(config)?))
-    });
+    let runtime_producer: RuntimeProducer =
+        opt.override_runtime_producer.clone().unwrap_or_else(|| {
+            Arc::new(move |_| {
+                let config = RuntimeConfig::new().with_temp_file_path(wd.clone());
+                Ok(Arc::new(RuntimeEnv::new(config)?))
+            })
+        });
 
     let logical = opt
-        .logical_codec
+        .override_logical_codec
         .clone()
         .unwrap_or_else(|| Arc::new(BallistaLogicalExtensionCodec::default()));
 
     let physical = opt
-        .physical_codec
+        .override_physical_codec
        .clone()
        .unwrap_or_else(|| Arc::new(BallistaPhysicalExtensionCodec::default()));
 
@@ -224,10 +188,10 @@ pub async fn start_executor_process(opt: Arc<ExecutorProcessConfig>) -> Result<()> {
         &work_dir,
         runtime_producer,
         config_producer,
-        opt.function_registry.clone().unwrap_or_default(),
+        opt.override_function_registry.clone().unwrap_or_default(),
         metrics_collector,
         concurrent_tasks,
-        opt.execution_engine.clone(),
+        opt.override_execution_engine.clone(),
     ));
 
     let connect_timeout = opt.scheduler_connect_timeout_seconds as u64;
diff --git a/ballista/executor/src/lib.rs b/ballista/executor/src/lib.rs
index bc9d23e87..f0284cbdb 100644
--- a/ballista/executor/src/lib.rs
+++ b/ballista/executor/src/lib.rs
@@ -18,6 +18,7 @@
 #![doc = include_str!("../README.md")]
 
 pub mod collect;
+pub mod config;
 pub mod execution_engine;
 pub mod execution_loop;
 pub mod executor;
diff --git a/ballista/scheduler/Cargo.toml b/ballista/scheduler/Cargo.toml
index 642e63d48..ad3e09636 100644
--- a/ballista/scheduler/Cargo.toml
+++ b/ballista/scheduler/Cargo.toml
@@ -41,7 +41,7 @@ prometheus-metrics = ["prometheus", "once_cell"]
 rest-api = []
 
 [dependencies]
-anyhow = "1"
+anyhow = { workspace = true }
 arrow-flight = { workspace = true }
 async-trait = { workspace = true }
 axum = "0.7.7"
diff --git a/ballista/scheduler/scheduler_config_spec.toml b/ballista/scheduler/scheduler_config_spec.toml
index 804987d9a..20bceb5f2 100644
--- a/ballista/scheduler/scheduler_config_spec.toml
+++ b/ballista/scheduler/scheduler_config_spec.toml
@@ -82,9 +82,9 @@ doc = "Delayed interval for cleaning up finished job state. Default: 3600"
 
 [[param]]
 name = "task_distribution"
-type = "ballista_scheduler::config::TaskDistribution"
+type = "crate::config::TaskDistribution"
 doc = "The policy of distributing tasks to available executor slots, possible values: bias, round-robin, consistent-hash. Default: bias"
-default = "ballista_scheduler::config::TaskDistribution::Bias"
+default = "crate::config::TaskDistribution::Bias"
 
 [[param]]
 name = "consistent_hash_num_replicas"
diff --git a/ballista/scheduler/src/bin/main.rs b/ballista/scheduler/src/bin/main.rs
index 7d8b4b1b0..f6a063284 100644
--- a/ballista/scheduler/src/bin/main.rs
+++ b/ballista/scheduler/src/bin/main.rs
@@ -17,35 +17,16 @@
 
 //! Ballista Rust scheduler binary.
 
-use std::sync::Arc; -use std::{env, io}; - use anyhow::Result; - -use crate::config::{Config, ResultExt}; use ballista_core::config::LogRotationPolicy; use ballista_core::print_version; use ballista_scheduler::cluster::BallistaCluster; -use ballista_scheduler::config::{ - ClusterStorageConfig, SchedulerConfig, TaskDistribution, TaskDistributionPolicy, -}; +use ballista_scheduler::config::{Config, ResultExt}; use ballista_scheduler::scheduler_process::start_server; +use std::sync::Arc; +use std::{env, io}; use tracing_subscriber::EnvFilter; -#[allow(unused_imports)] -#[macro_use] -extern crate configure_me; - -#[allow(clippy::all, warnings)] -mod config { - // Ideally we would use the include_config macro from configure_me, but then we cannot use - // #[allow(clippy::all)] to silence clippy warnings from the generated code - include!(concat!( - env!("OUT_DIR"), - "/scheduler_configure_me_config.rs" - )); -} - fn main() -> Result<()> { let runtime = tokio::runtime::Builder::new_multi_thread() .enable_io() @@ -67,19 +48,23 @@ async fn inner() -> Result<()> { std::process::exit(0); } - let special_mod_log_level = opt.log_level_setting; - let log_dir = opt.log_dir; - let print_thread_info = opt.print_thread_info; + let rust_log = env::var(EnvFilter::DEFAULT_ENV); + let log_filter = EnvFilter::new(rust_log.unwrap_or(opt.log_level_setting.clone())); - let log_file_name_prefix = format!( - "scheduler_{}_{}_{}", - opt.namespace, opt.external_host, opt.bind_port - ); + let tracing = tracing_subscriber::fmt() + .with_ansi(false) + .with_thread_names(opt.print_thread_info) + .with_thread_ids(opt.print_thread_info) + .with_writer(io::stdout) + .with_env_filter(log_filter); - let rust_log = env::var(EnvFilter::DEFAULT_ENV); - let log_filter = EnvFilter::new(rust_log.unwrap_or(special_mod_log_level)); // File layer - if let Some(log_dir) = log_dir { + if let Some(log_dir) = &opt.log_dir { + let log_file_name_prefix = format!( + "scheduler_{}_{}_{}", + opt.namespace, opt.external_host, opt.bind_port + ); + let log_file = match opt.log_rotation_policy { LogRotationPolicy::Minutely => { tracing_appender::rolling::minutely(log_dir, &log_file_name_prefix) @@ -94,68 +79,16 @@ async fn inner() -> Result<()> { tracing_appender::rolling::never(log_dir, &log_file_name_prefix) } }; - tracing_subscriber::fmt() - .with_ansi(false) - .with_thread_names(print_thread_info) - .with_thread_ids(print_thread_info) - .with_writer(log_file) - .with_env_filter(log_filter) - .init(); + + tracing.with_writer(log_file).init(); } else { - // Console layer - tracing_subscriber::fmt() - .with_ansi(false) - .with_thread_names(print_thread_info) - .with_thread_ids(print_thread_info) - .with_writer(io::stdout) - .with_env_filter(log_filter) - .init(); + tracing.init(); } - let addr = format!("{}:{}", opt.bind_host, opt.bind_port); let addr = addr.parse()?; - - let cluster_storage_config = ClusterStorageConfig::Memory; - - let task_distribution = match opt.task_distribution { - TaskDistribution::Bias => TaskDistributionPolicy::Bias, - TaskDistribution::RoundRobin => TaskDistributionPolicy::RoundRobin, - TaskDistribution::ConsistentHash => { - let num_replicas = opt.consistent_hash_num_replicas as usize; - let tolerance = opt.consistent_hash_tolerance as usize; - TaskDistributionPolicy::ConsistentHash { - num_replicas, - tolerance, - } - } - }; - - let config = SchedulerConfig { - namespace: opt.namespace, - external_host: opt.external_host, - bind_port: opt.bind_port, - scheduling_policy: opt.scheduler_policy, - event_loop_buffer_size: 
opt.event_loop_buffer_size,
-        task_distribution,
-        finished_job_data_clean_up_interval_seconds: opt
-            .finished_job_data_clean_up_interval_seconds,
-        finished_job_state_clean_up_interval_seconds: opt
-            .finished_job_state_clean_up_interval_seconds,
-        advertise_flight_sql_endpoint: opt.advertise_flight_sql_endpoint,
-        cluster_storage: cluster_storage_config,
-        job_resubmit_interval_ms: (opt.job_resubmit_interval_ms > 0)
-            .then_some(opt.job_resubmit_interval_ms),
-        executor_termination_grace_period: opt.executor_termination_grace_period,
-        scheduler_event_expected_processing_duration: opt
-            .scheduler_event_expected_processing_duration,
-        grpc_server_max_decoding_message_size: opt.grpc_server_max_decoding_message_size,
-        grpc_server_max_encoding_message_size: opt.grpc_server_max_encoding_message_size,
-        executor_timeout_seconds: opt.executor_timeout_seconds,
-        expire_dead_executor_interval_seconds: opt.expire_dead_executor_interval_seconds,
-    };
-
+    let config = opt.try_into()?;
     let cluster = BallistaCluster::new_from_config(&config).await?;
-
     start_server(cluster, addr, Arc::new(config)).await?;
+
     Ok(())
 }
diff --git a/ballista/scheduler/src/cluster/memory.rs b/ballista/scheduler/src/cluster/memory.rs
index 6e32510a0..6df044035 100644
--- a/ballista/scheduler/src/cluster/memory.rs
+++ b/ballista/scheduler/src/cluster/memory.rs
@@ -290,7 +290,7 @@ pub struct InMemoryJobState {
     session_builder: SessionBuilder,
     /// Sender of job events
     job_event_sender: ClusterEventSender<JobStateEvent>,
-
+    /// Config producer
     config_producer: ConfigProducer,
 }
diff --git a/ballista/scheduler/src/cluster/mod.rs b/ballista/scheduler/src/cluster/mod.rs
index 2869c8876..94f86969e 100644
--- a/ballista/scheduler/src/cluster/mod.rs
+++ b/ballista/scheduler/src/cluster/mod.rs
@@ -111,11 +121,21 @@ impl BallistaCluster {
     pub async fn new_from_config(config: &SchedulerConfig) -> Result<Self> {
         let scheduler = config.scheduler_name();
 
+        let session_builder = config
+            .override_session_builder
+            .clone()
+            .unwrap_or_else(|| Arc::new(default_session_builder));
+
+        let config_producer = config
+            .override_config_producer
+            .clone()
+            .unwrap_or_else(|| Arc::new(default_config_producer));
+
         match &config.cluster_storage {
             ClusterStorageConfig::Memory => Ok(BallistaCluster::new_memory(
                 scheduler,
-                Arc::new(default_session_builder),
-                Arc::new(default_config_producer),
+                session_builder,
+                config_producer,
             )),
         }
     }
diff --git a/ballista/scheduler/src/config.rs b/ballista/scheduler/src/config.rs
index ce542e519..7bb85bd48 100644
--- a/ballista/scheduler/src/config.rs
+++ b/ballista/scheduler/src/config.rs
@@ -18,12 +18,20 @@
 
 //! Ballista scheduler specific configuration
 
-use ballista_core::config::TaskSchedulingPolicy;
+use crate::SessionBuilder;
+use ballista_core::{config::TaskSchedulingPolicy, error::BallistaError, ConfigProducer};
 use clap::ValueEnum;
-use std::fmt;
+use datafusion_proto::logical_plan::LogicalExtensionCodec;
+use datafusion_proto::physical_plan::PhysicalExtensionCodec;
+use std::{fmt, sync::Arc};
+
+include!(concat!(
+    env!("OUT_DIR"),
+    "/scheduler_configure_me_config.rs"
+));
 
 /// Configurations for the ballista scheduler of scheduling jobs and tasks
-#[derive(Debug, Clone)]
+#[derive(Clone)]
 pub struct SchedulerConfig {
     /// Namespace of this scheduler. Schedulers using the same cluster storage and namespace
     /// will share global cluster state.
@@ -62,6 +70,65 @@ pub struct SchedulerConfig {
     pub executor_timeout_seconds: u64,
     /// The interval to check expired or dead executors
     pub expire_dead_executor_interval_seconds: u64,
+
+    /// [ConfigProducer] override option
+    pub override_config_producer: Option<ConfigProducer>,
+    /// [SessionBuilder] override option
+    pub override_session_builder: Option<SessionBuilder>,
+    /// [LogicalExtensionCodec] override option
+    pub override_logical_codec: Option<Arc<dyn LogicalExtensionCodec>>,
+    /// [PhysicalExtensionCodec] override option
+    pub override_physical_codec: Option<Arc<dyn PhysicalExtensionCodec>>,
+}
+
+impl std::fmt::Debug for SchedulerConfig {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("SchedulerConfig")
+            .field("namespace", &self.namespace)
+            .field("external_host", &self.external_host)
+            .field("bind_port", &self.bind_port)
+            .field("scheduling_policy", &self.scheduling_policy)
+            .field("event_loop_buffer_size", &self.event_loop_buffer_size)
+            .field("task_distribution", &self.task_distribution)
+            .field(
+                "finished_job_data_clean_up_interval_seconds",
+                &self.finished_job_data_clean_up_interval_seconds,
+            )
+            .field(
+                "finished_job_state_clean_up_interval_seconds",
+                &self.finished_job_state_clean_up_interval_seconds,
+            )
+            .field(
+                "advertise_flight_sql_endpoint",
+                &self.advertise_flight_sql_endpoint,
+            )
+            .field("job_resubmit_interval_ms", &self.job_resubmit_interval_ms)
+            .field("cluster_storage", &self.cluster_storage)
+            .field(
+                "executor_termination_grace_period",
+                &self.executor_termination_grace_period,
+            )
+            .field(
+                "scheduler_event_expected_processing_duration",
+                &self.scheduler_event_expected_processing_duration,
+            )
+            .field(
+                "grpc_server_max_decoding_message_size",
+                &self.grpc_server_max_decoding_message_size,
+            )
+            .field(
+                "grpc_server_max_encoding_message_size",
+                &self.grpc_server_max_encoding_message_size,
+            )
+            .field("executor_timeout_seconds", &self.executor_timeout_seconds)
+            .field(
+                "expire_dead_executor_interval_seconds",
+                &self.expire_dead_executor_interval_seconds,
+            )
+            .field("override_logical_codec", &self.override_logical_codec)
+            .field("override_physical_codec", &self.override_physical_codec)
+            .finish()
+    }
+}
 
 impl Default for SchedulerConfig {
@@ -84,6 +151,10 @@ impl Default for SchedulerConfig {
             grpc_server_max_encoding_message_size: 16777216,
             executor_timeout_seconds: 180,
             expire_dead_executor_interval_seconds: 15,
+            override_config_producer: None,
+            override_session_builder: None,
+            override_logical_codec: None,
+            override_physical_codec: None,
         }
     }
 }
@@ -231,3 +302,55 @@ pub enum TaskDistributionPolicy {
         tolerance: usize,
     },
 }
+
+impl TryFrom<Config> for SchedulerConfig {
+    type Error = BallistaError;
+
+    fn try_from(opt: Config) -> Result<Self, Self::Error> {
+        let task_distribution = match opt.task_distribution {
+            TaskDistribution::Bias => TaskDistributionPolicy::Bias,
+            TaskDistribution::RoundRobin => TaskDistributionPolicy::RoundRobin,
+            TaskDistribution::ConsistentHash => {
+                let num_replicas = opt.consistent_hash_num_replicas as usize;
+                let tolerance = opt.consistent_hash_tolerance as usize;
+                TaskDistributionPolicy::ConsistentHash {
+                    num_replicas,
+                    tolerance,
+                }
+            }
+        };
+
+        let config = SchedulerConfig {
+            namespace: opt.namespace,
+            external_host: opt.external_host,
+            bind_port: opt.bind_port,
+            scheduling_policy: opt.scheduler_policy,
+            event_loop_buffer_size: opt.event_loop_buffer_size,
+            task_distribution,
+            finished_job_data_clean_up_interval_seconds: opt
+                .finished_job_data_clean_up_interval_seconds,
+            finished_job_state_clean_up_interval_seconds: opt
+                .finished_job_state_clean_up_interval_seconds,
+            advertise_flight_sql_endpoint: opt.advertise_flight_sql_endpoint,
+            cluster_storage: ClusterStorageConfig::Memory,
+            job_resubmit_interval_ms: (opt.job_resubmit_interval_ms > 0)
+                .then_some(opt.job_resubmit_interval_ms),
+            executor_termination_grace_period: opt.executor_termination_grace_period,
+            scheduler_event_expected_processing_duration: opt
+                .scheduler_event_expected_processing_duration,
+            grpc_server_max_decoding_message_size: opt
+                .grpc_server_max_decoding_message_size,
+            grpc_server_max_encoding_message_size: opt
+                .grpc_server_max_encoding_message_size,
+            executor_timeout_seconds: opt.executor_timeout_seconds,
+            expire_dead_executor_interval_seconds: opt
+                .expire_dead_executor_interval_seconds,
+            override_config_producer: None,
+            override_logical_codec: None,
+            override_physical_codec: None,
+            override_session_builder: None,
+        };
+
+        Ok(config)
+    }
+}
diff --git a/ballista/scheduler/src/scheduler_process.rs b/ballista/scheduler/src/scheduler_process.rs
index 4b9706079..393b03b62 100644
--- a/ballista/scheduler/src/scheduler_process.rs
+++ b/ballista/scheduler/src/scheduler_process.rs
@@ -19,7 +19,9 @@ use anyhow::{Error, Result};
 #[cfg(feature = "flight-sql")]
 use arrow_flight::flight_service_server::FlightServiceServer;
 use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer;
-use ballista_core::serde::BallistaCodec;
+use ballista_core::serde::{
+    BallistaCodec, BallistaLogicalExtensionCodec, BallistaPhysicalExtensionCodec,
+};
 use ballista_core::utils::create_grpc_server;
 use ballista_core::BALLISTA_VERSION;
 use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
@@ -54,11 +56,23 @@ pub async fn start_server(
     let metrics_collector = default_metrics_collector()?;
 
+    let codec_logical = config
+        .override_logical_codec
+        .clone()
+        .unwrap_or_else(|| Arc::new(BallistaLogicalExtensionCodec::default()));
+
+    let codec_physical = config
+        .override_physical_codec
+        .clone()
+        .unwrap_or_else(|| Arc::new(BallistaPhysicalExtensionCodec::default()));
+
+    let codec = BallistaCodec::new(codec_logical, codec_physical);
+
     let mut scheduler_server: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> =
         SchedulerServer::new(
             config.scheduler_name(),
             cluster,
-            BallistaCodec::default(),
+            codec,
             config,
             metrics_collector,
         );
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 959d5844b..9289eab75 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -55,6 +55,7 @@ Table of content
    user-guide/tuning-guide
    user-guide/metrics
    user-guide/faq
+   user-guide/extending-components
 
 .. _toc.contributors:
diff --git a/docs/source/user-guide/extending-components.md b/docs/source/user-guide/extending-components.md
new file mode 100644
index 000000000..556c0a36b
--- /dev/null
+++ b/docs/source/user-guide/extending-components.md
@@ -0,0 +1,250 @@
+
+
+# Extending Ballista Scheduler And Executors
+
+Ballista scheduler and executor provide a set of configuration options
+which can be used to extend their basic functionality. They allow registering
+new configuration extensions, object stores, and logical and physical codecs:
+
+- `function registry` - provides the ability to override the set of built-in functions
+- `config producer` - function which creates a new `SessionConfig`, which can hold extended configuration options
+- `runtime producer` - function which creates a new `RuntimeEnv` based on the provided `SessionConfig`
+- `session builder` - function which creates a new `SessionState` for each user session
+- `logical codec` - overrides `LogicalCodec`
+- `physical codec` - overrides `PhysicalCodec`
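+
+To make the producer shapes concrete, here is a minimal sketch (not part of this
+patch) of the two callback aliases, assuming `ConfigProducer` and
+`RuntimeProducer` as exported by `ballista_core`:
+
+```rust
+use std::sync::Arc;
+
+use ballista_core::{ConfigProducer, RuntimeProducer};
+use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
+use datafusion::prelude::SessionConfig;
+
+fn sketch_producers() -> (ConfigProducer, RuntimeProducer) {
+    // produce a fresh SessionConfig for every session
+    let config_producer: ConfigProducer = Arc::new(SessionConfig::new);
+    // build a RuntimeEnv from whatever SessionConfig the caller hands in
+    let runtime_producer: RuntimeProducer = Arc::new(|_config: &SessionConfig| {
+        Ok(Arc::new(RuntimeEnv::new(RuntimeConfig::new())?))
+    });
+    (config_producer, runtime_producer)
+}
+```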
+
+Ballista executor can be configured using `ExecutorProcessConfig`, which supports overriding the `function registry`, `runtime producer`, `config producer`, `logical codec`, and `physical codec`.
+
+Ballista scheduler can be tuned using `SchedulerConfig`, which supports overriding the `config producer`, `session builder`, `logical codec`, and `physical codec`.
+
+## Example: Custom Object Store Integration
+
+Extending the basic building blocks will be demonstrated by integrating an S3 object store. For this, a new `ObjectStoreRegistry` and `S3Options` will be provided. The `ObjectStoreRegistry` creates new `ObjectStore` instances configured using `S3Options`.
+
+For this specific task a `config producer`, `runtime producer` and `session builder` have to be provided, and the client, scheduler and executor need to be configured.
+
+```rust
+/// Custom [SessionConfig] constructor method
+///
+/// This method registers config extension [S3Options]
+/// which is used to configure [ObjectStore] with ACCESS and
+/// SECRET key
+pub fn custom_session_config_with_s3_options() -> SessionConfig {
+    SessionConfig::new_with_ballista()
+        .with_information_schema(true)
+        .with_option_extension(S3Options::default())
+}
+```
+
+```rust
+/// Custom [RuntimeEnv] constructor method
+///
+/// It will register [CustomObjectStoreRegistry] which will
+/// use configuration extension [S3Options] to configure
+/// and create [ObjectStore]s
+pub fn custom_runtime_env_with_s3_support(
+    session_config: &SessionConfig,
+) -> Result<Arc<RuntimeEnv>> {
+    let s3options = session_config
+        .options()
+        .extensions
+        .get::<S3Options>()
+        .ok_or(DataFusionError::Configuration(
+            "S3 Options not set".to_string(),
+        ))?;
+
+    let config = RuntimeConfig::new().with_object_store_registry(Arc::new(
+        CustomObjectStoreRegistry::new(s3options.clone()),
+    ));
+
+    Ok(Arc::new(RuntimeEnv::new(config)?))
+}
+```
+
+```rust
+/// Custom [SessionState] constructor method
+///
+/// It will configure [SessionState] with provided [SessionConfig],
+/// and [RuntimeEnv].
+pub fn custom_session_state_with_s3_support(
+    session_config: SessionConfig,
+) -> SessionState {
+    let runtime_env = custom_runtime_env_with_s3_support(&session_config).unwrap();
+
+    SessionStateBuilder::new()
+        .with_runtime_env(runtime_env)
+        .with_config(session_config)
+        .build()
+}
+```
+
+The `S3Options` and `CustomObjectStoreRegistry` implementations can be found in the examples sub-project.
+
+### Configuring Scheduler
+
+```rust
+#[tokio::main]
+async fn main() -> Result<()> {
+    // parse CLI options (default options which Ballista scheduler exposes)
+    let (opt, _remaining_args) =
+        Config::including_optional_config_files(&["/etc/ballista/scheduler.toml"])
+            .unwrap_or_exit();
+
+    let addr = format!("{}:{}", opt.bind_host, opt.bind_port);
+    let addr = addr.parse()?;
+
+    // converting CLI options to SchedulerConfig
+    let mut config: SchedulerConfig = opt.try_into()?;
+
+    // overriding default config producer with a custom producer
+    // which registers the required S3 options
+    config.override_config_producer =
+        Some(Arc::new(custom_session_config_with_s3_options));
+
+    // overriding default session builder with one which creates a session state
+    // with the custom session configuration and runtime environment
+    config.override_session_builder = Some(Arc::new(|session_config: SessionConfig| {
+        custom_session_state_with_s3_support(session_config)
+    }));
+    let cluster = BallistaCluster::new_from_config(&config).await?;
+    start_server(cluster, addr, Arc::new(config)).await?;
+    Ok(())
+}
+```
+
+### Configuring Executor
+
+```rust
+#[tokio::main]
+async fn main() -> Result<()> {
+    // parse CLI options (default options which Ballista executor exposes)
+    let (opt, _remaining_args) =
+        Config::including_optional_config_files(&["/etc/ballista/executor.toml"])
+            .unwrap_or_exit();
+
+    // converting CLI options to executor configuration
+    let mut config: ExecutorProcessConfig = opt.try_into().unwrap();
+
+    // overriding default config producer with a custom producer
+    // which registers the required S3 configuration options
+    config.override_config_producer =
+        Some(Arc::new(custom_session_config_with_s3_options));
+
+    // overriding default runtime producer with a custom producer
+    // which knows how to create S3 connections
+    config.override_runtime_producer =
+        Some(Arc::new(|session_config: &SessionConfig| {
+            custom_runtime_env_with_s3_support(session_config)
+        }));
+
+    start_executor_process(Arc::new(config)).await
+}
+```
+
+### Configuring Client
+
+```rust
+let test_data = ballista_examples::test_util::examples_test_data();
+
+// new session state with required custom session configuration and runtime environment
+let state =
+    custom_session_state_with_s3_support(custom_session_config_with_s3_options());
+
+let ctx: SessionContext =
+    SessionContext::remote_with_state("df://localhost:50050", state).await?;
+
+// once everything is set up we can configure the object store
+//
+// as the session config has the relevant S3 options registered and exposed,
+// S3 configuration options can be changed using the SQL `SET` statement.
+
+ctx.sql("SET s3.allow_http = true").await?.show().await?;
+
+ctx.sql(&format!("SET s3.access_key_id = '{}'", S3_ACCESS_KEY_ID))
+    .await?
+    .show()
+    .await?;
+
+ctx.sql(&format!("SET s3.secret_access_key = '{}'", S3_SECRET_KEY))
+    .await?
+    .show()
+    .await?;
+
+ctx.sql("SET s3.endpoint = 'http://localhost:9000'")
+    .await?
+    .show()
+    .await?;
+
+ctx.register_parquet(
+    "test",
+    &format!("{test_data}/alltypes_plain.parquet"),
+    Default::default(),
+)
+.await?;
+
+let write_dir_path = &format!("s3://{}/write_test.parquet", S3_BUCKET);
+
+ctx.sql("select * from test")
+    .await?
+    .write_parquet(write_dir_path, Default::default(), Default::default())
+    .await?;
+
+ctx.register_parquet("written_table", write_dir_path, Default::default())
+    .await?;
+
+let result = ctx
+    .sql("select id, string_col, timestamp_col from written_table where id > 4")
+    .await?
+    .collect()
+    .await?;
+
+let expected = [
+    "+----+------------+---------------------+",
+    "| id | string_col | timestamp_col       |",
+    "+----+------------+---------------------+",
+    "| 5  | 31         | 2009-03-01T00:01:00 |",
+    "| 6  | 30         | 2009-04-01T00:00:00 |",
+    "| 7  | 31         | 2009-04-01T00:01:00 |",
+    "+----+------------+---------------------+",
+];
+
+assert_batches_eq!(expected, &result);
+```
+
+## Example: Client Side Logical/Physical Codec
+
+Default physical and logical codecs can be replaced if needed. For the scheduler and executor the procedure is similar to the previous example. On the client side the procedure is slightly different: `ballista::prelude::SessionConfigExt` provides methods for overriding the physical and logical codecs on the client side.
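+
+For instance, a scheduler binary can swap in custom codecs through the same
+`SchedulerConfig` override fields shown earlier (a sketch, not from this patch;
+`BetterLogicalCodec` and `BetterPhysicalCodec` stand in for user-defined codec
+types):
+
+```rust
+let mut config: SchedulerConfig = opt.try_into()?;
+config.override_logical_codec = Some(Arc::new(BetterLogicalCodec::default()));
+config.override_physical_codec = Some(Arc::new(BetterPhysicalCodec::default()));
+```
+
+On the client side the override looks like this: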
+```rust
+let session_config = SessionConfig::new_with_ballista()
+    .with_information_schema(true)
+    .with_ballista_physical_extension_codec(Arc::new(BetterPhysicalCodec::default()))
+    .with_ballista_logical_extension_codec(Arc::new(BetterLogicalCodec::default()));
+
+let state = SessionStateBuilder::new()
+    .with_default_features()
+    .with_config(session_config)
+    .build();
+
+let ctx: SessionContext = SessionContext::standalone_with_state(state).await?;
+```
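+
+Other override points follow the same pattern. For example, an executor could
+swap in a custom function registry (a hypothetical sketch, not part of this
+patch; `my_registry` stands for a value of the registry type expected by
+`override_function_registry`, carrying the extra UDFs):
+
+```rust
+// my_registry: a function registry holding the user's additional functions
+let mut config: ExecutorProcessConfig = opt.try_into()?;
+config.override_function_registry = Some(Arc::new(my_registry));
+```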
diff --git a/examples/Cargo.toml b/examples/Cargo.toml
index c87c039cf..97b9f441b 100644
--- a/examples/Cargo.toml
+++ b/examples/Cargo.toml
@@ -34,8 +34,16 @@ path = "examples/standalone-sql.rs"
 required-features = ["ballista/standalone"]
 
 [dependencies]
+anyhow = { workspace = true }
 ballista = { path = "../ballista/client", version = "0.12.0" }
+ballista-core = { path = "../ballista/core", version = "0.12.0" }
+ballista-executor = { path = "../ballista/executor", version = "0.12.0" }
+ballista-scheduler = { path = "../ballista/scheduler", version = "0.12.0" }
 datafusion = { workspace = true }
+env_logger = { workspace = true }
+log = { workspace = true }
+object_store = { workspace = true, features = ["aws"] }
+parking_lot = { workspace = true }
 tokio = { workspace = true, features = [
     "macros",
     "rt",
@@ -43,4 +51,4 @@ tokio = { workspace = true, features = [
     "sync",
     "parking_lot"
 ] }
-
+url = { workspace = true }
diff --git a/examples/examples/custom-client.rs b/examples/examples/custom-client.rs
new file mode 100644
index 000000000..3577621e4
--- /dev/null
+++ b/examples/examples/custom-client.rs
@@ -0,0 +1,123 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use ballista::extension::SessionContextExt;
+use ballista_examples::object_store::{
+    custom_session_config_with_s3_options, custom_session_state_with_s3_support,
+};
+use datafusion::error::Result;
+use datafusion::{assert_batches_eq, prelude::SessionContext};
+
+/// bucket name to be used for this example
+const S3_BUCKET: &str = "ballista";
+/// S3 access key
+const S3_ACCESS_KEY_ID: &str = "MINIO";
+/// S3 secret key
+const S3_SECRET_KEY: &str = "MINIOSECRET";
+///
+/// # Extending Ballista
+///
+/// This example demonstrates how to extend the Ballista scheduler and executor
+/// by registering a new object store registry.
+/// It uses a local [minio](https://min.io) instance to act as the S3 object store.
+///
+/// Ballista will be extended by providing a custom session configuration,
+/// runtime environment and session state.
+///
+/// Minio can be started with:
+///
+/// ```bash
+/// docker run --rm -p 9000:9000 -p 9001:9001 -e "MINIO_ACCESS_KEY=MINIO" -e "MINIO_SECRET_KEY=MINIOSECRET" quay.io/minio/minio server /data --console-address ":9001"
+/// ```
+/// Once minio is running, start the `custom-scheduler`:
+///
+/// ```bash
+/// cargo run --example custom-scheduler
+/// ```
+///
+/// and the `custom-executor`:
+///
+/// ```bash
+/// cargo run --example custom-executor
+/// ```
+///
+/// and finally this client:
+///
+/// ```bash
+/// cargo run --example custom-client
+/// ```
+#[tokio::main]
+async fn main() -> Result<()> {
+    let test_data = ballista_examples::test_util::examples_test_data();
+
+    // new session state with required custom session configuration and runtime environment
+    let state =
+        custom_session_state_with_s3_support(custom_session_config_with_s3_options());
+
+    let ctx: SessionContext =
+        SessionContext::remote_with_state("df://localhost:50050", state).await?;
+
+    // session config has relevant S3 options registered and exposed.
+    // S3 configuration options can be changed using the `SET` statement
+    ctx.sql("SET s3.allow_http = true").await?.show().await?;
+
+    ctx.sql(&format!("SET s3.access_key_id = '{}'", S3_ACCESS_KEY_ID))
+        .await?
+        .show()
+        .await?;
+    ctx.sql(&format!("SET s3.secret_access_key = '{}'", S3_SECRET_KEY))
+        .await?
+        .show()
+        .await?;
+    ctx.sql("SET s3.endpoint = 'http://localhost:9000'")
+        .await?
+        .show()
+        .await?;
+
+    ctx.register_parquet(
+        "test",
+        &format!("{test_data}/alltypes_plain.parquet"),
+        Default::default(),
+    )
+    .await?;
+
+    let write_dir_path = &format!("s3://{}/write_test.parquet", S3_BUCKET);
+
+    ctx.sql("select * from test")
+        .await?
+        .write_parquet(write_dir_path, Default::default(), Default::default())
+        .await?;
+
+    ctx.register_parquet("written_table", write_dir_path, Default::default())
+        .await?;
+
+    let result = ctx
+        .sql("select id, string_col, timestamp_col from written_table where id > 4")
+        .await?
+        .collect()
+        .await?;
+
+    let expected = [
+        "+----+------------+---------------------+",
+        "| id | string_col | timestamp_col       |",
+        "+----+------------+---------------------+",
+        "| 5  | 31         | 2009-03-01T00:01:00 |",
+        "| 6  | 30         | 2009-04-01T00:00:00 |",
+        "| 7  | 31         | 2009-04-01T00:01:00 |",
+        "+----+------------+---------------------+",
+    ];
+
+    assert_batches_eq!(expected, &result);
+    Ok(())
+}
diff --git a/examples/examples/custom-executor.rs b/examples/examples/custom-executor.rs
new file mode 100644
index 000000000..df3f7c241
--- /dev/null
+++ b/examples/examples/custom-executor.rs
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use anyhow::Result;
+use ballista_examples::object_store::{
+    custom_runtime_env_with_s3_support, custom_session_config_with_s3_options,
+};
+use ballista_executor::config::prelude::*;
+use ballista_executor::executor_process::{
+    start_executor_process, ExecutorProcessConfig,
+};
+use datafusion::prelude::SessionConfig;
+use std::sync::Arc;
+///
+/// # Custom Ballista Executor
+///
+/// This example demonstrates how to create custom Ballista executors.
+///
+#[tokio::main]
+async fn main() -> Result<()> {
+    let _ = env_logger::builder()
+        .filter_level(log::LevelFilter::Info)
+        .is_test(true)
+        .try_init();
+
+    let (opt, _remaining_args) =
+        Config::including_optional_config_files(&["/etc/ballista/executor.toml"])
+            .unwrap_or_exit();
+
+    if opt.version {
+        ballista_core::print_version();
+        std::process::exit(0);
+    }
+
+    let mut config: ExecutorProcessConfig = opt.try_into().unwrap();
+
+    // overriding default config producer with a custom producer
+    // which registers the required S3 configuration options
+    config.override_config_producer =
+        Some(Arc::new(custom_session_config_with_s3_options));
+
+    // overriding default runtime producer with a custom producer
+    // which knows how to create S3 connections
+    config.override_runtime_producer =
+        Some(Arc::new(|session_config: &SessionConfig| {
+            custom_runtime_env_with_s3_support(session_config)
+        }));
+
+    start_executor_process(Arc::new(config)).await
+}
diff --git a/examples/examples/custom-scheduler.rs b/examples/examples/custom-scheduler.rs
new file mode 100644
index 000000000..30aeb3e3f
--- /dev/null
+++ b/examples/examples/custom-scheduler.rs
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use anyhow::Result;
+use ballista_core::print_version;
+use ballista_examples::object_store::{
+    custom_session_config_with_s3_options, custom_session_state_with_s3_support,
+};
+use ballista_scheduler::cluster::BallistaCluster;
+use ballista_scheduler::config::{Config, ResultExt, SchedulerConfig};
+use ballista_scheduler::scheduler_process::start_server;
+use datafusion::prelude::SessionConfig;
+use std::sync::Arc;
+
+///
+/// # Custom Ballista Scheduler
+///
+/// This example demonstrates how to create custom Ballista schedulers.
+///
+#[tokio::main]
+async fn main() -> Result<()> {
+    let _ = env_logger::builder()
+        .filter_level(log::LevelFilter::Info)
+        .is_test(true)
+        .try_init();
+
+    // parse options
+    let (opt, _remaining_args) =
+        Config::including_optional_config_files(&["/etc/ballista/scheduler.toml"])
+            .unwrap_or_exit();
+
+    if opt.version {
+        print_version();
+        std::process::exit(0);
+    }
+
+    let addr = format!("{}:{}", opt.bind_host, opt.bind_port);
+    let addr = addr.parse()?;
+    let mut config: SchedulerConfig = opt.try_into()?;
+
+    // overriding default config producer with a custom producer
+    // which registers the required S3 options
+    config.override_config_producer =
+        Some(Arc::new(custom_session_config_with_s3_options));
+
+    // overriding default session builder with one which creates a session state
+    // with the custom session configuration and runtime environment
+    config.override_session_builder = Some(Arc::new(|session_config: SessionConfig| {
+        custom_session_state_with_s3_support(session_config)
+    }));
+    let cluster = BallistaCluster::new_from_config(&config).await?;
+    start_server(cluster, addr, Arc::new(config)).await?;
+
+    Ok(())
+}
diff --git a/examples/src/lib.rs b/examples/src/lib.rs
index 6dc48f6b9..f8d7cc59b 100644
--- a/examples/src/lib.rs
+++ b/examples/src/lib.rs
@@ -15,4 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+/// Provides required structures and methods to
+/// integrate with S3 object store
+pub mod object_store;
 pub mod test_util;
diff --git a/examples/src/object_store.rs b/examples/src/object_store.rs
new file mode 100644
index 000000000..130d47059
--- /dev/null
+++ b/examples/src/object_store.rs
@@ -0,0 +1,323 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! # Extending Ballista
+//!
+//! This example demonstrates extending standard Ballista behavior by
+//! integrating an external [ObjectStoreRegistry].
+//!
+//! The [ObjectStore] is provided by the [ObjectStoreRegistry], and configured
+//! using [ExtensionOptions], which can be set through the SQL `SET` command.
+
+use ballista::prelude::SessionConfigExt;
+use datafusion::common::{config_err, exec_err};
+use datafusion::config::{
+    ConfigEntry, ConfigExtension, ConfigField, ExtensionOptions, Visit,
+};
+use datafusion::error::Result;
+use datafusion::execution::object_store::ObjectStoreRegistry;
+use datafusion::execution::SessionState;
+use datafusion::prelude::SessionConfig;
+use datafusion::{
+    error::DataFusionError,
+    execution::{
+        runtime_env::{RuntimeConfig, RuntimeEnv},
+        SessionStateBuilder,
+    },
+};
+use object_store::aws::AmazonS3Builder;
+use object_store::local::LocalFileSystem;
+use object_store::ObjectStore;
+use parking_lot::RwLock;
+use std::any::Any;
+use std::fmt::Display;
+use std::sync::Arc;
+use url::Url;
+
+/// Custom [SessionConfig] constructor method
+///
+/// This method registers config extension [S3Options]
+/// which is used to configure [ObjectStore] with ACCESS and
+/// SECRET key
+pub fn custom_session_config_with_s3_options() -> SessionConfig {
+    SessionConfig::new_with_ballista()
+        .with_information_schema(true)
+        .with_option_extension(S3Options::default())
+}
+
+/// Custom [RuntimeEnv] constructor method
+///
+/// It will register [CustomObjectStoreRegistry] which will
+/// use configuration extension [S3Options] to configure
+/// and create [ObjectStore]s
+pub fn custom_runtime_env_with_s3_support(
+    session_config: &SessionConfig,
+) -> Result<Arc<RuntimeEnv>> {
+    let s3options = session_config
+        .options()
+        .extensions
+        .get::<S3Options>()
+        .ok_or(DataFusionError::Configuration(
+            "S3 Options not set".to_string(),
+        ))?;
+
+    let config = RuntimeConfig::new().with_object_store_registry(Arc::new(
+        CustomObjectStoreRegistry::new(s3options.clone()),
+    ));
+
+    Ok(Arc::new(RuntimeEnv::new(config)?))
+}
+
+/// Custom [SessionState] constructor method
+///
+/// It will configure [SessionState] with provided [SessionConfig],
+/// and [RuntimeEnv].
+pub fn custom_session_state_with_s3_support(
+    session_config: SessionConfig,
+) -> SessionState {
+    let runtime_env = custom_runtime_env_with_s3_support(&session_config).unwrap();
+
+    SessionStateBuilder::new()
+        .with_runtime_env(runtime_env)
+        .with_config(session_config)
+        .build()
+}
+
+/// Custom [ObjectStoreRegistry] which will create
+/// and configure [ObjectStore] using provided [S3Options]
+#[derive(Debug)]
+pub struct CustomObjectStoreRegistry {
+    local: Arc<dyn ObjectStore>,
+    s3options: S3Options,
+}
+
+impl CustomObjectStoreRegistry {
+    pub fn new(s3options: S3Options) -> Self {
+        Self {
+            s3options,
+            local: Arc::new(LocalFileSystem::new()),
+        }
+    }
+}
+
+impl ObjectStoreRegistry for CustomObjectStoreRegistry {
+    fn register_store(
+        &self,
+        _url: &Url,
+        _store: Arc<dyn ObjectStore>,
+    ) -> Option<Arc<dyn ObjectStore>> {
+        unimplemented!("register_store not supported")
+    }
+
+    fn get_store(&self, url: &Url) -> Result<Arc<dyn ObjectStore>> {
+        let scheme = url.scheme();
+        log::info!("get_store: {:?}", &self.s3options.config.read());
+        match scheme {
+            "" | "file" => Ok(self.local.clone()),
+            "s3" => {
+                let s3store =
+                    Self::s3_object_store_builder(url, &self.s3options.config.read())?
+                        .build()?;
+
+                Ok(Arc::new(s3store))
+            }
+
+            _ => exec_err!("get_store - store not supported, url {}", url),
+        }
+    }
+}
+
+impl CustomObjectStoreRegistry {
+    fn s3_object_store_builder(
+        url: &Url,
+        aws_options: &S3RegistryConfiguration,
+    ) -> Result<AmazonS3Builder> {
+        let S3RegistryConfiguration {
+            access_key_id,
+            secret_access_key,
+            session_token,
+            region,
+            endpoint,
+            allow_http,
+        } = aws_options;
+
+        let bucket_name = Self::get_bucket_name(url)?;
+        let mut builder = AmazonS3Builder::from_env().with_bucket_name(bucket_name);
+
+        if let (Some(access_key_id), Some(secret_access_key)) =
+            (access_key_id, secret_access_key)
+        {
+            builder = builder
+                .with_access_key_id(access_key_id)
+                .with_secret_access_key(secret_access_key);
+
+            if let Some(session_token) = session_token {
+                builder = builder.with_token(session_token);
+            }
+        } else {
+            return config_err!(
+                "'s3.access_key_id' & 's3.secret_access_key' must be configured"
+            );
+        }
+
+        if let Some(region) = region {
+            builder = builder.with_region(region);
+        }
+
+        if let Some(endpoint) = endpoint {
+            if let Ok(endpoint_url) = Url::try_from(endpoint.as_str()) {
+                if !matches!(allow_http, Some(true)) && endpoint_url.scheme() == "http" {
+                    return config_err!("Invalid endpoint: {endpoint}. HTTP is not allowed for S3 endpoints. To allow HTTP, set 's3.allow_http' to true");
+                }
+            }
+
+            builder = builder.with_endpoint(endpoint);
+        }
+
+        if let Some(allow_http) = allow_http {
+            builder = builder.with_allow_http(*allow_http);
+        }
+
+        Ok(builder)
+    }
+
+    fn get_bucket_name(url: &Url) -> Result<&str> {
+        url.host_str().ok_or_else(|| {
+            DataFusionError::Execution(format!(
+                "Not able to parse bucket name from url: {}",
+                url.as_str()
+            ))
+        })
+    }
+}
+
+/// Custom [SessionConfig] extension which allows
+/// users to configure [ObjectStore] access using SQL
+/// interface
+#[derive(Debug, Clone, Default)]
+pub struct S3Options {
+    config: Arc<RwLock<S3RegistryConfiguration>>,
+}
+
+impl ExtensionOptions for S3Options {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn as_any_mut(&mut self) -> &mut dyn Any {
+        self
+    }
+
+    fn cloned(&self) -> Box<dyn ExtensionOptions> {
+        Box::new(self.clone())
+    }
+
+    fn set(&mut self, key: &str, value: &str) -> Result<()> {
+        log::debug!("set config, key:{}, value:{}", key, value);
+        match key {
+            "access_key_id" => {
+                let mut c = self.config.write();
+                c.access_key_id.set(key, value)?;
+            }
+            "secret_access_key" => {
+                let mut c = self.config.write();
+                c.secret_access_key.set(key, value)?;
+            }
+            "session_token" => {
+                let mut c = self.config.write();
+                c.session_token.set(key, value)?;
+            }
+            "region" => {
+                let mut c = self.config.write();
+                c.region.set(key, value)?;
+            }
+            "endpoint" => {
+                let mut c = self.config.write();
+                c.endpoint.set(key, value)?;
+            }
+            "allow_http" => {
+                let mut c = self.config.write();
+                c.allow_http.set(key, value)?;
+            }
+            _ => {
+                log::warn!("Config value {} can't be set to {}", key, value);
+                return config_err!("Config value \"{}\" not found in S3Options", key);
+            }
+        }
+        Ok(())
+    }
+
+    fn entries(&self) -> Vec<ConfigEntry> {
+        struct Visitor(Vec<ConfigEntry>);
+
+        impl Visit for Visitor {
+            fn some<V: Display>(
+                &mut self,
+                key: &str,
+                value: V,
+                description: &'static str,
+            ) {
+                self.0.push(ConfigEntry {
+                    key: format!("{}.{}", S3Options::PREFIX, key),
+                    value: Some(value.to_string()),
+                    description,
+                })
+            }
+
+            fn none(&mut self, key: &str, description: &'static str) {
+                self.0.push(ConfigEntry {
+                    key: format!("{}.{}", S3Options::PREFIX, key),
+                    value: None,
+                    description,
+                })
+            }
+        }
+        let c = self.config.read();
+
+        let mut v = Visitor(vec![]);
+        c.access_key_id
+            .visit(&mut v, "access_key_id", "S3 Access Key");
+        c.secret_access_key
+            .visit(&mut v, "secret_access_key", "S3 Secret Key");
+        c.session_token
+            .visit(&mut v, "session_token", "S3 Session token");
+        c.region.visit(&mut v, "region", "S3 region");
+        c.endpoint.visit(&mut v, "endpoint", "S3 Endpoint");
+        c.allow_http.visit(&mut v, "allow_http", "S3 Allow Http");
+
+        v.0
+    }
+}
+
+impl ConfigExtension for S3Options {
+    const PREFIX: &'static str = "s3";
+}
+
+#[derive(Default, Debug, Clone)]
+struct S3RegistryConfiguration {
+    /// Access Key ID
+    pub access_key_id: Option<String>,
+    /// Secret Access Key
+    pub secret_access_key: Option<String>,
+    /// Session token
+    pub session_token: Option<String>,
+    /// AWS Region
+    pub region: Option<String>,
+    /// OSS or COS Endpoint
+    pub endpoint: Option<String>,
+    /// Allow HTTP (otherwise will always use https)
+    pub allow_http: Option<bool>,
+}
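A quick way to sanity-check the `S3Options` plumbing added above is a unit test along these lines (a hypothetical sketch, not part of the patch):

```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn s3_options_set_and_entries() -> Result<()> {
        let mut opts = S3Options::default();
        // keys are passed without the "s3." prefix, as in ExtensionOptions::set
        opts.set("region", "us-east-1")?;
        opts.set("allow_http", "true")?;
        // entries() reports keys back under the "s3." prefix
        assert!(opts
            .entries()
            .iter()
            .any(|e| e.key == "s3.region" && e.value.as_deref() == Some("us-east-1")));
        Ok(())
    }
}
```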