diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h index eaf0e5337b8b6..5f266dd14d36d 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h +++ b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h @@ -32,4 +32,5 @@ struct OrtCUDAProviderOptionsV2 { int tunable_op_max_tuning_duration_ms = 0; // Max tuning duration time limit for TunableOp. int enable_skip_layer_norm_strict_mode = 0; // flag specifying if SkipLayerNorm is in strict mode. If true, use LayerNormalization kernel. // The strict mode has better accuracy but lower performance. + int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not }; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 1dc85e6d345d7..73b83057bdbe9 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -103,7 +103,8 @@ Status CheckInputs(const T* query, } if (past_key_dims[2] != past_value_dims[2]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length)"); + "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length). ", + past_key_dims[2], " vs ", past_value_dims[2]); } if (past_key_dims[3] != head_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index c391f47e1927b..ff6ba86cb25ef 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -319,7 +319,12 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds, expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer, expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer, - create_beam_scorer_func_}; + create_beam_scorer_func_, + update_decoder_cross_qk_func_ ? update_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::UpdateDecoderCrossQK, + finalize_decoder_cross_qk_func_ ? finalize_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::FinalizeDecoderCrossQK, + cuda_device_prop_, + cuda_device_arch_}; + #ifdef USE_CUDA ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_)); #endif @@ -340,7 +345,12 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { update_decoder_feeds_fp16_func_ ? update_decoder_feeds_fp16_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds, expand_buffer_float_func_, expand_buffer_float16_func_, - create_beam_scorer_func_}; + create_beam_scorer_func_, + update_decoder_cross_qk_func_ ? update_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::UpdateDecoderCrossQK, + finalize_decoder_cross_qk_func_ ? 
finalize_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::FinalizeDecoderCrossQK, + cuda_device_prop_, + cuda_device_arch_}; + #ifdef USE_CUDA ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_)); #endif diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index 93b7e08fabf94..f3c5b50dfb84d 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -88,12 +88,16 @@ class BeamSearch : public IControlFlowKernel { const GenerationDeviceHelper::UpdateDecoderFeedsFunc& update_decoder_feeds_fp16_func, const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_int32_func, const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float_func, - const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float16_func) { + const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float16_func, + const GenerationDeviceHelper::UpdateDecoderCrossQKFunc& update_decoder_cross_qk_func, + const GenerationDeviceHelper::FinalizeDecoderCrossQKFunc& finalize_decoder_cross_qk_func) { update_decoder_feeds_func_ = update_decoder_feeds_func; update_decoder_feeds_fp16_func_ = update_decoder_feeds_fp16_func; expand_buffer_int32_func_ = expand_buffer_int32_func; expand_buffer_float_func_ = expand_buffer_float_func; expand_buffer_float16_func_ = expand_buffer_float16_func; + update_decoder_cross_qk_func_ = update_decoder_cross_qk_func; + finalize_decoder_cross_qk_func_ = finalize_decoder_cross_qk_func; } #ifdef USE_CUDA @@ -175,6 +179,10 @@ class BeamSearch : public IControlFlowKernel { BeamSearchParameters parameters_; bool has_init_decoder_ = false; + + GenerationDeviceHelper::UpdateDecoderCrossQKFunc update_decoder_cross_qk_func_; + + GenerationDeviceHelper::FinalizeDecoderCrossQKFunc finalize_decoder_cross_qk_func_; }; } // namespace transformers diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h index 8832b4314bad3..29b38fc234de5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h @@ -17,34 +17,35 @@ struct BeamSearchState : IBeamSearchState { BeamSearchState(const IGenerationParameters& parameters, AllocatorPtr allocator, int has_decoder_masked_attention, - bool use_position) { + bool use_position, + Stream* stream) { size_t batch_beam_size = SafeInt(parameters.batch_size) * parameters.num_beams; size_t next_token_size = SafeInt(batch_beam_size) * parameters.vocab_size; - this->next_token_logits = AllocateBuffer(allocator, next_token_logits_buffer_, next_token_size); - this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size); - this->next_tokens = AllocateBuffer(allocator, next_tokens_buffer_, SafeInt(2) * batch_beam_size); - this->next_indices = AllocateBuffer(allocator, next_indices_buffer_, SafeInt(2) * batch_beam_size); - this->next_scores = AllocateBuffer(allocator, next_scores_buffer_, SafeInt(2) * batch_beam_size); + this->next_token_logits = AllocateBuffer(allocator, next_token_logits_buffer_, next_token_size, stream); + this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size, stream); + this->next_tokens = AllocateBuffer(allocator, next_tokens_buffer_, SafeInt(2) * batch_beam_size, stream); + 
this->next_indices = AllocateBuffer(allocator, next_indices_buffer_, SafeInt(2) * batch_beam_size, stream); + this->next_scores = AllocateBuffer(allocator, next_scores_buffer_, SafeInt(2) * batch_beam_size, stream); constexpr size_t max_parts_of_vocab = 128; size_t topk_buffer_size = SafeInt(batch_beam_size) * (max_parts_of_vocab + 1) * parameters.num_beams * 2 * 2; - this->topk_buffer = AllocateBuffer(allocator, topk_temp_buffer_, topk_buffer_size); + this->topk_buffer = AllocateBuffer(allocator, topk_temp_buffer_, topk_buffer_size, stream); if (allocator->Info().device.Type() == OrtDevice::GPU) { size_t sequences_elements = SafeInt(2) * batch_beam_size * parameters.max_length; - this->sequences_device = AllocateBuffer(allocator, sequences_device_buffer_, sequences_elements); + this->sequences_device = AllocateBuffer(allocator, sequences_device_buffer_, sequences_elements, stream); } if (use_position) { - this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_beam_size); + this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_beam_size, stream); } - this->beam_scores = AllocateBuffer(allocator, beam_scores_buffer_, batch_beam_size); + this->beam_scores = AllocateBuffer(allocator, beam_scores_buffer_, batch_beam_size, stream); if (parameters.output_scores) { size_t elements = SafeInt(parameters.max_length - parameters.sequence_length) * parameters.batch_size * parameters.num_beams * parameters.vocab_size; - this->scores = AllocateBuffer(allocator, scores_buffer_, elements); + this->scores = AllocateBuffer(allocator, scores_buffer_, elements, stream); this->remaining_scores = this->scores; } @@ -68,35 +69,38 @@ struct BeamSearchState : IBeamSearchState { } private: - BufferUniquePtr next_token_logits_buffer_; - BufferUniquePtr next_token_scores_buffer_; - BufferUniquePtr next_tokens_buffer_; - BufferUniquePtr next_indices_buffer_; - BufferUniquePtr next_scores_buffer_; - BufferUniquePtr next_positions_buffer_; - BufferUniquePtr beam_scores_buffer_; - BufferUniquePtr scores_buffer_; - BufferUniquePtr topk_temp_buffer_; - BufferUniquePtr sequences_device_buffer_; + IAllocatorUniquePtr next_token_logits_buffer_; + IAllocatorUniquePtr next_token_scores_buffer_; + IAllocatorUniquePtr next_tokens_buffer_; + IAllocatorUniquePtr next_indices_buffer_; + IAllocatorUniquePtr next_scores_buffer_; + IAllocatorUniquePtr next_positions_buffer_; + IAllocatorUniquePtr beam_scores_buffer_; + IAllocatorUniquePtr scores_buffer_; + IAllocatorUniquePtr topk_temp_buffer_; + IAllocatorUniquePtr sequences_device_buffer_; }; struct BeamSearchCpuState : IBeamSearchCpuState { Sequences sequences; - BeamSearchCpuState(const IGenerationParameters& parameters, AllocatorPtr allocator, bool is_cuda) + BeamSearchCpuState(const IGenerationParameters& parameters, AllocatorPtr allocator, bool is_cuda, Stream* stream) : parameters_{parameters} { - sequence_lengths = AllocateBuffer(allocator, sequence_lengths_buffer_, batch_beam_size_); + sequence_lengths = AllocateBuffer(allocator, sequence_lengths_buffer_, batch_beam_size_, stream); size_t sequences_bytes = SafeInt(2) * batch_beam_size_ * parameters.max_length; - sequences_space = AllocateBuffer(allocator, sequences_space_buffer_, sequences_bytes, true /* fill */); + sequences_space = AllocateBuffer(allocator, sequences_space_buffer_, sequences_bytes, stream, true /* fill */); sequences.Init(sequences_space, batch_beam_size_, parameters.sequence_length, parameters.max_length); if (is_cuda) { // buffers used by CUDA operator but 
not by CPU operator. - topk_scores = AllocateBuffer(allocator, topk_scores_buffer_, 2 * static_cast(batch_beam_size_)); - topk_tokens = AllocateBuffer(allocator, topk_tokens_buffer_, 2 * static_cast(batch_beam_size_)); - topk_indices = AllocateBuffer(allocator, topk_indices_buffer_, 2 * static_cast(batch_beam_size_)); - final_beam_scores = AllocateBuffer(allocator, final_beam_scores_buffer_, batch_beam_size_); + topk_scores = AllocateBuffer(allocator, topk_scores_buffer_, 2 * static_cast(batch_beam_size_), stream); + topk_tokens = AllocateBuffer(allocator, topk_tokens_buffer_, 2 * static_cast(batch_beam_size_), stream); + topk_indices = AllocateBuffer(allocator, topk_indices_buffer_, 2 * static_cast(batch_beam_size_), stream); + final_beam_scores = AllocateBuffer(allocator, final_beam_scores_buffer_, batch_beam_size_, stream); + + size_t next_token_size = SafeInt(batch_beam_size_) * parameters.vocab_size; + next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size, stream); } } @@ -124,12 +128,13 @@ struct BeamSearchCpuState : IBeamSearchCpuState { const IGenerationParameters& parameters_; const int batch_beam_size_{parameters_.batch_size * parameters_.num_beams}; - BufferUniquePtr final_beam_scores_buffer_; - BufferUniquePtr sequence_lengths_buffer_; - BufferUniquePtr topk_scores_buffer_; - BufferUniquePtr topk_tokens_buffer_; - BufferUniquePtr topk_indices_buffer_; - BufferUniquePtr sequences_space_buffer_; + IAllocatorUniquePtr final_beam_scores_buffer_; + IAllocatorUniquePtr sequence_lengths_buffer_; + IAllocatorUniquePtr topk_scores_buffer_; + IAllocatorUniquePtr topk_tokens_buffer_; + IAllocatorUniquePtr topk_indices_buffer_; + IAllocatorUniquePtr sequences_space_buffer_; + IAllocatorUniquePtr next_token_scores_buffer_; }; // Base class of beam search implementation that is common for GPT-2, T5, and Whisper. 
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index 205d94fae9fab..56d950ca2f41e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -215,7 +215,8 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch BeamSearchCpuState cpu_state{*parameters, this->cpu_allocator_, - this->IsCuda()}; + this->IsCuda(), + this->ort_stream_}; // buffer in GPU for input_ids, position_ids and attention_mask IAllocatorUniquePtr buffer; @@ -240,7 +241,8 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch BeamSearchState beam_state{*parameters, this->temp_space_allocator_, gpt_subgraph_.has_decoder_masked_attention_, - true /* use_position */}; + true /* use_position */, + this->ort_stream_}; init_beam_state_func_(&beam_state, cpu_state.sequence_lengths, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 14a0db57c45de..94547887d3a90 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -144,7 +144,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches BeamSearchCpuState cpu_state{*parameters, this->cpu_allocator_, - this->IsCuda()}; + this->IsCuda(), + this->ort_stream_}; IAllocatorUniquePtr buffer; @@ -195,7 +196,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches BeamSearchState beam_state{*parameters, this->temp_space_allocator_, decoder_subgraph_.has_decoder_masked_attention_, - false /* use_position */}; + false /* use_position */, + this->ort_stream_}; init_beam_state_func_(&beam_state, cpu_state.sequence_lengths, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 198dec011c56f..d97f2e5f1a19e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -36,7 +36,11 @@ class BeamSearchWhisper : public BeamSearchBase { const GenerationDeviceHelper::UpdateDecoderFeedsFunc& update_decoder_feeds_func, const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float_func, const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float16_func, - const GenerationDeviceHelper::CreateBeamScorer& create_beam_scorer_func) + const GenerationDeviceHelper::CreateBeamScorer& create_beam_scorer_func, + const GenerationDeviceHelper::UpdateDecoderCrossQKFunc& update_decoder_cross_qk_func, + const GenerationDeviceHelper::FinalizeDecoderCrossQKFunc& finalize_decoder_cross_qk_func, + const void* cuda_device_prop, + int cuda_device_arch) : BeamSearchBase(context, decoder_session_state, thread_pool, ort_stream, cuda_dumper, params, topk_func, process_logits_func, device_copy_func, device_copy_int32_func), @@ -49,7 +53,11 @@ class BeamSearchWhisper : public BeamSearchBase { update_decoder_feeds_func_(update_decoder_feeds_func), expand_buffer_float_func_(expand_buffer_float_func), expand_buffer_float16_func_(expand_buffer_float16_func), - create_beam_scorer_func_(create_beam_scorer_func) {} + create_beam_scorer_func_(create_beam_scorer_func), + update_decoder_cross_qk_func_(update_decoder_cross_qk_func), + 
finalize_decoder_cross_qk_func_(finalize_decoder_cross_qk_func), + cuda_device_prop_(cuda_device_prop), + cuda_device_arch_(cuda_device_arch) {} #ifdef USE_CUDA Status InitializeCuda( @@ -95,6 +103,8 @@ class BeamSearchWhisper : public BeamSearchBase { GenerationDeviceHelper::ExpandBufferFunc expand_buffer_float16_func_; GenerationDeviceHelper::CreateBeamScorer create_beam_scorer_func_; + const GenerationDeviceHelper::UpdateDecoderCrossQKFunc update_decoder_cross_qk_func_; + const GenerationDeviceHelper::FinalizeDecoderCrossQKFunc finalize_decoder_cross_qk_func_; const void* cuda_device_prop_ = nullptr; int cuda_device_arch_ = 0; }; @@ -122,6 +132,15 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe TensorShape scores_shape(&scores_dims[0], sizeof(scores_dims) / sizeof(scores_dims[0])); Tensor* output_scores = this->context_.Output(2, scores_shape); + TensorShape no_speech_probs_shape{parameters->batch_size}; + Tensor* no_speech_probs = this->context_.Output(4, no_speech_probs_shape); + if (no_speech_probs && no_speech_probs->MutableData()) { + ORT_ENFORCE(parameters->no_speech_token >= 0 && parameters->no_speech_token < parameters->vocab_size, + "no_speech_token id out of range, it is ", parameters->no_speech_token, + ", vocab_size is ", parameters->vocab_size); + this->parameters_->no_speech_probs = (void*)no_speech_probs->MutableData(); + } + // Update the flag to indicate whether scores exists in output this->parameters_->output_scores = (output_scores != nullptr); @@ -136,7 +155,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe BeamSearchCpuState cpu_state{*parameters, this->cpu_allocator_, - this->IsCuda()}; + this->IsCuda(), + this->ort_stream_}; IAllocatorUniquePtr buffer; @@ -188,7 +208,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe BeamSearchState beam_state{*parameters, this->temp_space_allocator_, decoder_subgraph_.has_decoder_masked_attention_, - false /* use_position */}; + false /* use_position */, + this->ort_stream_}; init_beam_state_func_(&beam_state, cpu_state.sequence_lengths, @@ -222,6 +243,16 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe std::vector decoder_feeds; int current_length = parameters->sequence_length; + // for decoder subgraph output cross qk + int64_t frames_of_k = 0LL; + Tensor* cross_qk_output = nullptr; // output tensor + int64_t cross_qk_layer_head_pair_count = 0LL; + OrtValue cross_qk_buffer_value; + float* cross_qk_buffer_data = nullptr; + std::vector cross_qk_all_layer_heads; + const int32_t* cross_qk_layer_head_pairs = nullptr; + IAllocatorUniquePtr qk_layer_pointers; // if needed, device array hold the cross qk data pointers, shape of [num_layers] + std::vector decoder_fetches; if (current_length + 1 < parameters->max_length) { @@ -265,6 +296,41 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe } } + if (decoder_subgraph_.output_cross_qk_) { + ORT_ENFORCE(decoder_subgraph_.has_decoder_masked_attention_, "decoder subgraph: output_cross_qk could only work with has_decoder_masked_attention"); + ORT_ENFORCE(decoder_subgraph_.past_present_share_buffer_, "decoder subgraph: output_cross_qk could only work with past_present_share_buffer"); + + cross_qk_layer_head_pair_count = parameters->num_layers * parameters->num_heads; + const auto* input_tensor_cross_qk_layer_head = this->context_.template Input(12); + ORT_ENFORCE(input_tensor_cross_qk_layer_head != nullptr, "Must specify input 
cross_qk_layer_head"); + cross_qk_layer_head_pair_count = input_tensor_cross_qk_layer_head->Shape()[0]; + cross_qk_layer_head_pairs = input_tensor_cross_qk_layer_head->template Data(); // it is on GPU + + size_t decoder_input_first_cross_key = static_cast(decoder_subgraph_.GetFirstPastInputIndex()) + (2 * decoder_subgraph_.num_layers); + auto first_cross_attention_key = decoder_feeds[decoder_input_first_cross_key].GetMutable(); + frames_of_k = first_cross_attention_key->Shape()[2]; + + TensorShape layer_cross_qk_shape{ + static_cast(parameters->BatchBeamSize()), + static_cast(parameters->num_heads), + 1LL, + static_cast(frames_of_k)}; + for (int layer = 0; layer < decoder_subgraph_.num_layers; layer++) { + OrtValue cross_qk_value; + Tensor::InitOrtValue(DataTypeImpl::GetType(), layer_cross_qk_shape, this->temp_space_allocator_, cross_qk_value); + decoder_fetches.emplace_back(cross_qk_value); + } + + TensorShape cross_qk_shape{ + static_cast(parameters->batch_size), + static_cast(parameters->num_beams), + cross_qk_layer_head_pair_count, + static_cast(parameters->max_length), + frames_of_k}; + Tensor::InitOrtValue(DataTypeImpl::GetType(), cross_qk_shape, this->temp_space_allocator_, cross_qk_buffer_value); + cross_qk_buffer_data = cross_qk_buffer_value.GetMutable()->MutableData(); + } + if (decoder_subgraph_.has_decoder_masked_attention_) { size_t offset = static_cast(decoder_subgraph_.GetFirstPastInputIndex()); // Need to check cross attention's past key tensor size, suppose all layers cross attention key size are same @@ -316,6 +382,21 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe ORT_RETURN_IF_ERROR(status); + if (decoder_subgraph_.output_cross_qk_) { + int decoder_output_first_cross_qk = decoder_subgraph_.GetFirstPresentOutputIndex() + (2 * decoder_subgraph_.num_layers); + ORT_RETURN_IF_ERROR(this->update_decoder_cross_qk_func_( + iteration_counter, + this->ort_stream_, + &decoder_fetches[decoder_output_first_cross_qk], + qk_layer_pointers, + parameters->num_layers, + static_cast(cross_qk_layer_head_pair_count), + cross_qk_layer_head_pairs, + cross_qk_buffer_data, + parameters->max_length, + this->temp_space_allocator_)); + } + #ifdef DEBUG_GENERATION for (int i = 0; i <= decoder_subgraph_.GetFirstPresentOutputIndex(); i++) { dumper->Print("decoder_fetches", i, true); @@ -383,6 +464,35 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe } } + if (decoder_subgraph_.output_cross_qk_) { + TensorShape cross_qk_shape{ + static_cast(parameters->batch_size), + static_cast(parameters->num_return_sequences), + cross_qk_layer_head_pair_count, + static_cast(iteration_counter - 1), + frames_of_k}; + cross_qk_output = this->context_.Output(3, cross_qk_shape); + + size_t cache_indir_input_offset = static_cast(decoder_subgraph_.GetFirstPastInputIndex()) + 4 * static_cast(decoder_subgraph_.num_layers) + 2; + const int* cache_indir_data = decoder_feeds[cache_indir_input_offset].GetMutable()->Data(); + auto beam_indices = this->beam_scorer_->GetNextIndicesGPU(); // currently only support on GPU + ORT_RETURN_IF_ERROR(this->finalize_decoder_cross_qk_func_( + this->ort_stream_, + iteration_counter, + parameters->sequence_length, + parameters->batch_size, + parameters->num_beams, + parameters->max_length, + static_cast(cross_qk_layer_head_pair_count), + cross_qk_layer_head_pairs, + static_cast(frames_of_k), + cross_qk_buffer_data, + cross_qk_output->MutableData(), + parameters->num_return_sequences, + cache_indir_data, + beam_indices)); + } + 
gsl::span final_beam_scores = beam_state.beam_scores; this->beam_scorer_->Finalize(cpu_state.sequences, final_beam_scores, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 76011a5c89b66..fa8c7d85ff3d1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -25,6 +25,7 @@ void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) { decoder_start_token_id = static_cast(info.GetAttrOrDefault("decoder_start_token_id", -1)); no_repeat_ngram_size = static_cast(info.GetAttrOrDefault("no_repeat_ngram_size", 0)); vocab_size = static_cast(info.GetAttrOrDefault("vocab_size", -1)); + no_speech_token = static_cast(info.GetAttrOrDefault("no_speech_token", -1LL)); } void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { @@ -47,6 +48,23 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { } batch_size = static_cast(dims[0]); + extra_decoding_ids = gsl::span(); + if (this->model_type == IGenerationParameters::kModelTypeWhisper) { + const Tensor* extra_decoder_tensor = context->Input(13); + if (extra_decoder_tensor != nullptr) { + const auto& extra_decoder_tensor_dims = extra_decoder_tensor->Shape().GetDims(); + ORT_ENFORCE(extra_decoder_tensor_dims.size() == 2, + "extra_decoder_tensor shall have 2 dimensions. Got ", + extra_decoder_tensor_dims.size()); + ORT_ENFORCE(extra_decoder_tensor_dims[0] == batch_size, + "extra_decoder_tensor first dim not same as batch_size. Got ", + extra_decoder_tensor_dims[0], ", expecting ", batch_size); + if (extra_decoder_tensor->Shape().Size() > 0) { + extra_decoding_ids = gsl::span(extra_decoder_tensor->Data(), extra_decoder_tensor->Shape().Size()); + } + } + } + if (this->model_type == IGenerationParameters::kModelTypeGpt) { sequence_length = static_cast(dims[1]); } else if (this->model_type == IGenerationParameters::kModelTypeWhisper) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h index e889281abb023..680cb23fd887a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h @@ -33,24 +33,43 @@ gsl::span AllocateBuffer(AllocatorPtr allocator, return span; } +template +gsl::span AllocateBuffer(AllocatorPtr allocator, + IAllocatorUniquePtr& buffer, + size_t elements, + Stream* stream, + bool fill = false, + T fill_value = T{}) { + size_t bytes = SafeInt(sizeof(T)) * elements; + buffer = IAllocator::MakeUniquePtr(allocator, bytes, false, stream); + T* first = reinterpret_cast(buffer.get()); + auto span = gsl::make_span(first, elements); + + if (fill) { + std::fill_n(first, elements, fill_value); + } + + return span; +} + template inline void AllocateTempBufferForGetGreedySearchTopOne( int32_t batch_size, AllocatorPtr allocator, - BufferUniquePtr& buffer, + IAllocatorUniquePtr& buffer, gsl::span& stage_1_scores, // shape (batch_size, parts_of_vocab) gsl::span& stage_1_tokens, // shape (batch_size, parts_of_vocab) gsl::span& output_scores, // shape (batch_size) - gsl::span& output_tokens // shape (batch_size) -) { + gsl::span& output_tokens, // shape (batch_size) + Stream* stream) { constexpr size_t kMaxPartsPerVocab = 128; const size_t stage_1_element_size = kMaxPartsPerVocab * batch_size; const size_t output_element_size = batch_size; // Note: use 
float to allocate buffer for temporary value buffer to avoid unalignment - void* topk_data = allocator->Alloc((stage_1_element_size + output_element_size) * (sizeof(float) + sizeof(int32_t))); - BufferUniquePtr temp_buffer(topk_data, BufferDeleter(allocator)); - buffer = std::move(temp_buffer); + size_t bytes = (stage_1_element_size + output_element_size) * (sizeof(float) + sizeof(int32_t)); + buffer = IAllocator::MakeUniquePtr(allocator, bytes, false, stream); + void* topk_data = buffer.get(); ElementType* stage_1_scores_data = reinterpret_cast(topk_data); stage_1_scores = gsl::make_span(stage_1_scores_data, stage_1_element_size); diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 88348ad88dc27..ea80e01a7adda 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -1084,6 +1084,40 @@ template Status CreateWhisperEncoderInputs( OrtValue& encoder_input_features, OrtValue& decoder_input_ids); +Status UpdateDecoderCrossQK( + [[maybe_unused]] int iteration_number, + [[maybe_unused]] Stream* stream, + [[maybe_unused]] OrtValue* cross_qks, + [[maybe_unused]] IAllocatorUniquePtr& qk_layer_pointers, + [[maybe_unused]] int num_layers, + [[maybe_unused]] int cross_qk_layer_head_pair_count, + [[maybe_unused]] const int* cross_qk_layer_head_pairs, + [[maybe_unused]] float* cross_qk_buffer_data, + [[maybe_unused]] int max_length, + [[maybe_unused]] AllocatorPtr allocator) { + throw std::runtime_error("CPU beam search currently does not support outputting cross QK."); + return Status::OK(); +} + +Status FinalizeDecoderCrossQK( + [[maybe_unused]] Stream* stream, + [[maybe_unused]] int iteration_number, + [[maybe_unused]] int context_decoding_len, + [[maybe_unused]] int batch_size, + [[maybe_unused]] int num_beams, + [[maybe_unused]] int max_length, + [[maybe_unused]] int cross_qk_layer_head_pair_count, + [[maybe_unused]] const int* cross_qk_layer_head_pairs, + [[maybe_unused]] int frames_of_k, + [[maybe_unused]] const float* cross_qk_buffer_data, + [[maybe_unused]] float* cross_qk_output, + [[maybe_unused]] int num_return_sequences, + [[maybe_unused]] const int* cache_indir_data, + [[maybe_unused]] gsl::span beam_indices) { + throw std::runtime_error("CPU beam search currently does not support outputting cross QK."); + return Status::OK(); +} + } // namespace GenerationCpuDeviceHelper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index ba1b0b662f1a5..6dfdc6b027671 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -204,6 +204,35 @@ using ExpandBufferFunc = std::function; + +using UpdateDecoderCrossQKFunc = std::function& qk_layer_pointers, + int num_layers, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + float* cross_qk_buffer_data, + int max_length, + AllocatorPtr allocator)>; + +using FinalizeDecoderCrossQKFunc = std::function beam_indices)>; + } // namespace GenerationDeviceHelper // These are CPU specific device helper implementations @@ -368,6 +397,34 @@ Status ExpandBuffer( bool only_copy_shape, int max_sequence_length); +Status UpdateDecoderCrossQK( + int iteration_number, + Stream* stream, + OrtValue* cross_qks, + 
IAllocatorUniquePtr& qk_layer_pointers, + int num_layers, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + float* cross_qk_buffer_data, + int max_length, + AllocatorPtr allocator); + +Status FinalizeDecoderCrossQK( + Stream* stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + gsl::span beam_indices); + } // namespace GenerationCpuDeviceHelper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 719dd302d274d..2846c1db642d1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -53,6 +53,7 @@ struct IBeamSearchCpuState { gsl::span topk_tokens; // shape (batch_size, 2*num_beams), tokens of topk candidates. gsl::span topk_indices; // shape (batch_size, 2*num_beams), beam indices of topk candidates. gsl::span final_beam_scores; // shape (batch_size, num_beams) + gsl::span next_token_scores; // shape (batch_size, num_beams * vocab_size) }; template @@ -175,6 +176,11 @@ struct IGenerationParameters { int seed = 0; int min_tokens_to_keep = 1; bool custom_sampling = false; + + bool decoder_output_cross_qk = false; + gsl::span extra_decoding_ids; + int32_t no_speech_token = -1; + void* no_speech_probs = nullptr; }; } // namespace transformers diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h index be974ed2159d9..9f372e5b3a673 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h @@ -20,26 +20,27 @@ struct SamplingState : public ISamplingState { int vocab_size, int max_iter, int seed, - bool is_cuda) { + bool is_cuda, + Stream* stream) { int total_count = batch_size * vocab_size; - this->h_softmaxed_score = AllocateBuffer(cpu_allocator, h_softmaxed_score_buffer_, SafeInt(total_count)); + this->h_softmaxed_score = AllocateBuffer(cpu_allocator, h_softmaxed_score_buffer_, SafeInt(total_count), stream); this->generator = std::default_random_engine{gsl::narrow_cast(seed)}; if (is_cuda) { - this->d_index_in = AllocateBuffer(allocator, d_index_in_buffer_, SafeInt(total_count)); - this->d_index_out = AllocateBuffer(allocator, d_index_out_buffer_, SafeInt(total_count)); - this->d_offset = AllocateBuffer(allocator, d_offset_buffer_, SafeInt(batch_size + 1)); - this->d_sorted_score = AllocateBuffer(allocator, d_sorted_score_buffer_, SafeInt(total_count)); - this->d_sorted_softmaxed_score = AllocateBuffer(allocator, d_sorted_softmaxed_score_buffer_, SafeInt(total_count)); - this->d_softmaxed_score = AllocateBuffer(allocator, d_softmaxed_score_buffer_, SafeInt(total_count)); - this->d_sampled = AllocateBuffer(allocator, d_sampled_buffer_, SafeInt(batch_size)); - this->h_sampled_all = AllocateBuffer(cpu_allocator, h_sampled_all_buffer_, SafeInt(batch_size * max_iter)); - this->d_indices = AllocateBuffer(allocator, d_indices_buffer_, SafeInt(batch_size)); + this->d_index_in = AllocateBuffer(allocator, d_index_in_buffer_, SafeInt(total_count), stream); + this->d_index_out = 
AllocateBuffer(allocator, d_index_out_buffer_, SafeInt(total_count), stream); + this->d_offset = AllocateBuffer(allocator, d_offset_buffer_, SafeInt(batch_size + 1), stream); + this->d_sorted_score = AllocateBuffer(allocator, d_sorted_score_buffer_, SafeInt(total_count), stream); + this->d_sorted_softmaxed_score = AllocateBuffer(allocator, d_sorted_softmaxed_score_buffer_, SafeInt(total_count), stream); + this->d_softmaxed_score = AllocateBuffer(allocator, d_softmaxed_score_buffer_, SafeInt(total_count), stream); + this->d_sampled = AllocateBuffer(allocator, d_sampled_buffer_, SafeInt(batch_size), stream); + this->h_sampled_all = AllocateBuffer(cpu_allocator, h_sampled_all_buffer_, SafeInt(batch_size * max_iter), stream); + this->d_indices = AllocateBuffer(allocator, d_indices_buffer_, SafeInt(batch_size), stream); this->temp_storage_bytes = 0; // TODO: Do not allocate this buffer if there's no presence_mask - this->d_presence_mask = AllocateBuffer(allocator, d_presence_mask_buffer_, SafeInt(total_count)); + this->d_presence_mask = AllocateBuffer(allocator, d_presence_mask_buffer_, SafeInt(total_count), stream); std::uniform_real_distribution distribution(0.0, 1.0); static_cast(distribution(this->generator)); @@ -48,25 +49,25 @@ struct SamplingState : public ISamplingState { } } else { // TODO: Some buffer can be reused for CPU - this->sorted_scores = AllocateBuffer(cpu_allocator, sorted_scores_buffer_, SafeInt(total_count)); - this->cumulative_probs = AllocateBuffer(cpu_allocator, cumulative_probs_buffer_, SafeInt(total_count)); + this->sorted_scores = AllocateBuffer(cpu_allocator, sorted_scores_buffer_, SafeInt(total_count), stream); + this->cumulative_probs = AllocateBuffer(cpu_allocator, cumulative_probs_buffer_, SafeInt(total_count), stream); } } private: - BufferUniquePtr d_index_in_buffer_; - BufferUniquePtr d_index_out_buffer_; - BufferUniquePtr d_offset_buffer_; - BufferUniquePtr d_sorted_score_buffer_; - BufferUniquePtr d_sorted_softmaxed_score_buffer_; - BufferUniquePtr d_softmaxed_score_buffer_; - BufferUniquePtr h_softmaxed_score_buffer_; - BufferUniquePtr d_sampled_buffer_; - BufferUniquePtr h_sampled_all_buffer_; - BufferUniquePtr d_indices_buffer_; - BufferUniquePtr d_presence_mask_buffer_; - BufferUniquePtr sorted_scores_buffer_; - BufferUniquePtr cumulative_probs_buffer_; + IAllocatorUniquePtr d_index_in_buffer_; + IAllocatorUniquePtr d_index_out_buffer_; + IAllocatorUniquePtr d_offset_buffer_; + IAllocatorUniquePtr d_sorted_score_buffer_; + IAllocatorUniquePtr d_sorted_softmaxed_score_buffer_; + IAllocatorUniquePtr d_softmaxed_score_buffer_; + IAllocatorUniquePtr h_softmaxed_score_buffer_; + IAllocatorUniquePtr d_sampled_buffer_; + IAllocatorUniquePtr h_sampled_all_buffer_; + IAllocatorUniquePtr d_indices_buffer_; + IAllocatorUniquePtr d_presence_mask_buffer_; + IAllocatorUniquePtr sorted_scores_buffer_; + IAllocatorUniquePtr cumulative_probs_buffer_; }; template @@ -82,24 +83,25 @@ struct GreedySearchState : public IGreedySearchState { int num_heads, int head_size, bool has_decoder_masked_self_attention, - bool is_cuda) { + bool is_cuda, + Stream* stream) { // below buffers are on cpu this->sequences_space = AllocateBuffer(cpu_allocator, sequences_space_buffer_, - SafeInt(2) * batch_size * max_length); + SafeInt(2) * batch_size * max_length, stream); memset(this->sequences_space.data(), 0, this->sequences_space.size_bytes()); this->sequences.Init(this->sequences_space, static_cast(batch_size), sequence_length, max_length); - this->sequence_lengths = 
AllocateBuffer(cpu_allocator, sequence_lengths_buffer_, batch_size); - this->eos_meet = AllocateBuffer(cpu_allocator, eos_meet_buffer_, batch_size); + this->sequence_lengths = AllocateBuffer(cpu_allocator, sequence_lengths_buffer_, batch_size, stream); + this->eos_meet = AllocateBuffer(cpu_allocator, eos_meet_buffer_, batch_size, stream); memset(this->eos_meet.data(), 0, this->eos_meet.size_bytes()); - this->next_tokens = AllocateBuffer(cpu_allocator, next_tokens_buffer_, SafeInt(batch_size)); + this->next_tokens = AllocateBuffer(cpu_allocator, next_tokens_buffer_, SafeInt(batch_size), stream); // below buffers are on cpu or cuda size_t next_token_size = SafeInt(batch_size) * vocab_size; - this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size); - this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_size); + this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size, stream); + this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_size, stream); if (is_cuda) { AllocateTempBufferForGetGreedySearchTopOne( @@ -109,7 +111,8 @@ struct GreedySearchState : public IGreedySearchState { this->temp_topk_scores_buffer, this->temp_topk_tokens_buffer, this->topk_scores_buffer, - this->topk_tokens_buffer); + this->topk_tokens_buffer, + stream); // If at all we need to, we only need to re-order past state for CUDA as //`DecoderMaskedSelfAttention` is only supported on CUDA @@ -137,14 +140,14 @@ struct GreedySearchState : public IGreedySearchState { } private: - BufferUniquePtr sequences_space_buffer_; - BufferUniquePtr sequence_lengths_buffer_; - BufferUniquePtr next_token_scores_buffer_; - BufferUniquePtr next_tokens_buffer_; - BufferUniquePtr next_positions_buffer_; - BufferUniquePtr eos_meet_buffer_; - BufferUniquePtr temp_topk_buffer_; - BufferUniquePtr staging_for_past_state_reorder_buffer_; + IAllocatorUniquePtr sequences_space_buffer_; + IAllocatorUniquePtr sequence_lengths_buffer_; + IAllocatorUniquePtr next_token_scores_buffer_; + IAllocatorUniquePtr next_tokens_buffer_; + IAllocatorUniquePtr next_positions_buffer_; + IAllocatorUniquePtr eos_meet_buffer_; + IAllocatorUniquePtr temp_topk_buffer_; + IAllocatorUniquePtr staging_for_past_state_reorder_buffer_; }; // Base class of gready search implementation that is common for both GPT-2 and Bart/T5. 
diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index 4504b099e32bd..69d25eaabbe02 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -211,7 +211,8 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_ static_cast(parameters->num_heads), static_cast(parameters->head_size), gpt_subgraph_.has_decoder_masked_attention_, - this->IsCuda()); + this->IsCuda(), + this->ort_stream_); SamplingState sampling_state; if (std::is_same::value) { @@ -221,7 +222,8 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_ static_cast(parameters->vocab_size), static_cast(parameters->max_length - parameters->sequence_length), parameters->seed, - this->IsCuda()); + this->IsCuda(), + this->ort_stream_); } IAllocatorUniquePtr buffer; diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 9f77c32f0c7cc..f39f090c78b0c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -17,20 +17,6 @@ namespace onnxruntime { namespace contrib { namespace transformers { -template -gsl::span NextTokenScores::GetScores(int batch_beam_index) { - assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size); - return scores.subspan(static_cast(batch_beam_index) * vocab_size, vocab_size); -} - -template -void NextTokenScores::SetScore(int token_id, T score) { - assert(token_id >= 0 && token_id < vocab_size); - for (int i = 0; i < batch_beam_size; i++) { - scores[static_cast(i) * vocab_size + token_id] = score; - } -} - #ifdef DEBUG_GENERATION template void DumpScores(const char* name, const NextTokenScores& next_token_scores) { @@ -238,128 +224,6 @@ void PresencePenaltyLogitsProcessor::Process(const ISequences*, #endif } -template -TimestampLogitsProcessor::TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index) - : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} - -template -void TimestampLogitsProcessor::Process(const ISequences* sequences, - NextTokenScores& next_token_scores) { - const int beg_token_id_ = eos_token_id_ + 107; - const int not_token_id_ = eos_token_id_ + 106; - const int solm_token_id_ = eos_token_id_ + 105; - const int sot_token_id_ = eos_token_id_ + 1; - constexpr int translate_token_id_ = 50358; - constexpr int transcribe_token_id_ = 50359; - - const int batch_beam_size = next_token_scores.batch_beam_size; - const int vocab_size = next_token_scores.vocab_size; - for (int i = 0; i < batch_beam_size; i++) { - gsl::span beam_token_scores = next_token_scores.GetScores(i); - gsl::span sequence = sequences->GetSequence(i); - const size_t seq_length = sequence.size(); - - // Find first timestamp - size_t sample_begin = 0; - for (size_t j = 0; j < seq_length; j++) { - sample_begin++; - if (sequence[j] >= beg_token_id_) { - break; - } - } - - // Suppress tokens - for (int j = 0; j < vocab_size; j++) { - // Suppress notimestamps and solm tokens - if (j == not_token_id_ || j == solm_token_id_) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - - // Suppress sot, translate and transcribe tokens - if (seq_length > sample_begin) { - if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { - beam_token_scores[j] = 
std::numeric_limits::lowest(); - } - } - } - - // Timestamps should be in pair except the first one - const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_; - const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_; - if (last_was_timestamp) { - if (penultimate_was_timestamp) { - // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated - for (int j = beg_token_id_; j < vocab_size; j++) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } else { - // If timestamp doesn't show up in pair, generate timestamp - for (int j = 0; j < eos_token_id_; j++) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - } - - // Find timestamp tokens - std::vector timestamps; - for (const auto& word_id : sequence) { - if (word_id >= beg_token_id_) { - timestamps.push_back(word_id); - } - } - - // Timestamps will not decrease - const size_t timestamps_len = timestamps.size(); - if (timestamps_len > 0) { - int timestamp_last = 0; - if (last_was_timestamp && !penultimate_was_timestamp) { - // For single timestamp at the end, next timestamp must not be smaller - timestamp_last = timestamps.back(); - } else { - // For paired timestamp at the end, next timestamp must be greater - timestamp_last = timestamps.back() + 1; - } - - for (int j = beg_token_id_; j < timestamp_last; j++) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - - if (seq_length == sample_begin) { - const int last_allowed = beg_token_id_ + max_initial_timestamp_index_; - for (int j = last_allowed + 1; j < vocab_size; j++) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - - // Caculate logsumexp on timestamps - float timestamp_logprob = std::numeric_limits::lowest(); - { - float logsumexp = 0.0f; - const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end()); - for (int j = beg_token_id_; j < vocab_size; ++j) { - if (beam_token_scores[j] > std::numeric_limits::lowest()) { - logsumexp += expf(beam_token_scores[j] - logprob_max); - } - } - if (logsumexp > 0.0f) { - timestamp_logprob = logf(logsumexp) + logprob_max; - } - } - - const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_); - if (timestamp_logprob > max_text_token_logprob) { - for (int j = 0; j < beg_token_id_; ++j) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - } - -#ifdef DEBUG_GENERATION - DumpScores("TimestampLogitsProcessor", next_token_scores); -#endif -} - void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { LogitsProcessorInitImpl(parameters); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 664c497a106d4..c0ef83675399e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -6,6 +6,7 @@ #include "core/common/inlined_containers.h" #include "contrib_ops/cpu/transformers/sequences.h" #include "contrib_ops/cpu/transformers/beam_search_parameters.h" +#include "contrib_ops/cpu/transformers/dump_tensor.h" #include "contrib_ops/cpu/transformers/greedy_search_parameters.h" #include "contrib_ops/cpu/transformers/sampling_parameters.h" #include "contrib_ops/cpu/transformers/generation_shared.h" @@ -20,9 +21,17 @@ struct NextTokenScores { int batch_beam_size; int vocab_size; - 
gsl::span GetScores(int batch_beam_index); + gsl::span GetScores(int batch_beam_index) { + assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size); + return scores.subspan(static_cast(batch_beam_index) * vocab_size, vocab_size); + } - void SetScore(int token_id, T score); + void SetScore(int token_id, T score) { + assert(token_id >= 0 && token_id < vocab_size); + for (int i = 0; i < batch_beam_size; i++) { + scores[static_cast(i) * vocab_size + token_id] = score; + } + } }; // Interface for all scorers for beam search or beam sample. @@ -141,10 +150,125 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor { template class TimestampLogitsProcessor : public ILogitsProcessor { public: - TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index); + TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index) + : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} void Process(const ISequences* sequences, - NextTokenScores& next_token_scores) override; + NextTokenScores& next_token_scores) override { + const int beg_token_id_ = eos_token_id_ + 107; + const int not_token_id_ = eos_token_id_ + 106; + const int solm_token_id_ = eos_token_id_ + 105; + const int sot_token_id_ = eos_token_id_ + 1; + constexpr int translate_token_id_ = 50358; + constexpr int transcribe_token_id_ = 50359; + + const int batch_beam_size = next_token_scores.batch_beam_size; + const int vocab_size = next_token_scores.vocab_size; + for (int i = 0; i < batch_beam_size; i++) { + gsl::span beam_token_scores = next_token_scores.GetScores(i); + gsl::span sequence = sequences->GetSequence(i); + const size_t seq_length = sequence.size(); + + // Find first timestamp + size_t sample_begin = 0; + for (size_t j = 0; j < seq_length; j++) { + sample_begin++; + if (sequence[j] >= beg_token_id_) { + break; + } + } + + // Suppress tokens + for (int j = 0; j < vocab_size; j++) { + // Suppress notimestamps and solm tokens + if (j == not_token_id_ || j == solm_token_id_) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + + // Suppress sot, translate and transcribe tokens + if (seq_length > sample_begin) { + if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + } + + // Timestamps should be in pair except the first one + const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_; + const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_; + if (last_was_timestamp) { + if (penultimate_was_timestamp) { + // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated + for (int j = beg_token_id_; j < vocab_size; j++) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } else { + // If timestamp doesn't show up in pair, generate timestamp + for (int j = 0; j < eos_token_id_; j++) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + } + + // Find timestamp tokens + std::vector timestamps; + for (const auto& word_id : sequence) { + if (word_id >= beg_token_id_) { + timestamps.push_back(word_id); + } + } + + // Timestamps will not decrease + const size_t timestamps_len = timestamps.size(); + if (timestamps_len > 0) { + int timestamp_last = 0; + if (last_was_timestamp && !penultimate_was_timestamp) { + // For single timestamp at the end, next timestamp must not be smaller + timestamp_last = timestamps.back(); + } else 
{ + // For paired timestamp at the end, next timestamp must be greater + timestamp_last = timestamps.back() + 1; + } + + for (int j = beg_token_id_; j < timestamp_last; j++) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + + if (seq_length == sample_begin) { + const int last_allowed = beg_token_id_ + max_initial_timestamp_index_; + for (int j = last_allowed + 1; j < vocab_size; j++) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + + // Caculate logsumexp on timestamps + float timestamp_logprob = std::numeric_limits::lowest(); + { + float logsumexp = 0.0f; + const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end()); + for (int j = beg_token_id_; j < vocab_size; ++j) { + if (beam_token_scores[j] > std::numeric_limits::lowest()) { + logsumexp += expf(beam_token_scores[j] - logprob_max); + } + } + if (logsumexp > 0.0f) { + timestamp_logprob = logf(logsumexp) + logprob_max; + } + } + + const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_); + if (timestamp_logprob > max_text_token_logprob) { + for (int j = 0; j < beg_token_id_; ++j) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + } + +#ifdef DEBUG_GENERATION + DumpScores("TimestampLogitsProcessor", next_token_scores); +#endif + } private: int eos_token_id_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h index 3c11d2d324a85..487a35c55a85f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h @@ -45,6 +45,7 @@ class Subgraph { int num_layers; bool past_present_share_buffer_; bool has_decoder_masked_attention_; + bool output_cross_qk_ = false; // Setup execution Status Setup(const SessionState& session_state, diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 28acd81ae95fd..4d61ce71c69be 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -172,7 +172,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); AllocatorPtr buffer_allocator = std::make_shared(); size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size * sizeof(int)); - auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size); + auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); int* seq_copy_ptr = seq_copy.get(); if (!use_sequence_as_input_ids_) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h index 085d8f3903976..83dae49c7dcbd 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h @@ -5,6 +5,7 @@ #include "contrib_ops/cpu/transformers/subgraph_base.h" #include "contrib_ops/cpu/transformers/sequences.h" +#include "core/framework/op_kernel.h" namespace onnxruntime { namespace contrib { @@ -20,6 +21,13 @@ class T5DecoderSubgraph : public Subgraph { has_hidden_state_(false), use_sequence_as_input_ids_(true) { first_present_output_index_ = 1; + + // Currently just using parent node's attribute. Maybe better to find it purely in subgraph. 
+ const auto& attributes = node_in.GetAttributes(); + if (attributes.find("decoder_output_cross_qk") != attributes.end()) { + auto& attr = attributes.at("decoder_output_cross_qk"); + output_cross_qk_ = (attr.i() != 0LL); + } } // Create inputs for first inference of decoder subgraph. @@ -62,7 +70,7 @@ class T5DecoderSubgraph : public Subgraph { return first_present_output_index_; } - bool UseSequenceAsInputIds() const { + inline bool UseSequenceAsInputIds() const { return use_sequence_as_input_ids_; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc index 887a6a8984b83..7d0c62b618ee2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc @@ -70,8 +70,15 @@ Status WhisperDecoderSubgraph::Validate(const std::vector& subgr "number of inputs expected to be kFirstPastInputIndex + 4 * layers + 1, got:", num_subgraph_inputs); } - ORT_RETURN_IF(num_subgraph_outputs < 3 || (num_subgraph_outputs - first_present_output_index_) % 2 != 0, - "number of outputs expected to be 1 + 2 * layers, got:", num_subgraph_outputs); + if (!output_cross_qk_) { + ORT_RETURN_IF(num_subgraph_outputs < 3 || (num_subgraph_outputs - first_present_output_index_) % 2 != 0, + "number of outputs expected to be first_present_output_index_", + first_present_output_index_, " + 2 * layers, got:", num_subgraph_outputs); + } else { + ORT_RETURN_IF(num_subgraph_outputs < 4 || (num_subgraph_outputs - first_present_output_index_) % 3 != 0, + "When outputting cross qk, number of outputs expected to be first_present_output_index_", + first_present_output_index_, " + 3 * layers, got:", num_subgraph_outputs); + } ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", "decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name()); @@ -90,7 +97,8 @@ Status WhisperDecoderSubgraph::Validate(const std::vector& subgr // Save parameters related to the subgraph. ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false)); - num_layers = (static_cast(subgraph_outputs.size()) - first_present_output_index_) / 2; + + num_layers = (static_cast(subgraph_outputs.size()) - first_present_output_index_) / (output_cross_qk_ ? 3 : 2); // If input_ids's shape is ['batch_size', 1] then use next token as input_ids. // Otherwise in the case of shape ['batch_size', 'sequence'], use sequence as input_ids. 
@@ -112,12 +120,7 @@ Status WhisperDecoderSubgraph::Validate(const std::vector& subgr for (int i = first_past_input_index_; i < first_past_input_index_ + 4 * num_layers; i++) { ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type, - "decoder subgraph past inputs shall have same data type as that of encoder_hidden_states"); - } - - for (int i = 0; i < num_subgraph_outputs; i++) { - ORT_RETURN_IF(subgraph_outputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type, - "decoder subgraph output shall have same data type as that of encoder_hidden_states"); + "decoder subgraph past inputs shall have same data type as that of encoder_hidden_states."); } is_output_float16_ = (subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type() == float16_type); @@ -166,7 +169,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds( AllocatorPtr buffer_allocator = std::make_shared(); size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size * sizeof(int)); - auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size); + auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); int* seq_copy_ptr = seq_copy.get(); if (!use_sequence_as_input_ids_) { diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index f0385ea5abdfb..5e33e1307fcc4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -155,8 +155,10 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { if (use_causal_fused_runner) { // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. if (nullptr == fused_fp16_runner_.get()) { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_trt_flash_attention_, parameters.scale); + std::call_once(fused_fp16_runner_created_, [&]() { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, + enable_trt_flash_attention_, parameters.scale); + }); } // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. @@ -175,8 +177,10 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { if (use_fused_runner) { // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. if (nullptr == fused_fp16_runner_.get()) { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_trt_flash_attention_, parameters.scale); + std::call_once(fused_fp16_runner_created_, [&]() { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, + enable_trt_flash_attention_, parameters.scale); + }); } // In case some kernel not loaded due to shared memory limit, we need to double check here. 
@@ -213,11 +217,12 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; - IAllocatorUniquePtr gemm_buffer; + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); int m = batch_size * sequence_length; int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size); int k = parameters.input_hidden_size; - gemm_buffer = GetScratchBuffer(static_cast(m) * n, context->GetComputeStream()); + IAllocatorUniquePtr gemm_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(m * n) * sizeof(T), false, context->GetComputeStream()); CudaT one = ToCudaType::FromFloat(1.0f); CudaT zero = ToCudaType::FromFloat(0.0f); @@ -244,7 +249,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { use_flash_attention, use_fused_cross_attention, use_memory_efficient_attention); - auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); + IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); + ; typedef typename ToCudaType::MappedType CudaT; AttentionData data; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.h b/onnxruntime/contrib_ops/cuda/bert/attention.h index 455e55ba05a66..acafb379d713f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" @@ -28,6 +29,7 @@ class Attention final : public CudaKernel, public AttentionBase { bool disable_memory_efficient_attention_; int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; + mutable std::once_flag fused_fp16_runner_created_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index 4bdc6db30b036..54aad9cbaf387 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -22,6 +22,7 @@ static constexpr int kBeamWidthInputIndex = 8; static constexpr int kCacheIndirectionInputIndex = 9; static constexpr int kPastInputIndex = 5; static constexpr int kPresentOutputIndex = 1; +static constexpr int kQKOutputIndex = 3; static constexpr int kBiasIndex = 10; #define REGISTER_KERNEL_TYPED(T1, T2) \ @@ -50,6 +51,7 @@ DecoderMaskedMultiHeadAttention::DecoderMaskedMultiHeadAttention(const O mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); scale_ = info.GetAttrOrDefault("scale", 0.0f); past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL); + output_qk_ = info.GetAttrOrDefault("output_qk", 0LL); } template @@ -98,7 +100,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* // This kernel is for decoding only (i.e.) sequence length has to be 1 if (sequence_length != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention. 
Actual length is ", sequence_length); } if (parameters.head_size != parameters.v_head_size) { @@ -125,6 +127,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* TensorShape present_shape(present_dims); Tensor* present_key = context->Output(kPresentOutputIndex, present_shape); Tensor* present_value = context->Output(kPresentOutputIndex + 1, present_shape); + Tensor* cross_qk = nullptr; auto cuda_stream = Stream(context); @@ -191,6 +194,13 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* parameters.v_cache = present_value_data; } + if (output_qk_) { + int64_t qk_dims[] = {parameters.batch_size, parameters.num_heads, 1, parameters.total_sequence_length}; + TensorShape qk_shape(&qk_dims[0], sizeof(qk_dims) / sizeof(qk_dims[0])); + cross_qk = context->Output(kQKOutputIndex, qk_shape); + parameters.out_qk = cross_qk->MutableData(); + } + parameters.out = output->MutableDataRaw(); // Scale diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.h index 8200a66db383f..b5476e6b54c44 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.h @@ -22,6 +22,7 @@ class DecoderMaskedMultiHeadAttention final : public CudaKernel { float mask_filter_value_; float scale_; bool past_present_share_buffer_; + bool output_qk_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 5827bdfee1ab5..c9a06e2c6798b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -430,6 +430,15 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio // Compute the logits and start the sum. float sum = 0.f; int sum_tlength = params.is_cross_attention ? 
tlength - 1 : tlength; + + if (params.out_qk != nullptr) { + // store cross qk before softmax, out_qk has shape [B(batchxbeam), #Head, 1, total_sequence_length] + float* target = ((float*)params.out_qk) + ((int64_t)bhi * tlength); + for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) { + target[ti] = (float)(qk_smem[ti]); + } + } + for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) { // This is a deviation from FasterTransformer kernel implementation // but this aligns with ORT's other Attention kernels which strives to diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h index 6d7f368db4dd4..4b408dafa2d81 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h @@ -37,6 +37,7 @@ struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters { void* v_cache = nullptr; void* out = nullptr; + void* out_qk = nullptr; const int32_t* cache_indir = nullptr; const int32_t* mask = nullptr; // [B, total_sequence_length] diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 25f3f59165e43..e3f53ca6a63cb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -194,8 +194,10 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { // Here we assume that num_heads and head_size does not change for a MultiHeadAttention node. if (nullptr == fused_fp16_runner_.get()) { constexpr bool is_unidirectional = false; - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create( - num_heads_, parameters.head_size, sm, is_unidirectional, enable_trt_flash_attention_, parameters.scale); + std::call_once(fused_fp16_runner_created_, [&]() { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional, + enable_trt_flash_attention_, parameters.scale); + }); } // In case some kernel not loaded due to shared memory limit, we need to double check here. 
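The new out_qk buffer is written before the softmax, so it holds raw attention logits rather than probabilities; its layout is [batch*beam, num_heads, 1, total_sequence_length], and the kernel addresses it through bhi, the flattened batch*beam and head index. A small sketch of the equivalent host-side offset computation, assuming hypothetical dimension names:

#include <cstddef>
#include <cstdint>

// Flattened offset of element [b, h, 0, t] in a contiguous
// [batch_beam, num_heads, 1, total_sequence_length] float buffer.
inline size_t OutQkOffset(int64_t b, int64_t h, int64_t t,
                          int64_t num_heads, int64_t total_sequence_length) {
  const int64_t bhi = b * num_heads + h;  // mirrors the kernel's bhi
  return static_cast<size_t>(bhi * total_sequence_length + t);
}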
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 33fa3d50e4564..c162f7133cc1c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -32,6 +32,7 @@ class MultiHeadAttention final : public CudaKernel { bool disable_memory_efficient_attention_; int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; + mutable std::once_flag fused_fp16_runner_created_; mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h index 0cdd8021de4a1..f00c112fc73d2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h @@ -24,10 +24,11 @@ class TrtFusedAttention { protected: MHARunner* GetFusedRunner(const cudaDeviceProp& device_prop, const PackedAttentionParameters& parameters) const; - private: + protected: bool disable_fused_runner_; bool enable_trt_flash_attention_; mutable std::unique_ptr fused_fp16_runner_; + mutable std::once_flag fused_fp16_runner_created_; }; template diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 86c1cb93e8b6f..58102d5eb496c 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -113,6 +113,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, UnfoldTensor); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DynamicTimeWarping); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear); @@ -270,6 +272,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // TransposedMatMul is still here for backward compatibility diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc new file mode 100644 index 0000000000000..381316f605fc9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/cuda/tensor/dynamic_time_warping.h" +#include "contrib_ops/cuda/tensor/dynamic_time_warping_impl.h" +#include "core/providers/cpu/tensor/utils.h" + +#include +#include + +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + DynamicTimeWarping, + kMSDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("F", DataTypeImpl::GetTensorType()) + .TypeConstraint("I", DataTypeImpl::GetTensorType()), + DynamicTimeWarping); + +Status DynamicTimeWarping::ComputeInternal(OpKernelContext* ctx) const { + const Tensor& input_tensor = *ctx->Input(0); + const auto& input_dims = input_tensor.Shape().GetDims(); + int rank = SafeInt(input_dims.size()); + ORT_ENFORCE(rank == 2 || (rank == 3 && input_dims[0] == 1), "Currently input rank must be 2, or (3 with first dim equal to 1), but got:", rank); + + const size_t rows = SafeInt(input_dims[rank == 3 ? 1 : 0]); + const size_t cols = SafeInt(input_dims[rank == 3 ? 2 : 1]); + size_t max_index_len = 0; + + size_t buffer_size_in_bytes = GetDynamicTimeWarpingBufferSize(1, rows, cols, max_index_len); + IAllocatorUniquePtr buffer = GetScratchBuffer(buffer_size_in_bytes, ctx->GetComputeStream()); + + size_t result_len = 0; + ORT_RETURN_IF_ERROR(LaunchDynamicTimeWarping( + this->Stream(ctx), this->GetDeviceProp(), 1, rows, cols, + input_tensor.Data(), buffer.get(), result_len)); + + Tensor* output_tensor = ctx->Output(0, TensorShape{2LL, SafeInt(result_len)}); + + return CUDA_CALL(cudaMemcpy2DAsync( + output_tensor->MutableData(), result_len * sizeof(int32_t), + buffer.get() + ((max_index_len - result_len) * sizeof(int32_t)), max_index_len * sizeof(int32_t), + result_len * sizeof(int32_t), 2, + cudaMemcpyDeviceToDevice, this->Stream(ctx))); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.h b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.h new file mode 100644 index 0000000000000..3083e19aff6f2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/cuda_kernel.h" +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using onnxruntime::OpKernelContext; +using onnxruntime::OpKernelInfo; +using onnxruntime::cuda::CudaKernel; +class DynamicTimeWarping final : public CudaKernel { + public: + DynamicTimeWarping(const OpKernelInfo& info) : CudaKernel(info) {} + + ~DynamicTimeWarping() = default; + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu new file mode 100644 index 0000000000000..2072f001bc3a1 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
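Before the CUDA implementation below, it may help to see the same dynamic-time-warping recurrence in plain C++: each cell of the (rows+1) x (cols+1) cost table takes the cheapest of its diagonal, top, and left neighbours plus the local cost, the chosen direction is recorded in a trace table, and the alignment path is recovered by backtracking from the bottom-right corner. This is only an illustrative CPU sketch of the recurrence under those assumptions, not the code added in this change:

#include <cfloat>
#include <cstdint>
#include <utility>
#include <vector>

// Returns the aligned (row, col) pairs from (0, 0) to (rows-1, cols-1),
// given a rows x cols local cost matrix stored row-major in `input`.
std::vector<std::pair<int, int>> DtwPathCpu(const std::vector<float>& input, int rows, int cols) {
  std::vector<float> cost((rows + 1) * (cols + 1), FLT_MAX);
  std::vector<int8_t> trace((rows + 1) * (cols + 1), 2);
  cost[0] = 0.0f;
  for (int r = 1; r <= rows; ++r) {
    for (int c = 1; c <= cols; ++c) {
      const float c0 = cost[(r - 1) * (cols + 1) + (c - 1)];  // diagonal move
      const float c1 = cost[(r - 1) * (cols + 1) + c];        // move down a row
      const float c2 = cost[r * (cols + 1) + (c - 1)];        // move right a column
      float best = c0;
      int8_t t = 0;
      if (c1 < best) { best = c1; t = 1; }
      if (c2 < best) { best = c2; t = 2; }
      cost[r * (cols + 1) + c] = best + input[(r - 1) * cols + (c - 1)];
      trace[r * (cols + 1) + c] = t;
    }
  }
  std::vector<std::pair<int, int>> path;
  int r = rows - 1;
  int c = cols - 1;
  while (r >= 0 && c >= 0) {
    path.emplace_back(r, c);
    switch (trace[(r + 1) * (cols + 1) + (c + 1)]) {
      case 0: --r; --c; break;
      case 1: --r; break;
      default: --c; break;
    }
  }
  return {path.rbegin(), path.rend()};  // reverse so the path starts at (0, 0)
}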
+ +#include "contrib_ops/cuda/tensor/dynamic_time_warping_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/common/common.h" +#include +#include + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +__global__ void DynamicTimeWarpingInitCost(float* cost_buffer, int8_t* trace_buffer, size_t cols_plus_1) { + int r = blockIdx.x; + cost_buffer += cols_plus_1 * r; + for (size_t i = threadIdx.x; i < cols_plus_1; i += blockDim.x) { + cost_buffer[i] = FLT_MAX; + } + if (r == 0) { + for (size_t i = threadIdx.x; i < cols_plus_1; i += blockDim.x) { + trace_buffer[i] = 2; + } + } + if (threadIdx.x == 0) trace_buffer[cols_plus_1 * r] = 1; + if (threadIdx.x == 0 && r == 0) *cost_buffer = 0.0f; +} + +__global__ void DynamicTimeWarpingKernel( + size_t rows, + size_t cols, + size_t max_index_len, + const float* input, + float* cost_buffer, + int8_t* trace_buffer, + int32_t* result_buffer, + size_t* result_len_device +) { + const int diag_max = static_cast(rows + cols); + for (int d = 1; d <= diag_max; d++) { + for (int c = threadIdx.x + 1; c <= cols; c += blockDim.x) { + int r = d - c; + if (r >= 1 && r <= rows) { + int cost_idx = ((r - 1) * (cols + 1) + (c - 1)); //[r - 1, c - 1] + const float c0 = cost_buffer[cost_idx]; + const float c1 = cost_buffer[cost_idx + 1]; // [r - 1, c] + const float c2 = cost_buffer[cost_idx + cols + 1]; // [r, c - 1] + + float cost; + int8_t t; + if (c0 < c1 && c0 < c2) { + cost = c0; + t = 0; + } else if (c1 < c0 && c1 < c2) { + cost = c1; + t = 1; + } else { + cost = c2; + t = 2; + } + cost_idx += ((cols + 1) + 1); + cost_buffer[cost_idx] = cost + input[(r - 1) * cols + (c - 1)]; + trace_buffer[cost_idx] = t; + } + } + __syncthreads(); + } + + //back tracing, reverse append to result buffer + if (threadIdx.x == 0) { + int r = rows - 1; + int c = cols - 1; + int pos = static_cast(max_index_len); // reverse put + while (r >= 0 && c >= 0) { + --pos; + result_buffer[pos] = r; + result_buffer[max_index_len + pos] = c; + const int trace_index = (r + 1) * (cols + 1) + (c + 1); + int8_t t = trace_buffer[trace_index]; + switch (t) { + case 0: r -= 1; c -= 1; break; + case 1: r -= 1; break; + default: c -= 1; break; + } + } + *result_len_device = max_index_len - static_cast(pos); + } +} + +size_t GetDynamicTimeWarpingBufferSize(size_t batch, size_t rows, size_t cols, size_t& max_index_len) { + max_index_len = rows + cols + 1; + size_t cost_buffer_size = ((rows + 1) * (cols + 1)); + return batch * max_index_len * 2 * sizeof(int32_t) + // two index arrays + sizeof(int64_t) + // final index array length + batch* cost_buffer_size * sizeof(float) + // cost buffer + batch* cost_buffer_size * sizeof(int8_t); // trace buffer +} + +Status LaunchDynamicTimeWarping( + cudaStream_t stream, + const cudaDeviceProp& device_prop, + size_t batch, + size_t rows, + size_t cols, + const float* input, + void* buffer, + size_t& result_len +) { + ORT_ENFORCE(batch == 1); + size_t max_index_len = rows + cols + 1; + int32_t* result_buffer = (int32_t*)buffer; + size_t* result_len_device_buf = (size_t*)(result_buffer + (batch * max_index_len * 2)); + float* cost_buffer = (float*)(result_len_device_buf + 1); + int8_t* trace_buffer = (int8_t*)(cost_buffer + ((rows + 1) * (cols + 1))); + + dim3 block(device_prop.maxThreadsPerBlock); + dim3 grid_init((unsigned)SafeInt(rows + 1), (unsigned)SafeInt(batch)); + DynamicTimeWarpingInitCost<<>>(cost_buffer, trace_buffer, cols+1); + ORT_RETURN_IF_ERROR(CUDA_CALL(cudaGetLastError())); + + dim3 
grid(1, (unsigned)SafeInt(batch)); + DynamicTimeWarpingKernel<<>>( + rows, + cols, + max_index_len, + input, + cost_buffer, + trace_buffer, + result_buffer, + result_len_device_buf); + ORT_RETURN_IF_ERROR(CUDA_CALL(cudaGetLastError())); + + ORT_RETURN_IF_ERROR(CUDA_CALL(cudaMemcpyAsync(&result_len, result_len_device_buf, sizeof(size_t), cudaMemcpyDeviceToHost, stream))); + return CUDA_CALL(cudaGetLastError()); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.h b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.h new file mode 100644 index 0000000000000..cb4a0dfb16807 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +size_t GetDynamicTimeWarpingBufferSize(size_t batch, size_t rows, size_t cols, size_t& max_index_len); + +Status LaunchDynamicTimeWarping( + cudaStream_t stream, + const cudaDeviceProp& device_prop, + size_t batch, + size_t rows, + size_t cols, + const float* input, + void* buffer, + size_t& result_len); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/unfold.cc b/onnxruntime/contrib_ops/cuda/tensor/unfold.cc new file mode 100644 index 0000000000000..c38c8c5317f0a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/unfold.cc @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/tensor/unfold.h" +#include "contrib_ops/cuda/tensor/unfold_impl.h" +#include "core/providers/cpu/tensor/utils.h" + +#include +#include + +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + UnfoldTensor, + kMSDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()), + UnfoldTensor); + +Status UnfoldTensor::ComputeInternal(OpKernelContext* ctx) const { + const Tensor& input_tensor = *ctx->Input(0); + const auto& input_dims = input_tensor.Shape().GetDims(); + int rank = SafeInt(input_dims.size()); + + int dim = SafeInt(HandleNegativeAxis(dim_, rank)); + ORT_ENFORCE(dim < rank, "input rank:", rank, " is not bigger than attribut specified dim: ", dim); + ORT_ENFORCE(input_dims[dim] >= size_, "dimsize:", input_dims[dim], " is less than unfold size:", size_); + + int64_t leading_dims = std::accumulate(input_dims.begin(), input_dims.begin() + dim, 1LL, std::multiplies()); + int64_t tailing_dims = std::accumulate(input_dims.begin() + (dim + 1), input_dims.end(), 1LL, std::multiplies()); + + std::vector output_dims(rank + 1, 0); + std::copy(input_dims.begin(), input_dims.end(), output_dims.begin()); + output_dims[dim] = (input_dims[dim] - size_) / step_ + 1; + output_dims.back() = size_; + TensorShape output_shape(output_dims); + Tensor* output_tensor = ctx->Output(0, output_shape); + + cudaStream_t stream = this->Stream(ctx); + const cudaDeviceProp& device_prop = this->GetDeviceProp(); + size_t element_size = input_tensor.DataType()->Size(); + return LaunchUnfoldTensor( + stream, device_prop, element_size, input_tensor.DataRaw(), output_tensor->MutableDataRaw(), + leading_dims, input_dims[dim], 
tailing_dims, size_, step_); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/unfold.h b/onnxruntime/contrib_ops/cuda/tensor/unfold.h new file mode 100644 index 0000000000000..1717687593470 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/unfold.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/cuda_kernel.h" +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using onnxruntime::OpKernelContext; +using onnxruntime::OpKernelInfo; +using onnxruntime::cuda::CudaKernel; +class UnfoldTensor final : public CudaKernel { + public: + UnfoldTensor(const OpKernelInfo& info) : CudaKernel(info) { + dim_ = SafeInt(info.GetAttrOrDefault("dim", -1LL)); + step_ = SafeInt(info.GetAttrOrDefault("step", 1LL)); + ORT_ENFORCE(step_ > 0, "step must greater than zero!"); + + int64_t temp_size; + ORT_ENFORCE(info.GetAttr("size", &temp_size).IsOK()); + size_ = SafeInt(temp_size); + } + + ~UnfoldTensor() = default; + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int dim_; + int size_; + int step_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu new file mode 100644 index 0000000000000..996f340b483a3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/tensor/unfold_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/common/common.h" +#include + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void UnfoldTensorKenel( + const T* input, + T* output, + int64_t N, + int64_t unfold_size, // stride_tailing_dim_dst + int64_t tailing_dims_size, // stride_fold_dim_dst = tailing_dims_size * unfold_size, stride_append_dim_src = tailing_dims_size + int64_t stride_leading_dst, + int64_t stride_fold_dim_src, + int64_t stride_leading_src +) { + int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) return; + + const int64_t idx_leading = idx / stride_leading_dst; + int64_t n = idx % stride_leading_dst; + const int64_t stride_fold_dim_dst = tailing_dims_size * unfold_size; + const int64_t idx_fold = n / stride_fold_dim_dst; + n %= stride_fold_dim_dst; + const int64_t idx_tailing = n / unfold_size; + const int64_t idx_append = n % unfold_size; + + // stride_tailing_dim_src = 1 + int64_t idx_src = idx_leading * stride_leading_src + idx_fold * stride_fold_dim_src + idx_tailing + idx_append * tailing_dims_size; + output[idx] = input[idx_src]; +} + + +Status LaunchUnfoldTensor( + cudaStream_t stream, + const cudaDeviceProp& device_prop, + size_t element_size, + const void* input, + void* output, + int64_t leading_dims_size, + int64_t unfold_dim_size, + int64_t tailing_dims_size, + int64_t unfold_size, + int64_t step_size +) { + int64_t TPB = device_prop.maxThreadsPerBlock; + int64_t unfold_dim_size_dst = (unfold_dim_size - unfold_size) / step_size + 1; + int64_t N = leading_dims_size * unfold_dim_size_dst * tailing_dims_size * unfold_size; + int64_t num_blocks = (N + TPB - 1) / TPB; + + // int64_t stride_append_dim_dst = 1; + // int64_t 
stride_tailing_dim_dst = unfold_size; + // int64_t stride_fold_dim_dst = unfold_size * tailing_dims_size; + int64_t stride_leading_dst = unfold_size * tailing_dims_size * unfold_dim_size_dst; + + // int64_t stride_append_dim_src = tailing_dims_size; + // int64_t stride_tailing_dim_src = 1; + int64_t stride_fold_dim_src = tailing_dims_size * step_size; + int64_t stride_leading_src = tailing_dims_size * unfold_dim_size; + + dim3 block((unsigned)SafeInt(TPB)); + dim3 grid((unsigned)SafeInt(num_blocks)); + switch (element_size) { + case 1: + UnfoldTensorKenel<<>>( + (const int8_t*)input, (int8_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + case 2: + UnfoldTensorKenel<<>>( + (const int16_t*)input, (int16_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + case 4: + UnfoldTensorKenel<<>>( + (const int32_t*)input, (int32_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + case 8: + UnfoldTensorKenel<<>>( + (const int64_t*)input, (int64_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + case 16: + UnfoldTensorKenel<<>>( + (const float4*)input, (float4*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + default: + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unsupported element_size"); + } + + return CUDA_CALL(cudaGetLastError()); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.h b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.h new file mode 100644 index 0000000000000..9e82dccdec23c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
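For the UnfoldTensor kernel above, each destination index decomposes into (leading, fold, tailing, append) coordinates, and successive windows along the unfolded dimension start step_size elements apart in the source. An illustrative CPU sketch of the same mapping for a 1-D input (function and variable names are hypothetical, not taken from the kernel):

#include <cstdint>
#include <vector>

// Unfold a 1-D sequence into overlapping windows: the output has
// (n - size) / step + 1 windows of length `size`, mirroring the
// (input_dims[dim] - size_) / step_ + 1 shape rule in unfold.cc.
std::vector<std::vector<float>> Unfold1D(const std::vector<float>& x, int64_t size, int64_t step) {
  const int64_t n = static_cast<int64_t>(x.size());
  const int64_t windows = (n - size) / step + 1;
  std::vector<std::vector<float>> out;
  out.reserve(static_cast<size_t>(windows > 0 ? windows : 0));
  for (int64_t w = 0; w < windows; ++w) {
    // Window w starts at offset w * step in the source, analogous to
    // idx_fold * stride_fold_dim_src in the CUDA kernel.
    out.emplace_back(x.begin() + w * step, x.begin() + w * step + size);
  }
  return out;
}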
+ +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +Status LaunchUnfoldTensor( + cudaStream_t stream, + const cudaDeviceProp& device_prop, + size_t element_size, + const void* input, + void* output, + int64_t leading_dims_size, + int64_t tailing_dims_size, + int64_t dim_size, + int64_t unfold_size, + int64_t step_size); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc index d18460e016444..5e5211b2f88bd 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc @@ -58,7 +58,9 @@ BeamSearch::BeamSearch(const OpKernelInfo& info) GenerationCudaDeviceHelper::UpdateDecoderFeeds, GenerationCudaDeviceHelper::ExpandBuffer, GenerationCudaDeviceHelper::ExpandBuffer, - GenerationCudaDeviceHelper::ExpandBuffer); + GenerationCudaDeviceHelper::ExpandBuffer, + GenerationCudaDeviceHelper::UpdateDecoderCrossQK, + GenerationCudaDeviceHelper::FinalizeDecoderCrossQK); SetConsoleDumper(&g_cuda_dumper); diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 07a8896210d2c..d3b783df2083b 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -1315,6 +1315,235 @@ template void BufferExpansionKernelLauncher(const int32_t* input, int chunk_size, cudaStream_t stream); +template +__global__ void CopyCrossQKSingleDecodeStepKernel( + T* target, // shape [batchxbeam, layer_head_pair_count, max_length, frame] + T** qk_layer_pointers, + int token_index, + int num_layers, + int num_heads, + const int* cross_qk_layer_head_pairs, + int frames, + int max_length +) { + const int pair = blockIdx.x; + const int layer_head_pair_count = gridDim.x; + const int bbm = blockIdx.y; + cross_qk_layer_head_pairs += (pair * 2); + const int layer = *cross_qk_layer_head_pairs; + const int head = *(cross_qk_layer_head_pairs + 1); + + target += ((int64_t)bbm * layer_head_pair_count + pair) * max_length * frames + ((int64_t)token_index * frames); + T* src = qk_layer_pointers[layer] + ((int64_t)bbm * num_heads + head) * frames; + + for (int tid = threadIdx.x; tid < frames; tid += blockDim.x) { + target[tid] = src[tid]; // use vectorized read write in future if needed + } +} + +void LaunchCopyCrossQKSingleDecodeStep( + cudaStream_t stream, + float* cross_qk_buffer_data, + float** qk_layer_pointers, + int token_index, + int batchxbeam, + int num_layers, + int num_heads, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames, + int max_length +) { + dim3 block(512); + dim3 grid(cross_qk_layer_head_pair_count, batchxbeam); + typedef typename ToCudaType::MappedType CudaT; + + CopyCrossQKSingleDecodeStepKernel<<>>( + (CudaT*)cross_qk_buffer_data, + (CudaT**)qk_layer_pointers, + token_index, + num_layers, + num_heads, + cross_qk_layer_head_pairs, + frames, + max_length + ); +} + + +template +__global__ void CopyDecoderCrossQKAllStepsKernel( + int context_decoding_len, + int num_beams, + int num_return_sequences, + int max_length, + int frames_of_k, + const T* cross_qk_buffer_data, // [batch, num_beams, layer_head_pair_count, max_length, frames] + T* cross_qk_output, // [batch, num_return_sequences, 
layer_head_pair_count, total_decoding_length, frames] + const int* cache_indir_data, // [batch, num_beams, max_length] + const int32_t* beam_indices +) { + const int pair = blockIdx.y; + const int layer_head_pair_count = gridDim.y; + const int total_decoding_length = gridDim.x; + const int token_decoding_index = blockIdx.x; + const int br = blockIdx.z; + const int batch = br / num_return_sequences; + const int ret_seq_id = br % num_return_sequences; + + // get the real beam index, as the cache_indir_data did not updated in last token + const int src_beam = beam_indices[batch * num_beams + ret_seq_id] % num_beams; + + const int64_t offset_in_cache = ((int64_t)batch * num_beams + src_beam) * max_length + token_decoding_index + context_decoding_len; + int bm_mapped = ((num_beams <= 1) ? 0: ((token_decoding_index == total_decoding_length - 1) ? ret_seq_id : cache_indir_data[offset_in_cache])); + int bi_src = batch * num_beams + bm_mapped; + + T* target = cross_qk_output + + (((int64_t)br * layer_head_pair_count + (int64_t)pair) * total_decoding_length + token_decoding_index) * frames_of_k; + const T* src = cross_qk_buffer_data + + ((int64_t)bi_src * layer_head_pair_count * max_length + (int64_t)pair * max_length + token_decoding_index) * frames_of_k; + for (int tid = threadIdx.x; tid < frames_of_k; tid += blockDim.x) { + target[tid] = src[tid]; // use vectorized read write in future if needed + } +} + +void LaunchFinalizeCrossQK( + cudaStream_t stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + [[maybe_unused]] const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + const int32_t* beam_indices +) { + int64_t br = (int64_t)batch_size * num_return_sequences; + ORT_ENFORCE(br < 65536L && cross_qk_layer_head_pair_count < 65536); + const int total_decoding_length = iteration_number - 1; + dim3 block(512); + dim3 grid(total_decoding_length, cross_qk_layer_head_pair_count, (unsigned)br); + typedef typename ToCudaType::MappedType CudaT; + + CopyDecoderCrossQKAllStepsKernel<<>>( + context_decoding_len, + num_beams, + num_return_sequences, + max_length, + frames_of_k, + (const CudaT*)cross_qk_buffer_data, + (CudaT*)cross_qk_output, + cache_indir_data, + beam_indices); +} + +template +__global__ void ForceDecodingIdsKernel( + float* beam_scores, + const int vocab_size, + const int32_t* force_ids, + int id_len, + int step +) { + const int num_beams = gridDim.y; + const int beam = blockIdx.y; + const int batch = blockIdx.z; + beam_scores += (((int64_t)batch * num_beams + beam)* vocab_size); // move to (batch, beam) + const int32_t id_wanted = force_ids[((int64_t)batch * id_len) + step]; + if (id_wanted < 0 || id_wanted >= vocab_size) return; + + const int32_t elements_per_block = (int32_t)blockDim.x * ElementsPerThreads; + const int32_t block_start_id = blockIdx.x * elements_per_block; + + int32_t token_id = block_start_id + (int)threadIdx.x; + #pragma unroll + for (int elem = 0; elem < ElementsPerThreads; elem++) { + if (token_id < vocab_size) { + beam_scores[token_id] = ((token_id == id_wanted) ? 
0.0f : cub::FpLimits::Lowest()); + } + token_id += (int)blockDim.x; + } +} + + +void LaunchForceDecodingIds( + float* beam_scores, + const int batch_size, + const int num_beams, + const int vocab_size, + const int32_t* force_ids, + int id_len, + int step, + cudaStream_t stream +) { + dim3 blocks(512); + constexpr int ElementsPerThreads = 4; + unsigned gridx = static_cast((vocab_size + 512 * ElementsPerThreads - 1) / (512 * ElementsPerThreads)); + dim3 grids(gridx, num_beams, batch_size); + ForceDecodingIdsKernel<<>>( + beam_scores, vocab_size, force_ids, id_len, step + ); +} + +template +__global__ void SaveNoSpeechProbsKernel( + T* result_no_speech_probs, + const float* probs, + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id +) { + int b = blockIdx.x * blockDim.x + threadIdx.x; + if (b < batch_size) { + int64_t src_offset = b * num_beams * vocab_size + no_speech_token_id; + result_no_speech_probs[b] = (T)(probs[src_offset]); + } +} + +template +void LaunchSaveNoSpeechProbs( + T* result_no_speech_probs, /* [batch]*/ + const float* probs, /* [batch, num_beams, vocab_size]*/ + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id, + cudaStream_t stream +) { + int tpb = 256; + int bpg = (batch_size + 255) / 256; + + typedef typename ToCudaType::MappedType CudaT; + SaveNoSpeechProbsKernel<<>>( + (CudaT*)result_no_speech_probs, probs, batch_size, num_beams, vocab_size, no_speech_token_id); +} + +template void LaunchSaveNoSpeechProbs( + float* result_no_speech_probs, + const float* probs, + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id, + cudaStream_t stream +); + +template void LaunchSaveNoSpeechProbs( + MLFloat16* result_no_speech_probs, + const float* probs, + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id, + cudaStream_t stream +); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index 8c52f6fd52385..07ccbeced456e 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -213,6 +213,55 @@ void BufferExpansionKernelLauncher(const T* input, int chunk_size, cudaStream_t stream); +void LaunchCopyCrossQKSingleDecodeStep( + cudaStream_t stream, + float* cross_qk_buffer_data, + float** qk_layer_pointers, + int token_index, + int batchxbeam, + int num_layers, + int num_heads, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames, + int max_length); + +void LaunchFinalizeCrossQK( + cudaStream_t stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + const int32_t* beam_indices); + +void LaunchForceDecodingIds( + float* beam_scores, + const int batch_size, + const int num_beams, + const int vocab_size, + const int32_t* force_ids, + int id_len, + int step, + cudaStream_t stream); + +template +void LaunchSaveNoSpeechProbs( + T* result_no_speech_probs, /* [batch]*/ + const float* probs, /* [batch, num_beams, vocab_size]*/ + const int 
batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index e4de33499c6ca..32cf8655f3b04 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -13,6 +13,8 @@ #include #include "contrib_ops/cuda/transformers/generation_cuda_impl.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +#include "contrib_ops/cpu/transformers/logits_processor.h" +#include "contrib_ops/cpu/transformers/generation_shared.h" #include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" #include "contrib_ops/cpu/transformers/subgraph_gpt.h" #include "contrib_ops/cuda/transformers/beam_search_topk.h" @@ -210,7 +212,7 @@ Status AddToFeeds(Stream* ort_stream, ORT_ENFORCE(total_bytes > 0); cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; - auto pinned_buffer = IAllocator::MakeUniquePtr(host_allocator, total_bytes); + auto pinned_buffer = IAllocator::MakeUniquePtr(host_allocator, total_bytes, false, ort_stream); char* pinned_data = static_cast(pinned_buffer.get()); // Copy tensors to one pinned memory buffer (so that we only need copy to GPU once) char* destination = pinned_data; @@ -426,11 +428,21 @@ Status ProcessLogits(const OrtValue& logits, // dumper->Print("next_token_scores after softmax", next_token_scores.data(), batch_size, num_beams, vocab_size); #endif + const bool is_whisper_model = (parameters->model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper); + if (step == 1 && is_whisper_model && parameters->no_speech_probs) { + cuda::LaunchSaveNoSpeechProbs( + (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token, cuda_stream); + } + + // NOTE: currently we treat extra decoding ids are same + int extra_decoding_len = static_cast(parameters->extra_decoding_ids.size() / parameters->batch_size); + const bool need_handle_extra_decoding_ids = is_whisper_model && (!parameters->extra_decoding_ids.empty()) && (extra_decoding_len >= step); + cuda::LaunchLogitsProcessKernel( next_token_scores.data(), parameters->vocab_mask.data(), - step > 1 ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only. - nullptr, // parameters->presence_mask.data(), + (step > extra_decoding_len + 1) ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only. + nullptr, // parameters->presence_mask.data(), parameters->presence_penalty, parameters->temperature, parameters->batch_size, @@ -445,6 +457,50 @@ Status ProcessLogits(const OrtValue& logits, // parameters->no_repeat_ngram_size, cuda_stream); + // Whisper time stamp generation. 
+ // TODO: implement it on GPU + bool gen_timestamp = is_whisper_model && + (parameters->logits_processor == onnxruntime::contrib::transformers::IGenerationParameters::kLogitsProcessorTypeWhisper); + if (gen_timestamp) { + // Copy next token scores to cpu memory, copy Sequences to cpu + std::vector cpu_next_token_scores(next_token_scores.size()); + gsl::span cpu_next_token_scores_span(cpu_next_token_scores.data(), cpu_next_token_scores.size()); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_next_token_scores.data(), + next_token_scores.data(), + next_token_scores.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(const_cast(sequences->GetSequence(0).data()), + sequences->GetCurrentDeviceSequences().data(), + sequences->GetSequence(0).size_bytes() * batch_beam_size, + cudaMemcpyDeviceToHost, + cuda_stream)); + constexpr int max_initial_timestamp_index = 50; + onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, max_initial_timestamp_index); + onnxruntime::contrib::transformers::NextTokenScores next_token_scores_timestamp({cpu_next_token_scores_span, batch_beam_size, vocab_size}); + + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + time_logit_processor.Process(sequences, next_token_scores_timestamp); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(next_token_scores.data(), + cpu_next_token_scores.data(), + next_token_scores.size_bytes(), + cudaMemcpyHostToDevice, + cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + } + + if (need_handle_extra_decoding_ids && !parameters->extra_decoding_ids.empty()) { + cuda::LaunchForceDecodingIds( + next_token_scores.data(), + parameters->batch_size, + parameters->num_beams, + parameters->vocab_size, + parameters->extra_decoding_ids.data(), + parameters->extra_decoding_ids.size() / parameters->batch_size, + step - 1, + cuda_stream); + } + #ifdef DEBUG_GENERATION dumper->Print("next_token_scores after logits process", next_token_scores.data(), batch_size, num_beams, vocab_size); #endif @@ -807,13 +863,11 @@ Status GreedySearchProcessLogits( // Sequences generated by beam scorer is currently stored in CPU. 
// Copy sequences to device only when repetition penalty or no repeat ngram is used in kernel - BufferUniquePtr sequences_buffer; + IAllocatorUniquePtr sequences_buffer; int current_sequence_length = sequences->GetSequenceLength(); if (parameters->repetition_penalty != 1.0f) { size_t bytes = SafeInt(sizeof(int32_t)) * batch_beam_size * parameters->max_length; - void* data = allocator->Alloc(bytes); - BufferUniquePtr temp_buffer(data, BufferDeleter(allocator)); - sequences_buffer = std::move(temp_buffer); + sequences_buffer = IAllocator::MakeUniquePtr(allocator, bytes, false, stream); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sequences_buffer.get(), sequences->GetSequence(0).data(), bytes, cudaMemcpyHostToDevice, cuda_stream)); } @@ -1196,14 +1250,14 @@ Status UpdateDecoderFeeds( if (past_present_share_buffer) { // Update past sequence length input - const ptrdiff_t past_sequence_length_idx = 2 * (static_cast(last_outputs.size()) - t5_decoder_first_present_output_idx) + t5_decoder_first_past_input_idx; + const ptrdiff_t past_sequence_length_idx = 2 * num_present_tensors + t5_decoder_first_past_input_idx; *(next_inputs[past_sequence_length_idx].GetMutable()->MutableData()) = current_length - 1; // Update beam search specific input for DecoderMaskedSelfAttention (cache indirection) if present // If the last input is not `past_sequence_length`, then the beam search specific inputs // for `DecoderMaskedSelfAttention` is present - if (need_cache_indir) { + if (need_cache_indir && num_beams > 1) { ORT_ENFORCE(!beam_indices_gpu.empty(), "Beam indices must be present on CUDA while using DecoderMaskedMultiHeadAttention with BeamSearch"); // The cache indirection feed comes 2 feeds after the `past_sequence_length` feed @@ -1528,6 +1582,93 @@ template Status ExpandBuffer( OrtValue& expanded, bool only_copy_shape, int max_sequence_length); + +Status UpdateDecoderCrossQK( + int iteration_number, + Stream* stream, + OrtValue* cross_qks, + IAllocatorUniquePtr& qk_layer_pointers, + int num_layers, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + float* cross_qk_buffer_data, + int max_length, + AllocatorPtr allocator) { + cudaStream_t cuda_stream = stream ? 
static_cast(stream->GetHandle()) : nullptr; + + if (qk_layer_pointers.get() == nullptr) { + // Put all the qk pointers into gpu, as they did not change in following decoding steps + // also this help to use single kernel to process each step + qk_layer_pointers = IAllocator::MakeUniquePtr(allocator, static_cast(num_layers), false, stream); + std::vector qk_layer_data(num_layers, nullptr); + for (int layer = 0; layer < num_layers; layer++) { + qk_layer_data[layer] = cross_qks[layer].GetMutable()->MutableData(); + } + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync((void*)qk_layer_pointers.get(), qk_layer_data.data(), sizeof(qk_layer_data[0]) * num_layers, + cudaMemcpyHostToDevice, cuda_stream)); + } + + auto cross_qk_layer_shape = cross_qks[0].GetMutable()->Shape(); + int64_t batchxbeam = cross_qk_layer_shape[0]; + int64_t num_heads = cross_qk_layer_shape[1]; + int64_t frames = cross_qk_layer_shape[3]; + + cuda::LaunchCopyCrossQKSingleDecodeStep( + cuda_stream, + cross_qk_buffer_data, + qk_layer_pointers.get(), + iteration_number - 2, + batchxbeam, + num_layers, + num_heads, + cross_qk_layer_head_pair_count, + cross_qk_layer_head_pairs, + frames, + max_length); + + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + + return Status::OK(); +} + +Status FinalizeDecoderCrossQK( + Stream* stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + gsl::span beam_indices_gpu) { + cudaStream_t cuda_stream = stream ? static_cast(stream->GetHandle()) : nullptr; + + cuda::LaunchFinalizeCrossQK( + cuda_stream, + iteration_number, + context_decoding_len, + batch_size, + num_beams, + max_length, + cross_qk_layer_head_pair_count, + cross_qk_layer_head_pairs, + frames_of_k, + cross_qk_buffer_data, + cross_qk_output, + num_return_sequences, + cache_indir_data, + beam_indices_gpu.data()); + + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + + return Status::OK(); +} + } // namespace GenerationCudaDeviceHelper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h index f5f062d7a101b..7a718eb9f66c1 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h @@ -150,6 +150,34 @@ Status ExpandBuffer( bool only_copy_shape, int max_sequence_length = 0); +Status UpdateDecoderCrossQK( + int iteration_number, + Stream* stream, + OrtValue* cross_qks, + IAllocatorUniquePtr& qk_layer_pointers, + int num_layers, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + float* cross_qk_buffer_data, + int max_length, + AllocatorPtr allocator); + +Status FinalizeDecoderCrossQK( + Stream* stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + gsl::span beam_indices); + } // namespace GenerationCudaDeviceHelper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc 
b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index e5956a575d73d..880b2df3543e6 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -746,6 +746,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "The value to be filled in the attention mask. Default value is -10000.0f", AttributeProto::FLOAT, OPTIONAL_VALUE) + .Attr("output_qk", + "Need output the cross attention MatMul(Q, K)", + AttributeProto::INT, + OPTIONAL_VALUE) .Input(0, "query", "Query with shape (batch_size, 1, hidden_size) or packed QKV with shape " @@ -837,6 +841,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "while effective_seq_length = (past_sequence_length + kv_sequence_length).", "T", OpSchema::Optional) + .Output(3, + "qk", + "normalized Q * K, of shape (batch_size, num_heads, 1, head_size). ", + "V", + OpSchema::Optional) + .TypeConstraint("V", {"tensor(float)"}, "Constrain qk output types to float32 tensors.") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index a79203a94a3a7..00740ada2677d 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -407,6 +407,9 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 5, 1); if (ctx.getNumOutputs() > 2) { ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 5, 2); + if (ctx.getNumOutputs() > 3) { + ONNX_NAMESPACE::updateOutputElemType(ctx, 3, ONNX_NAMESPACE::TensorProto::FLOAT); + } } } @@ -415,6 +418,7 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { // output 0 (sequences) shape: (batch_size, num_return_sequences, max_length) // output 1 (sequences_scores) shape: (batch_size, num_return_sequences) // output 2 (scores) shape: (max_length - sequence_length, batch_size, num_beams, vocab_size) + // output 3 (cross_attention): shape: (batch_size, num_return_sequences, Layers, Heads, max_length, Frames) if (!hasInputShape(ctx, 0)) { return; } @@ -490,6 +494,22 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { } updateOutputShape(ctx, 2, scores_shape); } + + if (ctx.getNumOutputs() > 3) { + ONNX_NAMESPACE::TensorShapeProto cross_attn_shape; + cross_attn_shape.add_dim()->set_dim_value(batch_size); + cross_attn_shape.add_dim()->set_dim_value(num_return_sequences_value); + cross_attn_shape.add_dim(); // num of layer is unknown, no need to calc it from subgraph here + cross_attn_shape.add_dim(); // num of head is unknown, no need to calc it from subgraph here + cross_attn_shape.add_dim()->set_dim_value(max_length_value); + cross_attn_shape.add_dim()->set_dim_value(sequence_length); + updateOutputShape(ctx, 3, cross_attn_shape); + } + if (ctx.getNumOutputs() > 4) { + ONNX_NAMESPACE::TensorShapeProto non_speech_probs_shape; + non_speech_probs_shape.add_dim()->set_dim_value(batch_size); + updateOutputShape(ctx, 4, non_speech_probs_shape); + } } } @@ -1060,6 +1080,78 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GridSample, 1, updateOutputShape(ctx, 0, {N, C, H_out, W_out}); })); +ONNX_MS_OPERATOR_SET_SCHEMA( + UnfoldTensor, 1, + OpSchema() + .SetDoc("Returns a tensor which contains all slices of size size from input tensor in the dimension dim. " + "Step between two slices is given by step. 
" + "If sizedim is the size of dimension dim for input tensor, the size of dimension dim in " + "the returned tensor will be (sizedim - size) / step + 1. " + "An additional dimension of size size is appended in the returned tensor.") + .Attr("dim", "specify the dimension to unfold", AttributeProto::INT, static_cast(-1)) + .Attr("size", "specify the size", AttributeProto::INT) + .Attr("step", "specify the step.", AttributeProto::INT, static_cast(1)) + .Input(0, "input", "input tensor", "T") + .Output(0, "output", "Output tensor.", "T") + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Allow inputs and outputs to be any kind of tensor.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + + if (!hasInputShape(ctx, 0)) return; + auto& input_shape = getInputShape(ctx, 0); + const int rank = input_shape.dim_size(); + int64_t dim = getAttribute(ctx, "dim", -1); + dim = HandleNegativeAxis(dim, rank); + if (!input_shape.dim(dim).has_dim_value()) { + return; + } + int64_t dim_size = input_shape.dim(dim).dim_value(); + + const int64_t step = getAttribute(ctx, "step", -1); + if (step <= 0) { + fail_shape_inference("size attribute in UnfoldTensor must greater than 0.") + } + int64_t size = -1; + auto size_proto = ctx.getAttribute("size"); + if (!(size_proto)) { + fail_shape_inference("size attribute in UnfoldTensor not specified!") + } + size = size_proto->i(); + if (size > dim_size || size <= 0) { + fail_shape_inference("size attribute in UnfoldTensor not positive and less than the dim size!") + } + + ONNX_NAMESPACE::TensorShapeProto output_shape; + for (int d = 0; d < rank; d++) { + if (d == dim) { + output_shape.add_dim()->set_dim_value((dim_size - size) / step + 1); + } else { + *output_shape.add_dim() = input_shape.dim(d); + } + } + output_shape.add_dim()->set_dim_value(size); + updateOutputShape(ctx, 0, output_shape); + })); + +ONNX_MS_OPERATOR_SET_SCHEMA( + DynamicTimeWarping, 1, + OpSchema() + .SetDoc("Input is cost matrix where each value in input[r][c] is the cost for pass the point (r, c). From current point" + "(r, c), points (r+1, c), (r+1, c+1) or (r, c+1) could be arrived in next move. Given such cost matrix, return " + "dynamic time wrapping of shape [2, x], where the path made by all points (output[0][t], output[1][t])" + "have the lowest cost among all paths from (0, 0) to (M-1, N-1).") + .Input(0, "input", "Input cost tensor, it must be 2D tensor of shape M x N, or 1 x M x N", "F") + .Output(0, "output", "Output tensor. shape is [2, x], where max(M, N) <= x < M + N", "I") + .TypeConstraint("F", {"tensor(float)"}, "Constrain to float tensors.") + .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT32); + ONNX_NAMESPACE::TensorShapeProto resultShape; + resultShape.add_dim()->set_dim_value(2); + resultShape.add_dim(); + updateOutputShape(ctx, 0, resultShape); + })); + ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, OpSchema() .SetDoc("Beam Search for text generation. 
Supports GPT-2 decoder.") @@ -1068,7 +1160,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) - .Attr("model_type", "model type: 0 for GPT-2; 1 for encoder decoder like T5", AttributeProto::INT, static_cast(0)) + .Attr("model_type", "model type: 0 for GPT-2; 1 for encoder decoder like T5; 2 for whisper", AttributeProto::INT, static_cast(0)) .Attr("encoder", "The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE) .Attr("init_decoder", "The subgraph for the first decoding run. It will be called once before `decoder` subgraph. " @@ -1079,6 +1171,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, "Size of the vocabulary. " "If not provided, it will be inferred from the decoder subgraph's output shape", AttributeProto::INT, static_cast(-1)) + .Attr("decoder_output_cross_qk", "If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_speech_token", + "The token in whisper model that mark all sequence empty. With this model, whisper could output no_speech_prob after Default -1.", + AttributeProto::INT, OPTIONAL_VALUE) .Input(0, "input_ids", "The sequence used as a prompt for the generation in the encoder subgraph. Shape is (batch_size, sequence_length)", "F") .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) @@ -1095,6 +1191,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) .Input(11, "logits_processor", "Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional) + .Input(12, "cross_qk_layer_head", + "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all" + "its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]", + "I", OpSchema::Optional) + .Input(13, "extra_decoding_ids", + "Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len)." + "In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) " + "are treated as stop of the extra_decoding_ids for corresponding batch.", + "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. 
Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", @@ -1102,6 +1207,18 @@ "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." "Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)", "T", OpSchema::Optional) + .Output(3, "cross_qk", + "Output the accumulated stacked Q*K from cross attentions. Let H = number of heads of cross attention, " + "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, " + "B = batch size, R = num_return_sequences. It should then return a tensor of shape [B, R, L*H, T, F]. " + "If cross_qk_layer_head is given, the shape is [B, R, cross_qk_layer_head.shape[0], T, F]", + "V", OpSchema::Optional) + .Output(4, "non_speech_probs", + "For the whisper model, output the probabilities from the logits after encoder and context decoding for the no_speech_token. " + "Currently the last token's logits are treated as what we need; in the future, extra graph logic may be added to the encoder/context-decoder subgraph. " + "The prob is saved before the logits may be updated by extra_decoding_ids. The shape of non_speech_probs is [B]", + "T", OpSchema::Optional) + .TypeConstraint("V", {"tensor(float)"}, "Constrain to float32 tensors.") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.") .TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.") .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 3c31997286254..2a241dba922cb 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -102,6 +102,8 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Tokenizer); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, TorchEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, TransposeMatMul); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Trilu); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, UnfoldTensor); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DynamicTimeWarping); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Unique); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, WordConvEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmFastGelu); @@ -203,6 +205,8 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index ad892eab3b843..ef9c81440927f 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -250,7 +250,7 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in if (info.external_allocator_info.UseExternalAllocator()) { use_ep_level_unified_stream_ = true; stream_ = nullptr; - } else if (info.enable_cuda_graph) { + } else if (info.enable_cuda_graph || info.use_ep_level_unified_stream) { // current cuda graph implementation only works with single stream // use EP level unified stream for all the reqeust 
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index ca88b3474b758..966448051264d 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -29,6 +29,7 @@ constexpr const char* kTunableOpEnable = "tunable_op_enable"; constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable"; constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms"; constexpr const char* kEnableSkipLayerNormStrictMode = "enable_skip_layer_norm_strict_mode"; +constexpr const char* KUseEPLevelUnifiedStream = "use_ep_level_unified_stream"; } // namespace provider_option_names } // namespace cuda @@ -99,6 +100,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToReference(cuda::provider_option_names::kEnableCudaGraph, info.enable_cuda_graph) .AddAssignmentToReference(cuda::provider_option_names::kCudnnConv1dPadToNc1d, info.cudnn_conv1d_pad_to_nc1d) .AddAssignmentToReference(cuda::provider_option_names::kEnableSkipLayerNormStrictMode, info.enable_skip_layer_norm_strict_mode) + .AddAssignmentToReference(cuda::provider_option_names::KUseEPLevelUnifiedStream, info.use_ep_level_unified_stream) .AddValueParser( cuda::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { @@ -144,6 +146,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution {cuda::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)}, {cuda::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op.max_tuning_duration_ms)}, {cuda::provider_option_names::kEnableSkipLayerNormStrictMode, MakeStringWithClassicLocale(info.enable_skip_layer_norm_strict_mode)}, + {cuda::provider_option_names::KUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, }; return options; @@ -162,6 +165,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid {cuda::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op_enable)}, {cuda::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op_tuning_enable)}, {cuda::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op_max_tuning_duration_ms)}, + {cuda::provider_option_names::KUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, }; return options; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h index 789b02b0e1d8c..89b266f362e8d 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h @@ -72,6 +72,8 @@ struct CUDAExecutionProviderInfo { bool enable_skip_layer_norm_strict_mode{false}; + bool use_ep_level_unified_stream{false}; + static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const CUDAExecutionProviderInfo& info); static ProviderOptions ToProviderOptions(const OrtCUDAProviderOptionsV2& info); diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc 
b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 5a11f2529f38e..53746ff1e1fe6 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -222,6 +222,7 @@ struct CUDA_Provider : Provider { info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms; info.enable_skip_layer_norm_strict_mode = params->enable_skip_layer_norm_strict_mode != 0; + info.use_ep_level_unified_stream = params->use_ep_level_unified_stream != 0; return std::make_shared(info); } @@ -253,6 +254,7 @@ struct CUDA_Provider : Provider { cuda_options.enable_cuda_graph = internal_options.enable_cuda_graph; cuda_options.cudnn_conv1d_pad_to_nc1d = internal_options.cudnn_conv1d_pad_to_nc1d; cuda_options.enable_skip_layer_norm_strict_mode = internal_options.enable_skip_layer_norm_strict_mode; + cuda_options.use_ep_level_unified_stream = internal_options.use_ep_level_unified_stream; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index bf7a3bbd9d380..29ca5bd7b0055 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1332,6 +1332,7 @@ OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsToOrtCUDAProviderOptionsV2(const cuda_options_converted.enable_cuda_graph = 0; cuda_options_converted.cudnn_conv1d_pad_to_nc1d = 0; cuda_options_converted.enable_skip_layer_norm_strict_mode = 0; + cuda_options_converted.use_ep_level_unified_stream = 0; return cuda_options_converted; } diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index c0cabbb5e9759..c1c709d6d759b 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1272,7 +1272,46 @@ def find_past_seq_len_usage(subg: GraphProto): return tensor_names_to_rename, nodes_to_remove -def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphProto): +def update_decoder_subgraph_output_cross_attention(subg: GraphProto): + input_self_past_0 = 1 + # w/wo attention mask, w/wo hidden_state + graph_input_names = [gi.name for gi in subg.input] + while input_self_past_0 < 3 and not graph_input_names[input_self_past_0].startswith("past"): + input_self_past_0 += 1 + output_self_present_0 = 1 + + num_layers = (len(subg.output) - output_self_present_0) // 2 + input_cross_past_0 = 2 * num_layers + input_self_past_0 + past_key_cross_inputs = {subg.input[layer * 2 + input_cross_past_0].name: layer for layer in range(num_layers)} + print(f" --past_key_cross_inputs={past_key_cross_inputs}") + + input_past_key_cross_0_shape = shape_of(subg.input[input_cross_past_0]) + print(f"past_key_cross_0_shape is {input_past_key_cross_0_shape}") + batch_size_dim = input_past_key_cross_0_shape[0] + num_heads_dim = input_past_key_cross_0_shape[1] + cross_seq_len_dim = input_past_key_cross_0_shape[2] + + num_layer_output_qk = 0 + for node in subg.node: + if (node.op_type == "DecoderMaskedMultiHeadAttention") and (node.input[1] in past_key_cross_inputs): + print(f" -- add cross QK output from: node: {node.name} with output: {node.output}") + num_layer_output_qk += 1 + layer = past_key_cross_inputs[node.input[1]] + cross_attention_out_name = f"output_cross_qk_{layer}" + 
appended_names = [""] * (3 - len(node.output)) + appended_names.append(cross_attention_out_name) + node.output.extend(appended_names) + node.attribute.extend([onnx.helper.make_attribute("output_qk", 1)]) + + cross_attention = onnx.helper.make_tensor_value_info( + cross_attention_out_name, TensorProto.FLOAT, [batch_size_dim, num_heads_dim, 1, cross_seq_len_dim] + ) + subg.output.extend([cross_attention]) + if num_layer_output_qk != num_layers: + raise ValueError(f"Did not add cross QK for all layers{num_layers} vs {num_layer_output_qk}") + + +def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: ModelProto): input_self_past_0 = 1 # w/wo attention mask, w/wo hidden_state graph_input_names = [gi.name for gi in subg.input] diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 3562df1660ea9..c5f43ce782331 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -169,6 +169,69 @@ def parse_arguments(argv=None): ) parser.set_defaults(chain_model=True) + parser.add_argument( + "--extra_decoding_ids", + required=False, + action="store_true", + help="Need extra starting decoding ids for some feature like cross qk. Default if false.", + ) + parser.set_defaults(extra_decoding_ids=False) + + parser.add_argument( + "--collect_cross_qk", + required=False, + action="store_true", + help="Beam search model collect stacked cross QK.", + ) + parser.set_defaults(collect_cross_qk=False) + + parser.add_argument( + "--output_cross_qk", + required=False, + action="store_true", + help="Beam search model output collected qk as output. Also hint collect_cross_qk", + ) + parser.set_defaults(output_cross_qk=False) + + parser.add_argument( + "--no_speech_token_id", + default=50362, + type=int, + help="specify no_speech_token_id. Default is 1000. 
if >= 0, will be add into beam search attr", + ) + + parser.add_argument( + "--output_no_speech_probs", + required=False, + action="store_true", + help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.", + ) + parser.set_defaults(output_no_speech_probs=False) + + parser.add_argument( + "--output_scores", + required=False, + action="store_true", + help="Beam search model output scores over vocab per generated token.", + ) + parser.set_defaults(output_scores=False) + + parser.add_argument( + "--output_sequence_scores", + required=False, + action="store_true", + help="Beam search model output scores for each generated sequence.", + ) + parser.set_defaults(output_sequence_scores=False) + + parser.add_argument( + "--cross_qk_onnx_model", + required=False, + type=str, + default=None, + help="the model which consume cross_qk.", + ) + parser.add_argument( "--beam_output_model", type=str, @@ -220,6 +283,7 @@ def parse_arguments(argv=None): ) args = parser.parse_args(argv) + args.collect_cross_qk = args.collect_cross_qk or args.output_cross_qk return args diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index 3b1e656136547..a1ed0c7ed5ca2 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -1,12 +1,19 @@ import logging import os +import sys import onnx -from benchmark_helper import Precision -from convert_generation import get_shared_initializers, update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha from onnx import TensorProto, helper from transformers import WhisperConfig +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) +from benchmark_helper import Precision # noqa: E402 +from convert_generation import ( # noqa: E402 + get_shared_initializers, + update_decoder_subgraph_output_cross_attention, + update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha, +) + logger = logging.getLogger(__name__) @@ -42,8 +49,24 @@ def chain_model(args): "", # attention mask "decoder_input_ids" if args.use_forced_decoder_ids else "", "logits_processor" if args.use_logits_processor else "", + "cross_qk_layer_head" if args.collect_cross_qk else "", + "extra_decoding_ids" if args.extra_decoding_ids else "", ] + beam_outputs = ["sequences"] + if args.output_sequence_scores: + beam_outputs.append("sequence_scores") + if args.output_scores: + beam_outputs.append("scores") + + if args.collect_cross_qk: + while len(beam_outputs) < 3: + beam_outputs.extend([""]) + beam_outputs.extend(["cross_qk"]) + if args.output_no_speech_probs: + while len(beam_outputs) < 4: + beam_outputs.extend([""]) + beam_outputs.extend(["no_speech_probs_beam"]) input_features_cast_node, len_pen_cast_node, rep_pen_cast_node = None, None, None if args.precision == Precision.FLOAT16: @@ -81,6 +104,10 @@ def chain_model(args): helper.make_attribute("model_type", 2), ] ) + if args.collect_cross_qk: + node.attribute.extend([helper.make_attribute("decoder_output_cross_qk", 1)]) + if args.no_speech_token_id >= 0: + node.attribute.extend([helper.make_attribute("no_speech_token", args.no_speech_token_id)]) input_features = helper.make_tensor_value_info( "input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"] @@ -121,17 +148,50 @@ def chain_model(args): logits_processor = helper.make_tensor_value_info("logits_processor", 
TensorProto.INT32, [1]) graph_inputs.append(logits_processor) + if args.collect_cross_qk: + cross_qk_layer_head = helper.make_tensor_value_info( + "cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2] + ) + graph_inputs.append(cross_qk_layer_head) + + if args.extra_decoding_ids: + extra_decoding_ids = helper.make_tensor_value_info( + "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"] + ) + graph_inputs.append(extra_decoding_ids) + # graph outputs sequences = helper.make_tensor_value_info( "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"] ) graph_outputs = [sequences] + if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk): + cross_qk = helper.make_tensor_value_info( + "cross_qk", + TensorProto.FLOAT, + ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"], + ) + graph_outputs.extend([cross_qk]) + + if args.output_no_speech_probs: + no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"]) + graph_outputs.extend([no_speech_probs]) + + if args.output_sequence_scores: + sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"]) + graph_outputs.extend([sequence_scores]) + + if args.output_scores: + scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"]) + graph_outputs.extend([scores]) if hasattr(args, "use_gpu") and args.use_gpu: if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph): logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!") else: logger.warning("DecoderMaskedMultiHeadAttention could not be applied to whisper decoder subgraph") + if hasattr(args, "collect_cross_qk") and args.collect_cross_qk: + update_decoder_subgraph_output_cross_attention(decoder_model.graph) # Initializers/opsets # Delete shared data between decoder/encoder and move to larger graph initializers @@ -150,7 +210,34 @@ def chain_model(args): if args.precision == Precision.FLOAT16 else [node] ) + if args.output_no_speech_probs: + prob_cast_node = helper.make_node( + "Cast", + inputs=["no_speech_probs_beam"], + outputs=["no_speech_probs"], + name="no_speech_probs_cast_to_fp32", + to=TensorProto.FLOAT, + ) + graph_nodes.extend([prob_cast_node]) + beam_graph = helper.make_graph(graph_nodes, "beam-search-test", graph_inputs, graph_outputs, initializers) + beam_graph_input_names = [gi.name for gi in graph_inputs] + beam_graph_output_names = [go.name for go in graph_outputs] + + if args.cross_qk_onnx_model: + post_qk_model = onnx.load_model(args.cross_qk_onnx_model, load_external_data=True) + post_qk_graph = post_qk_model.graph + beam_graph.initializer.extend(post_qk_graph.initializer) + beam_graph.node.extend(post_qk_graph.node) + # TODO: Treat same name same input, user need check their shapes, etc + for pgi in post_qk_graph.input: + if ( + (pgi.name not in beam_graph_input_names) + and (pgi.name not in beam_graph_output_names) + and (pgi.name != "cross_qk") + ): + beam_graph.input.extend([pgi]) + beam_graph.output.extend(post_qk_graph.output) # Verify graph's inputs match beam search's inputs verify_inputs(beam_inputs, graph_inputs) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 3a81700a7fd04..8c22cd5e745b3 100644 --- 
a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -312,6 +312,7 @@ def verify_onnx( "tensor(uint8)": np.uint8, } + use_extra_decoding_ids = "extra_decoding_ids" in ort_names for name, dtype in zip(ort_names, ort_dtypes): if name == "input_features": inputs[name] = inputs[name].detach().cpu().numpy() @@ -320,9 +321,18 @@ def verify_onnx( elif name == "prefix_vocab_mask": inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) elif name == "decoder_input_ids": - inputs[name] = np.array([[config.decoder_start_token_id, 50259, 50359, 50363]], dtype=ort_to_np[dtype]) + raw_input_ids = ( + [[config.decoder_start_token_id]] + if use_extra_decoding_ids + else [[config.decoder_start_token_id, 50259, 50359, 50363]] + ) + inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype]) elif name == "logits_processor": inputs[name] = np.array([1], dtype=ort_to_np[dtype]) + elif name == "cross_qk_layer_head": + inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype]) + elif name == "extra_decoding_ids": + inputs[name] = np.repeat(np.array([[50259, 50359, 50363]], dtype=ort_to_np[dtype]), batch_size, 0) else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) ort_outputs = ort_session.run(None, inputs)[0][0] diff --git a/onnxruntime/test/contrib_ops/dynamic_time_warping_op_test.cc b/onnxruntime/test/contrib_ops/dynamic_time_warping_op_test.cc new file mode 100644 index 0000000000000..28f15b58eb467 --- /dev/null +++ b/onnxruntime/test/contrib_ops/dynamic_time_warping_op_test.cc @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" +#include "test/common/cuda_op_test_utils.h" + +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace test { + +TEST(DynamicTimeWarp, simple) { + if (NeedSkipIfCudaArchLowerThan(530)) { + return; + } + + std::vector X = { + 3.0f, + 8.0f, + 5.0f, + 1.0f, + 9.0f, + 8.0f, + 5.0f, + 7.0f, + 4.0f, + 4.0f, + 9.0f, + 6.0f, + 2.0f, + 9.0f, + 7.0f, + 2.0f, + 5.0f, + 6.0f, + 1.0f, + 8.0f, + 4.0f, + 6.0f, + 5.0f, + 8.0f, + 4.0f, + 8.0f, + 3.0f, + 6.0f, + 3.0f, + 9.0f, + 1.0f, + 1.0f, + 6.0f, + 8.0f, + 3.0f, + 5.0f, + 5.0f, + 3.0f, + 3.0f, + 8.0f, + 8.0f, + 7.0f, + 1.0f, + 2.0f, + 2.0f, + 1.0f, + 5.0f, + 4.0f, + 5.0f, + 0.0f, + 3.0f, + 6.0f, + 3.0f, + 7.0f, + 4.0f, + 5.0f, + 4.0f, + 5.0f, + 4.0f, + 0.0f, + }; + + std::vector Y = { + 0, + 1, + 2, + 3, + 4, + 4, + 4, + 4, + 5, + 5, + 5, + 5, + 0, + 1, + 1, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + }; + + OpTester tester("DynamicTimeWarping", 1, onnxruntime::kMSDomain); + tester.AddInput("input", {6, 10}, X); + tester.AddOutput("output", {2, 12}, Y); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/tensor_op_test.cc b/onnxruntime/test/contrib_ops/tensor_op_test.cc index 44cb49580ce8b..323a8b2cb00ef 100644 --- a/onnxruntime/test/contrib_ops/tensor_op_test.cc +++ b/onnxruntime/test/contrib_ops/tensor_op_test.cc @@ -6,6 +6,7 @@ #include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" #include 
"test/util/include/default_providers.h" +#include "test/common/cuda_op_test_utils.h" using namespace ONNX_NAMESPACE; using namespace onnxruntime::test; @@ -14,6 +15,76 @@ namespace test { using ExpectResult = OpTester::ExpectResult; +TEST(UnfoldTensorOpTest, LastDim) { + if (NeedSkipIfCudaArchLowerThan(530)) { + return; + } + + std::vector X = { + 1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 6.0f, 7.0f, 8.0f, 9.0f}; + + std::vector output = { + 1.0f, 2.0f, 3.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 6.0f, 7.0f, 8.0f, + 6.0f, 7.0f, 8.0f, 7.0f, 8.0f, 9.0f}; + + OpTester tester("UnfoldTensor", 1, onnxruntime::kMSDomain); + + tester.AddAttribute("size", 3LL); + tester.AddInput("input", {3, 4}, X); + tester.AddOutput("output", {3, 2, 3}, output); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(UnfoldTensorOpTest, NormalDim) { + if (NeedSkipIfCudaArchLowerThan(530)) { + return; + } + + std::vector X = { + 1, 2, 3, 4, 2, 2, 3, 4, 3, 2, 3, 4, + 4, 6, 7, 8, 5, 6, 7, 8, 6, 6, 7, 8, + 6, 7, 8, 9, 7, 7, 8, 9, 8, 7, 8, 9, + 9, 7, 8, 9, 10, 7, 8, 9, 11, 7, 8, 9}; + + std::vector output = { + 1, 2, 3, + 2, 2, 2, + 3, 3, 3, + 4, 4, 4, + + 3, 4, 5, + 2, 6, 6, + 3, 7, 7, + 4, 8, 8, + + 6, 7, 8, + 7, 7, 7, + 8, 8, 8, + 9, 9, 9, + + 8, 9, 10, + 7, 7, 7, + 8, 8, 8, + 9, 9, 9}; + + OpTester tester("UnfoldTensor", 1, onnxruntime::kMSDomain); + tester.AddAttribute("dim", 1LL); + tester.AddAttribute("size", 3LL); + tester.AddAttribute("step", 2LL); + tester.AddInput("input", {2, 6, 4}, X); + tester.AddOutput("output", {2, 2, 4, 3}, output); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + TEST(CropContribOpTest, CropBorderOnly) { constexpr int N = 2, C = 1, H = 3, W = 4; std::vector X = {1.0f, 2.0f, 3.0f, 4.0f, diff --git a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py index 66200af06f511..77ce09d7e793b 100644 --- a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py +++ b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py @@ -10,7 +10,7 @@ import pytest import torch -from onnxruntime import InferenceSession, SessionOptions +from onnxruntime import InferenceSession, SessionOptions, get_available_providers class TestTimestampProcessor(unittest.TestCase): @@ -52,12 +52,13 @@ def run_timestamp(self, provider: str): ort_transcription = processor.batch_decode( ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True ) + print(ort_transcription) expected_transcription = [ { - "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + "text": "<|0.00|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>", "offsets": [ { - "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + "text": "<|0.00|> Mr. 
Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>", "timestamp": (0.0, 5.44), } ], @@ -70,6 +71,12 @@ def test_timestamp_cpu(self): provider = "CPUExecutionProvider" self.run_timestamp(provider) + @pytest.mark.slow + def test_timestamp_cuda(self): + cuda_provider = "CUDAExecutionProvider" + if cuda_provider in get_available_providers(): + self.run_timestamp(cuda_provider) + if __name__ == "__main__": unittest.main()
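As a usage note for the CUDA execution provider change above: the new use_ep_level_unified_stream option is parsed from the string-keyed provider options (see cuda_execution_provider_info.cc), so it can be toggled from Python once this change is in place. A minimal sketch follows; the model path is a placeholder.

import onnxruntime as ort

cuda_options = {
    "device_id": 0,
    # New in this change: share one EP-level CUDA stream for all requests,
    # the same single-stream mode that enable_cuda_graph already implies.
    "use_ep_level_unified_stream": "1",
}

session = ort.InferenceSession(
    "model.onnx",  # placeholder path
    providers=[("CUDAExecutionProvider", cuda_options), "CPUExecutionProvider"],
)
print(session.get_providers())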
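And to tie the whisper pieces together, a rough sketch of running a chained beam search model exported with --output_cross_qk and --extra_decoding_ids. The model path, mel feature size, and the assumption that the scalar generation inputs (max_length, min_length, num_beams, num_return_sequences, length_penalty, repetition_penalty) are exposed as graph inputs follow the existing chain_model layout; they are not guaranteed by this patch, so the snippet filters by the session's actual input names.

import numpy as np
import onnxruntime as ort

# Assumed export artifact and feature size; adjust to the actual export.
session = ort.InferenceSession(
    "whisper_beamsearch.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)

batch_size, feature_size, num_frames = 1, 80, 3000
candidate_inputs = {
    "input_features": np.zeros((batch_size, feature_size, num_frames), dtype=np.float32),
    "max_length": np.array([128], dtype=np.int32),
    "min_length": np.array([0], dtype=np.int32),
    "num_beams": np.array([5], dtype=np.int32),
    "num_return_sequences": np.array([1], dtype=np.int32),
    "length_penalty": np.array([1.0], dtype=np.float32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
    # Inputs added by this change (present only when exported with the new flags):
    "cross_qk_layer_head": np.array([[0, 0]], dtype=np.int32),  # keep only (layer 0, head 0)
    "extra_decoding_ids": np.repeat(np.array([[50259, 50359, 50363]], dtype=np.int32), batch_size, axis=0),
}
graph_input_names = {i.name for i in session.get_inputs()}
outputs = session.run(None, {k: v for k, v in candidate_inputs.items() if k in graph_input_names})
# With --output_cross_qk, one of the outputs is "cross_qk" with shape
# [batch_size, num_return_sequences, num_kept_layer_heads, max_length, frames].
print([o.name for o in session.get_outputs()])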