Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Whisper Timestamps and Temperature #19509

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
7400dcf
Add temperature to model and clean up scripts
kunal-vaishnavi Jan 20, 2024
10cf102
Add task token ids to WhisperBeamSearch
kunal-vaishnavi Jan 26, 2024
9fbca34
Fix parity check
kunal-vaishnavi Jan 26, 2024
ffd190a
Merge branch 'main' into kvaishnavi/whisper-temperature
kunal-vaishnavi Feb 1, 2024
2722582
Add packages to requirements file
kunal-vaishnavi Feb 1, 2024
f44e427
Fix token ids in timestamps processor
kunal-vaishnavi Feb 12, 2024
6c99173
Convert other token ids to attrs
kunal-vaishnavi Feb 12, 2024
3df0429
Fix timestamps test case bugs
kunal-vaishnavi Feb 13, 2024
29ced94
Cleaning up comments
kunal-vaishnavi Feb 13, 2024
e723bb0
Add changes suggested by linter
kunal-vaishnavi Feb 13, 2024
298fc3c
Merge branch 'main' into kvaishnavi/whisper-temperature
kunal-vaishnavi Feb 13, 2024
781cab7
Fix CodeQL warnings
kunal-vaishnavi Feb 13, 2024
6df8430
Add copyright to Whisper scripts
kunal-vaishnavi Feb 13, 2024
75a575b
Add updated contrib ops doc
kunal-vaishnavi Feb 13, 2024
c009bbc
Make token id attributes optional
kunal-vaishnavi Feb 14, 2024
3c63a51
Add updated contrib ops doc
kunal-vaishnavi Feb 14, 2024
b89c13c
Address PR feedback
kunal-vaishnavi Feb 15, 2024
5578745
Add updated docs and fix linter warnings
kunal-vaishnavi Feb 15, 2024
9f9d883
Remove noqa to pass CI lintrunner and ignore local lintrunner failure
kunal-vaishnavi Feb 15, 2024
8bcb559
Update README and fix export
kunal-vaishnavi Feb 15, 2024
af38716
Add another expected transcription
kunal-vaishnavi Feb 16, 2024
a79814d
Merge branch 'main' into kvaishnavi/whisper-temperature
kunal-vaishnavi Feb 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>repetition_penalty</tt> (optional) : T</dt>
<dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
<dt><tt>vocab_mask</tt> (optional) : M</dt>
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)</dd>
<dt><tt>prefix_vocab_mask</tt> (optional) : M</dt>
<dd>Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
<dt><tt>attention_mask</tt> (optional) : I</dt>
Expand Down Expand Up @@ -2252,7 +2252,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>repetition_penalty</tt> (optional) : T</dt>
<dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
<dt><tt>vocab_mask</tt> (optional) : I</dt>
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)</dd>
<dt><tt>prefix_vocab_mask</tt> (optional) : I</dt>
<dd>Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
<dt><tt>attention_mask</tt> (optional) : I</dt>
Expand Down Expand Up @@ -5154,7 +5154,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>repetition_penalty</tt> (optional) : T</dt>
<dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
<dt><tt>vocab_mask</tt> (optional) : I</dt>
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)</dd>
<dt><tt>prefix_vocab_mask</tt> (optional) : I</dt>
<dd>Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
<dt><tt>attention_mask</tt> (optional) : I</dt>
Expand Down Expand Up @@ -5743,12 +5743,14 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes

<dl>
<dt><tt>beginning_timestamp_token_id</tt> : int</dt>
<dd>The id of the first timestamp</dd>
<dt><tt>decoder</tt> : graph (required)</dt>
<dd>Decoder subgraph to execute in a loop.</dd>
<dt><tt>decoder_output_cross_qk</tt> : int</dt>
<dd>If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.</dd>
<dt><tt>decoder_start_token_id</tt> : int</dt>
<dd>The id of the token that indicates decoding starts.</dd>
<dd>The id of the token that indicates decoding starts (i.e. the start of transcription token id)</dd>
<dt><tt>early_stopping</tt> : int</dt>
<dd>early stop or not</dd>
<dt><tt>encoder</tt> : graph</dt>
Expand All @@ -5761,10 +5763,18 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Must be 2 for whisper</dd>
<dt><tt>no_repeat_ngram_size</tt> : int</dt>
<dd>no repeat ngrams size</dd>
<dt><tt>no_speech_token</tt> : int</dt>
<dt><tt>no_speech_token_id</tt> : int</dt>
<dd>The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.</dd>
<dt><tt>no_timestamps_token_id</tt> : int</dt>
<dd>The id of the token that indicates no timestamps</dd>
<dt><tt>pad_token_id</tt> : int (required)</dt>
<dd>The id of the padding token</dd>
<dt><tt>start_of_lm_token_id</tt> : int</dt>
<dd>The id of the token that indicates LM starts</dd>
<dt><tt>transcribe_token_id</tt> : int</dt>
<dd>The id of the transcribe task</dd>
<dt><tt>translate_token_id</tt> : int</dt>
<dd>The id of the translate task</dd>
<dt><tt>vocab_size</tt> : int</dt>
<dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape</dd>
</dl>
Expand All @@ -5783,11 +5793,11 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>num_return_sequences</tt> : I</dt>
<dd>The number of returned sequences in the batch. Shape is (1)</dd>
<dt><tt>length_penalty</tt> (optional) : T</dt>
<dd>Exponential penalty to the length. Default value 1.0 means no penalty.Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences.Shape is (1,)</dd>
<dd>Exponential penalty to the length. Default value 1.0 means no penalty. Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. Shape is (1,)</dd>
<dt><tt>repetition_penalty</tt> (optional) : T</dt>
<dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
<dt><tt>vocab_mask</tt> (optional) : M</dt>
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)</dd>
<dt><tt>prefix_vocab_mask</tt> (optional) : M</dt>
<dd>Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
<dt><tt>attention_mask</tt> (optional) : I</dt>
Expand All @@ -5797,7 +5807,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>logits_processor</tt> (optional) : I</dt>
<dd>Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)</dd>
<dt><tt>cross_qk_layer_head</tt> (optional) : I</dt>
<dd>Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]</dd>
<dd>Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]</dd>
<dt><tt>extra_decoding_ids</tt> (optional) : I</dt>
<dd>Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.</dd>
<dt><tt>temperature</tt> (optional) : T</dt>
Expand All @@ -5812,11 +5822,11 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>sequences_scores</tt> (optional) : T</dt>
<dd>Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)</dd>
<dt><tt>scores</tt> (optional) : T</dt>
<dd>Processed beam scores for each vocabulary token at each generation step.Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam.Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)</dd>
<dd>Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)</dd>
<dt><tt>cross_qk</tt> (optional) : V</dt>
<dd>Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers,B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F].If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]</dd>
<dd>Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]</dd>
<dt><tt>non_speech_probs</tt> (optional) : T</dt>
<dd>For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token.Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph.The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]</dd>
<dd>For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. The shape of non_speech_probs is [B]</dd>
</dl>

#### Type Constraints
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe
TensorShape no_speech_probs_shape{parameters->batch_size};
Tensor* no_speech_probs = this->context_.Output(parameters->no_speech_probs_output_id, no_speech_probs_shape);
if (no_speech_probs && no_speech_probs->MutableData<T>()) {
ORT_ENFORCE(parameters->no_speech_token >= 0 && parameters->no_speech_token < parameters->vocab_size,
"no_speech_token id out of range, it is ", parameters->no_speech_token,
ORT_ENFORCE(parameters->no_speech_token_id >= 0 && parameters->no_speech_token_id < parameters->vocab_size,
"no_speech_token_id is out of range, it is ", parameters->no_speech_token_id,
", vocab_size is ", parameters->vocab_size);
this->parameters_->no_speech_probs = (void*)no_speech_probs->MutableData<T>();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,13 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info)
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", IGenerationParameters::kModelTypeWhisper));
ORT_ENFORCE(model_type == IGenerationParameters::kModelTypeWhisper);

no_speech_token = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_speech_token", -1LL));
// Token ids are defined below in the order that they appear in the tokenizer
translate_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("translate_token_id", -1LL));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To support old model, need add some fallback logic when the attribute is not available.

transcribe_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("transcribe_token_id", -1LL));
start_of_lm_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("start_of_lm_token_id", -1LL));
no_speech_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_speech_token_id", -1LL));
no_timestamps_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_timestamps_token_id", -1LL));
beginning_timestamp_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("beginning_timestamp_token_id", -1LL));
cross_qk_layer_head_input_id = 12;
extra_decoding_ids_input_id = 13;
cross_qk_output_id = 3;
Expand Down
9 changes: 8 additions & 1 deletion onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,14 @@ struct IGenerationParameters {
// Parameters for whisper model
bool decoder_output_cross_qk = false;
gsl::span<const int32_t> extra_decoding_ids;
int32_t no_speech_token = -1;

// Token ids are defined below in the order that they appear in the tokenizer
int32_t translate_token_id = -1;
int32_t transcribe_token_id = -1;
int32_t start_of_lm_token_id = -1;
int32_t no_speech_token_id = -1;
int32_t no_timestamps_token_id = -1;
int32_t beginning_timestamp_token_id = -1;
void* no_speech_probs = nullptr;

int cross_qk_layer_head_input_id = -1;
Expand Down
Loading
Loading