Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: BubbleCal <[email protected]>
  • Loading branch information
BubbleCal committed Dec 25, 2024
1 parent 5cedac5 commit 0593f4e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 17 deletions.
33 changes: 25 additions & 8 deletions rust/lance/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,8 @@ impl Scanner {
}

/// Find the k nearest neighbors of the query within the vector column.
/// the query can be a Float16Array, Float32Array, Float64Array, UInt8Array, or a ListArray of the above types.
/// The query can be a Float16Array, Float32Array, Float64Array, or UInt8Array,
/// or a ListArray/FixedSizeListArray of the above types.
pub fn nearest(&mut self, column: &str, q: &dyn Array, k: usize) -> Result<&mut Self> {
if !self.prefilter {
// We can allow fragment scan if the input to nearest is a prefilter.
Expand All @@ -659,7 +660,7 @@ impl Scanner {
let dim = get_vector_dim(self.dataset.schema(), column)?;

let q = match q.data_type() {
DataType::List(_) => {
DataType::List(_) | DataType::FixedSizeList(_, _) => {
if !matches!(vector_type, DataType::List(_)) {
return Err(Error::invalid_input(
format!(
Expand All @@ -669,22 +670,38 @@ impl Scanner {
location!(),
));
}
let list_array = q.as_list::<i32>();
for i in 0..list_array.len() {
let vec = list_array.value(i);
if vec.len() != dim {

if let Some(list_array) = q.as_list_opt::<i32>() {
for i in 0..list_array.len() {
let vec = list_array.value(i);
if vec.len() != dim {
return Err(Error::invalid_input(
format!(
"query dim({}) doesn't match the column {} vector dim({})",
vec.len(),
column,
dim,
),
location!(),
));
}
}
list_array.values().clone()
} else {
let fsl = q.as_fixed_size_list();
if fsl.value_length() as usize != dim {
return Err(Error::invalid_input(
format!(
"query dim({}) doesn't match the column {} vector dim({})",
vec.len(),
fsl.value_length(),
column,
dim,
),
location!(),
));
}
fsl.values().clone()
}
list_array.values().clone()
}
_ => {
if q.len() != dim {
Expand Down
18 changes: 9 additions & 9 deletions rust/lance/src/index/vector/ivf/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,12 @@ mod tests {
k: usize,
distance_type: DistanceType,
) -> Vec<(f32, u64)> {
multivec_distance(query, vectors, distance_type)
let query = if let Some(list_array) = query.as_list_opt::<i32>() {
list_array.values().clone()
} else {
query.as_fixed_size_list().values().clone()
};
multivec_distance(&query, vectors, distance_type)
.unwrap()
.into_iter()
.enumerate()
Expand Down Expand Up @@ -1017,17 +1022,12 @@ mod tests {
.await
.unwrap();

let multivector = vectors.value(0);
let query = multivector
.as_fixed_size_list()
.values()
.as_primitive::<T>();
let query = &query;
let query = vectors.value(0);
let k = 100;

let result = dataset
.scan()
.nearest("vector", query, k)
.nearest("vector", &query, k)
.unwrap()
.nprobs(nlist)
.with_row_id()
Expand All @@ -1048,7 +1048,7 @@ mod tests {
.collect::<Vec<_>>();
let row_ids = row_ids.into_iter().collect::<HashSet<_>>();

let gt = multivec_ground_truth(&vectors, query, k, params.metric_type);
let gt = multivec_ground_truth(&vectors, &query, k, params.metric_type);
let gt_set = gt.iter().map(|r| r.1).collect::<HashSet<_>>();

let recall = row_ids.intersection(&gt_set).count() as f32 / 10.0;
Expand Down

0 comments on commit 0593f4e

Please sign in to comment.