diff --git a/Cargo.lock b/Cargo.lock
index eba012b..200c841 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1689,7 +1689,7 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
 
 [[package]]
 name = "letsearch"
-version = "0.1.4"
+version = "0.1.9"
 dependencies = [
  "actix-web",
  "anyhow",
diff --git a/Cargo.toml b/Cargo.toml
index 19e427c..2c07117 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "letsearch"
-version = "0.1.4"
+version = "0.1.9"
 edition = "2021"
 
 [dependencies]
diff --git a/src/collection/collection_manager.rs b/src/collection/collection_manager.rs
index 46ec617..4f36b48 100644
--- a/src/collection/collection_manager.rs
+++ b/src/collection/collection_manager.rs
@@ -10,15 +10,17 @@ use super::collection_utils::{CollectionConfig, SearchResult};
 
 pub struct CollectionManager {
     collections: RwLock<HashMap<String, Arc<RwLock<Collection>>>>,
     model_manager: Arc<RwLock<ModelManager>>,
-    model_lookup: RwLock<HashMap<String, u32>>,
+    model_lookup: RwLock<HashMap<(String, String), u32>>,
+    token: Option<String>,
 }
 
 impl CollectionManager {
-    pub fn new() -> Self {
+    pub fn new(token: Option<String>) -> Self {
         CollectionManager {
             collections: RwLock::new(HashMap::new()),
             model_manager: Arc::new(RwLock::new(ModelManager::new())),
             model_lookup: RwLock::new(HashMap::new()),
+            token: token,
         }
     }
@@ -30,12 +32,18 @@ impl CollectionManager {
         let manager_guard = self.model_manager.write().await;
         for requested_model in requested_models {
             let mut lookup_guard = self.model_lookup.write().await;
-            if !lookup_guard.contains_key(requested_model.as_str()) {
+            if !lookup_guard.contains_key(&requested_model) {
+                let (model_path, model_variant) = requested_model.clone();
                 let model_id = manager_guard
-                    .load_model(requested_model.clone(), Backend::ONNX)
+                    .load_model(
+                        model_path.clone(),
+                        model_variant.clone(),
+                        Backend::ONNX,
+                        self.token.clone(),
+                    )
                     .await
                     .unwrap();
-                lookup_guard.insert(requested_model.clone(), model_id);
+                lookup_guard.insert(requested_model, model_id);
             }
         }
     }
@@ -59,12 +67,18 @@ impl CollectionManager {
         let manager_guard = self.model_manager.write().await;
         for requested_model in requested_models {
             let mut lookup_guard = self.model_lookup.write().await;
-            if !lookup_guard.contains_key(requested_model.as_str()) {
+            if !lookup_guard.contains_key(&requested_model) {
+                let (model_path, model_variant) = requested_model.clone();
                 let model_id = manager_guard
-                    .load_model(requested_model.clone(), Backend::ONNX)
+                    .load_model(
+                        model_path.clone(),
+                        model_variant.clone(),
+                        Backend::ONNX,
+                        self.token.clone(),
+                    )
                     .await
                     .unwrap();
-                lookup_guard.insert(requested_model.clone(), model_id);
+                lookup_guard.insert(requested_model, model_id);
             }
         }
     }
@@ -175,14 +189,16 @@ impl CollectionManager {
         };
 
         // Fetch model ID
-        let model_name = collection.read().await.config().model_name;
-        let model_id = {
-            let lookup_guard = self.model_lookup.read().await;
-            lookup_guard
-                .get(&model_name)
-                .copied()
-                .ok_or_else(|| anyhow::anyhow!("Model '{}' is not loaded", model_name))?
-        };
+        let config = collection.read().await.config();
+        let model = (config.model_name, config.model_variant);
+
+        let model_id = self
+            .model_lookup
+            .read()
+            .await
+            .get(&model)
+            .copied()
+            .ok_or_else(|| anyhow::anyhow!("Model '{:?}' is not loaded", model))?;
 
         // Perform embedding
         let mut collection_guard = collection.write().await;
@@ -212,12 +228,14 @@ impl CollectionManager {
             .ok_or_else(|| {
                 return anyhow::anyhow!("Collection '{}' does not exist", collection_name);
             })?;
-        let model_name = collection.read().await.config().model_name;
+        let config = collection.read().await.config();
+        let model = (config.model_name, config.model_variant);
+
         let model_id = self
             .model_lookup
             .read()
             .await
-            .get(model_name.as_str())
+            .get(&model)
             .copied()
             .ok_or_else(|| {
                 return anyhow::anyhow!(
diff --git a/src/collection/collection_type.rs b/src/collection/collection_type.rs
index e1df691..3017c65 100644
--- a/src/collection/collection_type.rs
+++ b/src/collection/collection_type.rs
@@ -340,8 +340,11 @@ impl Collection {
         Ok(())
     }
 
-    pub async fn requested_models(&self) -> Vec<String> {
-        vec![self.config.model_name.clone()]
+    pub async fn requested_models(&self) -> Vec<(String, String)> {
+        vec![(
+            self.config.model_name.clone(),
+            self.config.model_variant.clone(),
+        )]
     }
 
     pub async fn search(
diff --git a/src/collection/collection_utils.rs b/src/collection/collection_utils.rs
index 6991d90..a0a30b3 100644
--- a/src/collection/collection_utils.rs
+++ b/src/collection/collection_utils.rs
@@ -17,6 +17,8 @@ pub struct CollectionConfig {
     pub index_columns: Vec<String>,
     #[serde(default = "default_model_name")]
     pub model_name: String,
+    #[serde(default = "default_model_variant")]
+    pub model_variant: String,
     #[serde(default = "default_db_path")]
     pub db_path: String,
     #[serde(default = "default_index_dir")]
@@ -34,7 +36,11 @@ fn default_index_columns() -> Vec<String> {
 }
 
 fn default_model_name() -> String {
-    String::from("minilm")
+    String::from("mys/minilm")
+}
+
+fn default_model_variant() -> String {
+    String::from("f32")
 }
 
 fn default_db_path() -> String {
@@ -55,6 +61,7 @@ impl CollectionConfig {
             name: default_collection_name(),
             index_columns: default_index_columns(),
             model_name: default_model_name(),
+            model_variant: default_model_variant(),
             db_path: default_db_path(),
             index_dir: default_index_dir(),
             serialization_version: default_serialization_version(),
diff --git a/src/hf_ops.rs b/src/hf_ops.rs
new file mode 100644
index 0000000..e128fa4
--- /dev/null
+++ b/src/hf_ops.rs
@@ -0,0 +1,61 @@
+use crate::collection::collection_utils::home_dir;
+use hf_hub::api::sync::ApiBuilder;
+use std::fs;
+
+pub fn download_model(
+    model_path: String,
+    variant: String,
+    token: Option<String>,
+) -> anyhow::Result<(String, String)> {
+    // Build the Hugging Face API instance
+    let cache_dir = home_dir().join("models").to_path_buf();
+    let api = ApiBuilder::new()
+        .with_token(token)
+        .with_cache_dir(cache_dir)
+        .build()?;
+    let model_path = model_path.replace("hf://", "");
+    let repo = api.model(model_path);
+    let config_path = repo.get("metadata.json")?;
+
+    // Read the metadata.json file
+    let config_content = fs::read_to_string(config_path)?;
+    let config: serde_json::Value = serde_json::from_str(&config_content)?;
+
+    // Parse the "letsearch_version" and "variants"
+    let version = config["letsearch_version"].as_i64().ok_or_else(|| {
+        anyhow::anyhow!("This is probably not a letsearch-compatible model. Check it out")
+    })?;
+    assert_eq!(version, 1);
+
+    let variants = config["variants"]
+        .as_array()
+        .ok_or_else(|| anyhow::anyhow!("This is probably not a letsearch model. check it out"))?;
+
+    // Check if the requested variant exists
+    let variant_info = variants
+        .iter()
+        .find(|v| v["variant"] == variant)
+        .ok_or_else(|| anyhow::anyhow!("Variant not found in config"))?;
+
+    // Download the ONNX model for the specified variant
+    let model_file = match variant_info["path"].as_str() {
+        Some(model_path) => repo.get(model_path)?,
+        _ => unreachable!("unreachable"),
+    };
+
+    if let Some(required_files) = config["required_files"].as_array() {
+        for file in required_files {
+            repo.get(file.as_str().unwrap())?;
+        }
+    }
+
+    let model_dir = model_file.parent().unwrap().to_str().unwrap().to_string();
+    let model_file = model_file
+        .file_name()
+        .unwrap()
+        .to_str()
+        .unwrap()
+        .to_string();
+
+    Ok((model_dir, model_file))
+}
diff --git a/src/main.rs b/src/main.rs
index 5d67fdf..3f4ff8f 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -12,7 +12,7 @@ use std::io::Write;
 #[derive(Parser, Debug)]
 #[command(
     name = "letsearch",
-    version = "0.1.2",
+    version = "0.1.9",
     author = "yusufsarigoz@gmail.com",
     about = "Index and search your documents, and serve it if you wish",
     subcommand_required = true,
@@ -38,10 +38,19 @@ pub enum Commands {
         #[arg(short, long, required = true)]
         collection_name: String,
 
-        /// Model to create embeddings
+        /// Model to create embeddings.
+        /// You can also give a hf:// path and it will be automatically downloaded.
         #[arg(short, long, required = true)]
         model: String,
 
+        /// model variant. f32, f16 and i8 are supported for now.
+        #[arg(short, long, default_value = "f32")]
+        variant: String,
+
+        /// HuggingFace token. Only needed when you want to access private repos
+        #[arg(long)]
+        hf_token: Option<String>,
+
         /// batch size when embedding texts
         #[arg(short, long, default_value = "32")]
         batch_size: u64,
@@ -70,6 +79,10 @@ pub enum Commands {
         /// port to listen to
         #[arg(short, long, default_value = "7898")]
         port: i32,
+
+        /// HuggingFace token. Only needed when you want to access private repos
+        #[arg(long)]
+        hf_token: Option<String>,
     },
 }
 
@@ -96,6 +109,8 @@ async fn main() -> anyhow::Result<()> {
             files,
             collection_name,
             model,
+            variant,
+            hf_token,
             batch_size,
             index_columns,
             overwrite,
@@ -104,7 +119,8 @@ async fn main() -> anyhow::Result<()> {
             config.name = collection_name.to_string();
             config.index_columns = index_columns.to_vec();
             config.model_name = model.to_string();
-            let collection_manager = CollectionManager::new();
+            config.model_variant = variant.to_string();
+            let collection_manager = CollectionManager::new(hf_token.to_owned());
             collection_manager
                 .create_collection(config, overwrite.to_owned())
                 .await?;
@@ -133,11 +149,13 @@ async fn main() -> anyhow::Result<()> {
             collection_name,
             host,
             port,
+            hf_token,
         } => {
             run_server(
                 host.to_string(),
                 port.to_owned(),
                 collection_name.to_string(),
+                hf_token.to_owned(),
             )
             .await?;
         }
@@ -147,5 +165,6 @@ async fn main() -> anyhow::Result<()> {
 }
 
 mod collection;
+mod hf_ops;
 mod model;
 mod serve;
diff --git a/src/model/backends/onnx/bert_onnx.rs b/src/model/backends/onnx/bert_onnx.rs
index 1a60373..bc0a371 100644
--- a/src/model/backends/onnx/bert_onnx.rs
+++ b/src/model/backends/onnx/bert_onnx.rs
@@ -33,8 +33,8 @@ impl BertONNX {
 
 #[async_trait]
 impl ModelTrait for BertONNX {
-    async fn load_model(&mut self, model_path: &str) -> anyhow::Result<()> {
-        let model_source_path = Path::new(model_path);
+    async fn load_model(&mut self, model_dir: &str, model_file: &str) -> anyhow::Result<()> {
+        let model_source_path = Path::new(model_dir);
         ort::init()
             .with_name("onnx_model")
             .with_execution_providers([CPUExecutionProvider::default().build()])
@@ -47,11 +47,10 @@ impl ModelTrait for BertONNX {
             .unwrap()
             .with_intra_threads(available_parallelism()?.get())
             .unwrap()
-            .commit_from_file(Path::join(model_source_path, "model.onnx"))
+            .commit_from_file(model_source_path.join(model_file))
             .unwrap();
 
-        let mut tokenizer =
-            Tokenizer::from_file(Path::join(model_source_path, "tokenizer.json")).unwrap();
+        let mut tokenizer = Tokenizer::from_file(model_source_path.join("tokenizer.json")).unwrap();
 
         tokenizer.with_padding(Some(PaddingParams {
             strategy: tokenizers::PaddingStrategy::BatchLongest,
@@ -62,8 +61,6 @@ impl ModelTrait for BertONNX {
             pad_token: "".into(),
         }));
 
-        info!("Model loaded from {}", model_path);
-
         // TODO: instead of using a hardcoded index,
         // use .filter to get the output tensor by name
 
diff --git a/src/model/model_manager.rs b/src/model/model_manager.rs
index 58ff98c..b528ece 100644
--- a/src/model/model_manager.rs
+++ b/src/model/model_manager.rs
@@ -1,7 +1,9 @@
 use super::model_utils::{Backend, Embeddings, ModelOutputDType, ONNXModel};
+use crate::hf_ops::download_model;
 use crate::model::backends::onnx::bert_onnx::BertONNX;
 use anyhow::Error;
 use half::f16;
+use log::info;
 use ndarray::Array2;
 use std::collections::HashMap;
 use std::sync::Arc;
@@ -20,16 +22,28 @@ impl ModelManager {
         }
     }
 
-    pub async fn load_model(&self, model_path: String, model_type: Backend) -> anyhow::Result<u32> {
+    pub async fn load_model(
+        &self,
+        model_path: String,
+        model_variant: String,
+        model_type: Backend,
+        token: Option<String>,
+    ) -> anyhow::Result<u32> {
         let model: Arc<RwLock<dyn ModelTrait + Send + Sync>> = match model_type {
             Backend::ONNX => Arc::new(RwLock::new(BertONNX::new())),
             // _ => unreachable!("not implemented"),
         };
 
+        let (model_dir, model_file) = if model_path.starts_with("hf://") {
+            download_model(model_path.clone(), model_variant.clone(), token)?
+        } else {
+            (model_path.clone(), model_variant.clone())
+        };
+
         {
             let mut model_guard = model.write().await;
             model_guard
-                .load_model(&model_path)
+                .load_model(model_dir.as_str(), model_file.as_str())
                 .await
                 .map_err(|e| Error::msg(e.to_string()))?;
         }
@@ -40,6 +54,7 @@ impl ModelManager {
 
         let mut models = self.models.write().await;
         models.insert(model_id, model);
+        info!("Model loaded from {}", model_path.as_str());
 
         Ok(model_id)
     }
diff --git a/src/model/model_utils.rs b/src/model/model_utils.rs
index 645c6ce..33965e4 100644
--- a/src/model/model_utils.rs
+++ b/src/model/model_utils.rs
@@ -23,7 +23,7 @@ pub enum Embeddings {
 
 #[async_trait]
 pub trait ModelTrait {
-    async fn load_model(&mut self, model_path: &str) -> anyhow::Result<()>;
+    async fn load_model(&mut self, model_dir: &str, model_file: &str) -> anyhow::Result<()>;
     #[allow(dead_code)]
     async fn unload_model(&self) -> anyhow::Result<()>;
 }
diff --git a/src/serve.rs b/src/serve.rs
index 9faee1b..cb4ec18 100644
--- a/src/serve.rs
+++ b/src/serve.rs
@@ -156,8 +156,13 @@ async fn search(
     response
 }
 
-pub async fn run_server(host: String, port: i32, collection_name: String) -> std::io::Result<()> {
-    let collection_manager = CollectionManager::new();
+pub async fn run_server(
+    host: String,
+    port: i32,
+    collection_name: String,
+    token: Option<String>,
+) -> std::io::Result<()> {
+    let collection_manager = CollectionManager::new(token);
     let _ = collection_manager
         .load_collection(collection_name)
         .await