diff --git a/ballista-cli/src/command.rs b/ballista-cli/src/command.rs
index e5d88729b..3860e46db 100644
--- a/ballista-cli/src/command.rs
+++ b/ballista-cli/src/command.rs
@@ -21,11 +21,11 @@ use std::str::FromStr;
 use std::sync::Arc;
 use std::time::Instant;
 
-use ballista::prelude::{BallistaError, Result};
-
 use datafusion::arrow::array::{ArrayRef, StringArray};
 use datafusion::arrow::datatypes::{DataType, Field, Schema};
 use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::common::Result;
+use datafusion::error::DataFusionError;
 use datafusion::prelude::SessionContext;
 
 use crate::functions::{display_all_functions, Function};
@@ -60,25 +60,23 @@ impl Command {
             Self::Help =>
             // TODO need to provide valid schema
             {
-                print_options
-                    .print_batches(Arc::new(Schema::empty()), &[all_commands_info()], now)
-                    .map_err(BallistaError::DataFusionError)
+                print_options.print_batches(
+                    Arc::new(Schema::empty()),
+                    &[all_commands_info()],
+                    now,
+                )
             }
             Self::ListTables => {
                 let df = ctx.sql("SHOW TABLES").await?;
                 let schema = Arc::new(df.schema().as_arrow().clone());
                 let batches = df.collect().await?;
-                print_options
-                    .print_batches(schema, &batches, now)
-                    .map_err(BallistaError::DataFusionError)
+                print_options.print_batches(schema, &batches, now)
             }
             Self::DescribeTable(name) => {
                 let df = ctx.sql(&format!("SHOW COLUMNS FROM {name}")).await?;
                 let schema = Arc::new(df.schema().as_arrow().clone());
                 let batches = df.collect().await?;
-                print_options
-                    .print_batches(schema, &batches, now)
-                    .map_err(BallistaError::DataFusionError)
+                print_options.print_batches(schema, &batches, now)
             }
             Self::QuietMode(quiet) => {
                 if let Some(quiet) = quiet {
@@ -95,12 +93,10 @@ impl Command {
                 }
                 Ok(())
             }
-            Self::Quit => Err(BallistaError::Internal(
+            Self::Quit => Err(DataFusionError::Internal(
                 "Unexpected quit, this should be handled outside".to_string(),
             )),
-            Self::ListFunctions => {
-                display_all_functions().map_err(BallistaError::DataFusionError)
-            }
+            Self::ListFunctions => display_all_functions(),
             Self::SearchFunctions(function) => {
                 if let Ok(func) = function.parse::<Function>() {
                     let details = func.function_details()?;
@@ -108,10 +104,10 @@ impl Command {
                     Ok(())
                 } else {
                     let msg = format!("{function} is not a supported function");
-                    Err(BallistaError::NotImplemented(msg))
+                    Err(DataFusionError::NotImplemented(msg))
                 }
             }
-            Self::OutputFormat(_) => Err(BallistaError::Internal(
+            Self::OutputFormat(_) => Err(DataFusionError::Internal(
                 "Unexpected change output format, this should be handled outside"
                     .to_string(),
             )),
@@ -221,7 +217,7 @@ impl OutputFormat {
             println!("Output format is {:?}.", print_options.format);
             Ok(())
         } else {
-            Err(BallistaError::General(format!(
+            Err(DataFusionError::Execution(format!(
                 "{:?} is not a valid format type [possible values: {:?}]",
                 format, "TO BE FIXED", //PrintFormat::value_variants()
diff --git a/ballista-cli/src/exec.rs b/ballista-cli/src/exec.rs
index 8795b4e60..e9dd870d6 100644
--- a/ballista-cli/src/exec.rs
+++ b/ballista-cli/src/exec.rs
@@ -23,7 +23,7 @@ use std::io::BufReader;
 use std::sync::Arc;
 use std::time::Instant;
 
-use ballista::prelude::Result;
+use datafusion::common::Result;
 use datafusion::prelude::SessionContext;
 use rustyline::error::ReadlineError;
 use rustyline::Editor;
diff --git a/ballista-cli/src/main.rs b/ballista-cli/src/main.rs
index 3c9902a7f..caec5ed5f 100644
--- a/ballista-cli/src/main.rs
+++ b/ballista-cli/src/main.rs
@@ -18,15 +18,13 @@
 use std::env;
 use std::path::Path;
 
-use ballista::{
-    extension::SessionConfigExt,
-    prelude::{Result, SessionContextExt},
-};
+use ballista::{extension::SessionConfigExt, prelude::SessionContextExt};
 use ballista_cli::{
     exec, print_format::PrintFormat, print_options::PrintOptions, BALLISTA_CLI_VERSION,
 };
 use clap::Parser;
 use datafusion::{
+    common::Result,
     execution::SessionStateBuilder,
     prelude::{SessionConfig, SessionContext},
 };
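The three CLI files above drop Ballista's own error type in favor of DataFusion's. A minimal sketch of the resulting pattern, using only the `datafusion::common::Result` alias and the `DataFusionError` variants visible in these hunks (the function itself is hypothetical):

```rust,no_run
use datafusion::common::Result;
use datafusion::error::DataFusionError;

// Hypothetical command handler: it returns DataFusion's Result directly,
// so no `.map_err(BallistaError::DataFusionError)` adapter is needed.
fn check_output_format(format: &str) -> Result<()> {
    match format {
        "table" | "csv" | "json" => Ok(()),
        other => Err(DataFusionError::Execution(format!(
            "{other:?} is not a valid format type"
        ))),
    }
}
```

Because `Result<T>` defaults its error parameter to `DataFusionError`, every `?` on a DataFusion call now propagates unchanged, which is what lets the `map_err` adapters above be deleted.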
diff --git a/ballista/client/Cargo.toml b/ballista/client/Cargo.toml
index 219974d0d..a5d930307 100644
--- a/ballista/client/Cargo.toml
+++ b/ballista/client/Cargo.toml
@@ -50,6 +50,6 @@ object_store = { workspace = true, features = ["aws"] }
 testcontainers-modules = { version = "0.11", features = ["minio"] }
 
 [features]
-default = []
+default = ["standalone"]
 standalone = ["ballista-executor", "ballista-scheduler"]
 testcontainers = []
diff --git a/ballista/client/README.md b/ballista/client/README.md
index 503d6e08c..4f87b064b 100644
--- a/ballista/client/README.md
+++ b/ballista/client/README.md
@@ -92,16 +92,16 @@ data set. Download the file and add it to the `testdata` folder before running t
 
 ```rust,no_run
 use ballista::prelude::*;
-use datafusion::prelude::{col, ParquetReadOptions};
+use datafusion::common::Result;
+use datafusion::prelude::{col, SessionContext, ParquetReadOptions};
 use datafusion::functions_aggregate::{min_max::min, min_max::max, sum::sum, average::avg};
 
 #[tokio::main]
 async fn main() -> Result<()> {
-    // create configuration
-    let config = BallistaConfig::default();
-
     // connect to Ballista scheduler
-    let ctx = BallistaContext::remote("localhost", 50050, &config).await?;
+    let ctx = SessionContext::remote("df://localhost:50050").await?;
 
     let filename = "testdata/yellow_tripdata_2022-01.parquet";
 
@@ -154,4 +154,4 @@ The output should look similar to the following table.
 +-----------------+--------------------------+--------------------------+--------------------------+--------------------------+
 ```
 
-More [examples](https://github.com/apache/arrow-ballista/tree/main/examples) can be found in the arrow-ballista repository.
+More [examples](../../examples/examples/) can be found in the arrow-ballista repository.
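With `standalone` now a default feature of the client crate, an in-process scheduler and executor can be spun up without a remote cluster. A sketch of that entry point, assuming `SessionContext::standalone()` is the standalone counterpart of the `SessionContext::remote` constructor shown in the README hunk above:

```rust,no_run
use ballista::prelude::*;
use datafusion::common::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    // Starts an in-process scheduler/executor pair; relies on the
    // `standalone` feature that the Cargo.toml change enables by default.
    let ctx = SessionContext::standalone().await?;
    let df = ctx.sql("SELECT 1").await?;
    df.show().await?;
    Ok(())
}
```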
diff --git a/ballista/client/src/context.rs b/ballista/client/src/context.rs
deleted file mode 100644
index 481285372..000000000
--- a/ballista/client/src/context.rs
+++ /dev/null
@@ -1,1044 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Distributed execution context.
-
-#![allow(deprecated)] // TO BE REMOVED
-
-use datafusion::arrow::datatypes::SchemaRef;
-use datafusion::execution::context::DataFilePaths;
-use datafusion::sql::sqlparser::ast::Statement;
-use datafusion::sql::TableReference;
-use log::info;
-use parking_lot::Mutex;
-use std::collections::HashMap;
-use std::sync::Arc;
-
-use ballista_core::config::BallistaConfig;
-use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient;
-use ballista_core::serde::protobuf::{CreateSessionParams, KeyValuePair};
-use ballista_core::utils::{
-    create_df_ctx_with_ballista_query_planner, create_grpc_client_connection,
-    SessionConfigExt,
-};
-use datafusion_proto::protobuf::LogicalPlanNode;
-
-use datafusion::dataframe::DataFrame;
-use datafusion::datasource::{source_as_provider, TableProvider};
-use datafusion::error::{DataFusionError, Result};
-use datafusion::logical_expr::{
-    CreateExternalTable, DdlStatement, LogicalPlan, TableScan,
-};
-use datafusion::prelude::{
-    AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions,
-    SessionConfig, SessionContext,
-};
-use datafusion::sql::parser::{DFParser, Statement as DFStatement};
-
-struct BallistaContextState {
-    /// Ballista configuration
-    config: BallistaConfig,
-    /// Scheduler host
-    scheduler_host: String,
-    /// Scheduler port
-    scheduler_port: u16,
-    /// Tables that have been registered with this context
-    tables: HashMap<String, Arc<dyn TableProvider>>,
-}
-
-impl BallistaContextState {
-    pub fn new(
-        scheduler_host: String,
-        scheduler_port: u16,
-        config: &BallistaConfig,
-    ) -> Self {
-        Self {
-            config: config.clone(),
-            scheduler_host,
-            scheduler_port,
-            tables: HashMap::new(),
-        }
-    }
-
-    pub fn config(&self) -> &BallistaConfig {
-        &self.config
-    }
-}
-
-#[deprecated]
-pub struct BallistaContext {
-    state: Arc<Mutex<BallistaContextState>>,
-    context: Arc<SessionContext>,
-}
-
-impl BallistaContext {
-    /// Create a context for executing queries against a remote Ballista scheduler instance
-    pub async fn remote(
-        host: &str,
-        port: u16,
-        config: &BallistaConfig,
-    ) -> ballista_core::error::Result<Self> {
-        let state = BallistaContextState::new(host.to_owned(), port, config);
-
-        let scheduler_url =
-            format!("http://{}:{}", &state.scheduler_host, state.scheduler_port);
-        info!(
-            "Connecting to Ballista scheduler at {}",
-            scheduler_url.clone()
-        );
-        let connection = create_grpc_client_connection(scheduler_url.clone())
-            .await
-            .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
-        let limit = config.default_grpc_client_max_message_size();
-        let mut scheduler = SchedulerGrpcClient::new(connection)
-            .max_encoding_message_size(limit)
-            .max_decoding_message_size(limit);
-
-        let remote_session_id = scheduler
-            .create_session(CreateSessionParams {
-                settings: config
-                    .settings()
-                    .iter()
-                    .map(|(k, v)| KeyValuePair {
-                        key: k.to_owned(),
-                        value: v.to_owned(),
-                    })
-                    .collect::<Vec<_>>(),
-            })
-            .await
-            .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?
-            .into_inner()
-            .session_id;
-
-        info!(
-            "Server side SessionContext created with session id: {}",
-            remote_session_id
-        );
-
-        let ctx = {
-            create_df_ctx_with_ballista_query_planner::<LogicalPlanNode>(
-                scheduler_url,
-                remote_session_id,
-                state.config(),
-            )
-        };
-
-        Ok(Self {
-            state: Arc::new(Mutex::new(state)),
-            context: Arc::new(ctx),
-        })
-    }
-
-    #[cfg(feature = "standalone")]
-    pub async fn standalone(
-        config: &BallistaConfig,
-        concurrent_tasks: usize,
-    ) -> ballista_core::error::Result<Self> {
-        use ballista_core::serde::BallistaCodec;
-        use datafusion_proto::protobuf::PhysicalPlanNode;
-
-        log::info!("Running in local mode. Scheduler will be run in-proc");
-
-        let addr = ballista_scheduler::standalone::new_standalone_scheduler().await?;
-        let scheduler_url = format!("http://localhost:{}", addr.port());
-        let mut scheduler = loop {
-            match SchedulerGrpcClient::connect(scheduler_url.clone()).await {
-                Err(_) => {
-                    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
-                    log::info!("Attempting to connect to in-proc scheduler...");
-                }
-                Ok(scheduler) => break scheduler,
-            }
-        };
-
-        let remote_session_id = scheduler
-            .create_session(CreateSessionParams {
-                settings: config
-                    .settings()
-                    .iter()
-                    .map(|(k, v)| KeyValuePair {
-                        key: k.to_owned(),
-                        value: v.to_owned(),
-                    })
-                    .collect::<Vec<_>>(),
-            })
-            .await
-            .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?
-            .into_inner()
-            .session_id;
-
-        info!(
-            "Server side SessionContext created with session id: {}",
-            remote_session_id
-        );
-
-        let ctx = {
-            create_df_ctx_with_ballista_query_planner::<LogicalPlanNode>(
-                scheduler_url,
-                remote_session_id,
-                config,
-            )
-        };
-
-        let default_codec: BallistaCodec<LogicalPlanNode, PhysicalPlanNode> =
-            BallistaCodec::default();
-
-        ballista_executor::new_standalone_executor(
-            scheduler,
-            concurrent_tasks,
-            default_codec,
-        )
-        .await?;
-
-        let state =
-            BallistaContextState::new("localhost".to_string(), addr.port(), config);
-
-        Ok(Self {
-            state: Arc::new(Mutex::new(state)),
-            context: Arc::new(ctx),
-        })
-    }
-
-    /// Create a DataFrame representing an Json table scan
-    pub async fn read_json<P: DataFilePaths>(
-        &self,
-        paths: P,
-        options: NdJsonReadOptions<'_>,
-    ) -> Result<DataFrame> {
-        let df = self.context.read_json(paths, options).await?;
-        Ok(df)
-    }
-
-    /// Create a DataFrame representing an Avro table scan
-    pub async fn read_avro<P: DataFilePaths>(
-        &self,
-        paths: P,
-        options: AvroReadOptions<'_>,
-    ) -> Result<DataFrame> {
-        let df = self.context.read_avro(paths, options).await?;
-        Ok(df)
-    }
-
-    /// Create a DataFrame representing a Parquet table scan
-    pub async fn read_parquet<P: DataFilePaths>(
-        &self,
-        paths: P,
-        options: ParquetReadOptions<'_>,
-    ) -> Result<DataFrame> {
-        let df = self.context.read_parquet(paths, options).await?;
-        Ok(df)
-    }
-
-    /// Create a DataFrame representing a CSV table scan
-    pub async fn read_csv<P: DataFilePaths>(
-        &self,
-        paths: P,
-        options: CsvReadOptions<'_>,
-    ) -> Result<DataFrame> {
-        let df = self.context.read_csv(paths, options).await?;
-        Ok(df)
-    }
-
-    /// Register a DataFrame as a table that can be referenced from a SQL query
-    pub fn register_table(
-        &self,
-        name: &str,
-        table: Arc<dyn TableProvider>,
-    ) -> Result<()> {
-        let mut state = self.state.lock();
-        state.tables.insert(name.to_owned(), table);
-        Ok(())
-    }
-
-    pub async fn register_csv(
-        &self,
-        name: &str,
-        path: &str,
-        options: CsvReadOptions<'_>,
-    ) -> Result<()> {
-        let plan = self
-            .read_csv(path, options)
-            .await
-            .map_err(|e| {
-                DataFusionError::Context(format!("Can't read CSV: {path}"), Box::new(e))
-            })?
-            .into_optimized_plan()?;
-        match plan {
-            LogicalPlan::TableScan(TableScan { source, .. }) => {
-                self.register_table(name, source_as_provider(&source)?)
-            }
-            _ => Err(DataFusionError::Internal("Expected tables scan".to_owned())),
-        }
-    }
-
-    pub async fn register_parquet(
-        &self,
-        name: &str,
-        path: &str,
-        options: ParquetReadOptions<'_>,
-    ) -> Result<()> {
-        match self
-            .read_parquet(path, options)
-            .await?
-            .into_optimized_plan()?
-        {
-            LogicalPlan::TableScan(TableScan { source, .. }) => {
-                self.register_table(name, source_as_provider(&source)?)
-            }
-            _ => Err(DataFusionError::Internal("Expected tables scan".to_owned())),
-        }
-    }
-
-    pub async fn register_avro(
-        &self,
-        name: &str,
-        path: &str,
-        options: AvroReadOptions<'_>,
-    ) -> Result<()> {
-        match self.read_avro(path, options).await?.into_optimized_plan()? {
-            LogicalPlan::TableScan(TableScan { source, .. }) => {
-                self.register_table(name, source_as_provider(&source)?)
-            }
-            _ => Err(DataFusionError::Internal("Expected tables scan".to_owned())),
-        }
-    }
-
-    /// is a 'show *' sql
-    pub async fn is_show_statement(&self, sql: &str) -> Result<bool> {
-        let mut is_show_variable: bool = false;
-        let statements = DFParser::parse_sql(sql)?;
-
-        if statements.len() != 1 {
-            return Err(DataFusionError::NotImplemented(
-                "The context currently only supports a single SQL statement".to_string(),
-            ));
-        }
-
-        if let DFStatement::Statement(st) = &statements[0] {
-            match **st {
-                Statement::ShowVariable { .. } => {
-                    is_show_variable = true;
-                }
-                Statement::ShowColumns { .. } => {
-                    is_show_variable = true;
-                }
-                Statement::ShowTables { .. } => {
-                    is_show_variable = true;
-                }
-                _ => {
-                    is_show_variable = false;
-                }
-            }
-        };
-
-        Ok(is_show_variable)
-    }
-
-    pub fn context(&self) -> &SessionContext {
-        &self.context
-    }
-
-    /// Create a DataFrame from a SQL statement.
-    ///
-    /// This method is `async` because queries of type `CREATE EXTERNAL TABLE`
-    /// might require the schema to be inferred.
-    pub async fn sql(&self, sql: &str) -> Result<DataFrame> {
-        let mut ctx = self.context.clone();
-
-        let is_show = self.is_show_statement(sql).await?;
-        // the show tables、 show columns sql can not run at scheduler because the tables is store at client
-        if is_show {
-            ctx = Arc::new(SessionContext::new_with_config(
-                SessionConfig::new_with_ballista(),
-            ));
-        }
-
-        // register tables with DataFusion context
-        {
-            let state = self.state.lock();
-            for (name, prov) in &state.tables {
-                // ctx is shared between queries, check table exists or not before register
-                let table_ref = TableReference::Bare {
-                    table: name.as_str().into(),
-                };
-                if !ctx.table_exist(table_ref)? {
-                    ctx.register_table(
-                        TableReference::Bare {
-                            table: name.as_str().into(),
-                        },
-                        Arc::clone(prov),
-                    )?;
-                }
-            }
-        }
-
-        let plan = ctx.state().create_logical_plan(sql).await?;
-
-        match &plan {
-            LogicalPlan::Ddl(DdlStatement::CreateExternalTable(
-                CreateExternalTable {
-                    ref schema,
-                    ref name,
-                    ref location,
-                    ref file_type,
-                    ref table_partition_cols,
-                    ref if_not_exists,
-                    options,
-                    ..
-                },
-            )) => {
-                let table_exists = ctx.table_exist(name.to_owned())?;
-                let schema: SchemaRef = Arc::new(schema.as_ref().to_owned().into());
-                let table_partition_cols = table_partition_cols
-                    .iter()
-                    .map(|col| {
-                        schema
-                            .field_with_name(col)
-                            .map(|f| (f.name().to_owned(), f.data_type().to_owned()))
-                            .map_err(|e| DataFusionError::ArrowError(e, None))
-                    })
-                    .collect::<Result<Vec<_>>>()?;
-
-                match (if_not_exists, table_exists) {
-                    (_, false) => match file_type.to_lowercase().as_str() {
-                        "csv" => {
-                            let has_header = match options.get("format.has_header") {
-                                Some(str) => str.parse::<bool>().unwrap(),
-                                None => false,
-                            };
-                            let delimiter = match options.get("format.delimiter") {
-                                Some(str) => str.chars().next().unwrap(),
-                                None => ',',
-                            };
-                            let mut options = CsvReadOptions::new()
-                                .has_header(has_header)
-                                .delimiter(delimiter as u8)
-                                .table_partition_cols(table_partition_cols.to_vec());
-                            if !schema.fields().is_empty() {
-                                options = options.schema(&schema);
-                            }
-                            self.register_csv(name.table(), location, options).await?;
-                            Ok(DataFrame::new(ctx.state(), plan))
-                        }
-                        "parquet" => {
-                            self.register_parquet(
-                                name.table(),
-                                location,
-                                ParquetReadOptions::default()
-                                    .table_partition_cols(table_partition_cols),
-                            )
-                            .await?;
-                            Ok(DataFrame::new(ctx.state(), plan))
-                        }
-                        "avro" => {
-                            self.register_avro(
-                                name.table(),
-                                location,
-                                AvroReadOptions::default()
-                                    .table_partition_cols(table_partition_cols),
-                            )
-                            .await?;
-                            Ok(DataFrame::new(ctx.state(), plan))
-                        }
-                        _ => Err(DataFusionError::NotImplemented(format!(
-                            "Unsupported file type {file_type:?}."
-                        ))),
-                    },
-                    (true, true) => Ok(DataFrame::new(ctx.state(), plan)),
-                    (false, true) => Err(DataFusionError::Execution(format!(
-                        "Table '{name:?}' already exists"
-                    ))),
-                }
-            }
-            _ => ctx.execute_logical_plan(plan).await,
-        }
-    }
-
-    /// Execute the [`LogicalPlan`], return a [`DataFrame`]. This API
-    /// is not featured limited (so all SQL such as `CREATE TABLE` and
-    /// `COPY` will be run).
-    ///
-    /// If you wish to limit the type of plan that can be run from
-    /// SQL, see [`Self::sql_with_options`] and
-    /// [`SQLOptions::verify_plan`].
-    pub async fn execute_logical_plan(&self, plan: LogicalPlan) -> Result<DataFrame> {
-        let ctx = self.context.clone();
-        ctx.execute_logical_plan(plan).await
-    }
-}
-
-#[cfg(test)]
-#[cfg(feature = "standalone")]
-mod standalone_tests {
-    use ballista_core::config::BallistaConfig;
-    use datafusion::arrow;
-    use datafusion::arrow::util::pretty::pretty_format_batches;
-
-    use crate::context::BallistaContext;
-    use ballista_core::error::Result;
-    use datafusion::config::TableParquetOptions;
-    use datafusion::dataframe::DataFrameWriteOptions;
-    use datafusion::datasource::listing::ListingTableUrl;
-    use datafusion::prelude::ParquetReadOptions;
-    use tempfile::TempDir;
-
-    #[tokio::test]
-    async fn test_standalone_mode() {
-        use super::*;
-        let context = BallistaContext::standalone(&BallistaConfig::default(), 1)
-            .await
-            .unwrap();
-        let df = context.sql("SELECT 1;").await.unwrap();
-        df.collect().await.unwrap();
-    }
-
-    #[tokio::test]
-    async fn test_write_parquet() -> Result<()> {
-        use super::*;
-        let context = BallistaContext::standalone(&BallistaConfig::default(), 1).await?;
-        let df = context.sql("SELECT 1;").await?;
-        let tmp_dir = TempDir::new().unwrap();
-        let file_path = format!(
-            "{}",
-            tmp_dir.path().join("test_write_parquet.parquet").display()
-        );
-        df.write_parquet(
-            &file_path,
-            DataFrameWriteOptions::default(),
-            Some(TableParquetOptions::default()),
-        )
-        .await?;
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_write_csv() -> Result<()> {
-        use super::*;
-        let context = BallistaContext::standalone(&BallistaConfig::default(), 1).await?;
-        let df = context.sql("SELECT 1;").await?;
-        let tmp_dir = TempDir::new().unwrap();
-        let file_path =
-            format!("{}", tmp_dir.path().join("test_write_csv.csv").display());
-        df.write_csv(&file_path, DataFrameWriteOptions::default(), None)
-            .await?;
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_ballista_show_tables() {
-        use super::*;
-        use std::fs::File;
-        use std::io::Write;
-        use tempfile::TempDir;
-        let context = BallistaContext::standalone(&BallistaConfig::default(), 1)
-            .await
-            .unwrap();
-
-        let data = "Jorge,2018-12-13T12:12:10.011Z\n\
-                    Andrew,2018-11-13T17:11:10.011Z";
-
-        let tmp_dir = TempDir::new().unwrap();
-        let file_path = tmp_dir.path().join("timestamps.csv");
-
-        // scope to ensure the file is closed and written
-        {
-            File::create(&file_path)
-                .expect("creating temp file")
-                .write_all(data.as_bytes())
-                .expect("writing data");
-        }
-
-        let sql = format!(
-            "CREATE EXTERNAL TABLE csv_with_timestamps (
-                name VARCHAR,
-                ts TIMESTAMP
-            )
-            STORED AS CSV
-            LOCATION '{}'
-            OPTIONS ('has_header' 'false', 'delimiter' ',')
-            ",
-            file_path.to_str().expect("path is utf8")
-        );
-
-        context.sql(sql.as_str()).await.unwrap();
-
-        let df = context.sql("show columns from csv_with_timestamps;").await;
-
-        assert!(df.is_err());
-    }
-
-    #[tokio::test]
-    #[ignore = "this one fails after config change (will be removed)"]
-    async fn test_show_tables_not_with_information_schema() {
-        use super::*;
-
-        use std::fs::File;
-        use std::io::Write;
-        use tempfile::TempDir;
-        let config = BallistaConfig::default();
-        let context = BallistaContext::standalone(&config, 1).await.unwrap();
-
-        let data = "Jorge,2018-12-13T12:12:10.011Z\n\
-                    Andrew,2018-11-13T17:11:10.011Z";
-
-        let tmp_dir = TempDir::new().unwrap();
-        let file_path = tmp_dir.path().join("timestamps.csv");
-
-        // scope to ensure the file is closed and written
-        {
-            File::create(&file_path)
-                .expect("creating temp file")
-                .write_all(data.as_bytes())
-                .expect("writing data");
-        }
-
-        let sql = format!(
-            "CREATE EXTERNAL TABLE csv_with_timestamps (
-                name VARCHAR,
-                ts TIMESTAMP
-            )
-            STORED AS CSV
-            LOCATION '{}'
-            ",
-            file_path.to_str().expect("path is utf8")
-        );
-
-        context.sql(sql.as_str()).await.unwrap();
-        let df = context.sql("show tables;").await;
-        assert!(df.is_ok());
-    }
-
-    #[tokio::test]
-    #[ignore]
-    // Tracking: https://github.com/apache/arrow-datafusion/issues/1840
-    async fn test_task_stuck_when_referenced_task_failed() {
-        use super::*;
-        use datafusion::arrow::datatypes::Schema;
-        use datafusion::arrow::util::pretty;
-        use datafusion::datasource::file_format::csv::CsvFormat;
-        use datafusion::datasource::listing::{
-            ListingOptions, ListingTable, ListingTableConfig,
-        };
-
-        let config = BallistaConfig::default();
-        let context = BallistaContext::standalone(&config, 1).await.unwrap();
-
-        context
-            .register_parquet(
-                "single_nan",
-                "testdata/single_nan.parquet",
-                ParquetReadOptions::default(),
-            )
-            .await
-            .unwrap();
-
-        {
-            let mut guard = context.state.lock();
-            let csv_table = guard.tables.get("single_nan");
-
-            if let Some(table_provide) = csv_table {
-                if let Some(listing_table) = table_provide
-                    .clone()
-                    .as_any()
-                    .downcast_ref::<ListingTable>()
-                {
-                    let x = listing_table.options();
-                    let error_options = ListingOptions {
-                        file_extension: x.file_extension.clone(),
-                        format: Arc::new(CsvFormat::default()),
-                        table_partition_cols: x.table_partition_cols.clone(),
-                        collect_stat: x.collect_stat,
-                        target_partitions: x.target_partitions,
-                        file_sort_order: vec![],
-                    };
-
-                    let table_paths = listing_table
-                        .table_paths()
-                        .iter()
-                        .map(|t| ListingTableUrl::parse(t).unwrap())
-                        .collect();
-                    let config = ListingTableConfig::new_with_multi_paths(table_paths)
-                        .with_schema(Arc::new(Schema::empty()))
-                        .with_listing_options(error_options);
-
-                    let error_table = ListingTable::try_new(config).unwrap();
-
-                    // change the table to an error table
-                    guard
-                        .tables
-                        .insert("single_nan".to_string(), Arc::new(error_table));
-                }
-            }
-        }
-
-        let df = context
-            .sql("select count(1) from single_nan;")
-            .await
-            .unwrap();
-        let results = df.collect().await.unwrap();
-        pretty::print_batches(&results).unwrap();
-    }
-
-    #[tokio::test]
-    async fn test_empty_exec_with_one_row() {
-        use crate::context::BallistaContext;
-
-        let config = BallistaConfig::default();
-        let context = BallistaContext::standalone(&config, 1).await.unwrap();
-
-        let sql = "select EXTRACT(year FROM to_timestamp('2020-09-08T12:13:14+00:00'));";
-
-        let df = context.sql(sql).await.unwrap();
-        assert!(!df.collect().await.unwrap().is_empty());
-    }
-
-    #[tokio::test]
-    async fn test_union_and_union_all() {
-        use super::*;
-        use datafusion::arrow::util::pretty::pretty_format_batches;
-        let config = BallistaConfig::default();
-        let context = BallistaContext::standalone(&config, 1).await.unwrap();
-
-        let df = context
-            .sql("SELECT 1 as NUMBER union SELECT 1 as NUMBER;")
-            .await
-            .unwrap();
-        let res1 = df.collect().await.unwrap();
-        let expected1 = vec![
-            "+--------+",
-            "| number |",
-            "+--------+",
-            "| 1      |",
-            "+--------+",
-        ];
-        assert_eq!(
-            expected1,
-            pretty_format_batches(&res1)
-                .unwrap()
-                .to_string()
-                .trim()
-                .lines()
-                .collect::<Vec<&str>>()
-        );
-        let expected2 = vec![
-            "+--------+",
-            "| number |",
-            "+--------+",
-            "| 1      |",
-            "| 1      |",
-            "+--------+",
-        ];
-        let df = context
-            .sql("SELECT 1 as NUMBER union all SELECT 1 as NUMBER;")
-            .await
-            .unwrap();
-        let res2 = df.collect().await.unwrap();
-        assert_eq!(
-            expected2,
-            pretty_format_batches(&res2)
-                .unwrap()
-                .to_string()
-                .trim()
-                .lines()
-                .collect::<Vec<&str>>()
-        );
-    }
-
-    #[tokio::test]
-    async fn test_aggregate_min_max() {
-        let context = create_test_context().await;
-
-        let df = context.sql("select min(\"id\") from test").await.unwrap();
-        let res = df.collect().await.unwrap();
-        let expected = vec![
-            "+--------------+",
-            "| min(test.id) |",
-            "+--------------+",
-            "| 0            |",
-            "+--------------+",
-        ];
-        assert_result_eq(expected, &res);
-
-        let df = context.sql("select max(\"id\") from test").await.unwrap();
-        let res = df.collect().await.unwrap();
-        let expected = vec![
-            "+--------------+",
-            "| max(test.id) |",
-            "+--------------+",
-            "| 7            |",
-            "+--------------+",
-        ];
-        assert_result_eq(expected, &res);
-    }
-
-    #[tokio::test]
-    async fn test_aggregate_sum() {
-        let context = create_test_context().await;
-
-        let df = context.sql("select SUM(\"id\") from test").await.unwrap();
-        let res = df.collect().await.unwrap();
-        let expected = vec![
-            "+--------------+",
-            "| sum(test.id) |",
-            "+--------------+",
-            "| 28           |",
-            "+--------------+",
-        ];
-        assert_result_eq(expected, &res);
-    }
-    #[tokio::test]
-    async fn test_aggregate_avg() {
-        let context = create_test_context().await;
-
-        let df = context.sql("select AVG(\"id\") from test").await.unwrap();
-        let res = df.collect().await.unwrap();
-        let expected = vec![
-            "+--------------+",
-            "| avg(test.id) |",
-            "+--------------+",
-            "| 3.5          |",
-            "+--------------+",
-        ];
-        assert_result_eq(expected, &res);
-    }
-
-    #[tokio::test]
-    async fn test_aggregate_count() {
-        let context = create_test_context().await;
-
-        let df = context.sql("select COUNT(\"id\") from test").await.unwrap();
-        let res = df.collect().await.unwrap();
-        let expected = vec![
-            "+----------------+",
-            "| count(test.id) |",
-            "+----------------+",
-            "| 8              |",
-            "+----------------+",
-        ];
-        assert_result_eq(expected, &res);
-    }
-    #[tokio::test]
-    async fn test_aggregate_approx_distinct() {
-        let context = create_test_context().await;
-
-        let df = context
-            .sql("select approx_distinct(\"id\") from test")
-            .await
-            .unwrap();
-        let res = df.collect().await.unwrap();
-        let expected = vec![
-            "+--------------------------+",
-            "| approx_distinct(test.id) |",
-            "+--------------------------+",
-            "| 8                        |",
-            "+--------------------------+",
-        ];
-        assert_result_eq(expected, &res);
-    }
-    #[tokio::test]
-    async fn test_aggregate_array_agg() {
-        let context = create_test_context().await;
-
-        let df = context
-            .sql("select ARRAY_AGG(\"id\") from test")
-            .await
-            .unwrap();
-        let res = df.collect().await.unwrap();
-        let expected = vec![
-            "+--------------------------+",
-            "| array_agg(test.id)       |",
-            "+--------------------------+",
-            "| [4, 5, 6, 7, 2, 3, 0, 1] |",
-            "+--------------------------+",
-        ];
-        assert_result_eq(expected, &res);
-    }
-    #[tokio::test]
-    async fn test_aggregate_var() {
-        let context = create_test_context().await;
-
-        let df = context.sql("select VAR(\"id\") from test").await.unwrap();
-        let res = df.collect().await.unwrap();
-        let expected = vec![
-            "+-------------------+",
-            "| var(test.id)      |",
-            "+-------------------+",
-            "| 6.000000000000001 |",
-            "+-------------------+",
-        ];
-        assert_result_eq(expected, &res);
-
-        let df = context
-            .sql("select VAR_POP(\"id\") from test")
-            .await
-            .unwrap();
-        let res = df.collect().await.unwrap();
-        let expected = vec![
-            "+-------------------+",
-            "| var_pop(test.id)  |",
-            "+-------------------+",
-            "| 5.250000000000001 |",
-            "+-------------------+",
-        ];
-        assert_result_eq(expected, &res);
-
-        let df = context
-            .sql("select VAR_SAMP(\"id\") from test")
-            .await
-            .unwrap();
-        let res = df.collect().await.unwrap();
-        let expected = vec![
-            "+-------------------+",
-            "| var(test.id)      |",
-            "+-------------------+",
-            "| 6.000000000000001 |",
-            "+-------------------+",
-        ];
-        assert_result_eq(expected, &res);
-    }
-    #[tokio::test]
-    async fn test_aggregate_stddev() {
-        let context = create_test_context().await;
-
-        let df = context
-            .sql("select STDDEV(\"id\") from test")
-            .await
-            .unwrap();
-        let res = df.collect().await.unwrap();
-        let expected = vec![
-            "+--------------------+",
-            "| stddev(test.id)    |",
-            "+--------------------+",
-            "| 2.4494897427831783 |",
-            "+--------------------+",
-        ];
-        assert_result_eq(expected, &res);
-
-        let df = context
-            .sql("select STDDEV_SAMP(\"id\") from test")
-            .await
-            .unwrap();
-        let res = df.collect().await.unwrap();
-        let expected = vec![
-            "+--------------------+",
-            "| stddev(test.id)    |",
-            "+--------------------+",
-            "| 2.4494897427831783 |",
-            "+--------------------+",
-        ];
-        assert_result_eq(expected, &res);
-    }
-    #[tokio::test]
-    async fn test_aggregate_covar() {
-        let context = create_test_context().await;
-
-        let df = context
-            .sql("select COVAR(id, tinyint_col) from test")
-            .await
-            .unwrap();
-        let res = df.collect().await.unwrap();
-        let expected = vec![
-            "+--------------------------------------+",
-            "| covar_samp(test.id,test.tinyint_col) |",
-            "+--------------------------------------+",
-            "| 0.28571428571428586                  |",
-            "+--------------------------------------+",
-        ];
-        assert_result_eq(expected, &res);
-    }
-    #[tokio::test]
-    async fn test_aggregate_correlation() {
-        let context = create_test_context().await;
-
-        let df = context
-            .sql("select CORR(id, tinyint_col) from test")
-            .await
-            .unwrap();
-        let res = df.collect().await.unwrap();
-        let expected = vec![
-            "+--------------------------------+",
-            "| corr(test.id,test.tinyint_col) |",
-            "+--------------------------------+",
-            "| 0.21821789023599245            |",
-            "+--------------------------------+",
-        ];
-        assert_result_eq(expected, &res);
-    }
-    // enable when upgrading Datafusion to > 42
-    #[ignore]
-    #[tokio::test]
-    async fn test_aggregate_approx_percentile() {
-        let context = create_test_context().await;
-
-        let df = context
-            .sql("select approx_percentile_cont_with_weight(id, 2, 0.5) from test")
-            .await
-            .unwrap();
-        let res = df.collect().await.unwrap();
-        let expected = vec![
-            "+-------------------------------------------------------------------+",
-            "| approx_percentile_cont_with_weight(test.id,Int64(2),Float64(0.5)) |",
-            "+-------------------------------------------------------------------+",
-            "| 1                                                                 |",
-            "+-------------------------------------------------------------------+",
-        ];
-        assert_result_eq(expected, &res);
-
-        let df = context
-            .sql("select approx_percentile_cont(\"double_col\", 0.5) from test")
-            .await
-            .unwrap();
-        let res = df.collect().await.unwrap();
-        let expected = vec![
-            "+------------------------------------------------------+",
-            "| approx_percentile_cont(test.double_col,Float64(0.5)) |",
-            "+------------------------------------------------------+",
-            "| 7.574999999999999                                    |",
-            "+------------------------------------------------------+",
-        ];
-
-        assert_result_eq(expected, &res);
-    }
-
-    fn assert_result_eq(
-        expected: Vec<&str>,
-        results: &[arrow::record_batch::RecordBatch],
-    ) {
-        assert_eq!(
-            expected,
-            pretty_format_batches(results)
-                .unwrap()
-                .to_string()
-                .trim()
-                .lines()
-                .collect::<Vec<&str>>()
-        );
-    }
-    async fn create_test_context() -> BallistaContext {
-        let config = BallistaConfig::default();
-        let context = BallistaContext::standalone(&config, 4).await.unwrap();
-
-        context
-            .register_parquet(
-                "test",
-                "testdata/alltypes_plain.parquet",
-                ParquetReadOptions::default(),
-            )
-            .await
-            .unwrap();
-        context
-    }
-}
diff --git a/ballista/client/src/lib.rs b/ballista/client/src/lib.rs
index 76bd0c940..1605d405b 100644
--- a/ballista/client/src/lib.rs
+++ b/ballista/client/src/lib.rs
@@ -17,6 +17,5 @@
 
 #![doc = include_str!("../README.md")]
 
-pub mod context;
 pub mod extension;
 pub mod prelude;
diff --git a/ballista/client/src/prelude.rs b/ballista/client/src/prelude.rs
index 1410476b2..ca6860d9e 100644
--- a/ballista/client/src/prelude.rs
+++ b/ballista/client/src/prelude.rs
@@ -17,12 +17,10 @@
 
 //! Ballista Prelude (common imports)
 
-pub use ballista_core::{
-    config::BallistaConfig,
-    error::{BallistaError, Result},
-};
+// pub use ballista_core::{
+//     config::BallistaConfig,
+//     error::{BallistaError, Result},
+// };
 
-#[allow(deprecated)] // TO BE REMOVED
-pub use crate::context::BallistaContext;
 pub use crate::extension::{SessionConfigExt, SessionContextExt};
-pub use futures::StreamExt;
+//pub use futures::StreamExt;
diff --git a/ballista/client/tests/context_standalone.rs b/ballista/client/tests/context_standalone.rs
index bd83d5276..c17b53e59 100644
--- a/ballista/client/tests/context_standalone.rs
+++ b/ballista/client/tests/context_standalone.rs
@@ -24,10 +24,10 @@ mod common;
 #[cfg(test)]
 #[cfg(feature = "standalone")]
 mod standalone_tests {
-    use ballista::extension::SessionContextExt;
-    use ballista_core::error::Result;
+    use ballista::prelude::SessionContextExt;
     use datafusion::arrow;
     use datafusion::arrow::util::pretty::pretty_format_batches;
+    use datafusion::common::Result;
     use datafusion::config::TableParquetOptions;
     use datafusion::dataframe::DataFrameWriteOptions;
     use datafusion::prelude::ParquetReadOptions;
diff --git a/ballista/client/tests/remote.rs b/ballista/client/tests/remote.rs
index b0184b265..c03db8524 100644
--- a/ballista/client/tests/remote.rs
+++ b/ballista/client/tests/remote.rs
@@ -19,7 +19,7 @@ mod common;
 
 #[cfg(test)]
 mod remote {
-    use ballista::extension::SessionContextExt;
+    use ballista::prelude::SessionContextExt;
     use datafusion::{assert_batches_eq, prelude::SessionContext};
 
     #[tokio::test]
diff --git a/ballista/client/tests/setup.rs b/ballista/client/tests/setup.rs
index d1e487b5f..4dc6050df 100644
--- a/ballista/client/tests/setup.rs
+++ b/ballista/client/tests/setup.rs
@@ -19,7 +19,7 @@ mod common;
 
 #[cfg(test)]
 mod remote {
-    use ballista::extension::{SessionConfigExt, SessionContextExt};
+    use ballista::prelude::{SessionConfigExt, SessionContextExt};
     use datafusion::{
         assert_batches_eq,
         execution::SessionStateBuilder,
diff --git a/ballista/client/tests/standalone.rs b/ballista/client/tests/standalone.rs
index c5f519b29..fe9f3df1a 100644
--- a/ballista/client/tests/standalone.rs
+++ b/ballista/client/tests/standalone.rs
@@ -20,7 +20,8 @@ mod common;
 #[cfg(test)]
 #[cfg(feature = "standalone")]
 mod standalone {
-    use ballista::{extension::SessionContextExt, prelude::*};
+    use ballista::prelude::SessionContextExt;
+    use ballista_core::config::BallistaConfig;
     use datafusion::prelude::*;
     use datafusion::{assert_batches_eq, prelude::SessionContext};
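The deleted `context.rs` existed largely to wrap `SessionContext` in forwarding `register_*`/`sql` methods plus a client-side table cache. With the extension-trait design, those calls go straight to DataFusion. A hedged migration sketch (the scheduler URL and file path are placeholders; `remote` is the constructor the README hunk documents):

```rust,no_run
use ballista::prelude::*;
use datafusion::common::Result;
use datafusion::prelude::{ParquetReadOptions, SessionContext};

async fn migrated() -> Result<()> {
    // Before: BallistaContext::remote("localhost", 50050, &config).await?
    // After: a plain SessionContext, extended by SessionContextExt.
    let ctx = SessionContext::remote("df://localhost:50050").await?;

    // register_parquet and sql are DataFusion's own SessionContext methods,
    // so the forwarding wrappers removed above are no longer needed.
    ctx.register_parquet(
        "test",
        "testdata/alltypes_plain.parquet",
        ParquetReadOptions::default(),
    )
    .await?;
    ctx.sql("SELECT min(id) FROM test").await?.show().await?;
    Ok(())
}
```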
diff --git a/examples/examples/remote-dataframe.rs b/examples/examples/remote-dataframe.rs
index 74ae5f097..53fb4adf2 100644
--- a/examples/examples/remote-dataframe.rs
+++ b/examples/examples/remote-dataframe.rs
@@ -18,6 +18,7 @@
 use ballista::prelude::*;
 use ballista_examples::test_util;
 use datafusion::{
+    common::Result,
     execution::SessionStateBuilder,
     prelude::{col, lit, ParquetReadOptions, SessionConfig, SessionContext},
 };
diff --git a/examples/examples/remote-sql.rs b/examples/examples/remote-sql.rs
index 673b2dd62..566abb5d2 100644
--- a/examples/examples/remote-sql.rs
+++ b/examples/examples/remote-sql.rs
@@ -18,6 +18,7 @@
 use ballista::prelude::*;
 use ballista_examples::test_util;
 use datafusion::{
+    common::Result,
     execution::SessionStateBuilder,
     prelude::{CsvReadOptions, SessionConfig, SessionContext},
 };
diff --git a/examples/examples/standalone-sql.rs b/examples/examples/standalone-sql.rs
index 9a03bd9b4..da711d308 100644
--- a/examples/examples/standalone-sql.rs
+++ b/examples/examples/standalone-sql.rs
@@ -15,9 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use ballista::prelude::{Result, SessionConfigExt, SessionContextExt};
+use ballista::prelude::{SessionConfigExt, SessionContextExt};
 use ballista_examples::test_util;
 use datafusion::{
+    common::Result,
     execution::{options::ParquetReadOptions, SessionStateBuilder},
     prelude::{SessionConfig, SessionContext},
 };
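Taken together, the example programs now pull `Result` from `datafusion::common` and configure Ballista through the session config and state builders. A sketch of how the pieces fit after this change; `remote_with_state` is assumed here as the state-accepting counterpart of `remote`, and the URL and CSV path are placeholders:

```rust,no_run
use ballista::prelude::*;
use datafusion::common::Result;
use datafusion::execution::SessionStateBuilder;
use datafusion::prelude::{CsvReadOptions, SessionConfig, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    // SessionConfigExt (re-exported by ballista::prelude) supplies
    // new_with_ballista(); SessionStateBuilder is plain DataFusion.
    let config = SessionConfig::new_with_ballista();
    let state = SessionStateBuilder::new()
        .with_config(config)
        .with_default_features()
        .build();
    let ctx = SessionContext::remote_with_state("df://localhost:50050", state).await?;

    ctx.register_csv(
        "test",
        "testdata/aggregate_test_100.csv",
        CsvReadOptions::new(),
    )
    .await?;
    ctx.sql("SELECT count(*) FROM test").await?.show().await?;
    Ok(())
}
```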