From 0593f4e203d573c94d0a2e4ea9709e7894f89d3c Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 25 Dec 2024 12:21:04 +0800 Subject: [PATCH] fix Signed-off-by: BubbleCal --- rust/lance/src/dataset/scanner.rs | 33 ++++++++++++++++++++------- rust/lance/src/index/vector/ivf/v2.rs | 18 +++++++-------- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 02ac14216d..f1fa23caaa 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -634,7 +634,8 @@ impl Scanner { } /// Find k-nearest neighbor within the vector column. - /// the query can be a Float16Array, Float32Array, Float64Array, UInt8Array, or a ListArray of the above types. + /// the query can be a Float16Array, Float32Array, Float64Array, UInt8Array, + /// or a ListArray/FixedSizeListArray of the above types. pub fn nearest(&mut self, column: &str, q: &dyn Array, k: usize) -> Result<&mut Self> { if !self.prefilter { // We can allow fragment scan if the input to nearest is a prefilter. 
@@ -659,7 +660,7 @@ impl Scanner { let dim = get_vector_dim(self.dataset.schema(), column)?; let q = match q.data_type() { - DataType::List(_) => { + DataType::List(_) | DataType::FixedSizeList(_, _) => { if !matches!(vector_type, DataType::List(_)) { return Err(Error::invalid_input( format!( @@ -669,22 +670,38 @@ impl Scanner { location!(), )); } - let list_array = q.as_list::(); - for i in 0..list_array.len() { - let vec = list_array.value(i); - if vec.len() != dim { + + if let Some(list_array) = q.as_list_opt::() { + for i in 0..list_array.len() { + let vec = list_array.value(i); + if vec.len() != dim { + return Err(Error::invalid_input( + format!( + "query dim({}) doesn't match the column {} vector dim({})", + vec.len(), + column, + dim, + ), + location!(), + )); + } + } + list_array.values().clone() + } else { + let fsl = q.as_fixed_size_list(); + if fsl.value_length() as usize != dim { return Err(Error::invalid_input( format!( "query dim({}) doesn't match the column {} vector dim({})", - vec.len(), + fsl.value_length(), column, dim, ), location!(), )); } + fsl.values().clone() } - list_array.values().clone() } _ => { if q.len() != dim { diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index cbf4ae9c95..4382d09685 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -672,7 +672,12 @@ mod tests { k: usize, distance_type: DistanceType, ) -> Vec<(f32, u64)> { - multivec_distance(query, vectors, distance_type) + let query = if let Some(list_array) = query.as_list_opt::() { + list_array.values().clone() + } else { + query.as_fixed_size_list().values().clone() + }; + multivec_distance(&query, vectors, distance_type) .unwrap() .into_iter() .enumerate() @@ -1017,17 +1022,12 @@ mod tests { .await .unwrap(); - let multivector = vectors.value(0); - let query = multivector - .as_fixed_size_list() - .values() - .as_primitive::(); - let query = &query; + let query = vectors.value(0); let 
k = 100; let result = dataset .scan() - .nearest("vector", query, k) + .nearest("vector", &query, k) .unwrap() .nprobs(nlist) .with_row_id() @@ -1048,7 +1048,7 @@ mod tests { .collect::>(); let row_ids = row_ids.into_iter().collect::>(); - let gt = multivec_ground_truth(&vectors, query, k, params.metric_type); + let gt = multivec_ground_truth(&vectors, &query, k, params.metric_type); let gt_set = gt.iter().map(|r| r.1).collect::>(); let recall = row_ids.intersection(&gt_set).count() as f32 / 10.0;