Skip to content

Commit

Permalink
feat: support multivector type
Browse files Browse the repository at this point in the history
Signed-off-by: BubbleCal <[email protected]>
  • Loading branch information
BubbleCal committed Dec 2, 2024
1 parent dc9afbb commit 55b0e08
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 33 deletions.
25 changes: 21 additions & 4 deletions rust/lance/src/index/vector/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ use super::v2::IVFIndex;
pub struct IvfIndexBuilder<S: IvfSubIndex, Q: Quantization + Clone> {
dataset: Dataset,
column: String,
// used for multivector type: when set, `column[modal_index]` is projected instead of `column`
modal_index: Option<usize>,
index_dir: Path,
distance_type: DistanceType,
shuffler: Arc<dyn Shuffler>,
Expand Down Expand Up @@ -105,6 +107,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + Clone + 'static> IvfIndexBuilde
Ok(Self {
dataset,
column,
modal_index: None,
index_dir,
distance_type,
shuffler: shuffler.into(),
Expand Down Expand Up @@ -142,6 +145,11 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + Clone + 'static> IvfIndexBuilde
)
}

/// Sets the modal index used for multivector columns and returns `self`.
///
/// Once set, the builder projects `column[modal_index]` instead of the bare
/// column when shuffling the dataset, and forwards the index when building
/// the IVF model and when sampling training data for the quantizer.
pub fn with_modal_index(&mut self, modal_index: usize) -> &mut Self {
self.modal_index = Some(modal_index);
self
}

// build the index with all the data in the dataset
pub async fn build(&mut self) -> Result<()> {
// step 1. train IVF & quantizer
Expand Down Expand Up @@ -190,6 +198,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + Clone + 'static> IvfIndexBuilde
super::build_ivf_model(
&self.dataset,
&self.column,
self.modal_index,
dim,
self.distance_type,
ivf_params,
Expand All @@ -211,9 +220,13 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + Clone + 'static> IvfIndexBuilde
"loading training data for quantizer. sample size: {}",
sample_size_hint
);
let training_data =
utils::maybe_sample_training_data(&self.dataset, &self.column, sample_size_hint)
.await?;
let training_data = utils::maybe_sample_training_data(
&self.dataset,
&self.column,
self.modal_index,
sample_size_hint,
)
.await?;
info!(
"Finished loading training data in {:02} seconds",
start.elapsed().as_secs_f32()
Expand Down Expand Up @@ -252,11 +265,15 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + Clone + 'static> IvfIndexBuilde
}

async fn shuffle_dataset(&mut self) -> Result<()> {
let column = match self.modal_index {
Some(index) => format!("{}[{}]", self.column, index),
None => self.column.clone(),
};
let stream = self
.dataset
.scan()
.batch_readahead(get_num_compute_intensive_cpus())
.project(&[self.column.as_str()])?
.project(&[column])?
.with_row_id()
.try_into_stream()
.await?;
Expand Down
23 changes: 17 additions & 6 deletions rust/lance/src/index/vector/ivf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1131,6 +1131,7 @@ fn sanity_check_params(ivf: &IvfBuildParams, pq: &PQBuildParams) -> Result<()> {
pub async fn build_ivf_model(
dataset: &Dataset,
column: &str,
modal_index: Option<usize>,
dim: usize,
metric_type: MetricType,
params: &IvfBuildParams,
Expand All @@ -1156,7 +1157,8 @@ pub async fn build_ivf_model(
"Loading training data for IVF. Sample size: {}",
sample_size_hint
);
let training_data = maybe_sample_training_data(dataset, column, sample_size_hint).await?;
let training_data =
maybe_sample_training_data(dataset, column, modal_index, sample_size_hint).await?;
info!(
"Finished loading training data in {:02} seconds",
start.elapsed().as_secs_f32()
Expand All @@ -1181,6 +1183,8 @@ pub async fn build_ivf_model(
Ok(ivf)
}

// Build the IVF model and the PQ model from the dataset.
// This is only used for the v1 index.
async fn build_ivf_model_and_pq(
dataset: &Dataset,
column: &str,
Expand Down Expand Up @@ -1208,7 +1212,7 @@ async fn build_ivf_model_and_pq(
});
};

let ivf_model = build_ivf_model(dataset, column, dim, metric_type, ivf_params).await?;
let ivf_model = build_ivf_model(dataset, column, None, dim, metric_type, ivf_params).await?;

let ivf_residual = if matches!(metric_type, MetricType::Cosine | MetricType::L2) {
Some(&ivf_model)
Expand Down Expand Up @@ -2277,7 +2281,7 @@ mod tests {
let (dataset, _) = generate_test_dataset(test_uri, 1000.0..1100.0).await;

let ivf_params = IvfBuildParams::new(2);
let ivf_model = build_ivf_model(&dataset, "vector", DIM, MetricType::L2, &ivf_params)
let ivf_model = build_ivf_model(&dataset, "vector", None, DIM, MetricType::L2, &ivf_params)
.await
.unwrap();
assert_eq!(2, ivf_model.centroids.as_ref().unwrap().len());
Expand Down Expand Up @@ -2305,9 +2309,16 @@ mod tests {
let (dataset, _) = generate_test_dataset(test_uri, 1000.0..1100.0).await;

let ivf_params = IvfBuildParams::new(2);
let ivf_model = build_ivf_model(&dataset, "vector", DIM, MetricType::Cosine, &ivf_params)
.await
.unwrap();
let ivf_model = build_ivf_model(
&dataset,
"vector",
None,
DIM,
MetricType::Cosine,
&ivf_params,
)
.await
.unwrap();
assert_eq!(2, ivf_model.centroids.as_ref().unwrap().len());
assert_eq!(32, ivf_model.centroids.as_ref().unwrap().value_length());
assert_eq!(2, ivf_model.num_partitions());
Expand Down
15 changes: 11 additions & 4 deletions rust/lance/src/index/vector/pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ pub async fn build_pq_model(
);
let start = std::time::Instant::now();
let mut training_data =
maybe_sample_training_data(dataset, column, expected_sample_size).await?;
maybe_sample_training_data(dataset, column, None, expected_sample_size).await?;
info!(
"Finished loading training data in {:02} seconds",
start.elapsed().as_secs_f32()
Expand Down Expand Up @@ -572,9 +572,16 @@ mod tests {
let (dataset, vectors) = generate_dataset(test_uri, 100.0..120.0).await;

let ivf_params = IvfBuildParams::new(4);
let ivf = build_ivf_model(&dataset, "vector", DIM, MetricType::Cosine, &ivf_params)
.await
.unwrap();
let ivf = build_ivf_model(
&dataset,
"vector",
None,
DIM,
MetricType::Cosine,
&ivf_params,
)
.await
.unwrap();
let params = PQBuildParams::new(16, 8);
let pq = build_pq_model(
&dataset,
Expand Down
37 changes: 18 additions & 19 deletions rust/lance/src/index/vector/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ pub fn get_vector_dim(dataset: &Dataset, column: &str) -> Result<usize> {
message: format!("Column {} does not exist in schema {}", column, schema),
location: location!(),
})?;
let data_type = field.data_type();
if let arrow_schema::DataType::FixedSizeList(_, dim) = data_type {
Ok(dim as usize)
} else {
Err(Error::Index {
message: format!(
"Column {} is not a FixedSizeListArray, but {:?}",
column, data_type
),
infer_vector_dim(&field.data_type())
}

/// Infers the vector dimension from an Arrow data type.
///
/// A `FixedSizeList` yields its list size; a `List` is resolved recursively
/// through the data type of its element field. Any other data type produces
/// an `Error::Index` whose message includes the offending type.
fn infer_vector_dim(data_type: &arrow_schema::DataType) -> Result<usize> {
match data_type {
// NOTE(review): `as usize` would wrap a negative size — confirm the size is always non-negative here.
arrow_schema::DataType::FixedSizeList(_, dim) => Ok(*dim as usize),
// Recurse into the element type of the list.
arrow_schema::DataType::List(field) => infer_vector_dim(field.data_type()),
_ => Err(Error::Index {
message: format!("Column is not a FixedSizeListArray, but {:?}", data_type),
location: location!(),
})
}),
}
}

Expand All @@ -40,15 +40,20 @@ pub fn get_vector_dim(dataset: &Dataset, column: &str) -> Result<usize> {
pub async fn maybe_sample_training_data(
dataset: &Dataset,
column: &str,
modal_index: Option<usize>,
sample_size_hint: usize,
) -> Result<FixedSizeListArray> {
let num_rows = dataset.count_rows(None).await?;
let projection = dataset.schema().project(&[column])?;
let vector_column = match modal_index {
Some(index) => format!("{}[{}]", column, index),
None => column.to_string(),
};
let projection = dataset.schema().project(&[&vector_column])?;
let batch = if num_rows > sample_size_hint {
dataset.sample(sample_size_hint, &projection).await?
} else {
let mut scanner = dataset.scan();
scanner.project(&[column])?;
scanner.project(&[vector_column])?;
let batches = scanner
.try_into_stream()
.await?
Expand All @@ -57,13 +62,7 @@ pub async fn maybe_sample_training_data(
concat_batches(&Arc::new(ArrowSchema::from(&projection)), &batches)?
};

let array = batch.column_by_name(column).ok_or(Error::Index {
message: format!(
"Sample training data: column {} does not exist in return",
column
),
location: location!(),
})?;
let array = batch.column(0);
Ok(array.as_fixed_size_list().clone())
}

Expand Down

0 comments on commit 55b0e08

Please sign in to comment.