From cd51d55c9dd804f8767c9e492341c87f2c961b64 Mon Sep 17 00:00:00 2001 From: Hugo Larcher Date: Thu, 17 Oct 2024 17:29:50 +0200 Subject: [PATCH] fix: customize model name (for TRTLLM). Allow for empty content chunks (TRTLLM). --- build.rs | 40 +++++++++++++++++++++++++--------------- src/lib.rs | 2 +- src/main.rs | 15 +++++++++++---- src/requests.rs | 9 +++++++-- 4 files changed, 44 insertions(+), 22 deletions(-) diff --git a/build.rs b/build.rs index 781b4c2..9701ee3 100644 --- a/build.rs +++ b/build.rs @@ -1,20 +1,30 @@ use std::error::Error; -// use vergen_gitcl::{Emitter, GitclBuilder}; +use vergen_gitcl::{Emitter, GitclBuilder}; fn main() -> Result<(), Box> { - // // Try to get the git sha from the local git repository - // let gitcl = GitclBuilder::all_git()?; - // if Emitter::default() - // .fail_on_error() - // .add_instructions(&gitcl)? - // .emit() - // .is_err() - // { - // // Unable to get the git sha - // if let Ok(sha) = std::env::var("GIT_SHA") { - // // Set it from an env var - // println!("cargo:rustc-env=VERGEN_GIT_SHA={sha}"); - // } - // } + // Try to get the git sha from the local git repository + let gitcl = match GitclBuilder::all_git() { + Ok(gitcl) => gitcl, + Err(_) => { + fallback_git_sha(); + return Ok(()); + } + }; + if Emitter::default() + .fail_on_error() + .add_instructions(&gitcl)? + .emit() + .is_err() + { + fallback_git_sha(); + } Ok(()) } + +fn fallback_git_sha() { + // Unable to get the git sha + if let Ok(sha) = std::env::var("GIT_SHA") { + // Set it from an env var + println!("cargo:rustc-env=VERGEN_GIT_SHA={sha}"); + } +} diff --git a/src/lib.rs b/src/lib.rs index 8f459cf..173d91c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,7 +45,7 @@ pub struct RunConfiguration { pub dataset_file: String, pub hf_token: Option, pub extra_metadata: Option>, - pub model_name: String + pub model_name: String, } pub async fn run(run_config: RunConfiguration, stop_sender: Sender<()>) -> anyhow::Result<()> { diff --git a/src/main.rs b/src/main.rs index dd2ba2d..6fbdfd1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -133,8 +133,8 @@ fn parse_tokenizer_options(s: &str) -> Result { } if tokenizer_options.num_tokens.is_some() && (tokenizer_options.num_tokens.unwrap() == 0 - || tokenizer_options.min_tokens == 0 - || tokenizer_options.max_tokens == 0) + || tokenizer_options.min_tokens == 0 + || tokenizer_options.max_tokens == 0) { return Err(Error::new(InvalidValue)); } @@ -148,7 +148,11 @@ fn parse_tokenizer_options(s: &str) -> Result { async fn main() { let args = Args::parse(); let git_sha = option_env!("VERGEN_GIT_SHA").unwrap_or("unknown"); - println!("Text Generation Inference Benchmark {} ({})", env!("CARGO_PKG_VERSION"), git_sha); + println!( + "Text Generation Inference Benchmark {} ({})", + env!("CARGO_PKG_VERSION"), + git_sha + ); let (stop_sender, _) = broadcast::channel(1); // handle ctrl-c @@ -171,7 +175,10 @@ async fn main() { Some(token) => Some(token), None => cache.token(), }; - let model_name = args.model_name.clone().unwrap_or(args.tokenizer_name.clone()); + let model_name = args + .model_name + .clone() + .unwrap_or(args.tokenizer_name.clone()); let run_config = RunConfiguration { url: args.url.clone(), tokenizer_name: args.tokenizer_name.clone(), diff --git a/src/requests.rs b/src/requests.rs index 2b37df1..ffef36f 100644 --- a/src/requests.rs +++ b/src/requests.rs @@ -71,7 +71,7 @@ pub struct OpenAITextGenerationMessage { #[derive(Deserialize, Serialize, Clone, Debug)] pub struct OpenAITextGenerationDelta { - pub content: String, + pub content: Option, } #[derive(Deserialize, Serialize, Clone, Debug)] @@ -187,7 +187,12 @@ impl TextGenerationBackend for OpenAITextGenerationBackend { } }; let choices = oai_response.choices; - let content = choices[0].clone().delta.unwrap().content; + let content = choices[0] + .clone() + .delta + .unwrap() + .content + .unwrap_or("".to_string()); if content.is_empty() { // skip empty responses continue;