Skip to content

Commit

Permalink
fix: customize model name (for TRTLLM). Allow for empty content chunk…
Browse files Browse the repository at this point in the history
…s (TRTLLM).
  • Loading branch information
Hugoch committed Oct 17, 2024
1 parent 8f56497 commit cd51d55
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 22 deletions.
40 changes: 25 additions & 15 deletions build.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
use std::error::Error;
// use vergen_gitcl::{Emitter, GitclBuilder};
use vergen_gitcl::{Emitter, GitclBuilder};

fn main() -> Result<(), Box<dyn Error>> {
// // Try to get the git sha from the local git repository
// let gitcl = GitclBuilder::all_git()?;
// if Emitter::default()
// .fail_on_error()
// .add_instructions(&gitcl)?
// .emit()
// .is_err()
// {
// // Unable to get the git sha
// if let Ok(sha) = std::env::var("GIT_SHA") {
// // Set it from an env var
// println!("cargo:rustc-env=VERGEN_GIT_SHA={sha}");
// }
// }
// Try to get the git sha from the local git repository
let gitcl = match GitclBuilder::all_git() {
Ok(gitcl) => gitcl,
Err(_) => {
fallback_git_sha();
return Ok(());
}
};
if Emitter::default()
.fail_on_error()
.add_instructions(&gitcl)?
.emit()
.is_err()
{
fallback_git_sha();
}
Ok(())
}

fn fallback_git_sha() {
// Unable to get the git sha
if let Ok(sha) = std::env::var("GIT_SHA") {
// Set it from an env var
println!("cargo:rustc-env=VERGEN_GIT_SHA={sha}");
}
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub struct RunConfiguration {
pub dataset_file: String,
pub hf_token: Option<String>,
pub extra_metadata: Option<HashMap<String, String>>,
pub model_name: String
pub model_name: String,
}

pub async fn run(run_config: RunConfiguration, stop_sender: Sender<()>) -> anyhow::Result<()> {
Expand Down
15 changes: 11 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ fn parse_tokenizer_options(s: &str) -> Result<TokenizeOptions, Error> {
}
if tokenizer_options.num_tokens.is_some()
&& (tokenizer_options.num_tokens.unwrap() == 0
|| tokenizer_options.min_tokens == 0
|| tokenizer_options.max_tokens == 0)
|| tokenizer_options.min_tokens == 0
|| tokenizer_options.max_tokens == 0)
{
return Err(Error::new(InvalidValue));
}
Expand All @@ -148,7 +148,11 @@ fn parse_tokenizer_options(s: &str) -> Result<TokenizeOptions, Error> {
async fn main() {
let args = Args::parse();
let git_sha = option_env!("VERGEN_GIT_SHA").unwrap_or("unknown");
println!("Text Generation Inference Benchmark {} ({})", env!("CARGO_PKG_VERSION"), git_sha);
println!(
"Text Generation Inference Benchmark {} ({})",
env!("CARGO_PKG_VERSION"),
git_sha
);

let (stop_sender, _) = broadcast::channel(1);
// handle ctrl-c
Expand All @@ -171,7 +175,10 @@ async fn main() {
Some(token) => Some(token),
None => cache.token(),
};
let model_name = args.model_name.clone().unwrap_or(args.tokenizer_name.clone());
let model_name = args
.model_name
.clone()
.unwrap_or(args.tokenizer_name.clone());
let run_config = RunConfiguration {
url: args.url.clone(),
tokenizer_name: args.tokenizer_name.clone(),
Expand Down
9 changes: 7 additions & 2 deletions src/requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub struct OpenAITextGenerationMessage {

#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct OpenAITextGenerationDelta {
pub content: String,
pub content: Option<String>,
}

#[derive(Deserialize, Serialize, Clone, Debug)]
Expand Down Expand Up @@ -187,7 +187,12 @@ impl TextGenerationBackend for OpenAITextGenerationBackend {
}
};
let choices = oai_response.choices;
let content = choices[0].clone().delta.unwrap().content;
let content = choices[0]
.clone()
.delta
.unwrap()
.content
.unwrap_or("".to_string());
if content.is_empty() {
// skip empty responses
continue;
Expand Down

0 comments on commit cd51d55

Please sign in to comment.