[Whisper] Pipeline: handle long form generation #35750

Open · wants to merge 7 commits into main
27 changes: 25 additions & 2 deletions src/transformers/models/whisper/tokenization_whisper.py
@@ -908,7 +908,7 @@ def _convert_to_list(token_ids):
     return token_ids


-def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
+def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision, segment_size=1500):
     """
     Internal method meant to only be used by asr pipeline. Handles all the little quirks specific to whisper to handle
     the various options not allowed in other seq2seq models
@@ -960,6 +960,12 @@ def new_chunk():
         last_timestamp = None
         first_timestamp = timestamp_begin

+        # long-form generation: handle the case where a single call to `generate`
+        # returns segments concatenated from multiple underlying generate calls
+        cur_max_timestamp = 0.0
+        prev_segments_len = 0.0
+        penultimate_timestamp = 0.0
+
         if "stride" in output:
             chunk_len, stride_left, stride_right = output["stride"]
             # Offset the timings to account for the other `model_outputs`.
@@ -1022,7 +1028,24 @@ def new_chunk():
                 pass
             elif token >= timestamp_begin:
                 # 3/ Timestamp token
+
+                timestamp = float((token - timestamp_begin) * time_precision)
+                if timestamp < cur_max_timestamp:
+                    # next segment has started
+                    last_was_single_ending = i >= 2 and not (
+                        token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
+                    )
+                    if last_was_single_ending:
+                        prev_segments_len += time_precision * segment_size
+                    else:
+                        cur_max_timestamp = penultimate_timestamp
+                        prev_segments_len += penultimate_timestamp
+
+                penultimate_timestamp = cur_max_timestamp
+                cur_max_timestamp = timestamp
+
-                time = (token - timestamp_begin) * time_precision + time_offset
+                time = (token - timestamp_begin) * time_precision + time_offset + prev_segments_len
+
                 time = round(time, 2)
                 if last_timestamp and token >= last_timestamp:
                     # Whisper outputted a timestamp token, but it falls within
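For reviewers, a minimal standalone sketch (not the library code) of the offset logic this hunk adds: when `generate` glues several 30 s segments into one token stream, timestamp tokens restart near zero at each boundary, so the accumulated length of finished segments must be added back to get absolute times. `TIMESTAMP_BEGIN` and the toy token stream are illustrative values, not real Whisper vocabulary ids.

TIME_PRECISION = 0.02     # seconds per timestamp step, as in Whisper
SEGMENT_SIZE = 1500       # steps per 30 s window (1500 * 0.02 s)
TIMESTAMP_BEGIN = 50_000  # hypothetical id of <|0.00|>


def absolute_times(token_ids):
    """Return the absolute time in seconds for every timestamp token."""
    cur_max_timestamp = 0.0
    prev_segments_len = 0.0
    penultimate_timestamp = 0.0
    times = []
    for i, token in enumerate(token_ids):
        if token < TIMESTAMP_BEGIN:
            continue  # plain text token
        timestamp = float((token - TIMESTAMP_BEGIN) * TIME_PRECISION)
        if timestamp < cur_max_timestamp:
            # Timestamps went backwards: a new segment has started.
            last_was_single_ending = i >= 2 and not (
                token_ids[i - 1] >= TIMESTAMP_BEGIN and token_ids[i - 2] >= TIMESTAMP_BEGIN
            )
            if last_was_single_ending:
                # Previous segment ended with a lone timestamp: the model
                # consumed the full 30 s window.
                prev_segments_len += TIME_PRECISION * SEGMENT_SIZE
            else:
                # Previous segment ended with a timestamp pair: its real
                # length is the second-to-last timestamp seen.
                cur_max_timestamp = penultimate_timestamp
                prev_segments_len += penultimate_timestamp
        penultimate_timestamp = cur_max_timestamp
        cur_max_timestamp = timestamp
        times.append(round(timestamp + prev_segments_len, 2))
    return times


def ts(seconds):
    return TIMESTAMP_BEGIN + int(seconds / TIME_PRECISION)


# Two segments: the first ends with the pair <|5.00|><|5.00|>, the second
# restarts at <|0.00|>, so its timestamps are shifted by the 5 s already elapsed.
print(absolute_times([ts(0.0), 1, 2, ts(5.0), ts(5.0), ts(0.0), 3, ts(2.0)]))
# -> [0.0, 5.0, 5.0, 5.0, 7.0]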
5 changes: 5 additions & 0 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -303,6 +303,11 @@ def _sanitize_parameters(
                 " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
                 " ignore_warning=True)"
             )
+        elif self.type == "seq2seq_whisper" and not ignore_warning:
+            logger.warning(
+                "Using `chunk_length_s` with Whisper models is not recommended and can produce unreliable results, as Whisper uses its own chunking mechanism "
+                "(cf. the original Whisper paper, section 3.8, Long-form Transcription)."
+            )
         preprocess_params["chunk_length_s"] = chunk_length_s
     if stride_length_s is not None:
         preprocess_params["stride_length_s"] = stride_length_s

Collaborator comment on lines +306 to +309: As I mentioned offline, it would be a pity not to use that batch algo in some cases! But up for debate!
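A hedged usage sketch of the flow this PR is aiming for (the model name and audio path are placeholders): omit `chunk_length_s` so that `generate` runs Whisper's own sequential long-form transcription and `_decode_asr` stitches the absolute timestamps back together.

from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-small")

# Passing `chunk_length_s=30` here would now trigger the warning added above;
# without it, audio longer than 30 s is handled by Whisper's own algorithm.
result = asr("long_audio.wav", return_timestamps=True)
print(result["chunks"])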