Add Temperature to WhisperBeamSearch input #19188

Merged · 11 commits · Jan 23, 2024
4 changes: 3 additions & 1 deletion docs/ContribOperators.md
@@ -5761,7 +5761,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape</dd>
</dl>

- #### Inputs (5 - 14)
+ #### Inputs (5 - 15)

<dl>
<dt><tt>input_ids</tt> : F</dt>
@@ -5792,6 +5792,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default is to collect all. Its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2], ...]</dd>
<dt><tt>extra_decoding_ids</tt> (optional) : I</dt>
<dd>The part of the decoder_input_ids for which cross QK is needed; its shape is (batch_size, extra_decoding_ids_len). In that case, these ids should be removed from the tail of the decoder_input_ids and supplied here. Ids < 0 (for multiple batches) are treated as the stop of the extra_decoding_ids for the corresponding batch.</dd>
+ <dt><tt>temperature</tt> (optional) : T</dt>
+ <dd>Temperature value to apply to logits processing during this execution's decoding. Shape is (1)</dd>
</dl>
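The new optional input slots in after extra_decoding_ids. As a minimal sketch of wiring it up at inference time (not part of this PR; the model path and graph input names are assumptions that depend on how the Whisper beam-search graph was exported), the shape-(1) tensor can be passed through the ONNX Runtime C++ API:

```cpp
#include <onnxruntime_cxx_api.h>

#include <array>
#include <vector>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "whisper-demo");
  // Assumed model file exported with a WhisperBeamSearch node.
  Ort::Session session(env, "whisper_beamsearch.onnx", Ort::SessionOptions{});

  auto cpu_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

  // The new optional input: a single float of shape (1). When omitted, the
  // kernel falls back to a temperature of 1.0f (see ParseFromInputs below).
  float temperature = 0.8f;
  std::array<int64_t, 1> shape{1};
  Ort::Value temperature_tensor = Ort::Value::CreateTensor<float>(
      cpu_info, &temperature, 1, shape.data(), shape.size());

  std::vector<const char*> input_names;
  std::vector<Ort::Value> inputs;
  // ... append the required inputs (input_ids, max_length, min_length,
  // num_beams, num_return_sequences) here, unchanged by this PR ...
  input_names.push_back("temperature");
  inputs.push_back(std::move(temperature_tensor));

  const char* output_names[] = {"sequences"};
  auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names.data(),
                             inputs.data(), inputs.size(), output_names, 1);
  return 0;
}
```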

#### Outputs (1 - 5)
4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
@@ -499,7 +499,7 @@ Do not modify directly.*
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Trilu|*in* X:**T**<br> *in* k:**tensor(int64)**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int64)|
|Unique|*in* x:**T**<br> *out* y:**T**<br> *out* idx:**tensor(int64)**<br> *out* counts:**tensor(int64)**|1+|**T** = tensor(float)|
- |WhisperBeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *in* cross_qk_layer_head:**I**<br> *in* extra_decoding_ids:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**<br> *out* cross_qk:**V**<br> *out* non_speech_probs:**T**|1+|**T** = tensor(float)|
+ |WhisperBeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *in* cross_qk_layer_head:**I**<br> *in* extra_decoding_ids:**I**<br> *in* temperature:**T**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**<br> *out* cross_qk:**V**<br> *out* non_speech_probs:**T**|1+|**T** = tensor(float)|
|WordConvEmbedding|*in* Sequence:**T**<br> *in* W:**T1**<br> *in* B:**T1**<br> *in* C:**T1**<br> *out* Y:**T1**|1+|**T** = tensor(int32)<br/> **T1** = tensor(float)|
| |
| |
@@ -876,7 +876,7 @@ Do not modify directly.*
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|Trilu|*in* X:**T**<br> *in* k:**tensor(int64)**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|UnfoldTensor|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
- |WhisperBeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *in* cross_qk_layer_head:**I**<br> *in* extra_decoding_ids:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**<br> *out* cross_qk:**V**<br> *out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)|
+ |WhisperBeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *in* cross_qk_layer_head:**I**<br> *in* extra_decoding_ids:**I**<br> *in* temperature:**T**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**<br> *out* cross_qk:**V**<br> *out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)|
| |
| |

12 changes: 12 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
@@ -123,8 +123,20 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
    logits_processor = logits_processor_tensor ? static_cast<int>(*logits_processor_tensor->Data<int32_t>()) : 0;
    ORT_ENFORCE(logits_processor >= 0,
                "logits_processor shall be a non-negative integer, got ", logits_processor);
  }

+  if (this->model_type == IGenerationParameters::kModelTypeWhisper) {
+    auto* temperature_tensor = context->Input<Tensor>(14);
+    if (temperature_tensor) {
+      if (temperature_tensor->IsDataType<float>()) {
+        temperature = *temperature_tensor->Data<float>();
+      } else {
+        temperature = static_cast<float>(*temperature_tensor->Data<MLFloat16>());
+      }
+    } else {
+      temperature = 1.0f;
+    }
+  }
}

void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
// Override vocab_size using the inferred shape from the decoder subgraph ONLY IF
// the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch)
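The added branch accepts temperature as either float or MLFloat16 and falls back to 1.0f when the optional input is absent. As a hedged illustration of the float16 path (the helper below is hypothetical, not part of this PR), caller-owned float16 storage can be wrapped into the shape-(1) tensor from the C++ API:

```cpp
#include <onnxruntime_cxx_api.h>

// Hypothetical helper: wraps caller-owned float16 storage in a shape-(1)
// tensor, matching the MLFloat16 branch above. CreateTensor does not copy,
// so `storage` must outlive the returned Ort::Value.
Ort::Value MakeTemperatureFp16(const Ort::MemoryInfo& cpu_info,
                               Ort::Float16_t* storage) {
  const int64_t shape[] = {1};
  return Ort::Value::CreateTensor<Ort::Float16_t>(cpu_info, storage, 1, shape, 1);
}
```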
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/transformers/beam_search.cc
@@ -49,6 +49,7 @@ ONNX_OPERATOR_KERNEL_EX(
.InputMemoryType(OrtMemTypeCPUInput, 9) // 'attention_mask' needs to be on CPU
.InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU
.InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU
+      .InputMemoryType(OrtMemTypeCPUInput, 14)  // 'temperature' needs to be on CPU
.OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU
.OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'sequences_scores' output on CPU
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
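Because input 14 is registered with OrtMemTypeCPUInput, the temperature value is read on the host even when the beam search itself runs on GPU. A minimal sketch, assuming a CUDA-enabled build and the same hypothetical model file as above:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "whisper-cuda");

  Ort::SessionOptions options;
  OrtCUDAProviderOptions cuda_options{};  // defaults: device 0
  options.AppendExecutionProvider_CUDA(cuda_options);

  // The 'temperature' tensor should still be created in CPU memory, exactly
  // as in the earlier sketch; ONNX Runtime will not look for it on the GPU.
  Ort::Session session(env, "whisper_beamsearch.onnx", options);
  return 0;
}
```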
1 change: 1 addition & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1231,6 +1231,7 @@
"In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) "
"are treated as stop of the extra_decoding_ids for corresponding batch.",
"I", OpSchema::Optional)
+      .Input(14, "temperature", "Temperature value to apply to logits processing during this execution's decoding. Shape is (1)", "T", OpSchema::Optional)

GitHub Actions / cpplint warning on line 1234 of onnxruntime/core/graph/contrib_ops/contrib_defs.cc: Lines should be <= 120 characters long [whitespace/line_length] [2]
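One way to address the warning (a suggestion, not a change this PR makes) is to split the description across adjacent string literals so each source line stays under 120 characters:

```cpp
.Input(14, "temperature",
       "Temperature value to apply to logits processing during this "
       "execution's decoding. Shape is (1)",
       "T", OpSchema::Optional)
```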
.Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I")
.Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional)
.Output(2, "scores",