From 9ab3f9729fc1444687578a9dc913760b0d8d9963 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 16 Nov 2023 22:10:31 +0000
Subject: [PATCH] Use the whisper-v3 tokenizer now that it has been added.
 (#1337)

* Use the whisper-v3 tokenizer now that it has been added.

* Use the appropriate nospeech token.
---
 candle-examples/examples/whisper/main.rs      | 14 ++++++++------
 candle-transformers/src/models/whisper/mod.rs |  2 +-
 candle-wasm-examples/whisper/src/worker.rs    |  8 +++++++-
 3 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index d2caebcdca..5be81f2dd2 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -128,7 +128,13 @@ impl Decoder {
         let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
         let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
         let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
-        let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
+        let no_speech_token = m::NO_SPEECH_TOKENS
+            .iter()
+            .find_map(|token| token_id(&tokenizer, token).ok());
+        let no_speech_token = match no_speech_token {
+            None => anyhow::bail!("unable to find any non-speech token"),
+            Some(n) => n,
+        };
         Ok(Self {
             model,
             rng: rand::rngs::StdRng::seed_from_u64(seed),
@@ -512,11 +518,7 @@ fn main() -> Result<()> {
             )
         } else {
             let config = repo.get("config.json")?;
-            let tokenizer = if args.model == WhichModel::LargeV3 {
-                panic!("openai/whisper-large-v3 does not provide a compatible tokenizer.json config at the moment")
-            } else {
-                repo.get("tokenizer.json")?
-            };
+            let tokenizer = repo.get("tokenizer.json")?;
             let model = repo.get("model.safetensors")?;
             (config, tokenizer, model)
         };
diff --git a/candle-transformers/src/models/whisper/mod.rs b/candle-transformers/src/models/whisper/mod.rs
index bf24045abc..8028cf2c66 100644
--- a/candle-transformers/src/models/whisper/mod.rs
+++ b/candle-transformers/src/models/whisper/mod.rs
@@ -43,4 +43,4 @@ pub const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
 pub const TRANSLATE_TOKEN: &str = "<|translate|>";
 pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
 pub const EOT_TOKEN: &str = "<|endoftext|>";
-pub const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
+pub const NO_SPEECH_TOKENS: [&str; 2] = ["<|nocaptions|>", "<|nospeech|>"];
diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs
index 09d4f58077..db5e8bb14c 100644
--- a/candle-wasm-examples/whisper/src/worker.rs
+++ b/candle-wasm-examples/whisper/src/worker.rs
@@ -129,7 +129,13 @@ impl Decoder {
         let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
         let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
         let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
-        let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
+        let no_speech_token = m::NO_SPEECH_TOKENS
+            .iter()
+            .find_map(|token| token_id(&tokenizer, token).ok());
+        let no_speech_token = match no_speech_token {
+            None => anyhow::bail!("unable to find any non-speech token"),
+            Some(n) => n,
+        };
         let seed = 299792458;
         Ok(Self {
             model,
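
Note: below is a minimal standalone sketch of the fallback lookup this patch introduces, assuming the tokenizers and anyhow crates and a local tokenizer.json; the file path and the main() wrapper are illustrative, not part of the patch.

use tokenizers::Tokenizer;

// Candidate markers for the "no speech" token: whisper-large-v3 uses
// <|nospeech|> where earlier releases used <|nocaptions|>, so both
// spellings are tried in order.
const NO_SPEECH_TOKENS: [&str; 2] = ["<|nocaptions|>", "<|nospeech|>"];

fn main() -> anyhow::Result<()> {
    // Illustrative path; any Whisper tokenizer.json can be used here.
    let tokenizer = Tokenizer::from_file("tokenizer.json").map_err(anyhow::Error::msg)?;

    // Same pattern as the patch: keep the first candidate the tokenizer knows about.
    let no_speech_token = NO_SPEECH_TOKENS
        .iter()
        .find_map(|token| tokenizer.token_to_id(token));
    match no_speech_token {
        Some(id) => println!("no-speech token id: {id}"),
        None => anyhow::bail!("unable to find any non-speech token"),
    }
    Ok(())
}

Falling back across both names keeps older tokenizer.json files (which define <|nocaptions|>) working while also supporting the large-v3 tokenizer, which uses <|nospeech|> instead.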