Skip to content

Commit

Permalink
chore: Cleanup.
Browse files Browse the repository at this point in the history
  • Loading branch information
Hugoch committed Oct 10, 2024
1 parent a17e35d commit ecaa400
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 31 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
on:
workflow_dispatch:
push:
branches:
- 'main'
tags:
- 'v*'
pull_request:
paths:
- 'src/**'
Expand Down
8 changes: 3 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ mod flux;
mod requests;
mod results;
mod scheduler;
mod table;
mod tokens;
mod writers;
mod table;

pub struct RunConfiguration {
pub url: String,
Expand Down Expand Up @@ -147,7 +147,7 @@ pub async fn run(run_config: RunConfiguration, stop_sender: Sender<()>) -> anyho
run_config.dataset_file,
run_config.hf_token.clone(),
)
.expect("Can't download dataset");
.expect("Can't download dataset");
let requests = requests::ConversationTextRequestGenerator::load(
filepath,
run_config.tokenizer_name.clone(),
Expand All @@ -167,7 +167,7 @@ pub async fn run(run_config: RunConfiguration, stop_sender: Sender<()>) -> anyho
tokio::select! {
report = benchmark.run() => {
match report {
Ok(results) => {
Ok(_) => {
let report = benchmark.get_report();
let path = format!("results/{}_{}.json",run_config.tokenizer_name.replace("/","_").replace(".","_"), chrono::Utc::now().format("%Y-%m-%d-%H-%M-%S"));
let path=Path::new(&path);
Expand Down Expand Up @@ -204,5 +204,3 @@ pub async fn run(run_config: RunConfiguration, stop_sender: Sender<()>) -> anyho

Ok(())
}


18 changes: 9 additions & 9 deletions src/requests.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::fmt::Display;
use async_trait::async_trait;
use futures_util::StreamExt;
use hf_hub::api::sync::ApiBuilder;
Expand All @@ -9,6 +8,7 @@ use rayon::iter::split;
use rayon::prelude::*;
use reqwest_eventsource::{Error, Event, EventSource};
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::path::PathBuf;
use std::sync::atomic::AtomicI64;
use std::sync::{Arc, Mutex};
Expand Down Expand Up @@ -354,7 +354,7 @@ impl ConversationTextRequestGenerator {
ProgressStyle::with_template(
"Tokenizing prompts [{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}",
)
.unwrap(),
.unwrap(),
);
split(data, entry_splitter).for_each(|subrange| {
for entry in subrange {
Expand Down Expand Up @@ -687,7 +687,7 @@ mod tests {
"gpt2".to_string(),
tokenizer,
)
.unwrap();
.unwrap();
let request = TextGenerationRequest {
prompt: "Hello, world!".to_string(),
num_prompt_tokens: 2,
Expand Down Expand Up @@ -745,7 +745,7 @@ mod tests {
"gpt2".to_string(),
tokenizer,
)
.unwrap();
.unwrap();
let request = TextGenerationRequest {
prompt: "Hello, world!".to_string(),
num_prompt_tokens: 2,
Expand Down Expand Up @@ -797,7 +797,7 @@ mod tests {
assert!(
inter_token_latency_avg > expected_inter_token_latency_avg
&& inter_token_latency_avg
< expected_inter_token_latency_avg + inter_token_latency_overhead,
< expected_inter_token_latency_avg + inter_token_latency_overhead,
"inter_token_latency_avg: {:?} < {:?} < {:?}",
expected_inter_token_latency_avg,
inter_token_latency_avg,
Expand Down Expand Up @@ -829,7 +829,7 @@ mod tests {
"gpt2".to_string(),
tokenizer,
)
.unwrap();
.unwrap();
let request = TextGenerationRequest {
prompt: "Hello, world!".to_string(),
num_prompt_tokens: 2,
Expand Down Expand Up @@ -874,7 +874,7 @@ mod tests {
"gpt2".to_string(),
tokenizer,
)
.unwrap();
.unwrap();
let request = TextGenerationRequest {
prompt: "Hello, world!".to_string(),
num_prompt_tokens: 2,
Expand Down Expand Up @@ -919,7 +919,7 @@ mod tests {
"gpt2".to_string(),
tokenizer,
)
.unwrap();
.unwrap();
let request = TextGenerationRequest {
prompt: "Hello, world!".to_string(),
num_prompt_tokens: 2,
Expand Down Expand Up @@ -967,7 +967,7 @@ mod tests {
"gpt2".to_string(),
tokenizer,
)
.unwrap();
.unwrap();
let request = TextGenerationRequest {
prompt: "Hello, world!".to_string(),
num_prompt_tokens: 2,
Expand Down
70 changes: 54 additions & 16 deletions src/table.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,35 @@
use tabled::builder::Builder;
use crate::BenchmarkConfig;
use crate::results::BenchmarkReport;
use crate::BenchmarkConfig;
use tabled::builder::Builder;

pub fn parameters_table(benchmark: BenchmarkConfig) -> tabled::Table {
let mut builder = Builder::default();
let rates = benchmark.rates.map_or("N/A".to_string(), |e| format!("{:?}", e));
let prompt_options = benchmark.prompt_options.map_or("N/A".to_string(), |e| format!("{}", e));
let decode_options = benchmark.decode_options.map_or("N/A".to_string(), |e| format!("{}", e));
let extra_metadata = benchmark.extra_metadata.map_or("N/A".to_string(), |e| format!("{:?}", e));
let rates = benchmark
.rates
.map_or("N/A".to_string(), |e| format!("{:?}", e));
let prompt_options = benchmark
.prompt_options
.map_or("N/A".to_string(), |e| format!("{}", e));
let decode_options = benchmark
.decode_options
.map_or("N/A".to_string(), |e| format!("{}", e));
let extra_metadata = benchmark
.extra_metadata
.map_or("N/A".to_string(), |e| format!("{:?}", e));
builder.set_header(vec!["Parameter", "Value"]);
builder.push_record(vec!["Max VUs", benchmark.max_vus.to_string().as_str()]);
builder.push_record(vec!["Duration", benchmark.duration.as_secs().to_string().as_str()]);
builder.push_record(vec!["Warmup Duration", benchmark.warmup_duration.as_secs().to_string().as_str()]);
builder.push_record(vec!["Benchmark Kind", benchmark.benchmark_kind.to_string().as_str()]);
builder.push_record(vec![
"Duration",
benchmark.duration.as_secs().to_string().as_str(),
]);
builder.push_record(vec![
"Warmup Duration",
benchmark.warmup_duration.as_secs().to_string().as_str(),
]);
builder.push_record(vec![
"Benchmark Kind",
benchmark.benchmark_kind.to_string().as_str(),
]);
builder.push_record(vec!["Rates", rates.as_str()]);
builder.push_record(vec!["Num Rates", benchmark.num_rates.to_string().as_str()]);
builder.push_record(vec!["Prompt Options", prompt_options.as_str()]);
Expand All @@ -26,20 +43,41 @@ pub fn parameters_table(benchmark: BenchmarkConfig) -> tabled::Table {

/// Builds a `tabled::Table` summarizing a benchmark report, one row per result.
///
/// Columns: benchmark id, QPS, end-to-end latency (seconds), time to first
/// token (ms), inter-token latency (ms), token throughput, and error rate (%).
/// The finished table is styled with `tabled::settings::Style::sharp()`.
///
/// # Panics
///
/// Every metric accessor is `unwrap()`ed, so this panics if any of them does
/// not yield a value for a result.
///
/// NOTE(review): this listing appears to interleave the pre- and post-reformat
/// versions of the same statements (header, `ttft`, `itl`, `push_record`,
/// `table.with`), so several of them show up twice — confirm against the
/// actual file before relying on the exact statement sequence.
pub fn results_table(benchmark: BenchmarkReport) -> tabled::Table {
let mut builder = Builder::default();
// Header row; the same seven columns are listed again below in multi-line form.
builder.set_header(vec!["Benchmark", "QPS", "E2E Latency", "TTFT", "ITL", "Throughput", "Error Rate"]);
builder.set_header(vec![
"Benchmark",
"QPS",
"E2E Latency",
"TTFT",
"ITL",
"Throughput",
"Error Rate",
]);
let results = benchmark.get_results();
for result in results {
// Each metric is pre-formatted to two decimals with its unit suffix.
let qps = format!("{:.2} req/s", result.successful_request_rate().unwrap());
let e2e = format!("{:.2} sec", result.e2e_latency_avg().unwrap().as_secs_f64());
// TTFT and ITL are read in microseconds and divided by 1000 to display milliseconds.
let ttft = format!("{:.2} ms", result.time_to_first_token_avg().unwrap().as_micros() as f64 / 1000.0);
let itl = format!("{:.2} ms", result.inter_token_latency_avg().unwrap().as_micros() as f64 / 1000.0);
let ttft = format!(
"{:.2} ms",
result.time_to_first_token_avg().unwrap().as_micros() as f64 / 1000.0
);
let itl = format!(
"{:.2} ms",
result.inter_token_latency_avg().unwrap().as_micros() as f64 / 1000.0
);
let throughput = format!("{:.2} tokens/sec", result.token_throughput_secs().unwrap());
// NOTE(review): the division is performed before the `as f64` cast on the next
// line. If `failed_requests()` / `total_requests()` return integers (their types
// are not visible here), the quotient truncates — so the rate would print as
// 0.00% unless every request failed — and a zero total would panic. Confirm the
// return types; casting both operands to f64 before dividing would avoid this.
let error_rate = result.failed_requests() / result.total_requests();
let error_rate = format!("{:.2}%", error_rate as f64 * 100.0);
// One record per result, in the same column order as the header.
builder.push_record(vec![result.id.as_str(), qps.as_str(), e2e.as_str(), ttft.as_str(), itl.as_str(), throughput.as_str(), error_rate.as_str()]);
builder.push_record(vec![
result.id.as_str(),
qps.as_str(),
e2e.as_str(),
ttft.as_str(),
itl.as_str(),
throughput.as_str(),
error_rate.as_str(),
]);
}
let mut table = builder.build();
// Apply the "sharp" border style before returning the table.
table.with(tabled::settings::Style::sharp());
table
.with(tabled::settings::Style::sharp());
table
}
}
2 changes: 1 addition & 1 deletion src/writers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ impl BenchmarkReportWriter {
Ok(())
}

pub async fn stdout(&self){
pub async fn stdout(&self) {
let param_table = table::parameters_table(self.config.clone());
println!("\n{param_table}\n");
let results_table = table::results_table(self.report.clone());
Expand Down

0 comments on commit ecaa400

Please sign in to comment.