From 805438f85a09f0428ff4afeaac564fff32ba331a Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Fri, 20 Dec 2024 13:10:41 -0800 Subject: [PATCH] fix: when taking struct fields they should be merged into the output in the correct order (#3277) In various situations we need to fetch some fields from a struct and then later add more fields to the struct. For example, maybe we have a `struct` column with children `filter` and `big_string`. We might query with a filter on `filter` and then use late materialization to take `big_string`. When we do this we were previously creating a `struct` with its children in the wrong order, which would cause issues since that isn't the correct data type. In creating this fix I added a new `Projection` concept and I would like to slowly replace a lot of the places where we use schemas as projections with `Projection` instead. Not necessarily for performance but more for convenience. --- python/python/tests/test_filter.py | 16 + rust/lance-arrow/src/lib.rs | 159 +++++++++- rust/lance-core/src/datatypes.rs | 4 +- rust/lance-core/src/datatypes/field.rs | 160 +++++++++- rust/lance-core/src/datatypes/schema.rs | 279 ++++++++++++++++- rust/lance/src/datafusion/logical_plan.rs | 7 +- rust/lance/src/dataset.rs | 18 +- rust/lance/src/dataset/fragment.rs | 10 +- rust/lance/src/dataset/scanner.rs | 168 ++++++----- rust/lance/src/dataset/updater.rs | 12 +- rust/lance/src/dataset/write.rs | 16 +- rust/lance/src/dataset/write/merge_insert.rs | 33 +- rust/lance/src/io/exec/optimizer.rs | 80 ++++- rust/lance/src/io/exec/take.rs | 302 ++++++++++++++----- 14 files changed, 1045 insertions(+), 219 deletions(-) diff --git a/python/python/tests/test_filter.py b/python/python/tests/test_filter.py index 5ca6e645e4..2fad73a7b8 100644 --- a/python/python/tests/test_filter.py +++ b/python/python/tests/test_filter.py @@ -257,3 +257,19 @@ def test_duckdb(tmp_path): expected = duckdb.query("SELECT id, meta, price FROM ds").to_df() expected = expected[expected.meta == "aa"].reset_index(drop=True) tm.assert_frame_equal(actual, expected) + + +def test_struct_field_order(tmp_path): + """ + This test guards against a regression where the order of struct fields would get + messed up due to late materialization and we would get {y,x} instead of {x,y} + """ + data = pa.table({"struct": [{"x": i, "y": i} for i in range(10)]}) + dataset = lance.write_dataset(data, tmp_path) + + for late_materialization in [True, False]: + result = dataset.to_table( + filter="struct.y > 5", late_materialization=late_materialization + ) + expected = pa.table({"struct": [{"x": i, "y": i} for i in range(6, 10)]}) + assert result == expected diff --git a/rust/lance-arrow/src/lib.rs b/rust/lance-arrow/src/lib.rs index 9a806b0492..78c2b224e9 100644 --- a/rust/lance-arrow/src/lib.rs +++ b/rust/lance-arrow/src/lib.rs @@ -349,6 +349,17 @@ pub trait RecordBatchExt { /// Merge with another [`RecordBatch`] and returns a new one. /// + /// Fields are merged based on name. First we iterate the left columns. If a matching + /// name is found in the right then we merge the two columns. If there is no match then + /// we add the left column to the output. + /// + /// To merge two columns we consider the type. If both arrays are struct arrays we recurse. + /// Otherwise we use the left array. + /// + /// Afterwards we add all non-matching right columns to the output. + /// + /// Note: This method likely does not handle nested fields correctly and you may want to consider + /// using [`merge_with_schema`] instead.
/// ``` /// use std::sync::Arc; /// use arrow_array::*; @@ -382,6 +393,17 @@ pub trait RecordBatchExt { /// TODO: add merge nested fields support. fn merge(&self, other: &RecordBatch) -> Result; + /// Create a batch by merging columns between two batches with a given schema. + /// + /// A reference schema is used to determine the proper ordering of nested fields. + /// + /// For each field in the reference schema we look for corresponding fields in + /// the left and right batches. If a field is found in both batches we recursively merge + /// it. + /// + /// If a field is only in the left or right batch we take it as it is. + fn merge_with_schema(&self, other: &RecordBatch, schema: &Schema) -> Result; + /// Drop one column specified with the name and return the new [`RecordBatch`]. /// /// If the named column does not exist, it returns a copy of this [`RecordBatch`]. @@ -450,6 +472,23 @@ impl RecordBatchExt for RecordBatch { self.try_new_from_struct_array(merge(&left_struct_array, &right_struct_array)) } + fn merge_with_schema(&self, other: &RecordBatch, schema: &Schema) -> Result { + if self.num_rows() != other.num_rows() { + return Err(ArrowError::InvalidArgumentError(format!( + "Attempt to merge two RecordBatch with different sizes: {} != {}", + self.num_rows(), + other.num_rows() + ))); + } + let left_struct_array: StructArray = self.clone().into(); + let right_struct_array: StructArray = other.clone().into(); + self.try_new_from_struct_array(merge_with_schema( + &left_struct_array, + &right_struct_array, + schema.fields(), + )) + } + fn drop_column(&self, name: &str) -> Result { let mut fields = vec![]; let mut columns = vec![]; @@ -542,7 +581,6 @@ fn project(struct_array: &StructArray, fields: &Fields) -> Result { StructArray::try_new(fields.clone(), columns, None) } -/// Merge the fields and columns of two RecordBatch's recursively fn merge(left_struct_array: &StructArray, right_struct_array: &StructArray) -> StructArray { let mut fields: Vec = vec![]; let mut columns: Vec = vec![]; @@ -616,6 +654,77 @@ fn merge(left_struct_array: &StructArray, right_struct_array: &StructArray) -> S StructArray::from(zipped) } +fn merge_with_schema( + left_struct_array: &StructArray, + right_struct_array: &StructArray, + fields: &Fields, +) -> StructArray { + // Helper function that returns true if both types are struct or both are non-struct + fn same_type_kind(left: &DataType, right: &DataType) -> bool { + match (left, right) { + (DataType::Struct(_), DataType::Struct(_)) => true, + (DataType::Struct(_), _) => false, + (_, DataType::Struct(_)) => false, + _ => true, + } + } + + let mut output_fields: Vec = Vec::with_capacity(fields.len()); + let mut columns: Vec = Vec::with_capacity(fields.len()); + + let left_fields = left_struct_array.fields(); + let left_columns = left_struct_array.columns(); + let right_fields = right_struct_array.fields(); + let right_columns = right_struct_array.columns(); + + for field in fields { + let left_match_idx = left_fields.iter().position(|f| { + f.name() == field.name() && same_type_kind(f.data_type(), field.data_type()) + }); + let right_match_idx = right_fields.iter().position(|f| { + f.name() == field.name() && same_type_kind(f.data_type(), field.data_type()) + }); + + match (left_match_idx, right_match_idx) { + (None, Some(right_idx)) => { + output_fields.push(right_fields[right_idx].as_ref().clone()); + columns.push(right_columns[right_idx].clone()); + } + (Some(left_idx), None) => { + output_fields.push(left_fields[left_idx].as_ref().clone()); + 
columns.push(left_columns[left_idx].clone()); + } + (Some(left_idx), Some(right_idx)) => { + if let DataType::Struct(child_fields) = field.data_type() { + let left_sub_array = left_columns[left_idx].as_struct(); + let right_sub_array = right_columns[right_idx].as_struct(); + let merged_sub_array = + merge_with_schema(left_sub_array, right_sub_array, child_fields); + output_fields.push(Field::new( + field.name(), + merged_sub_array.data_type().clone(), + field.is_nullable(), + )); + columns.push(Arc::new(merged_sub_array) as ArrayRef); + } else { + output_fields.push(left_fields[left_idx].as_ref().clone()); + columns.push(left_columns[left_idx].clone()); + } + } + (None, None) => { + // The field will not be included in the output + } + } + } + + let zipped: Vec<(FieldRef, ArrayRef)> = output_fields + .into_iter() + .map(Arc::new) + .zip(columns) + .collect::>(); + StructArray::from(zipped) +} + fn get_sub_array<'a>(array: &'a ArrayRef, components: &[&str]) -> Option<&'a ArrayRef> { if components.is_empty() { return Some(array); @@ -721,7 +830,7 @@ impl BufferExt for arrow_buffer::Buffer { #[cfg(test)] mod tests { use super::*; - use arrow_array::{Int32Array, StringArray}; + use arrow_array::{new_empty_array, Int32Array, StringArray}; #[test] fn test_merge_recursive() { @@ -808,6 +917,52 @@ mod tests { assert_eq!(result, merged_batch); } + #[test] + fn test_merge_with_schema() { + fn test_batch(names: &[&str], types: &[DataType]) -> (Schema, RecordBatch) { + let fields: Fields = names + .iter() + .zip(types) + .map(|(name, ty)| Field::new(name.to_string(), ty.clone(), false)) + .collect(); + let schema = Schema::new(vec![Field::new( + "struct", + DataType::Struct(fields.clone()), + false, + )]); + let children = types + .iter() + .map(|ty| new_empty_array(ty)) + .collect::>(); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(StructArray::new(fields, children, None)) as ArrayRef], + ); + (schema, batch.unwrap()) + } + + let (_, left_batch) = test_batch(&["a", "b"], &[DataType::Int32, DataType::Int64]); + let (_, right_batch) = test_batch(&["c", "b"], &[DataType::Int32, DataType::Int64]); + let (output_schema, _) = test_batch( + &["b", "a", "c"], + &[DataType::Int64, DataType::Int32, DataType::Int32], + ); + + // If we use merge_with_schema the schema is respected + let merged = left_batch + .merge_with_schema(&right_batch, &output_schema) + .unwrap(); + assert_eq!(merged.schema().as_ref(), &output_schema); + + // If we use merge we get first-come first-serve based on the left batch + let (naive_schema, _) = test_batch( + &["a", "b", "c"], + &[DataType::Int32, DataType::Int64, DataType::Int32], + ); + let merged = left_batch.merge(&right_batch).unwrap(); + assert_eq!(merged.schema().as_ref(), &naive_schema); + } + #[test] fn test_take_record_batch() { let schema = Arc::new(Schema::new(vec![ diff --git a/rust/lance-core/src/datatypes.rs b/rust/lance-core/src/datatypes.rs index e7d3f28a97..920e4cf38e 100644 --- a/rust/lance-core/src/datatypes.rs +++ b/rust/lance-core/src/datatypes.rs @@ -19,10 +19,10 @@ mod schema; use crate::{Error, Result}; pub use field::{ - Encoding, Field, NullabilityComparison, SchemaCompareOptions, StorageClass, + Encoding, Field, NullabilityComparison, OnTypeMismatch, SchemaCompareOptions, StorageClass, LANCE_STORAGE_CLASS_SCHEMA_META_KEY, }; -pub use schema::Schema; +pub use schema::{OnMissing, Projectable, Projection, Schema}; pub const COMPRESSION_META_KEY: &str = "lance-encoding:compression"; pub const COMPRESSION_LEVEL_META_KEY: &str 
= "lance-encoding:compression-level"; diff --git a/rust/lance-core/src/datatypes/field.rs b/rust/lance-core/src/datatypes/field.rs index 45351ebb86..d94926c31d 100644 --- a/rust/lance-core/src/datatypes/field.rs +++ b/rust/lance-core/src/datatypes/field.rs @@ -4,8 +4,8 @@ //! Lance Schema Field use std::{ - cmp::max, - collections::HashMap, + cmp::{max, Ordering}, + collections::{HashMap, VecDeque}, fmt::{self, Display}, str::FromStr, sync::Arc, @@ -25,7 +25,7 @@ use snafu::{location, Location}; use super::{ schema::{compare_fields, explain_fields_difference}, - Dictionary, LogicalType, + Dictionary, LogicalType, Projection, }; use crate::{Error, Result}; @@ -108,6 +108,13 @@ impl FromStr for StorageClass { } } +/// What to do on a merge operation if the types of the fields don't match +#[derive(Debug, Clone, Copy, PartialEq, Eq, DeepSizeOf)] +pub enum OnTypeMismatch { + TakeSelf, + Error, +} + /// Lance Schema Field /// #[derive(Debug, Clone, PartialEq, DeepSizeOf)] @@ -162,6 +169,106 @@ impl Field { self.storage_class } + /// Merge a field with another field using a reference field to ensure + /// the correct order of fields + /// + /// For each child in the reference field we look for a matching child + /// in self and other. + /// + /// If we find a match in both we recursively merge the children. + /// If we find a match in one but not the other we take the matching child. + /// + /// Primitive fields we simply clone self and return. + /// + /// Matches are determined using field names and so ids are not required. + pub fn merge_with_reference(&self, other: &Self, reference: &Self) -> Self { + let mut new_children = Vec::with_capacity(reference.children.len()); + let mut self_children_itr = self.children.iter().peekable(); + let mut other_children_itr = other.children.iter().peekable(); + for ref_child in &reference.children { + match (self_children_itr.peek(), other_children_itr.peek()) { + (Some(&only_child), None) => { + // other is exhausted so just check if self matches + if only_child.name == ref_child.name { + new_children.push(only_child.clone()); + self_children_itr.next(); + } + } + (None, Some(&only_child)) => { + // Self is exhausted so just check if other matches + if only_child.name == ref_child.name { + new_children.push(only_child.clone()); + other_children_itr.next(); + } + } + (Some(&self_child), Some(&other_child)) => { + // Both iterators have potential, see if any match + match ( + ref_child.name.cmp(&self_child.name), + ref_child.name.cmp(&other_child.name), + ) { + (Ordering::Equal, Ordering::Equal) => { + // Both match, recursively merge + new_children + .push(self_child.merge_with_reference(other_child, ref_child)); + self_children_itr.next(); + other_children_itr.next(); + } + (Ordering::Equal, _) => { + // Self matches, other doesn't, use self as-is + new_children.push(self_child.clone()); + self_children_itr.next(); + } + (_, Ordering::Equal) => { + // Other matches, self doesn't, use other as-is + new_children.push(other_child.clone()); + other_children_itr.next(); + } + _ => { + // Neither match, field is projected out + } + } + } + (None, None) => { + // Both iterators are exhausted, we can quit, all remaining fields projected out + break; + } + } + } + Self { + children: new_children, + ..self.clone() + } + } + + pub fn apply_projection(&self, projection: &Projection) -> Option { + let children = self + .children + .iter() + .filter_map(|c| c.apply_projection(projection)) + .collect::>(); + + // The following case is invalid: + // - This is a nested 
field (has children) + // - All children were projected away + // - Caller is asking for the parent field + assert!( + // One of the following must be true + !children.is_empty() // Some children were projected + || !projection.contains_field_id(self.id) // Caller is not asking for this field + || self.children.is_empty() // This isn't a nested field + ); + + if children.is_empty() && !projection.contains_field_id(self.id) { + None + } else { + Some(Self { + children, + ..self.clone() + }) + } + } + pub(crate) fn explain_differences( &self, expected: &Self, @@ -456,7 +563,7 @@ impl Field { /// Project by a field. /// - pub fn project_by_field(&self, other: &Self) -> Result { + pub fn project_by_field(&self, other: &Self, on_type_mismatch: OnTypeMismatch) -> Result { if self.name != other.name { return Err(Error::Schema { message: format!( @@ -496,7 +603,7 @@ impl Field { location: location!(), }); }; - fields.push(child.project_by_field(other_field)?); + fields.push(child.project_by_field(other_field, on_type_mismatch)?); } let mut cloned = self.clone(); cloned.children = fields; @@ -504,7 +611,8 @@ impl Field { } (DataType::List(_), DataType::List(_)) | (DataType::LargeList(_), DataType::LargeList(_)) => { - let projected = self.children[0].project_by_field(&other.children[0])?; + let projected = + self.children[0].project_by_field(&other.children[0], on_type_mismatch)?; let mut cloned = self.clone(); cloned.children = vec![projected]; Ok(cloned) @@ -524,13 +632,33 @@ impl Field { { Ok(self.clone()) } - _ => Err(Error::Schema { - message: format!( - "Attempt to project incompatible fields: {} and {}", - self, other - ), - location: location!(), - }), + _ => match on_type_mismatch { + OnTypeMismatch::Error => Err(Error::Schema { + message: format!( + "Attempt to project incompatible fields: {} and {}", + self, other + ), + location: location!(), + }), + OnTypeMismatch::TakeSelf => Ok(self.clone()), + }, + } + } + + pub(crate) fn resolve<'a>( + &'a self, + split: &mut VecDeque<&str>, + fields: &mut Vec<&'a Self>, + ) -> bool { + fields.push(self); + if split.is_empty() { + return true; + } + let first = split.pop_front().unwrap(); + if let Some(child) = self.children.iter().find(|c| c.name == first) { + child.resolve(split, fields) + } else { + false } } @@ -970,19 +1098,19 @@ mod tests { let f2: Field = ArrowField::new("a", DataType::Null, true) .try_into() .unwrap(); - let p1 = f1.project_by_field(&f2).unwrap(); + let p1 = f1.project_by_field(&f2, OnTypeMismatch::Error).unwrap(); assert_eq!(p1, f1); let f3: Field = ArrowField::new("b", DataType::Null, true) .try_into() .unwrap(); - assert!(f1.project_by_field(&f3).is_err()); + assert!(f1.project_by_field(&f3, OnTypeMismatch::Error).is_err()); let f4: Field = ArrowField::new("a", DataType::Int32, true) .try_into() .unwrap(); - assert!(f1.project_by_field(&f4).is_err()); + assert!(f1.project_by_field(&f4, OnTypeMismatch::Error).is_err()); } #[test] diff --git a/rust/lance-core/src/datatypes/schema.rs b/rust/lance-core/src/datatypes/schema.rs index ed17394824..dde75bbfcf 100644 --- a/rust/lance-core/src/datatypes/schema.rs +++ b/rust/lance-core/src/datatypes/schema.rs @@ -4,8 +4,9 @@ //! 
Schema use std::{ - collections::{HashMap, HashSet}, + collections::{HashMap, HashSet, VecDeque}, fmt::{self, Debug, Formatter}, + sync::Arc, }; use arrow_array::RecordBatch; @@ -14,8 +15,8 @@ use deepsize::DeepSizeOf; use lance_arrow::*; use snafu::{location, Location}; -use super::field::{Field, SchemaCompareOptions, StorageClass}; -use crate::{Error, Result}; +use super::field::{Field, OnTypeMismatch, SchemaCompareOptions, StorageClass}; +use crate::{Error, Result, ROW_ADDR, ROW_ID}; /// Lance Schema. #[derive(Default, Debug, Clone, DeepSizeOf)] @@ -152,6 +153,26 @@ impl Schema { ArrowSchema::from(self).to_compact_string(indent) } + /// Given a string column reference, resolve the path of fields + /// + /// For example, given a.b.c we will return the fields [a, b, c] + /// + /// Returns None if we can't find a segment at any point + pub fn resolve(&self, column: impl AsRef) -> Option> { + let mut split = column.as_ref().split('.').collect::>(); + let mut fields = Vec::with_capacity(split.len()); + let first = split.pop_front().unwrap(); + if let Some(field) = self.field(first) { + if field.resolve(&mut split, &mut fields) { + Some(fields) + } else { + None + } + } else { + None + } + } + fn do_project>(&self, columns: &[T], err_on_missing: bool) -> Result { let mut candidates: Vec = vec![]; for col in columns { @@ -304,13 +325,15 @@ impl Schema { pub fn project_by_schema>( &self, projection: S, + on_missing: OnMissing, + on_type_mismatch: OnTypeMismatch, ) -> Result { let projection = projection.try_into()?; let mut new_fields = vec![]; for field in projection.fields.iter() { if let Some(self_field) = self.field(&field.name) { - new_fields.push(self_field.project_by_field(field)?); - } else { + new_fields.push(self_field.project_by_field(field, on_type_mismatch)?); + } else if matches!(on_missing, OnMissing::Error) { return Err(Error::Schema { message: format!("Field {} not found", field.name), location: location!(), @@ -750,6 +773,248 @@ fn explain_metadata_difference( } } +/// What to do when a column is missing in the schema +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum OnMissing { + Error, + Ignore, +} + +/// A trait for something that we can project fields from. 
+pub trait Projectable: Debug + Send + Sync { + fn schema(&self) -> &Schema; +} + +impl Projectable for Schema { + fn schema(&self) -> &Schema { + self + } +} + +/// A projection is a selection of fields in a schema +/// +/// In addition we record whether the row_id or row_addr are +/// selected (these fields have no field id) +#[derive(Clone)] +pub struct Projection { + base: Arc, + pub field_ids: HashSet, + pub with_row_id: bool, + pub with_row_addr: bool, +} + +impl Debug for Projection { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("Projection") + .field("schema", &self.to_schema()) + .field("with_row_id", &self.with_row_id) + .field("with_row_addr", &self.with_row_addr) + .finish() + } +} + +impl Projection { + /// Create a new empty projection + pub fn empty(base: Arc) -> Self { + Self { + base, + field_ids: HashSet::new(), + with_row_id: false, + with_row_addr: false, + } + } + + /// Add a column (and any of its parents) to the projection from a string reference + pub fn union_column(mut self, column: impl AsRef, on_missing: OnMissing) -> Result { + let column = column.as_ref(); + if column == ROW_ID { + self.with_row_id = true; + return Ok(self); + } else if column == ROW_ADDR { + self.with_row_addr = true; + return Ok(self); + } + + if let Some(fields) = self.base.schema().resolve(column) { + self.field_ids.extend(fields.iter().map(|f| f.id)); + } else if matches!(on_missing, OnMissing::Error) { + return Err(Error::InvalidInput { + source: format!("Column {} does not exist", column).into(), + location: location!(), + }); + } + Ok(self) + } + + /// True if the projection selects the given field id + pub fn contains_field_id(&self, id: i32) -> bool { + self.field_ids.contains(&id) + } + + /// Add multiple columns (and their parents) to the projection + pub fn union_columns( + mut self, + columns: impl IntoIterator>, + on_missing: OnMissing, + ) -> Result { + for column in columns { + self = self.union_column(column, on_missing)?; + } + Ok(self) + } + + /// Adds all fields from the base schema satisfying a predicate + pub fn union_predicate(mut self, predicate: impl Fn(&Field) -> bool) -> Self { + for field in self.base.schema().fields_pre_order() { + if predicate(field) { + self.field_ids.insert(field.id); + } + } + self + } + + /// Removes all fields in the base schema satisfying a predicate + pub fn subtract_predicate(mut self, predicate: impl Fn(&Field) -> bool) -> Self { + for field in self.base.schema().fields_pre_order() { + if predicate(field) { + self.field_ids.remove(&field.id); + } + } + self + } + + /// Creates a new projection that is the intersection of this projection and another + pub fn intersect(mut self, other: &Self) -> Self { + self.field_ids = HashSet::from_iter(self.field_ids.intersection(&other.field_ids).copied()); + self.with_row_id = self.with_row_id && other.with_row_id; + self.with_row_addr = self.with_row_addr && other.with_row_addr; + self + } + + /// Adds all fields from the provided schema to the projection + /// + /// Fields are only added if they exist in the base schema, otherwise they + /// are ignored. + /// + /// Will panic if a field in the given schema has a non-negative id and is not in the base schema. 
+ pub fn union_schema(mut self, other: &Schema) -> Self { + for field in other.fields_pre_order() { + if field.id >= 0 { + self.field_ids.insert(field.id); + } else if field.name == ROW_ID { + self.with_row_id = true; + } else if field.name == ROW_ADDR { + self.with_row_addr = true; + } else { + // If a field is not in our schema then it should probably have an id of -1. If it isn't -1 + // that probably implies some kind of weird schema mixing is going on and we should panic. + debug_assert_eq!(field.id, -1); + } + } + self + } + + /// Creates a new projection that is the union of this projection and another + pub fn union_projection(mut self, other: &Self) -> Self { + self.field_ids.extend(&other.field_ids); + self.with_row_id = self.with_row_id || other.with_row_id; + self.with_row_addr = self.with_row_addr || other.with_row_addr; + self + } + + /// Adds all fields from the given schema to the projection + /// + /// on_missing controls what happen to fields that are not in the base schema + /// + /// Name based matching is used to determine if a field is in the base schema. + pub fn union_arrow_schema( + mut self, + other: &ArrowSchema, + on_missing: OnMissing, + ) -> Result { + self.with_row_id |= other.fields().iter().any(|f| f.name() == ROW_ID); + self.with_row_addr |= other.fields().iter().any(|f| f.name() == ROW_ADDR); + let other = + self.base + .schema() + .project_by_schema(other, on_missing, OnTypeMismatch::TakeSelf)?; + Ok(self.union_schema(&other)) + } + + /// Removes all fields from the projection that are in the given schema + /// + /// on_missing controls what happen to fields that are not in the base schema + /// + /// Name based matching is used to determine if a field is in the base schema. + pub fn subtract_arrow_schema( + mut self, + other: &ArrowSchema, + on_missing: OnMissing, + ) -> Result { + self.with_row_id &= !other.fields().iter().any(|f| f.name() == ROW_ID); + self.with_row_addr &= !other.fields().iter().any(|f| f.name() == ROW_ADDR); + let other = + self.base + .schema() + .project_by_schema(other, on_missing, OnTypeMismatch::TakeSelf)?; + Ok(self.subtract_schema(&other)) + } + + /// Removes all fields from this projection that are present in the given projection + pub fn subtract_projection(mut self, other: &Self) -> Self { + self.field_ids = self + .field_ids + .difference(&other.field_ids) + .copied() + .collect(); + self.with_row_addr = self.with_row_addr && !other.with_row_addr; + self.with_row_id = self.with_row_id && !other.with_row_id; + self + } + + /// Removes all fields from the projection that are in the given schema + /// + /// Fields are only removed if they exist in the base schema, otherwise they + /// are ignored. + /// + /// Will panic if a field in the given schema has a non-negative id and is not in the base schema. 
+ pub fn subtract_schema(mut self, other: &Schema) -> Self { + for field in other.fields_pre_order() { + if field.id >= 0 { + self.field_ids.remove(&field.id); + } else if field.name == ROW_ID { + self.with_row_id = false; + } else if field.name == ROW_ADDR { + self.with_row_addr = false; + } else { + debug_assert_eq!(field.id, -1); + } + } + self + } + + /// True if the projection does not select any fields + pub fn is_empty(&self) -> bool { + self.field_ids.is_empty() + } + + /// Convert the projection to a schema + pub fn to_schema(&self) -> Schema { + let field_ids = self.field_ids.iter().copied().collect::>(); + self.base.schema().project_by_ids(&field_ids, false) + } + + /// Convert the projection to a schema + pub fn into_schema(self) -> Schema { + self.to_schema() + } + + /// Convert the projection to a schema reference + pub fn into_schema_ref(self) -> Arc { + Arc::new(self.into_schema()) + } +} + #[cfg(test)] mod tests { use std::sync::Arc; @@ -921,7 +1186,9 @@ mod tests { false, ), ]); - let projected = schema.project_by_schema(&projection).unwrap(); + let projected = schema + .project_by_schema(&projection, OnMissing::Error, OnTypeMismatch::TakeSelf) + .unwrap(); assert_eq!(ArrowSchema::from(&projected), projection); } diff --git a/rust/lance/src/datafusion/logical_plan.rs b/rust/lance/src/datafusion/logical_plan.rs index 6afb94e332..9c20a3d43c 100644 --- a/rust/lance/src/datafusion/logical_plan.rs +++ b/rust/lance/src/datafusion/logical_plan.rs @@ -13,6 +13,7 @@ use datafusion::{ physical_plan::ExecutionPlan, prelude::Expr, }; +use lance_core::datatypes::{OnMissing, OnTypeMismatch}; use crate::Dataset; @@ -52,7 +53,11 @@ impl TableProvider for Dataset { if projection.len() != schema_ref.fields.len() { let arrow_schema: ArrowSchema = schema_ref.into(); let arrow_schema = arrow_schema.project(projection)?; - schema_ref.project_by_schema(&arrow_schema)? + schema_ref.project_by_schema( + &arrow_schema, + OnMissing::Error, + OnTypeMismatch::Error, + )? } else { schema_ref.clone() } diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index fcd5959d71..84ba4bf528 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -12,6 +12,7 @@ use futures::future::BoxFuture; use futures::stream::{self, StreamExt, TryStreamExt}; use futures::{FutureExt, Stream}; use itertools::Itertools; +use lance_core::datatypes::{OnMissing, OnTypeMismatch, Projectable, Projection}; use lance_core::traits::DatasetTakeRows; use lance_core::utils::address::RowAddress; use lance_core::utils::tokio::get_num_compute_intensive_cpus; @@ -270,7 +271,11 @@ impl ProjectionRequest { pub fn into_projection_plan(self, dataset_schema: &Schema) -> Result { match self { Self::Schema(schema) => Ok(ProjectionPlan::new_empty( - Arc::new(dataset_schema.project_by_schema(schema.as_ref())?), + Arc::new(dataset_schema.project_by_schema( + schema.as_ref(), + OnMissing::Error, + OnTypeMismatch::Error, + )?), /*load_blobs=*/ false, )), Self::Sql(columns) => { @@ -1036,6 +1041,11 @@ impl Dataset { &self.manifest.local_schema } + /// Creates a new empty projection into the dataset schema + pub fn empty_projection(self: &Arc) -> Projection { + Projection::empty(self.clone()) + } + /// Get fragments. /// /// If `filter` is provided, only fragments with the given name will be returned. 
@@ -1651,6 +1661,12 @@ fn write_manifest_file_to_path<'a>( }) } +impl Projectable for Dataset { + fn schema(&self) -> &Schema { + self.schema() + } +} + #[cfg(test)] mod tests { use std::vec; diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs index 938ff646ab..e739dbc47f 100644 --- a/rust/lance/src/dataset/fragment.rs +++ b/rust/lance/src/dataset/fragment.rs @@ -18,7 +18,7 @@ use datafusion::logical_expr::Expr; use datafusion::scalar::ScalarValue; use futures::future::try_join_all; use futures::{join, stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; -use lance_core::datatypes::SchemaCompareOptions; +use lance_core::datatypes::{OnMissing, OnTypeMismatch, SchemaCompareOptions}; use lance_core::utils::deletion::DeletionVector; use lance_core::utils::tokio::get_num_compute_intensive_cpus; use lance_core::{datatypes::Schema, Error, Result}; @@ -754,9 +754,11 @@ impl FileFragment { Some(&self.dataset.session.file_metadata_cache), ) .await?; - let initialized_schema = reader - .schema() - .project_by_schema(schema_per_file.as_ref())?; + let initialized_schema = reader.schema().project_by_schema( + schema_per_file.as_ref(), + OnMissing::Error, + OnTypeMismatch::Error, + )?; let reader = V1Reader::new(reader, Arc::new(initialized_schema)); Ok(Some(Box::new(reader))) } else { diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 3f3500cde1..4537b75961 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -13,6 +13,7 @@ use arrow_array::{ use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema, SchemaRef, SortOptions}; use arrow_select::concat::concat_batches; use async_recursion::async_recursion; +use datafusion::common::SchemaExt; use datafusion::functions_aggregate; use datafusion::functions_aggregate::count::count_udaf; use datafusion::logical_expr::Expr; @@ -39,7 +40,7 @@ use futures::stream::{Stream, StreamExt}; use futures::TryStreamExt; use lance_arrow::floats::{coerce_float_vector, FloatType}; use lance_arrow::DataTypeExt; -use lance_core::datatypes::Field; +use lance_core::datatypes::{Field, OnMissing, Projection}; use lance_core::utils::tokio::get_num_compute_intensive_cpus; use lance_core::{ROW_ADDR, ROW_ADDR_FIELD, ROW_ID, ROW_ID_FIELD}; use lance_datafusion::exec::{execute_plan, LanceExecutionOptions}; @@ -945,6 +946,7 @@ impl Scanner { #[instrument(skip_all)] pub async fn try_into_stream(&self) -> Result { let plan = self.create_plan().await?; + Ok(DatasetRecordBatchStream::new(execute_plan( plan, LanceExecutionOptions::default(), @@ -1064,37 +1066,23 @@ impl Scanner { fn calc_eager_columns(&self, filter_plan: &FilterPlan) -> Result> { let columns = filter_plan.refine_columns(); - // If the column didn't exist in the scan output schema then we wouldn't make - // it to this point. However, there may be columns (like _rowid, _distance, etc.) - // which do not exist in the dataset schema but are added by the scan. We can ignore - // those as eager columns. - let filter_schema = self.dataset.schema().project_or_drop(&columns)?; - if filter_schema.fields.iter().any(|f| !f.is_default_storage()) { + let early_schema = self + .dataset + .empty_projection() + // We need the filter columns + .union_columns(columns, OnMissing::Error)? 
+ // And also any columns that are eager + .union_predicate(|f| self.is_early_field(f)) + .into_schema_ref(); + + if early_schema.fields.iter().any(|f| !f.is_default_storage()) { return Err(Error::NotSupported { source: "non-default storage columns cannot be used as filters".into(), location: location!(), }); } - let physical_schema = self.projection_plan.physical_schema.clone(); - let remaining_schema = physical_schema.exclude(&filter_schema)?; - - let narrow_fields = remaining_schema - .fields - .iter() - .filter(|f| self.is_early_field(f)) - .cloned() - .collect::>(); - if narrow_fields.is_empty() { - Ok(Arc::new(filter_schema)) - } else { - let mut new_fields = filter_schema.fields; - new_fields.extend(narrow_fields); - Ok(Arc::new(Schema { - fields: new_fields, - metadata: HashMap::new(), - })) - } + Ok(early_schema) } /// Create [`ExecutionPlan`] for Scan. @@ -1331,34 +1319,33 @@ impl Scanner { }; // Stage 1.5 load columns needed for stages 2 & 3 - let mut additional_schema = None; + // Calculate the schema needed for the filter and ordering. + let mut pre_filter_projection = self + .dataset + .empty_projection() + .union_schema(&self.projection_plan.physical_schema) + .subtract_predicate(|field| !self.is_early_field(field)); + // We may need to take filter columns if we are going to refine - // an indexed scan. Otherwise, the filter was applied during the scan - // and this should be false + // an indexed scan. if filter_plan.has_refine() { - let eager_schema = self.calc_eager_columns(&filter_plan)?; - let base_schema = Schema::try_from(plan.schema().as_ref())?; - let still_to_load = eager_schema.exclude(base_schema)?; - if still_to_load.fields.is_empty() { - additional_schema = None; - } else { - additional_schema = Some(still_to_load); - } + // It's ok for some filter columns to be missing (e.g. _rowid) + pre_filter_projection = pre_filter_projection + .union_columns(filter_plan.refine_columns(), OnMissing::Ignore)?; } + + // TODO: Does it always make sense to take the ordering columns here? If there is a filter then + // maybe we wait until after the filter to take the ordering columns? Maybe it would be better to + // grab the ordering column in the initial scan (if it is eager) and if it isn't then we should + // take it after the filtering phase, if any (we already have a take there). 
if let Some(ordering) = &self.ordering { - additional_schema = self.calc_new_fields( - &additional_schema - .map(Ok::) - .unwrap_or_else(|| Schema::try_from(plan.schema().as_ref()))?, - &ordering - .iter() - .map(|col| &col.column_name) - .collect::>(), + pre_filter_projection = pre_filter_projection.union_columns( + ordering.iter().map(|col| &col.column_name), + OnMissing::Error, )?; } - if let Some(additional_schema) = additional_schema { - plan = self.take(plan, &additional_schema, self.batch_readahead)?; - } + + plan = self.take(plan, pre_filter_projection, self.batch_readahead)?; // Stage 2: filter if let Some(refine_expr) = filter_plan.refine_expr { @@ -1372,19 +1359,13 @@ impl Scanner { // Stage 3: sort if let Some(ordering) = &self.ordering { - let order_by_schema = Arc::new( - self.dataset.schema().project( - &ordering - .iter() - .map(|col| &col.column_name) - .collect::>(), - )?, - ); - let remaining_schema = order_by_schema.exclude(plan.schema().as_ref())?; - if !remaining_schema.fields.is_empty() { - // We haven't loaded the sort column yet so take it now - plan = self.take(plan, &remaining_schema, self.batch_readahead)?; - } + let ordering_columns = ordering.iter().map(|col| &col.column_name); + let projection_with_ordering = self + .dataset + .empty_projection() + .union_columns(ordering_columns, OnMissing::Error)?; + // We haven't loaded the sort column yet so take it now + plan = self.take(plan, projection_with_ordering, self.batch_readahead)?; let col_exprs = ordering .iter() .map(|col| { @@ -1408,12 +1389,14 @@ impl Scanner { // Stage 5: take remaining columns required for projection let physical_schema = self.scan_output_schema(&self.projection_plan.physical_schema, false)?; - let remaining_schema = physical_schema.exclude(plan.schema().as_ref())?; - if !remaining_schema.fields.is_empty() { - plan = self.take(plan, &remaining_schema, self.batch_readahead)?; - } + let physical_projection = self + .dataset + .empty_projection() + .union_schema(&physical_schema); + plan = self.take(plan, physical_projection, self.batch_readahead)?; // Stage 6: physical projection -- reorder physical columns needed before final projection let output_arrow_schema = physical_schema.as_ref().into(); + if plan.schema().as_ref() != &output_arrow_schema { plan = Arc::new(project(plan, &physical_schema.as_ref().into())?); } @@ -1635,9 +1618,13 @@ impl Scanner { let ann_node = self.ann(q, &deltas, filter_plan).await?; // _distance, _rowid let mut knn_node = if q.refine_factor.is_some() { - let with_vector = self.dataset.schema().project(&[&q.column])?; + let vector_projection = self + .dataset + .empty_projection() + .union_column(&q.column, OnMissing::Error) + .unwrap(); let knn_node_with_vector = - self.take(ann_node, &with_vector, self.batch_readahead)?; + self.take(ann_node, vector_projection, self.batch_readahead)?; // TODO: now we just open an index to get its metric type. let idx = self .dataset @@ -1701,8 +1688,12 @@ impl Scanner { // If the vector column is not present, we need to take the vector column, so // that the distance value is comparable with the flat search ones. 
if knn_node.schema().column_with_name(&q.column).is_none() { - let with_vector = self.dataset.schema().project(&[&q.column])?; - knn_node = self.take(knn_node, &with_vector, self.batch_readahead)?; + let vector_projection = self + .dataset + .empty_projection() + .union_column(&q.column, OnMissing::Error) + .unwrap(); + knn_node = self.take(knn_node, vector_projection, self.batch_readahead)?; } let mut columns = vec![q.column.clone()]; @@ -1740,7 +1731,9 @@ impl Scanner { // knn_node: _distance, _rowid, vector // topk_appended: vector, , _rowid, _distance let topk_appended = project(topk_appended, knn_node.schema().as_ref())?; - assert_eq!(topk_appended.schema(), knn_node.schema()); + assert!(topk_appended + .schema() + .equivalent_names_and_types(&knn_node.schema())); // union let unioned = UnionExec::new(vec![Arc::new(topk_appended), knn_node]); // Enforce only 1 partition. @@ -1822,7 +1815,8 @@ impl Scanner { _ => true, }; if needs_take { - plan = self.take(plan, projection, self.batch_readahead)?; + let take_projection = self.dataset.empty_projection().union_schema(projection); + plan = self.take(plan, take_projection, self.batch_readahead)?; } if self.with_row_address { @@ -2095,16 +2089,24 @@ impl Scanner { fn take( &self, input: Arc, - projection: &Schema, + output_projection: Projection, batch_readahead: usize, ) -> Result> { - let coalesced = Arc::new(CoalesceBatchesExec::new(input, self.get_batch_size())); - Ok(Arc::new(TakeExec::try_new( + let coalesced = Arc::new(CoalesceBatchesExec::new( + input.clone(), + self.get_batch_size(), + )); + if let Some(take_plan) = TakeExec::try_new( self.dataset.clone(), coalesced, - Arc::new(projection.clone()), + output_projection, batch_readahead, - )?)) + )? { + Ok(Arc::new(take_plan)) + } else { + // No new columns needed + Ok(input) + } } /// Global offset-limit of the result of the input plan @@ -4609,11 +4611,11 @@ mod test { assert_plan_equals( &dataset.dataset, |scan| scan.use_stats(false).filter("s IS NOT NULL"), - "ProjectionExec: expr=[i@1 as i, s@0 as s, vec@3 as vec] - Take: columns=\"s, i, _rowid, (vec)\" + "ProjectionExec: expr=[i@0 as i, s@1 as s, vec@3 as vec] + Take: columns=\"i, s, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 - FilterExec: s@0 IS NOT NULL - LanceScan: uri..., projection=[s, i], row_id=true, row_addr=false, ordered=true", + FilterExec: s@1 IS NOT NULL + LanceScan: uri..., projection=[i, s], row_id=true, row_addr=false, ordered=true", ) .await?; @@ -4625,9 +4627,9 @@ mod test { .materialization_style(MaterializationStyle::AllEarly) .filter("s IS NOT NULL") }, - "ProjectionExec: expr=[i@1 as i, s@0 as s, vec@2 as vec] - FilterExec: s@0 IS NOT NULL - LanceScan: uri..., projection=[s, i, vec], row_id=true, row_addr=false, ordered=true", + "ProjectionExec: expr=[i@0 as i, s@1 as s, vec@2 as vec] + FilterExec: s@1 IS NOT NULL + LanceScan: uri..., projection=[i, s, vec], row_id=true, row_addr=false, ordered=true", ) .await?; diff --git a/rust/lance/src/dataset/updater.rs b/rust/lance/src/dataset/updater.rs index f12b201de8..10a3023b9b 100644 --- a/rust/lance/src/dataset/updater.rs +++ b/rust/lance/src/dataset/updater.rs @@ -3,6 +3,7 @@ use arrow_array::{RecordBatch, UInt32Array}; use futures::StreamExt; +use lance_core::datatypes::{OnMissing, OnTypeMismatch}; use lance_core::utils::deletion::DeletionVector; use lance_core::{datatypes::Schema, Error, Result}; use lance_table::format::Fragment; @@ -182,12 +183,11 @@ impl Updater { 
final_schema.set_field_id(Some(self.fragment.dataset().manifest.max_field_id())); self.final_schema = Some(final_schema); self.final_schema.as_ref().unwrap().validate()?; - self.write_schema = Some( - self.final_schema - .as_ref() - .unwrap() - .project_by_schema(output_schema.as_ref())?, - ); + self.write_schema = Some(self.final_schema.as_ref().unwrap().project_by_schema( + output_schema.as_ref(), + OnMissing::Error, + OnTypeMismatch::Error, + )?); } self.writer = Some( diff --git a/rust/lance/src/dataset/write.rs b/rust/lance/src/dataset/write.rs index 600176fae6..44385bd66f 100644 --- a/rust/lance/src/dataset/write.rs +++ b/rust/lance/src/dataset/write.rs @@ -6,7 +6,9 @@ use std::sync::Arc; use arrow_array::RecordBatch; use datafusion::physical_plan::SendableRecordBatchStream; use futures::{StreamExt, TryStreamExt}; -use lance_core::datatypes::{NullabilityComparison, SchemaCompareOptions, StorageClass}; +use lance_core::datatypes::{ + NullabilityComparison, OnMissing, OnTypeMismatch, SchemaCompareOptions, StorageClass, +}; use lance_core::{datatypes::Schema, Error, Result}; use lance_datafusion::chunker::{break_stream, chunk_stream}; use lance_datafusion::utils::StreamingWriteSource; @@ -335,7 +337,11 @@ pub async fn write_fragments_internal( }, )?; // Project from the dataset schema, because it has the correct field ids. - let write_schema = dataset.schema().project_by_schema(&schema)?; + let write_schema = dataset.schema().project_by_schema( + &schema, + OnMissing::Error, + OnTypeMismatch::Error, + )?; // Use the storage version from the dataset, ignoring any version from the user. let data_storage_version = dataset .manifest() @@ -362,7 +368,11 @@ pub async fn write_fragments_internal( (schema, params.storage_version_or_default()) }; - let data_schema = schema.project_by_schema(data.schema().as_ref())?; + let data_schema = schema.project_by_schema( + data.schema().as_ref(), + OnMissing::Error, + OnTypeMismatch::Error, + )?; let (data, blob_data) = data.extract_blob_stream(&data_schema); diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index fa4d05682a..1d603dec40 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -54,7 +54,7 @@ use futures::{ Stream, StreamExt, TryStreamExt, }; use lance_core::{ - datatypes::SchemaCompareOptions, + datatypes::{OnMissing, OnTypeMismatch, SchemaCompareOptions}, error::{box_error, InvalidInputSnafu}, utils::{ futures::Capacity, @@ -454,12 +454,19 @@ impl MergeInsertJob { } // 4 - Take the mapped row ids - let mut target = Arc::new(TakeExec::try_new( - self.dataset.clone(), - index_mapper, - Arc::new(self.dataset.schema().project_by_schema(schema.as_ref())?), - get_num_compute_intensive_cpus(), - )?) as Arc; + let projection = self + .dataset + .empty_projection() + .union_arrow_schema(schema.as_ref(), OnMissing::Error)?; + let mut target = Arc::new( + TakeExec::try_new( + self.dataset.clone(), + index_mapper, + projection, + get_num_compute_intensive_cpus(), + )? + .unwrap(), + ) as Arc; // 5 - Take puts the row id and row addr at the beginning. A full scan (used when there is // no scalar index) puts the row id and addr at the end. 
We need to match these up so @@ -632,7 +639,11 @@ impl MergeInsertJob { ) -> Result { // batches still have _rowaddr let write_schema = batches[0].schema().as_ref().without_column(ROW_ADDR); - let write_schema = dataset.local_schema().project_by_schema(&write_schema)?; + let write_schema = dataset.local_schema().project_by_schema( + &write_schema, + OnMissing::Error, + OnTypeMismatch::Error, + )?; let updated_rows: usize = batches.iter().map(|batch| batch.num_rows()).sum(); if Some(updated_rows) == metadata.physical_rows { @@ -804,7 +815,11 @@ impl MergeInsertJob { let reader = RecordBatchIterator::new(batches, write_schema.clone()); let stream = reader_to_stream(Box::new(reader)); - let write_schema = dataset.schema().project_by_schema(write_schema.as_ref())?; + let write_schema = dataset.schema().project_by_schema( + write_schema.as_ref(), + OnMissing::Error, + OnTypeMismatch::Error, + )?; let fragments = write_fragments_internal( Some(dataset.as_ref()), diff --git a/rust/lance/src/io/exec/optimizer.rs b/rust/lance/src/io/exec/optimizer.rs index b05e5f5feb..d84ddf33f6 100644 --- a/rust/lance/src/io/exec/optimizer.rs +++ b/rust/lance/src/io/exec/optimizer.rs @@ -6,18 +6,68 @@ use std::sync::Arc; use super::TakeExec; +use arrow_schema::Schema as ArrowSchema; use datafusion::{ common::tree_node::{Transformed, TreeNode}, config::ConfigOptions, error::Result as DFResult, physical_optimizer::{optimizer::PhysicalOptimizer, PhysicalOptimizerRule}, - physical_plan::{projection::ProjectionExec as DFProjectionExec, ExecutionPlan}, + physical_plan::{ + coalesce_batches::CoalesceBatchesExec, projection::ProjectionExec, ExecutionPlan, + }, }; -use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; /// Rule that eliminates [TakeExec] nodes that are immediately followed by another [TakeExec]. pub struct CoalesceTake; +impl CoalesceTake { + fn field_order_differs(old_schema: &ArrowSchema, new_schema: &ArrowSchema) -> bool { + old_schema + .fields + .iter() + .zip(&new_schema.fields) + .any(|(old, new)| old.name() != new.name()) + } + + fn remap_collapsed_output( + old_schema: &ArrowSchema, + new_schema: &ArrowSchema, + plan: Arc, + ) -> Arc { + let mut project_exprs = Vec::with_capacity(old_schema.fields.len()); + for field in &old_schema.fields { + project_exprs.push(( + Arc::new(Column::new_with_schema(field.name(), new_schema).unwrap()) + as Arc, + field.name().clone(), + )); + } + Arc::new(ProjectionExec::try_new(project_exprs, plan).unwrap()) + } + + fn collapse_takes( + inner_take: &TakeExec, + outer_take: &TakeExec, + outer_exec: Arc, + ) -> Arc { + let inner_take_input = inner_take.children()[0].clone(); + let old_output_schema = outer_take.schema(); + let collapsed = outer_exec + .with_new_children(vec![inner_take_input]) + .unwrap(); + let new_output_schema = collapsed.schema(); + + // It's possible that collapsing the take can change the field order. This disturbs DF's planner and + // so we must restore it. 
+ if Self::field_order_differs(&old_output_schema, &new_output_schema) { + Self::remap_collapsed_output(&old_output_schema, &new_output_schema, collapsed) + } else { + collapsed + } + } +} + impl PhysicalOptimizerRule for CoalesceTake { fn optimize( &self, @@ -26,11 +76,27 @@ impl PhysicalOptimizerRule for CoalesceTake { ) -> DFResult> { Ok(plan .transform_down(|plan| { - if let Some(take) = plan.as_any().downcast_ref::() { - let child = take.children()[0]; - if let Some(exec_child) = child.as_any().downcast_ref::() { + if let Some(outer_take) = plan.as_any().downcast_ref::() { + let child = outer_take.children()[0]; + // Case 1: TakeExec -> TakeExec + if let Some(inner_take) = child.as_any().downcast_ref::() { + return Ok(Transformed::yes(Self::collapse_takes( + inner_take, + outer_take, + plan.clone(), + ))); + // Case 2: TakeExec -> CoalesceBatchesExec -> TakeExec + } else if let Some(exec_child) = + child.as_any().downcast_ref::() + { let inner_child = exec_child.children()[0].clone(); - return Ok(Transformed::yes(plan.with_new_children(vec![inner_child])?)); + if let Some(inner_take) = inner_child.as_any().downcast_ref::() { + return Ok(Transformed::yes(Self::collapse_takes( + inner_take, + outer_take, + plan.clone(), + ))); + } } } Ok(Transformed::no(plan)) @@ -59,7 +125,7 @@ impl PhysicalOptimizerRule for SimplifyProjection { ) -> DFResult> { Ok(plan .transform_down(|plan| { - if let Some(proj) = plan.as_any().downcast_ref::() { + if let Some(proj) = plan.as_any().downcast_ref::() { let children = proj.children(); if children.len() != 1 { return Ok(Transformed::no(plan)); diff --git a/rust/lance/src/io/exec/take.rs b/rust/lance/src/io/exec/take.rs index bf7b808844..3bdb2e6cad 100644 --- a/rust/lance/src/io/exec/take.rs +++ b/rust/lance/src/io/exec/take.rs @@ -17,6 +17,7 @@ use datafusion::physical_plan::{ use datafusion_physical_expr::EquivalenceProperties; use futures::stream::{self, Stream, StreamExt, TryStreamExt}; use futures::{Future, FutureExt}; +use lance_core::datatypes::{Field, OnMissing, Projection}; use tokio::sync::mpsc::{self, Receiver}; use tokio::task::JoinHandle; use tracing::{instrument, Instrument}; @@ -33,7 +34,6 @@ use crate::{arrow::*, Error}; pub struct Take { rx: Receiver>, bg_thread: Option>, - output_schema: SchemaRef, } @@ -55,15 +55,18 @@ impl Take { ) -> Self { let (tx, rx) = mpsc::channel(4); + let output_schema_copy = output_schema.clone(); let bg_thread = tokio::spawn( async move { if let Err(e) = child .zip(stream::repeat_with(|| { (dataset.clone(), projection.clone()) })) - .map(|(batch, (dataset, extra))| async move { - Self::take_batch(batch?, dataset, extra).await - }) + .map(|(batch, (dataset, extra))| { + let output_schema_copy = output_schema_copy.clone(); + async move { + Self::take_batch(batch?, dataset, extra, output_schema_copy).await + }}) .buffered(batch_readahead) .map(|r| r.map_err(|e| DataFusionError::Execution(e.to_string()))) .try_for_each(|b| async { @@ -110,6 +113,7 @@ impl Take { batch: RecordBatch, dataset: Arc, extra: Arc, + output_schema: SchemaRef, ) -> impl Future> + Send { async move { let row_id_arr = batch.column_by_name(ROW_ID).unwrap(); @@ -121,7 +125,7 @@ impl Take { .take_rows(row_ids.values(), ProjectionRequest::Schema(extra)) .await?; debug_assert_eq!(batch.num_rows(), new_columns.num_rows()); - batch.merge(&new_columns)? + batch.merge_with_schema(&new_columns, &output_schema)? }; Ok::(rows) } @@ -173,14 +177,20 @@ impl RecordBatchStream for Take { #[derive(Debug)] pub struct TakeExec { /// Dataset to read from. 
- dataset: Arc, + pub(crate) dataset: Arc, - pub(crate) extra_schema: Arc, + /// The original projection is kept to recalculate `with_new_children`. + pub(crate) original_projection: Arc, + + /// The schema to pass to dataset.take, this should be the original projection + /// minus any fields in the input schema. + schema_to_take: Arc, input: Arc, - /// Output schema is the merged schema between input schema and extra schema. - output_schema: Schema, + /// Output schema is the merged schema between input schema and extra schema and + /// tells us how to merge the input and extra columns. + output_schema: Arc, batch_readahead: usize, @@ -190,7 +200,7 @@ pub struct TakeExec { impl DisplayAs for TakeExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { let extra_fields = self - .extra_schema + .schema_to_take .fields .iter() .map(|f| f.name.clone()) @@ -221,38 +231,109 @@ impl TakeExec { /// /// - dataset: the dataset to read from /// - input: the upstream [`ExecutionPlan`] to feed data in. - /// - extra_schema: the extra schema to take / read from the dataset. + /// - projection: the desired output projection, can overlap with the input schema if desired + /// + /// Returns None if no extra columns are required (everything in the projection exists in the input schema). pub fn try_new( dataset: Arc, input: Arc, - extra_schema: Arc, + projection: Projection, batch_readahead: usize, - ) -> Result { + ) -> Result> { + let original_projection = projection.clone().into_schema_ref(); + let projection = + projection.subtract_arrow_schema(input.schema().as_ref(), OnMissing::Ignore)?; + if projection.is_empty() { + return Ok(None); + } + + // We actually need a take so lets make sure we have a row id if input.schema().column_with_name(ROW_ID).is_none() { return Err(DataFusionError::Plan( "TakeExec requires the input plan to have a column named '_rowid'".to_string(), )); } - let input_schema = Schema::try_from(input.schema().as_ref())?; - let output_schema = input_schema.merge(extra_schema.as_ref())?; - - let remaining_schema = extra_schema.exclude(&input_schema)?; + // Can't use take if we don't want any fields and we can't use take to add row_id or row_addr + assert!( + !projection.with_row_id && !projection.with_row_addr, + "Take cannot insert row_id / row_addr: {:#?}", + projection + ); - let output_arrow = Arc::new(ArrowSchema::from(&output_schema)); + let output_schema = Arc::new(Self::calculate_output_schema( + dataset.schema(), + &input.schema(), + &projection, + )); + let output_arrow = Arc::new(ArrowSchema::from(output_schema.as_ref())); let properties = input .properties() .clone() .with_eq_properties(EquivalenceProperties::new(output_arrow)); - Ok(Self { + Ok(Some(Self { dataset, - extra_schema: Arc::new(remaining_schema), + original_projection, + schema_to_take: projection.into_schema_ref(), input, output_schema, batch_readahead, properties, - }) + })) + } + + /// The output of a take operation will be all columns from the input schema followed + /// by any new columns from the dataset. + /// + /// The output fields will always be added in dataset schema order + /// + /// Nested columns in the input schema may have new fields inserted into them. + /// + /// If this happens the order of the new nested fields will match the order defined in + /// the dataset schema. 
+ fn calculate_output_schema( + dataset_schema: &Schema, + input_schema: &ArrowSchema, + projection: &Projection, + ) -> Schema { + // TakeExec doesn't reorder top-level fields and so the first thing we need to do is determine the + // top-level field order. + let mut top_level_fields_added = HashSet::with_capacity(input_schema.fields.len()); + let projected_schema = projection.to_schema(); + + let mut output_fields = + Vec::with_capacity(input_schema.fields.len() + projected_schema.fields.len()); + // TakeExec always moves the _rowid to the start of the output schema + output_fields.extend(input_schema.fields.iter().map(|f| { + let f = Field::try_from(f.as_ref()).unwrap(); + if let Some(ds_field) = dataset_schema.field(&f.name) { + top_level_fields_added.insert(ds_field.id); + // Field is in the dataset, it might have new fields added to it + if let Some(projected_field) = ds_field.apply_projection(projection) { + f.merge_with_reference(&projected_field, ds_field) + } else { + // No new fields added, keep as-is + f + } + } else { + // Field not in dataset, not possible to add extra fields, use as-is + f + } + })); + + // Now we add to the end any brand new top-level fields. These will be added + // dataset schema order. + output_fields.extend( + projected_schema + .fields + .into_iter() + .filter(|f| !top_level_fields_added.contains(&f.id)), + ); + Schema { + fields: output_fields, + metadata: dataset_schema.metadata.clone(), + } } /// Get the dataset. @@ -273,7 +354,7 @@ impl ExecutionPlan for TakeExec { } fn schema(&self) -> SchemaRef { - ArrowSchema::from(&self.output_schema).into() + Arc::new(self.output_schema.as_ref().into()) } fn children(&self) -> Vec<&Arc> { @@ -291,18 +372,24 @@ impl ExecutionPlan for TakeExec { )); } - let child = &children[0]; - - let extra_schema = self.output_schema.exclude(child.schema().as_ref())?; + let projection = self + .dataset + .empty_projection() + .union_schema(&self.original_projection); let plan = Self::try_new( self.dataset.clone(), children[0].clone(), - Arc::new(extra_schema), + projection, self.batch_readahead, )?; - Ok(Arc::new(plan)) + if let Some(plan) = plan { + Ok(Arc::new(plan)) + } else { + // Is this legal or do we need to insert a no-op node? 
+ Ok(children[0].clone()) + } } fn execute( @@ -311,10 +398,11 @@ impl ExecutionPlan for TakeExec { context: Arc, ) -> Result { let input_stream = self.input.execute(partition, context)?; + let output_schema_arrow = Arc::new(ArrowSchema::from(self.output_schema.as_ref())); Ok(Box::pin(Take::new( self.dataset.clone(), - self.extra_schema.clone(), - self.schema(), + self.schema_to_take.clone(), + output_schema_arrow, input_stream, self.batch_readahead, ))) @@ -336,20 +424,35 @@ impl ExecutionPlan for TakeExec { mod tests { use super::*; - use arrow_array::{ArrayRef, Float32Array, Int32Array, RecordBatchIterator, StringArray}; - use arrow_schema::{DataType, Field}; - use tempfile::tempdir; + use arrow_array::{ + ArrayRef, Float32Array, Int32Array, RecordBatchIterator, StringArray, StructArray, + }; + use arrow_schema::{DataType, Field, Fields}; + use datafusion::execution::TaskContext; + use lance_core::datatypes::OnMissing; + use tempfile::{tempdir, TempDir}; use crate::{ dataset::WriteParams, io::exec::{LanceScanConfig, LanceScanExec}, }; - async fn create_dataset() -> Arc { + struct TestFixture { + dataset: Arc, + _tmp_dir_guard: TempDir, + } + + async fn test_fixture() -> TestFixture { + let struct_fields = Fields::from(vec![ + Arc::new(Field::new("x", DataType::Int32, false)), + Arc::new(Field::new("y", DataType::Int32, false)), + ]); + let schema = Arc::new(ArrowSchema::new(vec![ Field::new("i", DataType::Int32, false), Field::new("f", DataType::Float32, false), Field::new("s", DataType::Utf8, false), + Field::new("struct", DataType::Struct(struct_fields.clone()), false), ])); // Write 3 batches. @@ -362,7 +465,15 @@ mod tests { value_range.clone().map(|v| v as f32), )), Arc::new(StringArray::from_iter_values( - value_range.map(|v| format!("str-{v}")), + value_range.clone().map(|v| format!("str-{v}")), + )), + Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from_iter(value_range.clone())), + Arc::new(Int32Array::from_iter(value_range)), + ], + None, )), ]; RecordBatch::try_new(schema.clone(), columns).unwrap() @@ -381,19 +492,19 @@ mod tests { .await .unwrap(); - Arc::new(Dataset::open(test_uri).await.unwrap()) + TestFixture { + dataset: Arc::new(Dataset::open(test_uri).await.unwrap()), + _tmp_dir_guard: test_dir, + } } #[tokio::test] async fn test_take_schema() { - let dataset = create_dataset().await; + let TestFixture { dataset, .. 
 
         let scan_arrow_schema = ArrowSchema::new(vec![Field::new("i", DataType::Int32, false)]);
         let scan_schema = Arc::new(Schema::try_from(&scan_arrow_schema).unwrap());
 
-        let extra_arrow_schema = ArrowSchema::new(vec![Field::new("s", DataType::Int32, false)]);
-        let extra_schema = Arc::new(Schema::try_from(&extra_arrow_schema).unwrap());
-
         // With row id
         let config = LanceScanConfig {
             with_row_id: true,
@@ -406,7 +517,14 @@ mod tests {
             scan_schema,
             config,
         ));
-        let take_exec = TakeExec::try_new(dataset, input, extra_schema, 10).unwrap();
+
+        let projection = dataset
+            .empty_projection()
+            .union_column("s", OnMissing::Error)
+            .unwrap();
+        let take_exec = TakeExec::try_new(dataset, input, projection, 10)
+            .unwrap()
+            .unwrap();
         let schema = take_exec.schema();
         assert_eq!(
             schema.fields.iter().map(|f| f.name()).collect::<Vec<_>>(),
@@ -415,18 +533,15 @@
     }
 
     #[tokio::test]
-    async fn test_take_no_extra_columns() {
-        let dataset = create_dataset().await;
-
-        let scan_arrow_schema = ArrowSchema::new(vec![
-            Field::new("i", DataType::Int32, false),
-            Field::new("s", DataType::Int32, false),
-        ]);
-        let scan_schema = Arc::new(Schema::try_from(&scan_arrow_schema).unwrap());
+    async fn test_take_struct() {
+        // When taking fields into an existing struct, the field order should be maintained
+        // according to the schema of the struct.
+        let TestFixture {
+            dataset,
+            _tmp_dir_guard,
+        } = test_fixture().await;
 
-        // Extra column is already read.
-        let extra_arrow_schema = ArrowSchema::new(vec![Field::new("s", DataType::Int32, false)]);
-        let extra_schema = Arc::new(Schema::try_from(&extra_arrow_schema).unwrap());
+        let scan_schema = Arc::new(dataset.schema().project(&["struct.y"]).unwrap());
 
         let config = LanceScanConfig {
             with_row_id: true,
@@ -439,26 +554,50 @@
             scan_schema,
             config,
         ));
-        let take_exec = TakeExec::try_new(dataset, input, extra_schema, 10).unwrap();
+
+        let projection = dataset
+            .empty_projection()
+            .union_column("struct.x", OnMissing::Error)
+            .unwrap();
+
+        let take_exec = TakeExec::try_new(dataset, input, projection, 10)
+            .unwrap()
+            .unwrap();
+
+        let expected_schema = ArrowSchema::new(vec![
+            Field::new(
+                "struct",
+                DataType::Struct(Fields::from(vec![
+                    Arc::new(Field::new("x", DataType::Int32, false)),
+                    Arc::new(Field::new("y", DataType::Int32, false)),
+                ])),
+                false,
+            ),
+            Field::new(ROW_ID, DataType::UInt64, true),
+        ]);
         let schema = take_exec.schema();
-        assert_eq!(
-            schema.fields.iter().map(|f| f.name()).collect::<Vec<_>>(),
-            vec!["i", "s", ROW_ID]
-        );
+        assert_eq!(schema.as_ref(), &expected_schema);
+
+        let mut stream = take_exec
+            .execute(0, Arc::new(TaskContext::default()))
+            .unwrap();
+
+        while let Some(batch) = stream.try_next().await.unwrap() {
+            assert_eq!(batch.schema().as_ref(), &expected_schema);
+        }
     }
 
     #[tokio::test]
     async fn test_take_no_row_id() {
-        let dataset = create_dataset().await;
+        let TestFixture { dataset, ..
} = test_fixture().await; - let scan_arrow_schema = ArrowSchema::new(vec![ - Field::new("i", DataType::Int32, false), - Field::new("s", DataType::Int32, false), - ]); + let scan_arrow_schema = ArrowSchema::new(vec![Field::new("i", DataType::Int32, false)]); let scan_schema = Arc::new(Schema::try_from(&scan_arrow_schema).unwrap()); - let extra_arrow_schema = ArrowSchema::new(vec![Field::new("s", DataType::Int32, false)]); - let extra_schema = Arc::new(Schema::try_from(&extra_arrow_schema).unwrap()); + let projection = dataset + .empty_projection() + .union_column("s", OnMissing::Error) + .unwrap(); // No row ID let input = Arc::new(LanceScanExec::new( @@ -468,38 +607,43 @@ mod tests { scan_schema, LanceScanConfig::default(), )); - assert!(TakeExec::try_new(dataset, input, extra_schema, 10).is_err()); + assert!(TakeExec::try_new(dataset, input, projection, 10).is_err()); } #[tokio::test] async fn test_with_new_children() -> Result<()> { - let dataset = create_dataset().await; + let TestFixture { dataset, .. } = test_fixture().await; let config = LanceScanConfig { with_row_id: true, ..Default::default() }; + + let input_schema = Arc::new(dataset.schema().project(&["i"])?); + let projection = dataset + .empty_projection() + .union_column("s", OnMissing::Error) + .unwrap(); + let input = Arc::new(LanceScanExec::new( dataset.clone(), dataset.fragments().clone(), None, - Arc::new(dataset.schema().project(&["i"])?), + input_schema, config, )); + assert_eq!(input.schema().field_names(), vec!["i", ROW_ID],); - let take_exec = TakeExec::try_new( - dataset.clone(), - input.clone(), - Arc::new(dataset.schema().project(&["s"])?), - 10, - )?; + let take_exec = TakeExec::try_new(dataset.clone(), input.clone(), projection, 10)?.unwrap(); assert_eq!(take_exec.schema().field_names(), vec!["i", ROW_ID, "s"],); - let outer_take = Arc::new(TakeExec::try_new( - dataset.clone(), - Arc::new(take_exec), - Arc::new(dataset.schema().project(&["f"])?), - 10, - )?); + + let projection = dataset + .empty_projection() + .union_columns(["s", "f"], OnMissing::Error) + .unwrap(); + + let outer_take = + Arc::new(TakeExec::try_new(dataset, Arc::new(take_exec), projection, 10)?.unwrap()); assert_eq!( outer_take.schema().field_names(), vec!["i", ROW_ID, "s", "f"], @@ -507,7 +651,7 @@ mod tests { // with_new_children should preserve the output schema. let edited = outer_take.with_new_children(vec![input])?; - assert_eq!(edited.schema().field_names(), vec!["i", ROW_ID, "s", "f"],); + assert_eq!(edited.schema().field_names(), vec!["i", ROW_ID, "f", "s"],); Ok(()) } }
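
Reviewer note (outside the patch): a minimal sketch of how the new Projection-based TakeExec constructor is driven, mirroring the tests above. The import paths, the public visibility of TakeExec, and the helper name are assumptions; `dataset` is an opened Dataset and `input` is a child plan that already carries _rowid (for example a LanceScanExec built with `with_row_id: true`).

    use std::sync::Arc;

    use datafusion::physical_plan::ExecutionPlan;
    use lance::dataset::Dataset;
    use lance::io::exec::TakeExec; // assumed path/visibility
    use lance_core::datatypes::OnMissing;

    // Hypothetical helper, not part of this patch.
    fn plan_late_materialization(
        dataset: Arc<Dataset>,
        input: Arc<dyn ExecutionPlan>,
    ) -> Arc<dyn ExecutionPlan> {
        // Describe the columns to fetch late; OnMissing::Error rejects unknown names.
        let projection = dataset
            .empty_projection()
            .union_column("s", OnMissing::Error)
            .unwrap();

        // try_new now yields None when the projection adds nothing beyond the
        // child's schema, in which case the child plan can be reused as-is.
        match TakeExec::try_new(dataset, input.clone(), projection, 10).unwrap() {
            Some(take) => Arc::new(take) as Arc<dyn ExecutionPlan>,
            None => input,
        }
    }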