fix: Correctly build token count
Hugoch committed Oct 14, 2024
1 parent 9f91e3a commit a406689
Showing 3 changed files with 156 additions and 74 deletions.
15 changes: 10 additions & 5 deletions src/main.rs
@@ -117,20 +117,25 @@ fn parse_tokenizer_options(s: &str) -> Result<TokenizeOptions, Error> {
return Err(Error::new(InvalidValue));
}
match key_value[0] {
"num_tokens" => tokenizer_options.num_tokens = key_value[1].parse::<u64>().unwrap(),
"num_tokens" => {
tokenizer_options.num_tokens = Some(key_value[1].parse::<u64>().unwrap())
}
"min_tokens" => tokenizer_options.min_tokens = key_value[1].parse::<u64>().unwrap(),
"max_tokens" => tokenizer_options.max_tokens = key_value[1].parse::<u64>().unwrap(),
"variance" => tokenizer_options.variance = key_value[1].parse::<u64>().unwrap(),
_ => return Err(Error::new(InvalidValue)),
}
}
if tokenizer_options.num_tokens == 0
|| tokenizer_options.min_tokens == 0
|| tokenizer_options.max_tokens == 0
|| tokenizer_options.min_tokens > tokenizer_options.max_tokens
if tokenizer_options.num_tokens.is_some()
&& (tokenizer_options.num_tokens.unwrap() == 0
|| tokenizer_options.min_tokens == 0
|| tokenizer_options.max_tokens == 0)
{
return Err(Error::new(InvalidValue));
}
if tokenizer_options.min_tokens > tokenizer_options.max_tokens {
return Err(Error::new(InvalidValue));
}
Ok(tokenizer_options)
}
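With num_tokens now optional, the validation above only enforces non-zero values when a fixed budget is actually requested, while the min/max ordering check runs unconditionally. A rough sketch of accepted and rejected option strings, assuming the parser starts from TokenizeOptions::new() and receives the comma-separated key=value syntax implied by the code above (the CLI flag that carries this string is not shown in this diff):

```rust
// Accepted: no fixed budget, only bounds are enforced.
assert!(parse_tokenizer_options("min_tokens=10,max_tokens=512").is_ok());

// Accepted: fixed budget with explicit, consistent bounds.
assert!(parse_tokenizer_options("num_tokens=200,min_tokens=200,max_tokens=200").is_ok());

// Rejected: a fixed budget of zero is meaningless.
assert!(parse_tokenizer_options("num_tokens=0,min_tokens=1,max_tokens=10").is_err());

// Rejected: inverted bounds, caught by the second check regardless of num_tokens.
assert!(parse_tokenizer_options("min_tokens=100,max_tokens=10").is_err());
```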

208 changes: 139 additions & 69 deletions src/requests.rs
Expand Up @@ -145,7 +145,7 @@ impl TextGenerationBackend for OpenAITextGenerationBackend {
max_tokens: request.num_decode_tokens,
stream: true,
stop: None,
temperature: 0.0
temperature: 0.0,
};
let req = self
.client
@@ -293,7 +293,7 @@ pub struct ConversationEntry

#[derive(Clone, Serialize, Debug)]
pub struct TokenizeOptions {
pub num_tokens: u64,
pub num_tokens: Option<u64>,
pub min_tokens: u64,
pub max_tokens: u64,
pub variance: u64,
@@ -302,9 +302,9 @@ pub struct TokenizeOptions {
impl TokenizeOptions {
pub fn new() -> Self {
Self {
num_tokens: 0,
num_tokens: None,
min_tokens: 0,
max_tokens: 0,
max_tokens: u64::MAX,
variance: 0,
}
}
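The reworked defaults mean a freshly constructed TokenizeOptions no longer imposes any budget: num_tokens starts unset and max_tokens starts at u64::MAX, so only explicitly supplied limits take effect. A small illustrative check of the behaviour implied by new() above (not a test from this commit):

```rust
let opts = TokenizeOptions::new();
assert_eq!(opts.num_tokens, None);     // no fixed prompt length requested
assert_eq!(opts.min_tokens, 0);        // no lower bound on prompt length
assert_eq!(opts.max_tokens, u64::MAX); // effectively no upper bound
assert_eq!(opts.variance, 0);          // no sampling spread
```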
@@ -320,7 +320,7 @@ impl Display for TokenizeOptions {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"num_tokens={num_tokens},min_tokens={min_tokens},max_tokens={max_tokens},variance={variance}",
"num_tokens={num_tokens:?},min_tokens={min_tokens},max_tokens={max_tokens},variance={variance}",
num_tokens = self.num_tokens,
min_tokens = self.min_tokens,
max_tokens = self.max_tokens,
@@ -358,11 +358,9 @@ impl ConversationTextRequestGenerator {
filepath = filepath.display().to_string()
);
let bar = ProgressBar::new(data.len() as u64);
bar.set_style(
ProgressStyle::with_template(
"Tokenizing prompts [{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}",
)?,
);
bar.set_style(ProgressStyle::with_template(
"Tokenizing prompts [{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}",
)?);
split(data, entry_splitter).for_each(|subrange| {
for entry in subrange {
bar.inc(1);
@@ -376,14 +374,17 @@
.map(|c| c.content.clone());
let system_prompt_tokens = match system_prompt {
Some(ref prompt) => {
let (_, num_tokens) =
match tokenize_prompt(prompt.clone(), tokenizer.clone(), None) {
Ok((prompt, num_tokens)) => (prompt, num_tokens),
Err(e) => {
debug!("Error tokenizing system prompt: {e}");
return;
}
};
let (_, num_tokens) = match tokenize_prompt(
prompt.clone(),
tokenizer.clone(),
&TokenizeOptions::default(),
) {
Ok((prompt, num_tokens)) => (prompt, num_tokens),
Err(e) => {
debug!("Error tokenizing system prompt: {e}");
return;
}
};
num_tokens
}
None => 0,
@@ -397,20 +398,22 @@
let num_decode_tokens = decode_tokenize_opts.clone().map_or_else(
|| None,
|opts| {
Some(sample_num_tokens(
opts.num_tokens,
opts.min_tokens,
opts.max_tokens,
opts.variance,
))
opts.num_tokens.map(|num_tokens| {
sample_num_tokens(
num_tokens,
opts.min_tokens,
opts.max_tokens,
opts.variance,
)
})
},
);
match &prompt_tokenize_opts {
None => {
let (_, num_tokens) = match tokenize_prompt(
prompt.clone(),
tokenizer.clone(),
None,
&TokenizeOptions::default(),
) {
Ok((prompt, num_tokens)) => (prompt, num_tokens),
Err(e) => {
Expand All @@ -426,17 +429,11 @@ impl ConversationTextRequestGenerator {
});
}
Some(options) => {
let num_tokens = options.num_tokens;
let min_tokens = options.min_tokens;
let max_tokens = options.max_tokens;
let variance = options.variance;
// compute number of tokens to generate using a Gaussian distribution
let num_tokens =
sample_num_tokens(num_tokens, min_tokens, max_tokens, variance);
let sampled_prompt = match tokenize_prompt(
let (sampled_prompt, prompt_tokens) = match tokenize_prompt(
prompt.clone(),
tokenizer.clone(),
Some(num_tokens),
options,
) {
Ok(prompt) => prompt,
Err(e) => {
Expand All @@ -445,8 +442,8 @@ impl ConversationTextRequestGenerator {
}
};
requests.lock().unwrap().push(TextGenerationRequest {
prompt: sampled_prompt.0,
num_prompt_tokens: num_tokens + system_prompt_tokens,
prompt: sampled_prompt,
num_prompt_tokens: prompt_tokens + system_prompt_tokens,
num_decode_tokens,
system_prompt: system_prompt.clone(),
});
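After this change, sample_num_tokens is only invoked for the decode budget, and only when a fixed num_tokens was requested; the prompt branch now hands the full TokenizeOptions to tokenize_prompt instead. The sampler itself is not part of this diff; the sketch below is an assumed implementation (signature, clamping behaviour, and the rand/rand_distr dependency are guesses based on the call site and the "Gaussian distribution" comment), not the repository's actual code:

```rust
use rand_distr::{Distribution, Normal};

// Assumed sketch: draw a token count around `num_tokens` using a Gaussian
// whose spread is derived from `variance`, then clamp into [min_tokens, max_tokens].
fn sample_num_tokens(num_tokens: u64, min_tokens: u64, max_tokens: u64, variance: u64) -> u64 {
    let std_dev = (variance as f64).sqrt();
    let normal = Normal::new(num_tokens as f64, std_dev)
        .expect("finite mean and non-negative standard deviation");
    let sampled = normal.sample(&mut rand::thread_rng()).round();
    (sampled.max(0.0) as u64).clamp(min_tokens, max_tokens)
}
```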
@@ -520,13 +517,26 @@ impl TextRequestGenerator for ConversationTextRequestGenerator {
fn tokenize_prompt(
prompt: String,
tokenizer: Arc<Tokenizer>,
num_tokens: Option<u64>,
options: &TokenizeOptions,
) -> anyhow::Result<(String, u64)> {
let prompt_tokens = tokenizer
.encode(prompt.clone(), false)
.map_err(|_| anyhow::anyhow!("Error tokenizing prompt"))?;
match num_tokens {
None => Ok((prompt, prompt_tokens.len() as u64)),
match options.num_tokens {
None => {
// check if we have a min/max number of tokens, skip prompts that are too short or too long
if prompt_tokens.len() > options.max_tokens as usize
|| prompt_tokens.len() < options.min_tokens as usize
{
return Err(anyhow::anyhow!(format!(
"Prompt is too short or too long, skipping: {}<{}<{}",
options.min_tokens,
prompt_tokens.len(),
options.max_tokens
)));
}
Ok((prompt, prompt_tokens.len() as u64))
}
Some(num_tokens) => {
if prompt_tokens.len() < num_tokens as usize {
return Err(anyhow::anyhow!(format!(
Expand All @@ -535,36 +545,14 @@ fn tokenize_prompt(
num_tokens
)));
}
// let's do a binary search to find the right number of tokens
let mut low = 1;
let mut high = prompt.len() as u64;
let mut prompt_sub = String::new();
while low < high {
let mid = (low + high) / 2;
prompt_sub = prompt
.chars()
.skip((low - 1) as usize)
.take(high as usize)
.collect::<String>();
let tokenized_len = match tokenizer.encode(prompt_sub.clone(), false) {
Ok(tokens) => tokens.len(),
Err(_) => {
return Err(anyhow::anyhow!("Error tokenizing prompt"));
}
};
match tokenized_len.cmp(&(num_tokens as usize)) {
std::cmp::Ordering::Equal => {
return Ok((prompt_sub.to_string(), num_tokens));
}
std::cmp::Ordering::Greater => {
high = mid;
}
std::cmp::Ordering::Less => {
low = mid + 1;
}
}
}
Ok((prompt_sub.to_string(), prompt_tokens.len() as u64))
let tokens = prompt_tokens
.get_ids()
.iter()
.take(num_tokens as usize)
.copied()
.collect::<Vec<u32>>();
let prompt = tokenizer.decode(&tokens, true).unwrap();
Ok((prompt, num_tokens))
}
}
}
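The rewritten helper either filters prompts against the min/max bounds (when no fixed budget is set) or truncates to exactly num_tokens by keeping the first token ids and decoding them back. A hypothetical usage sketch, assuming a Hugging Face tokenizers::Tokenizer fetched via from_pretrained (which needs the crate's http feature) and the TokenizeOptions shape from this commit; the model name and prompt text are illustrative only:

```rust
use std::sync::Arc;
use tokenizers::Tokenizer;

fn truncate_to_budget() -> anyhow::Result<()> {
    let tokenizer = Arc::new(
        Tokenizer::from_pretrained("gpt2", None).map_err(|e| anyhow::anyhow!(e))?,
    );
    let options = TokenizeOptions {
        num_tokens: Some(16), // fixed budget: keep exactly the first 16 token ids
        min_tokens: 16,
        max_tokens: 16,
        variance: 0,
    };
    let long_prompt = "The quick brown fox jumps over the lazy dog. ".repeat(20);
    let (prompt, count) = tokenize_prompt(long_prompt, tokenizer, &options)?;
    assert_eq!(count, 16);
    println!("truncated prompt: {prompt}");
    Ok(())
}
```

Note that the reported count is the requested budget itself; decoding the truncated ids and re-encoding them could, for some tokenizers, yield a slightly different length, which the helper does not re-check.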
@@ -1029,7 +1017,7 @@ mod tests {
tokenizer,
time::Duration::from_secs(1),
)
.unwrap();
.unwrap();
let request = TextGenerationRequest {
prompt: "Hello, world!".to_string(),
num_prompt_tokens: 2,
Expand All @@ -1053,4 +1041,86 @@ mod tests {
assert_eq!(responses.len(), 1);
assert_eq!(responses[0].failed, true);
}

/// Test that conversations are correctly loaded
#[tokio::test]
async fn test_load_conversations_from_file() {
let filepath = PathBuf::from("test_data/conversations.json");
let tokenizer = "gpt2".to_string();
let prompt_tokenize_opts = TokenizeOptions::default();
let decode_tokenize_opts = TokenizeOptions::default();
let hf_token = None;
let generator = ConversationTextRequestGenerator::load(
filepath,
tokenizer,
Some(prompt_tokenize_opts),
Some(decode_tokenize_opts),
hf_token,
)
.unwrap();
assert_eq!(generator.requests.len(), 17005);
}

/// Test that conversations are bounded by the min/max number of tokens
#[tokio::test]
async fn test_load_conversations_bounded() {
let filepath = PathBuf::from("test_data/conversations.json");
let tokenizer = "gpt2".to_string();
let prompt_tokenize_opts = TokenizeOptions {
num_tokens: None,
min_tokens: 4,
max_tokens: 1024,
variance: 0,
};
let decode_tokenize_opts = TokenizeOptions::default();
let hf_token = None;
let generator = ConversationTextRequestGenerator::load(
filepath,
tokenizer,
Some(prompt_tokenize_opts),
Some(decode_tokenize_opts),
hf_token,
)
.unwrap();
let min_tokens = generator
.requests
.iter()
.map(|r| r.num_prompt_tokens)
.min()
.unwrap();
let max_tokens = generator
.requests
.iter()
.map(|r| r.num_prompt_tokens)
.max()
.unwrap();
assert!(min_tokens >= 4, "Min tokens: {}", min_tokens);
assert!(max_tokens <= 1024, "Max tokens: {}", max_tokens);
}

/// Test that conversation prompts have the correct number of tokens
#[tokio::test]
async fn test_load_conversations_fixed_tokens() {
let filepath = PathBuf::from("test_data/conversations.json");
let tokenizer = "gpt2".to_string();
let prompt_tokenize_opts = TokenizeOptions {
num_tokens: Some(200),
min_tokens: 200,
max_tokens: 200,
variance: 0,
};
let decode_tokenize_opts = TokenizeOptions::default();
let hf_token = None;
let generator = ConversationTextRequestGenerator::load(
filepath,
tokenizer,
Some(prompt_tokenize_opts),
Some(decode_tokenize_opts),
hf_token,
)
.unwrap();
for r in generator.requests.iter() {
assert_eq!(r.num_prompt_tokens, 200);
}
}
}
7 changes: 7 additions & 0 deletions src/table.rs
@@ -51,6 +51,7 @@ pub fn results_table(benchmark: BenchmarkReport) -> anyhow::Result<tabled::Table
"ITL (avg)",
"Throughput",
"Error Rate",
"Sucessful Requests",
]);
let results = benchmark.get_results();
for result in results {
@@ -75,6 +76,12 @@
itl.as_str(),
throughput.as_str(),
error_rate.as_str(),
format!(
"{}/{}",
result.successful_requests(),
result.total_requests()
)
.as_str(),
]);
}
let mut table = builder.build();
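The new column renders successful versus total requests as a single cell. A small illustration of the formatting, assuming successful_requests() and total_requests() both return plain counters (their exact signatures are not shown in this diff):

```rust
// Hypothetical values standing in for result.successful_requests() / result.total_requests().
let successful_requests: u64 = 95;
let total_requests: u64 = 100;
let cell = format!("{}/{}", successful_requests, total_requests);
assert_eq!(cell, "95/100"); // rendered under the "Successful Requests" header
```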
