-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/monatis/letsearch
- Loading branch information
Showing
11 changed files
with
163 additions
and
38 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[package] | ||
name = "letsearch" | ||
version = "0.1.4" | ||
version = "0.1.9" | ||
edition = "2021" | ||
|
||
[dependencies] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
use crate::collection::collection_utils::home_dir; | ||
use hf_hub::api::sync::ApiBuilder; | ||
use std::fs; | ||
|
||
pub fn download_model( | ||
model_path: String, | ||
variant: String, | ||
token: Option<String>, | ||
) -> anyhow::Result<(String, String)> { | ||
// Build the Hugging Face API instance | ||
let cache_dir = home_dir().join("models").to_path_buf(); | ||
let api = ApiBuilder::new() | ||
.with_token(token) | ||
.with_cache_dir(cache_dir) | ||
.build()?; | ||
let model_path = model_path.replace("hf://", ""); | ||
let repo = api.model(model_path); | ||
let config_path = repo.get("metadata.json")?; | ||
|
||
// Read the metadata.json file | ||
let config_content = fs::read_to_string(config_path)?; | ||
let config: serde_json::Value = serde_json::from_str(&config_content)?; | ||
|
||
// Parse the "letsearch_version" and "variants" | ||
let version = config["letsearch_version"].as_i64().ok_or_else(|| { | ||
anyhow::anyhow!("This is probably not a letsearch-compatible model. Check it out") | ||
})?; | ||
assert_eq!(version, 1); | ||
|
||
let variants = config["variants"] | ||
.as_array() | ||
.ok_or_else(|| anyhow::anyhow!("This is probably not a letsearch model. check it out"))?; | ||
|
||
// Check if the requested variant exists | ||
let variant_info = variants | ||
.iter() | ||
.find(|v| v["variant"] == variant) | ||
.ok_or_else(|| anyhow::anyhow!("Variant not found in config"))?; | ||
|
||
// Download the ONNX model for the specified variant | ||
let model_file = match variant_info["path"].as_str() { | ||
Some(model_path) => repo.get(model_path)?, | ||
_ => unreachable!("unreachable"), | ||
}; | ||
|
||
if let Some(required_files) = config["required_files"].as_array() { | ||
for file in required_files { | ||
repo.get(file.as_str().unwrap())?; | ||
} | ||
} | ||
|
||
let model_dir = model_file.parent().unwrap().to_str().unwrap().to_string(); | ||
let model_file = model_file | ||
.file_name() | ||
.unwrap() | ||
.to_str() | ||
.unwrap() | ||
.to_string(); | ||
|
||
Ok((model_dir, model_file)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,7 @@ use std::io::Write; | |
#[derive(Parser, Debug)] | ||
#[command( | ||
name = "letsearch", | ||
version = "0.1.2", | ||
version = "0.1.9", | ||
author = "[email protected]", | ||
about = "Index and search your documents, and serve it if you wish", | ||
subcommand_required = true, | ||
|
@@ -38,10 +38,19 @@ pub enum Commands { | |
#[arg(short, long, required = true)] | ||
collection_name: String, | ||
|
||
/// Model to create embeddings | ||
/// Model to create embeddings. | ||
/// You can also give a hf:// path and it will be automatically downloaded. | ||
#[arg(short, long, required = true)] | ||
model: String, | ||
|
||
/// model variant. f32, f16 and i8 are supported for now. | ||
#[arg(short, long, default_value = "f32")] | ||
variant: String, | ||
|
||
/// HuggingFace token. Only needed when you want to access private repos | ||
#[arg(long)] | ||
hf_token: Option<String>, | ||
|
||
/// batch size when embedding texts | ||
#[arg(short, long, default_value = "32")] | ||
batch_size: u64, | ||
|
@@ -70,6 +79,10 @@ pub enum Commands { | |
/// port to listen to | ||
#[arg(short, long, default_value = "7898")] | ||
port: i32, | ||
|
||
/// HuggingFace token. Only needed when you want to access private repos | ||
#[arg(long)] | ||
hf_token: Option<String>, | ||
}, | ||
} | ||
|
||
|
@@ -96,6 +109,8 @@ async fn main() -> anyhow::Result<()> { | |
files, | ||
collection_name, | ||
model, | ||
variant, | ||
hf_token, | ||
batch_size, | ||
index_columns, | ||
overwrite, | ||
|
@@ -104,7 +119,8 @@ async fn main() -> anyhow::Result<()> { | |
config.name = collection_name.to_string(); | ||
config.index_columns = index_columns.to_vec(); | ||
config.model_name = model.to_string(); | ||
let collection_manager = CollectionManager::new(); | ||
config.model_variant = variant.to_string(); | ||
let collection_manager = CollectionManager::new(hf_token.to_owned()); | ||
collection_manager | ||
.create_collection(config, overwrite.to_owned()) | ||
.await?; | ||
|
@@ -133,11 +149,13 @@ async fn main() -> anyhow::Result<()> { | |
collection_name, | ||
host, | ||
port, | ||
hf_token, | ||
} => { | ||
run_server( | ||
host.to_string(), | ||
port.to_owned(), | ||
collection_name.to_string(), | ||
hf_token.to_owned(), | ||
) | ||
.await?; | ||
} | ||
|
@@ -147,5 +165,6 @@ async fn main() -> anyhow::Result<()> { | |
} | ||
|
||
mod collection; | ||
mod hf_ops; | ||
mod model; | ||
mod serve; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.