Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/monatis/letsearch
Browse files Browse the repository at this point in the history
  • Loading branch information
monatis committed Dec 3, 2024
2 parents 2e17bc6 + 72733e0 commit 914891e
Show file tree
Hide file tree
Showing 11 changed files with 163 additions and 38 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "letsearch"
version = "0.1.4"
version = "0.1.9"
edition = "2021"

[dependencies]
Expand Down
54 changes: 36 additions & 18 deletions src/collection/collection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@ use super::collection_utils::{CollectionConfig, SearchResult};
pub struct CollectionManager {
collections: RwLock<HashMap<String, Arc<RwLock<Collection>>>>,
model_manager: Arc<RwLock<ModelManager>>,
model_lookup: RwLock<HashMap<String, u32>>,
model_lookup: RwLock<HashMap<(String, String), u32>>,
token: Option<String>,
}

impl CollectionManager {
pub fn new() -> Self {
pub fn new(token: Option<String>) -> Self {
CollectionManager {
collections: RwLock::new(HashMap::new()),
model_manager: Arc::new(RwLock::new(ModelManager::new())),
model_lookup: RwLock::new(HashMap::new()),
token: token,
}
}

Expand All @@ -30,12 +32,18 @@ impl CollectionManager {
let manager_guard = self.model_manager.write().await;
for requested_model in requested_models {
let mut lookup_guard = self.model_lookup.write().await;
if !lookup_guard.contains_key(requested_model.as_str()) {
if !lookup_guard.contains_key(&requested_model) {
let (model_path, model_variant) = requested_model.clone();
let model_id = manager_guard
.load_model(requested_model.clone(), Backend::ONNX)
.load_model(
model_path.clone(),
model_variant.clone(),
Backend::ONNX,
self.token.clone(),
)
.await
.unwrap();
lookup_guard.insert(requested_model.clone(), model_id);
lookup_guard.insert(requested_model, model_id);
}
}
}
Expand All @@ -59,12 +67,18 @@ impl CollectionManager {
let manager_guard = self.model_manager.write().await;
for requested_model in requested_models {
let mut lookup_guard = self.model_lookup.write().await;
if !lookup_guard.contains_key(requested_model.as_str()) {
if !lookup_guard.contains_key(&requested_model) {
let (model_path, model_variant) = requested_model.clone();
let model_id = manager_guard
.load_model(requested_model.clone(), Backend::ONNX)
.load_model(
model_path.clone(),
model_variant.clone(),
Backend::ONNX,
self.token.clone(),
)
.await
.unwrap();
lookup_guard.insert(requested_model.clone(), model_id);
lookup_guard.insert(requested_model, model_id);
}
}
}
Expand Down Expand Up @@ -175,14 +189,16 @@ impl CollectionManager {
};

// Fetch model ID
let model_name = collection.read().await.config().model_name;
let model_id = {
let lookup_guard = self.model_lookup.read().await;
lookup_guard
.get(&model_name)
.copied()
.ok_or_else(|| anyhow::anyhow!("Model '{}' is not loaded", model_name))?
};
let config = collection.read().await.config();
let model = (config.model_name, config.model_variant);

let model_id = self
.model_lookup
.read()
.await
.get(&model)
.copied()
.ok_or_else(|| anyhow::anyhow!("Model '{:?}' is not loaded", model))?;

// Perform embedding
let mut collection_guard = collection.write().await;
Expand Down Expand Up @@ -212,12 +228,14 @@ impl CollectionManager {
.ok_or_else(|| {
return anyhow::anyhow!("Collection '{}' does not exist", collection_name);
})?;
let model_name = collection.read().await.config().model_name;
let config = collection.read().await.config();
let model = (config.model_name, config.model_variant);

let model_id = self
.model_lookup
.read()
.await
.get(model_name.as_str())
.get(&model)
.copied()
.ok_or_else(|| {
return anyhow::anyhow!(
Expand Down
7 changes: 5 additions & 2 deletions src/collection/collection_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,11 @@ impl Collection {
Ok(())
}

pub async fn requested_models(&self) -> Vec<String> {
vec![self.config.model_name.clone()]
pub async fn requested_models(&self) -> Vec<(String, String)> {
vec![(
self.config.model_name.clone(),
self.config.model_variant.clone(),
)]
}

pub async fn search(
Expand Down
9 changes: 8 additions & 1 deletion src/collection/collection_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pub struct CollectionConfig {
pub index_columns: Vec<String>,
#[serde(default = "default_model_name")]
pub model_name: String,
#[serde(default = "default_model_variant")]
pub model_variant: String,
#[serde(default = "default_db_path")]
pub db_path: String,
#[serde(default = "default_index_dir")]
Expand All @@ -34,7 +36,11 @@ fn default_index_columns() -> Vec<String> {
}

fn default_model_name() -> String {
String::from("minilm")
String::from("mys/minilm")
}

fn default_model_variant() -> String {
String::from("f32")
}

fn default_db_path() -> String {
Expand All @@ -55,6 +61,7 @@ impl CollectionConfig {
name: default_collection_name(),
index_columns: default_index_columns(),
model_name: default_model_name(),
model_variant: default_model_variant(),
db_path: default_db_path(),
index_dir: default_index_dir(),
serialization_version: default_serialization_version(),
Expand Down
61 changes: 61 additions & 0 deletions src/hf_ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use crate::collection::collection_utils::home_dir;
use hf_hub::api::sync::ApiBuilder;
use std::fs;

/// Download a letsearch-compatible model from the HuggingFace Hub.
///
/// `model_path` is an `hf://` repo path (the prefix is stripped before the
/// API call); `variant` selects one of the quantization variants declared in
/// the repo's `metadata.json`; `token` is an optional HF access token, only
/// needed for private repos.
///
/// Returns `(model_dir, model_file)`: the local cache directory holding the
/// downloaded files and the ONNX file name inside it.
///
/// # Errors
/// Returns an error when the repo or its files cannot be fetched, when
/// `metadata.json` is missing/malformed, when `letsearch_version` is not 1,
/// or when the requested `variant` is not listed.
pub fn download_model(
    model_path: String,
    variant: String,
    token: Option<String>,
) -> anyhow::Result<(String, String)> {
    // Build the Hugging Face API instance, caching under the app home dir.
    let cache_dir = home_dir().join("models").to_path_buf();
    let api = ApiBuilder::new()
        .with_token(token)
        .with_cache_dir(cache_dir)
        .build()?;
    let model_path = model_path.replace("hf://", "");
    let repo = api.model(model_path);
    let config_path = repo.get("metadata.json")?;

    // Read and parse the metadata.json file.
    let config_content = fs::read_to_string(config_path)?;
    let config: serde_json::Value = serde_json::from_str(&config_content)?;

    // Parse the "letsearch_version" and reject unsupported versions with a
    // proper error instead of panicking (the old code used assert_eq!).
    let version = config["letsearch_version"].as_i64().ok_or_else(|| {
        anyhow::anyhow!("This is probably not a letsearch-compatible model. Check it out")
    })?;
    if version != 1 {
        return Err(anyhow::anyhow!(
            "Unsupported letsearch_version: {} (only version 1 is supported)",
            version
        ));
    }

    let variants = config["variants"]
        .as_array()
        .ok_or_else(|| anyhow::anyhow!("This is probably not a letsearch model. check it out"))?;

    // Check if the requested variant exists.
    let variant_info = variants
        .iter()
        .find(|v| v["variant"] == variant)
        .ok_or_else(|| anyhow::anyhow!("Variant not found in config"))?;

    // Download the ONNX model for the specified variant. A missing or
    // non-string "path" IS reachable with malformed metadata, so return an
    // error rather than hitting unreachable!().
    let model_file = match variant_info["path"].as_str() {
        Some(path) => repo.get(path)?,
        None => {
            return Err(anyhow::anyhow!(
                "Variant '{}' has no valid 'path' field in metadata.json",
                variant
            ))
        }
    };

    // Fetch any auxiliary files the model requires (e.g. tokenizer files).
    // Entries are validated instead of unwrap()ed so bad metadata cannot panic.
    if let Some(required_files) = config["required_files"].as_array() {
        for file in required_files {
            let file_name = file.as_str().ok_or_else(|| {
                anyhow::anyhow!("Entries of 'required_files' must be strings")
            })?;
            repo.get(file_name)?;
        }
    }

    // Split the resolved local path into (directory, file name) strings,
    // surfacing path oddities (no parent, non-UTF-8) as errors, not panics.
    let model_dir = model_file
        .parent()
        .ok_or_else(|| anyhow::anyhow!("Downloaded model path has no parent directory"))?
        .to_str()
        .ok_or_else(|| anyhow::anyhow!("Model directory path is not valid UTF-8"))?
        .to_string();
    let model_file = model_file
        .file_name()
        .ok_or_else(|| anyhow::anyhow!("Downloaded model path has no file name"))?
        .to_str()
        .ok_or_else(|| anyhow::anyhow!("Model file name is not valid UTF-8"))?
        .to_string();

    Ok((model_dir, model_file))
}
25 changes: 22 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::io::Write;
#[derive(Parser, Debug)]
#[command(
name = "letsearch",
version = "0.1.2",
version = "0.1.9",
author = "[email protected]",
about = "Index and search your documents, and serve it if you wish",
subcommand_required = true,
Expand All @@ -38,10 +38,19 @@ pub enum Commands {
#[arg(short, long, required = true)]
collection_name: String,

/// Model to create embeddings
/// Model to create embeddings.
/// You can also give a hf:// path and it will be automatically downloaded.
#[arg(short, long, required = true)]
model: String,

/// Model variant. f32, f16 and i8 are supported for now.
#[arg(short, long, default_value = "f32")]
variant: String,

/// HuggingFace token. Only needed when you want to access private repos
#[arg(long)]
hf_token: Option<String>,

/// batch size when embedding texts
#[arg(short, long, default_value = "32")]
batch_size: u64,
Expand Down Expand Up @@ -70,6 +79,10 @@ pub enum Commands {
/// port to listen to
#[arg(short, long, default_value = "7898")]
port: i32,

/// HuggingFace token. Only needed when you want to access private repos
#[arg(long)]
hf_token: Option<String>,
},
}

Expand All @@ -96,6 +109,8 @@ async fn main() -> anyhow::Result<()> {
files,
collection_name,
model,
variant,
hf_token,
batch_size,
index_columns,
overwrite,
Expand All @@ -104,7 +119,8 @@ async fn main() -> anyhow::Result<()> {
config.name = collection_name.to_string();
config.index_columns = index_columns.to_vec();
config.model_name = model.to_string();
let collection_manager = CollectionManager::new();
config.model_variant = variant.to_string();
let collection_manager = CollectionManager::new(hf_token.to_owned());
collection_manager
.create_collection(config, overwrite.to_owned())
.await?;
Expand Down Expand Up @@ -133,11 +149,13 @@ async fn main() -> anyhow::Result<()> {
collection_name,
host,
port,
hf_token,
} => {
run_server(
host.to_string(),
port.to_owned(),
collection_name.to_string(),
hf_token.to_owned(),
)
.await?;
}
Expand All @@ -147,5 +165,6 @@ async fn main() -> anyhow::Result<()> {
}

mod collection;
mod hf_ops;
mod model;
mod serve;
11 changes: 4 additions & 7 deletions src/model/backends/onnx/bert_onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ impl BertONNX {

#[async_trait]
impl ModelTrait for BertONNX {
async fn load_model(&mut self, model_path: &str) -> anyhow::Result<()> {
let model_source_path = Path::new(model_path);
async fn load_model(&mut self, model_dir: &str, model_file: &str) -> anyhow::Result<()> {
let model_source_path = Path::new(model_dir);
ort::init()
.with_name("onnx_model")
.with_execution_providers([CPUExecutionProvider::default().build()])
Expand All @@ -47,11 +47,10 @@ impl ModelTrait for BertONNX {
.unwrap()
.with_intra_threads(available_parallelism()?.get())
.unwrap()
.commit_from_file(Path::join(model_source_path, "model.onnx"))
.commit_from_file(model_source_path.join(model_file))
.unwrap();

let mut tokenizer =
Tokenizer::from_file(Path::join(model_source_path, "tokenizer.json")).unwrap();
let mut tokenizer = Tokenizer::from_file(model_source_path.join("tokenizer.json")).unwrap();

tokenizer.with_padding(Some(PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
Expand All @@ -62,8 +61,6 @@ impl ModelTrait for BertONNX {
pad_token: "<pad>".into(),
}));

info!("Model loaded from {}", model_path);

// TODO: instead of using a hardcoded index,
// use .filter to get the output tensor by name

Expand Down
19 changes: 17 additions & 2 deletions src/model/model_manager.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use super::model_utils::{Backend, Embeddings, ModelOutputDType, ONNXModel};
use crate::hf_ops::download_model;
use crate::model::backends::onnx::bert_onnx::BertONNX;
use anyhow::Error;
use half::f16;
use log::info;
use ndarray::Array2;
use std::collections::HashMap;
use std::sync::Arc;
Expand All @@ -20,16 +22,28 @@ impl ModelManager {
}
}

pub async fn load_model(&self, model_path: String, model_type: Backend) -> anyhow::Result<u32> {
pub async fn load_model(
&self,
model_path: String,
model_variant: String,
model_type: Backend,
token: Option<String>,
) -> anyhow::Result<u32> {
let model: Arc<RwLock<dyn ONNXModel>> = match model_type {
Backend::ONNX => Arc::new(RwLock::new(BertONNX::new())),
// _ => unreachable!("not implemented"),
};

let (model_dir, model_file) = if model_path.starts_with("hf://") {
download_model(model_path.clone(), model_variant.clone(), token)?
} else {
(model_path.clone(), model_variant.clone())
};

{
let mut model_guard = model.write().await;
model_guard
.load_model(&model_path)
.load_model(model_dir.as_str(), model_file.as_str())
.await
.map_err(|e| Error::msg(e.to_string()))?;
}
Expand All @@ -40,6 +54,7 @@ impl ModelManager {

let mut models = self.models.write().await;
models.insert(model_id, model);
info!("Model loaded from {}", model_path.as_str());

Ok(model_id)
}
Expand Down
Loading

0 comments on commit 914891e

Please sign in to comment.