Whisper Timestamps and Temperature (#19509)
This PR updates exporting and running the Whisper model with beam search
by making the following changes.

- Adds temperature as a graph input to the exported model (see the
inference sketch after this list)
- Fixes the token ids by adding them as attributes to
`WhisperBeamSearch`
- Fixes the timestamps test cases so that they now pass
- Fixes a bug with invoking `torch.onnx.export`
- Cleans up the Whisper scripts and groups the arguments in
`convert_to_onnx.py`
- Adds a `requirements.txt` file to specify package dependencies
- Adds `whisper-large-v3` to list of pretrained models
- Fixes a bug with missing cross-attention KV cache inputs in the
decoder subgraph
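
A minimal sketch of how the new `temperature` graph input might be supplied at inference time with ONNX Runtime is shown below. This is not code from this PR; the model path, input names, shapes, and dtypes are illustrative and should be checked against the actual exported graph.

```python
# Minimal sketch: running the exported Whisper beam search model and passing
# temperature as a graph input. Input names/shapes/dtypes are illustrative;
# check them against the actual exported model.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("whisper_beamsearch.onnx")  # hypothetical path

inputs = {
    "input_features": np.zeros((1, 80, 3000), dtype=np.float32),  # log-mel spectrogram
    "max_length": np.array([448], dtype=np.int32),
    "min_length": np.array([0], dtype=np.int32),
    "num_beams": np.array([5], dtype=np.int32),
    "num_return_sequences": np.array([1], dtype=np.int32),
    "length_penalty": np.array([1.0], dtype=np.float32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
    "temperature": np.array([1.0], dtype=np.float32),  # graph input added by this PR
}

sequences = session.run(None, inputs)[0]  # (batch_size, num_return_sequences, max_length)
```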

- This is a follow-up to [this
PR](#19188).
- The incorrect token ids in the timestamps processor were first noticed
during [this PR
review](#17500 (comment)).
When they were originally added in [this
PR](#15853), the offsets
were constant across the Whisper model sizes. When comparing
the new `whisper-large-v3` variant, the English-only variants (e.g.
`whisper-tiny.en`), and the original multilingual variants (e.g.
`whisper-tiny`), both the values and the offsets differ. It is therefore
easier to set the token ids as attributes on `WhisperBeamSearch` when
exporting, to ensure the right values are used in the timestamps
processor (see the export sketch after this list).
- The Hugging Face API for returning timestamps and the expected outputs
from the PyTorch model have both changed.
- The fix for `torch.onnx.export` is a follow-up to [this PR
review](#17179 (comment)).
- The argument grouping is a follow-up to [this PR
review](#17500 (comment)).
- Specific package versions are needed to run the Whisper scripts and
the `requirements.txt` file ensures that these versions are installed.
- The `whisper-large-v3` variant has been released and should be in the
list of official pretrained models.
- After the changes from [this
PR](#17316), the exported
model fails to load in an ORT inference session because the
cross-attention KV cache inputs are missing from the decoder subgraph.
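
As a rough illustration of the attribute-based approach (not the exact export code in this PR), the token ids could be attached to the `WhisperBeamSearch` node as sketched below. The numeric ids are placeholders for the multilingual tokenizer and must be read from the model's tokenizer/config at export time; the encoder/decoder subgraph attributes and most other inputs and attributes are omitted.

```python
# Rough sketch: attaching Whisper token ids as attributes on the
# WhisperBeamSearch node so the timestamps logits processor uses values that
# match the exported variant. Numeric ids are placeholders; a real export
# reads them from the model's tokenizer/generation config.
from onnx import helper

beam_search_node = helper.make_node(
    "WhisperBeamSearch",
    inputs=[
        "input_features", "max_length", "min_length", "num_beams",
        "num_return_sequences", "length_penalty", "repetition_penalty",
    ],
    outputs=["sequences"],
    name="BeamSearch_node",
    domain="com.microsoft",
    # Existing attributes (encoder/decoder subgraphs and others omitted here)
    decoder_start_token_id=50258,
    pad_token_id=50257,
    # Token ids added as WhisperBeamSearch attributes by this PR (placeholder values)
    translate_token_id=50358,
    transcribe_token_id=50359,
    start_of_lm_token_id=50360,
    no_speech_token_id=50362,
    no_timestamps_token_id=50363,
    beginning_timestamp_token_id=50364,
)
```
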
kunal-vaishnavi authored and rachguo committed Feb 17, 2024
1 parent ad86d13 commit 485e17e
Showing 21 changed files with 578 additions and 370 deletions.
32 changes: 21 additions & 11 deletions docs/ContribOperators.md
@@ -461,7 +461,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>repetition_penalty</tt> (optional) : T</dt>
<dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
<dt><tt>vocab_mask</tt> (optional) : M</dt>
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)</dd>
<dt><tt>prefix_vocab_mask</tt> (optional) : M</dt>
<dd>Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
<dt><tt>attention_mask</tt> (optional) : I</dt>
@@ -2252,7 +2252,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>repetition_penalty</tt> (optional) : T</dt>
<dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
<dt><tt>vocab_mask</tt> (optional) : I</dt>
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)</dd>
<dt><tt>prefix_vocab_mask</tt> (optional) : I</dt>
<dd>Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
<dt><tt>attention_mask</tt> (optional) : I</dt>
@@ -5154,7 +5154,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>repetition_penalty</tt> (optional) : T</dt>
<dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
<dt><tt>vocab_mask</tt> (optional) : I</dt>
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)</dd>
<dt><tt>prefix_vocab_mask</tt> (optional) : I</dt>
<dd>Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
<dt><tt>attention_mask</tt> (optional) : I</dt>
@@ -5743,12 +5743,14 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes

<dl>
<dt><tt>beginning_timestamp_token_id</tt> : int</dt>
<dd>The id of the first timestamp</dd>
<dt><tt>decoder</tt> : graph (required)</dt>
<dd>Decoder subgraph to execute in a loop.</dd>
<dt><tt>decoder_output_cross_qk</tt> : int</dt>
<dd>If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.</dd>
<dt><tt>decoder_start_token_id</tt> : int</dt>
<dd>The id of the token that indicates decoding starts.</dd>
<dd>The id of the token that indicates decoding starts (i.e. the start of transcription token id)</dd>
<dt><tt>early_stopping</tt> : int</dt>
<dd>early stop or not</dd>
<dt><tt>encoder</tt> : graph</dt>
@@ -5761,10 +5763,18 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Must be 2 for whisper</dd>
<dt><tt>no_repeat_ngram_size</tt> : int</dt>
<dd>no repeat ngrams size</dd>
<dt><tt>no_speech_token</tt> : int</dt>
<dt><tt>no_speech_token_id</tt> : int</dt>
<dd>The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.</dd>
<dt><tt>no_timestamps_token_id</tt> : int</dt>
<dd>The id of the token that indicates no timestamps</dd>
<dt><tt>pad_token_id</tt> : int (required)</dt>
<dd>The id of the padding token</dd>
<dt><tt>start_of_lm_token_id</tt> : int</dt>
<dd>The id of the token that indicates LM starts</dd>
<dt><tt>transcribe_token_id</tt> : int</dt>
<dd>The id of the transcribe task</dd>
<dt><tt>translate_token_id</tt> : int</dt>
<dd>The id of the translate task</dd>
<dt><tt>vocab_size</tt> : int</dt>
<dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape</dd>
</dl>
@@ -5783,11 +5793,11 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>num_return_sequences</tt> : I</dt>
<dd>The number of returned sequences in the batch. Shape is (1)</dd>
<dt><tt>length_penalty</tt> (optional) : T</dt>
<dd>Exponential penalty to the length. Default value 1.0 means no penalty.Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences.Shape is (1,)</dd>
<dd>Exponential penalty to the length. Default value 1.0 means no penalty. Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. Shape is (1,)</dd>
<dt><tt>repetition_penalty</tt> (optional) : T</dt>
<dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
<dt><tt>vocab_mask</tt> (optional) : M</dt>
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)</dd>
<dt><tt>prefix_vocab_mask</tt> (optional) : M</dt>
<dd>Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
<dt><tt>attention_mask</tt> (optional) : I</dt>
@@ -5797,7 +5807,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>logits_processor</tt> (optional) : I</dt>
<dd>Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)</dd>
<dt><tt>cross_qk_layer_head</tt> (optional) : I</dt>
<dd>Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]</dd>
<dd>Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]</dd>
<dt><tt>extra_decoding_ids</tt> (optional) : I</dt>
<dd>Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.</dd>
</dl>
@@ -5810,11 +5820,11 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>sequences_scores</tt> (optional) : T</dt>
<dd>Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)</dd>
<dt><tt>scores</tt> (optional) : T</dt>
<dd>Processed beam scores for each vocabulary token at each generation step.Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam.Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)</dd>
<dd>Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)</dd>
<dt><tt>cross_qk</tt> (optional) : V</dt>
<dd>Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers,B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F].If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]</dd>
<dd>Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]</dd>
<dt><tt>non_speech_probs</tt> (optional) : T</dt>
<dd>For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token.Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph.The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]</dd>
<dd>For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. The shape of non_speech_probs is [B]</dd>
</dl>

#### Type Constraints
@@ -134,8 +134,8 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe
TensorShape no_speech_probs_shape{parameters->batch_size};
Tensor* no_speech_probs = this->context_.Output(parameters->no_speech_probs_output_id, no_speech_probs_shape);
if (no_speech_probs && no_speech_probs->MutableData<T>()) {
ORT_ENFORCE(parameters->no_speech_token >= 0 && parameters->no_speech_token < parameters->vocab_size,
"no_speech_token id out of range, it is ", parameters->no_speech_token,
ORT_ENFORCE(parameters->no_speech_token_id >= 0 && parameters->no_speech_token_id < parameters->vocab_size,
"no_speech_token_id is out of range, it is ", parameters->no_speech_token_id,
", vocab_size is ", parameters->vocab_size);
this->parameters_->no_speech_probs = (void*)no_speech_probs->MutableData<T>();
}
@@ -141,7 +141,13 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info)
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", IGenerationParameters::kModelTypeWhisper));
ORT_ENFORCE(model_type == IGenerationParameters::kModelTypeWhisper);

no_speech_token = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_speech_token", -1LL));
// Token ids are defined below in the order that they appear in the tokenizer
translate_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("translate_token_id", -1LL));
transcribe_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("transcribe_token_id", -1LL));
start_of_lm_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("start_of_lm_token_id", -1LL));
no_speech_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_speech_token_id", -1LL));
no_timestamps_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_timestamps_token_id", -1LL));
beginning_timestamp_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("beginning_timestamp_token_id", -1LL));
cross_qk_layer_head_input_id = 12;
extra_decoding_ids_input_id = 13;
cross_qk_output_id = 3;
9 changes: 8 additions & 1 deletion onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -180,7 +180,14 @@ struct IGenerationParameters {
// Parameters for whisper model
bool decoder_output_cross_qk = false;
gsl::span<const int32_t> extra_decoding_ids;
int32_t no_speech_token = -1;

// Token ids are defined below in the order that they appear in the tokenizer
int32_t translate_token_id = -1;
int32_t transcribe_token_id = -1;
int32_t start_of_lm_token_id = -1;
int32_t no_speech_token_id = -1;
int32_t no_timestamps_token_id = -1;
int32_t beginning_timestamp_token_id = -1;
void* no_speech_probs = nullptr;

int cross_qk_layer_head_input_id = -1;